diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 2e54e9a7de7f5c..cd3db5b97fb35d 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -160,6 +160,20 @@ case "$image" in TRITON=yes INDUCTOR_BENCHMARKS=yes ;; + pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9) + CUDA_VERSION=12.1.0 + CUDNN_VERSION=8 + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=9 + PROTOBUF=yes + DB=yes + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + CONDA_CMAKE=yes + TRITON=yes + ;; pytorch-linux-focal-py3-clang7-asan) ANACONDA_PYTHON_VERSION=3.9 CLANG_VERSION=7 diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index f3e3c95213592e..52af279e34691b 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -105,5 +105,13 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then pip_install -r /opt/conda/requirements-docs.txt fi + # HACK HACK HACK + # gcc-9 for ubuntu-18.04 from http://ppa.launchpad.net/ubuntu-toolchain-r/test/ubuntu + # Pulls libstdc++6 13.1.0-8ubuntu1~18.04 which is too new for conda + # So remove libstdc++.so.6.0.29 installed by https://anaconda.org/anaconda/libstdcxx-ng/files?version=11.2.0 + if grep 18.04.6 /etc/issue >/dev/null; then + rm /opt/conda/envs/py_$ANACONDA_PYTHON_VERSION/lib/libstdc++.so.6 + fi + popd fi diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 17e06fa6c8ac31..0c6593ffd00c30 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -75,10 +75,10 @@ librosa>=0.6.2 ; python_version < "3.11" #Pinned versions: #test that import: -mypy==0.960 +mypy==1.4.1 # Pin MyPy version because new errors are likely to appear with each release #Description: linter -#Pinned versions: 0.960 +#Pinned versions: 1.4.1 #test that import: test_typing.py, test_type_hints.py networkx==2.8.8 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 1dc711481358c1..82fab435fe5c37 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -40,6 +40,7 @@ jobs: - docker-image-name: pytorch-linux-bionic-cuda11.8-cudnn8-py3-gcc7-inductor-benchmarks - docker-image-name: pytorch-linux-bionic-py3.8-clang9 - docker-image-name: pytorch-linux-bionic-py3.11-clang9 + - docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 - docker-image-name: pytorch-linux-focal-rocm-n-1-py3 - docker-image-name: pytorch-linux-focal-rocm-n-py3 - docker-image-name: pytorch-linux-jammy-cuda11.8-cudnn8-py3.8-clang12 diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 279782edf572f3..3279d47f1e52b0 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -305,12 +305,12 @@ jobs: { config: "default", shard: 1, num_shards: 1, runner: "linux.4xlarge" }, ]} - linux-bionic-cuda12_1-py3_10-gcc9-bazel-test: - name: linux-bionic-cuda12.1-py3.10-gcc9-bazel-test + linux-focal-cuda12_1-py3_10-gcc9-bazel-test: + name: linux-focal-cuda12.1-py3.10-gcc9-bazel-test uses: ./.github/workflows/_bazel-build-test.yml with: - build-environment: linux-bionic-cuda12.1-py3.10-gcc9-bazel-test - docker-image-name: pytorch-linux-bionic-cuda12.1-cudnn8-py3-gcc9 + build-environment: linux-focal-cuda12.1-py3.10-gcc9-bazel-test + docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9 cuda-version: "12.1" test-matrix: | { include: [ diff --git a/.lintrunner.toml b/.lintrunner.toml index 2ebf50e039fa9d..abfe157b2cb31f 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -152,7 +152,7 @@ 
init_command = [ '--dry-run={{DRYRUN}}', 'numpy==1.24.3', 'expecttest==0.1.3', - 'mypy==0.960', + 'mypy==1.4.1', 'types-requests==2.27.25', 'types-PyYAML==6.0.7', 'types-tabulate==0.8.8', diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py index dad452b5216a3d..5189af6f2c5369 100755 --- a/caffe2/contrib/aten/gen_op.py +++ b/caffe2/contrib/aten/gen_op.py @@ -49,7 +49,7 @@ # use faster C loader if available from yaml import CSafeLoader as Loader except ImportError: - from yaml import SafeLoader as Loader # type: ignore[misc] + from yaml import SafeLoader as Loader # type: ignore[assignment, misc] def write(filename, s): diff --git a/caffe2/contrib/tensorboard/tensorboard.py b/caffe2/contrib/tensorboard/tensorboard.py index e086a74f879cd7..4b89964627d261 100644 --- a/caffe2/contrib/tensorboard/tensorboard.py +++ b/caffe2/contrib/tensorboard/tensorboard.py @@ -106,7 +106,7 @@ def graph_def_to_event(step, graph_def): wall_time=step, step=step, graph_def=graph_def.SerializeToString()) -@cli.command("tensorboard-graphs") +@cli.command("tensorboard-graphs") # type: ignore[arg-type, attr-defined] @click.option("--c2-netdef", type=click.Path(exists=True, dir_okay=False), multiple=True) @click.option("--tf-dir", type=click.Path(exists=True)) @@ -129,7 +129,7 @@ def parse_net_def(path): log.info("Wrote %s graphs to logdir %s", len(events), tf_dir) -@cli.command("tensorboard-events") +@cli.command("tensorboard-events") # type: ignore[arg-type, attr-defined] @click.option("--c2-dir", type=click.Path(exists=True, file_okay=False), help="Root directory of the Caffe2 run") @click.option("--tf-dir", type=click.Path(writable=True), @@ -209,4 +209,4 @@ def event(step, values): if __name__ == "__main__": - cli() + cli() # type: ignore[misc] diff --git a/test/test_futures.py b/test/test_futures.py index 8d0f429de08dd0..33814eda41eafb 100644 --- a/test/test_futures.py +++ b/test/test_futures.py @@ -21,7 +21,7 @@ def test_set_exception(self) -> None: error_msg = "Intentional Value Error" value_error = ValueError(error_msg) - f = Future[T]() + f = Future[T]() # type: ignore[valid-type] # Set exception f.set_exception(value_error) # Exception should throw on wait @@ -29,7 +29,7 @@ def test_set_exception(self) -> None: f.wait() # Exception should also throw on value - f = Future[T]() + f = Future[T]() # type: ignore[valid-type] f.set_exception(value_error) with self.assertRaisesRegex(ValueError, "Intentional"): f.value() @@ -37,7 +37,7 @@ def test_set_exception(self) -> None: def cb(fut): fut.value() - f = Future[T]() + f = Future[T]() # type: ignore[valid-type] f.set_exception(value_error) with self.assertRaisesRegex(RuntimeError, "Got the following error"): @@ -54,7 +54,7 @@ def wait_future(f): with self.assertRaisesRegex(ValueError, "Intentional"): f.wait() - f = Future[T]() + f = Future[T]() # type: ignore[valid-type] t = threading.Thread(target=wait_future, args=(f, )) t.start() f.set_exception(value_error) @@ -68,7 +68,7 @@ def then_future(f): with self.assertRaisesRegex(RuntimeError, "Got the following error"): fut.wait() - f = Future[T]() + f = Future[T]() # type: ignore[valid-type] t = threading.Thread(target=then_future, args=(f, )) t.start() f.set_exception(value_error) diff --git a/test/test_torch.py b/test/test_torch.py index 1c96c6d393ff45..7fbd358dc80917 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -29,7 +29,7 @@ from functools import partial from torch import multiprocessing as mp from torch.testing import make_tensor -from torch.testing._internal.common_utils 
import ( +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] TEST_WITH_TORCHINDUCTOR, TestCase, TEST_WITH_ROCM, run_tests, IS_JETSON, IS_WINDOWS, IS_FILESYSTEM_UTF8_ENCODING, NO_MULTIPROCESSING_SPAWN, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, skipIfTorchInductor, load_tests, slowTest, @@ -8490,7 +8490,7 @@ def _spawn_method(self, method, arg): except RuntimeError: pass with mp.Pool(1) as pool: - out: list = pool.map(method, [arg]) + out = pool.map(method, [arg]) self.assertTrue(out[0]) def _test_multinomial_invalid_probs(probs): diff --git a/tools/code_coverage/oss_coverage.py b/tools/code_coverage/oss_coverage.py index f04fbd3b63b0f3..15b3104a13d9f3 100644 --- a/tools/code_coverage/oss_coverage.py +++ b/tools/code_coverage/oss_coverage.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 import time -from package.oss.cov_json import get_json_report -from package.oss.init import initialization -from package.tool.summarize_jsons import summarize_jsons -from package.util.setting import TestPlatform -from package.util.utils import print_time +from package.oss.cov_json import get_json_report # type: ignore[import] +from package.oss.init import initialization # type: ignore[import] +from package.tool.summarize_jsons import summarize_jsons # type: ignore[import] +from package.util.setting import TestPlatform # type: ignore[import] +from package.util.utils import print_time # type: ignore[import] def report_coverage() -> None: diff --git a/tools/code_coverage/package/tool/summarize_jsons.py b/tools/code_coverage/package/tool/summarize_jsons.py index 22dae3e9e24abc..7c5d8891ea83d3 100644 --- a/tools/code_coverage/package/tool/summarize_jsons.py +++ b/tools/code_coverage/package/tool/summarize_jsons.py @@ -45,7 +45,7 @@ def transform_file_name( return file_path[file_path.find(folder) :] # remove pytorch base folder path if platform == TestPlatform.OSS: - from package.oss.utils import get_pytorch_folder + from package.oss.utils import get_pytorch_folder # type: ignore[import] pytorch_foler = get_pytorch_folder() assert file_path.startswith(pytorch_foler) diff --git a/tools/code_coverage/package/util/utils.py b/tools/code_coverage/package/util/utils.py index dee9cbc1a1b6d5..e0b4befb578b9c 100644 --- a/tools/code_coverage/package/util/utils.py +++ b/tools/code_coverage/package/util/utils.py @@ -89,7 +89,9 @@ def get_raw_profiles_folder() -> str: def detect_compiler_type(platform: TestPlatform) -> CompilerType: if platform == TestPlatform.OSS: - from package.oss.utils import detect_compiler_type # type: ignore[misc] + from package.oss.utils import ( # type: ignore[assignment, import, misc] + detect_compiler_type, + ) cov_type = detect_compiler_type() # type: ignore[call-arg] else: @@ -100,7 +102,7 @@ def detect_compiler_type(platform: TestPlatform) -> CompilerType: cov_type = detect_compiler_type() check_compiler_type(cov_type) - return cov_type + return cov_type # type: ignore[no-any-return] def get_test_name_from_whole_path(path: str) -> str: diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py index 9269f39cda4cc3..02c9c39270b107 100644 --- a/tools/gen_vulkan_spv.py +++ b/tools/gen_vulkan_spv.py @@ -20,7 +20,7 @@ try: from yaml import CLoader as Loader except ImportError: - from yaml import Loader # type: ignore[misc] + from yaml import Loader # type: ignore[assignment, misc] H_NAME = "spv.h" CPP_NAME = "spv.cpp" diff --git a/tools/linter/adapters/workflow_consistency_linter.py b/tools/linter/adapters/workflow_consistency_linter.py index 0359a52f1055d4..b856880aa001d2 
100644 --- a/tools/linter/adapters/workflow_consistency_linter.py +++ b/tools/linter/adapters/workflow_consistency_linter.py @@ -16,7 +16,7 @@ try: from yaml import CSafeLoader as Loader except ImportError: - from yaml import SafeLoader as Loader # type: ignore[misc] + from yaml import SafeLoader as Loader # type: ignore[assignment, misc] class LintSeverity(str, Enum): diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py index b46c7b675d13aa..18e09ddecd130b 100644 --- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py +++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py @@ -11,7 +11,7 @@ try: from yaml import CSafeLoader as Loader except ImportError: - from yaml import SafeLoader as Loader # type: ignore[misc] + from yaml import SafeLoader as Loader # type: ignore[assignment, misc] if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) { diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index e1824a5e3fc091..c03fd87f25b6aa 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -10,7 +10,7 @@ # use faster C loader if available from yaml import CSafeLoader as YamlLoader except ImportError: - from yaml import SafeLoader as YamlLoader # type: ignore[misc] + from yaml import SafeLoader as YamlLoader # type: ignore[assignment, misc] NATIVE_FUNCTIONS_PATH = "aten/src/ATen/native/native_functions.yaml" TAGS_PATH = "aten/src/ATen/native/tags.yaml" diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index f777aa088e0381..dcd2f623c6ed70 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -429,6 +429,7 @@ def _jit_is_script_object(obj: Any) -> _bool: ... def _last_executed_optimized_graph() -> Graph: ... def parse_type_comment(comment: str) -> Decl: ... def _get_upgraders_map_size() -> _int: ... +def _get_upgraders_entry_map() -> Dict[str, str]: ... def _dump_upgraders_map() -> Dict[str, str]: ... def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ... def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ... diff --git a/torch/__init__.py b/torch/__init__.py index e083cf2b924ef2..236e6db366c0ab 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -417,7 +417,7 @@ def sym_int(a): if isinstance(a, SymInt): return a elif isinstance(a, SymFloat): - return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type] + return math.floor(a) if a >= 0 else math.ceil(a) # type: ignore[arg-type, call-overload] return py_int(a) # type: ignore[operator] def sym_max(a, b): @@ -1320,7 +1320,7 @@ def manager_path(): # Some type signatures pulled in from _VariableFunctions here clash with # signatures already imported. For now these clashes are ignored; see # PR #43339 for details. 
- from torch._C._VariableFunctions import * # type: ignore[misc] # noqa: F403 + from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403 # Fixup segment_reduce visibility _segment_reduce = segment_reduce del segment_reduce diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 359566222df4e9..4c1d74d2c26e40 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -10,20 +10,20 @@ class _Union: @classmethod def create(cls, **kwargs): assert len(kwargs) == 1 - return cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) + return cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type] def __post_init__(self): - assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1 + assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1 # type: ignore[arg-type, misc] @property def value(self): - val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None) + val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type] assert val is not None return val @property def type(self): - val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None) + val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None) # type: ignore[arg-type] assert val_type is not None return val_type diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index d281b45d974056..8ee57694457371 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -9,7 +9,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import sympy @@ -749,7 +749,7 @@ def __init__(self): self.module = torch.nn.Module() @contextmanager - def save_graph_module(self) -> None: + def save_graph_module(self) -> Iterator[None]: saved = self.graph, self.module, self.serialized_name_to_node, self.serialized_name_to_meta self.graph = torch.fx.Graph() self.module = torch.nn.Module() @@ -773,7 +773,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: if vr := self.symbol_name_to_range.get(val.expr_str): symbolic_shapes._constrain_symbol_range( - self.shape_env, sym, vr.lower, vr.upper + self.shape_env, sym, vr.lower, vr.upper # type: ignore[arg-type] ) return self.shape_env.create_symintnode(sym, hint=val.hint) @@ -855,6 +855,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: output_node.meta["val"] = tuple( arg.meta["val"] for arg in output_node.args[0] ) + return output_node def deserialize_node(self, serialized_node: Node, target: Callable) -> None: if target.__module__ == "_operator": # TODO(zhxchen17) Follow up on this. @@ -1050,7 +1051,7 @@ def deserialize_multiple_outputs(self, serialized_node: Node, fx_node: torch.fx. 
self.serialized_name_to_node[fx_node.name] = fx_node def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: - ret = {} + ret: Dict[str, Any] = {} if stack_trace := metadata.get("stack_trace"): ret["stack_trace"] = stack_trace diff --git a/torch/_export/serde/upgrade.py b/torch/_export/serde/upgrade.py index 25aaf46e4b5a3a..3e8e69141cd936 100644 --- a/torch/_export/serde/upgrade.py +++ b/torch/_export/serde/upgrade.py @@ -33,7 +33,7 @@ def get_upgraders() -> Dict[str, Tuple[str, str]]: """Getting upgraders entry map and operator version map and merge them into one dict.""" upgraders = torch._C._get_upgraders_entry_map() op_version_map = torch._C._get_operator_version_map() - output = defaultdict(tuple) + output: Dict[str, Tuple[str, str]] = defaultdict(tuple) # type: ignore[arg-type] for opname, entry_list in op_version_map.items(): if not entry_list: raise RuntimeError(f"Op version map has an empty entry for opname {opname}") diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index bccc03e3ae98ed..53e30ee9acb601 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -1,5 +1,5 @@ from collections import Counter -from typing import Any, Dict, List, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -12,7 +12,7 @@ def functional_call( module: "torch.nn.Module", parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]], args: Union[Any, Tuple], - kwargs: Dict[str, Any] = None, + kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, strict: bool = False, diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 2387cd9df1677e..f419ac04b005b8 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1004,7 +1004,7 @@ def _linalg_svd_meta( A: Tensor, full_matrices: bool = False, compute_uv: bool = True, - driver: str = None, + driver: Optional[str] = None, ): checkIsMatrix(A, "linalg.svd") checkFloatingOrComplex(A, "linalg.svd") @@ -1147,7 +1147,7 @@ def linalg_solve_triangular_meta( upper: bool, left: bool = True, unitriangular: bool = False, - out: Tensor = None, + out: Optional[Tensor] = None, ) -> Tensor: if out is None: out = A.new_empty([0]) @@ -4695,8 +4695,8 @@ def upsample_nearest2d_backward( grad_output: Tensor, output_size: Sequence[Union[int, torch.types.SymInt]], input_size: Sequence[Union[int, torch.types.SymInt]], - scales_h: float = None, - scales_w: float = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, ): full_output_size = upsample_common_check( input_size, output_size, num_spatial_dims=2 diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 65a408086afd6b..7e8a37da76b06d 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -331,7 +331,7 @@ class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum): def _elementwise_meta( *args, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, - args_with_fixed_dtypes: Tuple[TensorLikeType, ...] 
= None, + args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None, ) -> FakeTensor: """ Meta function for elementwise operations that produce outputs in the same dtype diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 7cb3d50c87ffd9..cab005c13ba664 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -1,6 +1,6 @@ import functools from contextlib import nullcontext -from typing import Any, Callable, Dict, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from warnings import warn import torch @@ -111,7 +111,7 @@ def __torch_function__( orig_func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Dict = None, + kwargs: Optional[Dict] = None, ): if kwargs is None: kwargs = {} @@ -161,7 +161,7 @@ def __torch_function__( orig_func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Dict = None, + kwargs: Optional[Dict] = None, ): if kwargs is None: kwargs = {} @@ -374,7 +374,7 @@ def __torch_function__( orig_func: Callable, types: Sequence, args: Sequence[Any] = (), - kwargs: Dict = None, + kwargs: Optional[Dict] = None, ): if kwargs is None: kwargs = {} diff --git a/torch/_prims/debug_prims.py b/torch/_prims/debug_prims.py index 2ddd23ddbea2aa..cf8d9caacb4cd1 100644 --- a/torch/_prims/debug_prims.py +++ b/torch/_prims/debug_prims.py @@ -27,7 +27,7 @@ def load_tensor_reader(loc): def register_debug_prims(): @custom_op("debugprims::load_tensor") - def load_tensor( + def load_tensor( # type: ignore[empty-body] name: str, size: Sequence[int], stride: Sequence[int], diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 041fb764c8b1c1..938465cac36318 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -10,7 +10,7 @@ import torch._prims_common as utils from torch.utils._pytree import tree_flatten, tree_unflatten -from typing import Callable, Sequence, Tuple, NamedTuple, overload +from typing import Callable, Sequence, Tuple, NamedTuple, Optional, overload import inspect from functools import wraps import warnings @@ -97,7 +97,7 @@ def __init__( self, *, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, - type_promoting_args: Sequence[str] = None, + type_promoting_args: Optional[Sequence[str]] = None, ): self.type_promoting_arg_names = type_promoting_args self.type_promotion_kind = type_promotion_kind diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index d36a28df60a289..611c6e3b9a625a 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -1832,7 +1832,7 @@ def clamp( @out_wrapper() def clamp_min( self: TensorLikeType, - min: TensorOrNumberLikeType = None, + min: Optional[TensorOrNumberLikeType] = None, ) -> TensorLikeType: return torch.clamp(self, min=min) # type: ignore[arg-type] @@ -1841,7 +1841,7 @@ def clamp_min( @out_wrapper() def clamp_max( self: TensorLikeType, - max: TensorOrNumberLikeType = None, + max: Optional[TensorOrNumberLikeType] = None, ) -> TensorLikeType: return torch.clamp(self, max=max) # type: ignore[arg-type] @@ -4654,7 +4654,7 @@ def logspace( ret = torch.linspace( start, end, - steps, + steps, # type: ignore[arg-type] dtype=torch.float64, layout=layout, device=device, diff --git a/torch/ao/nn/quantizable/modules/activation.py b/torch/ao/nn/quantizable/modules/activation.py index d94c18eda309ff..b7ba9dd8dc72c2 100644 --- a/torch/ao/nn/quantizable/modules/activation.py +++ b/torch/ao/nn/quantizable/modules/activation.py @@ -63,7 +63,7 @@ class MultiheadAttention(nn.MultiheadAttention): def 
__init__(self, embed_dim: int, num_heads: int, dropout: float = 0., bias: bool = True, add_bias_kv: bool = False, add_zero_attn: bool = False, - kdim: int = None, vdim: int = None, batch_first: bool = False, + kdim: Optional[int] = None, vdim: Optional[int] = None, batch_first: bool = False, device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__(embed_dim, num_heads, dropout, diff --git a/torch/ao/ns/_numeric_suite_fx.py b/torch/ao/ns/_numeric_suite_fx.py index eed6415f83b3b5..b26dbadc006823 100644 --- a/torch/ao/ns/_numeric_suite_fx.py +++ b/torch/ao/ns/_numeric_suite_fx.py @@ -856,7 +856,7 @@ def prepare_n_shadows_model( create_n_transformed_and_logged_copies_of_subgraph( mt, subgraph_idx, match_name, nodes_in_this_subgraph, qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig, - custom_prepare_fn, custom_prepare_kwargs + custom_prepare_fn, custom_prepare_kwargs # type: ignore[arg-type] ) return mt diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index dac75e60bde876..fa328deb0f592a 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -443,7 +443,7 @@ def create_one_transformed_and_logged_copy_of_subgraph( example_inputs: Any, last_added_shadow_node_list: List[Optional[Node]], custom_prepare_fn: Optional[Callable] = None, - custom_prepare_kwargs: Dict[str, Any] = None, + custom_prepare_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ Given a subgraph in `mt` and a subgraph candidate idx, inserts the @@ -575,7 +575,7 @@ def create_n_transformed_and_logged_copies_of_subgraph( qconfig_mappings: List[QConfigMapping], list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]], custom_prepare_fn: Optional[Callable] = None, - custom_prepare_kwargs: Dict[str, Any] = None, + custom_prepare_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ Given a model `mt` and a subgraph_idx, creates the needed copies @@ -756,7 +756,7 @@ def _get_subgraph_containing_node(node, subgraphs_dedup): create_n_transformed_and_logged_copies_of_subgraph( model, cur_subgraph_idx, match_name, maybe_subgraph, [qconfig_mapping], [node_name_to_qconfig], - None, None + None, None # type: ignore[arg-type] ) # find the created shadow module and record it so we # can find it easily in step 2 diff --git a/torch/ao/ns/fx/weight_utils.py b/torch/ao/ns/fx/weight_utils.py index 870b183acc61a5..d375694b88b300 100644 --- a/torch/ao/ns/fx/weight_utils.py +++ b/torch/ao/ns/fx/weight_utils.py @@ -78,7 +78,7 @@ def get_lstm_mod_weights(mod: nn.Module) -> List[torch.Tensor]: res.append(param_value) return res else: - assert isinstance(mod, nnqd.LSTM), f"type {type(res)} not handled yet" + assert isinstance(mod, nnqd.LSTM), f"type {type(mod)} not handled yet" res = [] for weight_value in mod._all_weight_values: res.append(weight_value.param.__getstate__()[0][4][0].__getstate__()[0][0]) diff --git a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py index 88ccc8cfc41088..c336799c622562 100644 --- a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py +++ b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, List +from typing import Any, Dict, List, Optional import torch from collections import defaultdict from torch import nn @@ -205,7 +205,7 @@ def register_layer(self, layer: nn.Module, 
aggregate_fn=None, reduce_fn=None, # or sparsify_hook() self.data_groups[name]['hook_state'] = "aggregate" # aggregate hook is attached - def get_mask(self, name: str = None, layer: nn.Module = None): + def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None): """ Returns mask associated to the layer. diff --git a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py index d26b2161dcedf9..31600118f662d1 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/evaluate_model_metrics.py @@ -3,7 +3,7 @@ from dlrm_s_pytorch import unpack_batch # type: ignore[import] import numpy as np # type: ignore[import] import sklearn # type: ignore[import] -from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model +from dlrm_utils import make_test_data_loader, dlrm_wrap, fetch_model # type: ignore[import] import pandas as pd # type: ignore[import] import argparse diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index beba25ed3b8a9b..28f32bbabe1760 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -1,7 +1,7 @@ import torch from torch.nn import functional as F from functools import reduce -from typing import Tuple, Any, List +from typing import Any, List, Optional, Tuple from .base_data_sparsifier import BaseDataSparsifier @@ -31,9 +31,9 @@ class DataNormSparsifier(BaseDataSparsifier): arguments and could be overriden by the configuration provided in the `add_data` step. """ - def __init__(self, data_list: List[Tuple[str, Any]] = None, sparsity_level: float = 0.5, + def __init__(self, data_list: Optional[List[Tuple[str, Any]]] = None, sparsity_level: float = 0.5, sparse_block_shape: Tuple[int, int] = (1, 4), - zeros_per_block: int = None, norm: str = 'L1'): + zeros_per_block: Optional[int] = None, norm: str = 'L1'): if zeros_per_block is None: zeros_per_block = reduce((lambda x, y: x * y), sparse_block_shape) diff --git a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py index 1a2791c359b6e5..1e76cfc345ac5f 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.ao.pruning.sparsifier.utils import module_to_fqn, fqn_to_module -from typing import Dict, List +from typing import Dict, List, Optional SUPPORTED_MODULES = { nn.Embedding, @@ -28,7 +28,7 @@ def _fetch_all_embeddings(model): def post_training_sparse_quantize(model, data_sparsifier_class, sparsify_first=True, - select_embeddings: List[nn.Module] = None, + select_embeddings: Optional[List[nn.Module]] = None, **sparse_config): """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags. The quantization step can happen before or after sparsification depending on the `sparsify_first` argument. 
diff --git a/torch/ao/quantization/experimental/quantizer.py b/torch/ao/quantization/experimental/quantizer.py index 1d8845cd2b6548..e7e6048fb00e08 100644 --- a/torch/ao/quantization/experimental/quantizer.py +++ b/torch/ao/quantization/experimental/quantizer.py @@ -44,7 +44,7 @@ def quantize(self, tensor2quantize: Tensor): from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT - result = TensorAPoT(self, tensor2quantize) + result = TensorAPoT(self, tensor2quantize) # type: ignore[assignment] return result @@ -83,7 +83,7 @@ def dequantize(self, apot_tensor) -> Tensor: def quant_dequant(self, tensor2quantize: Tensor) -> Tensor: levels_lst = list(self.quantization_levels) - result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst)) + result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst)) # type: ignore[call-arg] return result diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py index 77a7f3079906a0..cf287db8c52454 100644 --- a/torch/ao/quantization/fx/match_utils.py +++ b/torch/ao/quantization/fx/match_utils.py @@ -85,9 +85,9 @@ def _find_matches( modules: Dict[str, torch.nn.Module], patterns: Dict[Pattern, QuantizeHandler], root_node_getter_mapping: Dict[Pattern, Callable], - standalone_module_names: List[str] = None, - standalone_module_classes: List[Type] = None, - custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]: + standalone_module_names: Optional[List[str]] = None, + standalone_module_classes: Optional[List[Type]] = None, + custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]: """ Matches the nodes in the input graph to quantization patterns, and outputs the information needed to quantize them in future steps. 
diff --git a/torch/ao/quantization/fx/quantize_handler.py b/torch/ao/quantization/fx/quantize_handler.py index 57e3c97411a506..e98bc334f0060b 100644 --- a/torch/ao/quantization/fx/quantize_handler.py +++ b/torch/ao/quantization/fx/quantize_handler.py @@ -18,7 +18,7 @@ ) from abc import ABC -from typing import Callable, Dict, List, Type +from typing import Callable, Dict, List, Type, Optional __all__ = [ "QuantizeHandler", @@ -52,7 +52,7 @@ def __init__( self, node_pattern: NodePattern, modules: Dict[str, torch.nn.Module], - root_node_getter: Callable = None, + root_node_getter: Optional[Callable] = None, is_custom_module=False, is_standalone_module=False): """ Records pattern information in __init__, which will be used @@ -113,7 +113,7 @@ def __init__( self, node_pattern: NodePattern, modules: Dict[str, torch.nn.Module], - root_node_getter: Callable = None): + root_node_getter: Optional[Callable] = None): super().__init__(node_pattern, modules, root_node_getter) if num_tensor_args_to_observation_type: assert self.num_tensor_args in num_tensor_args_to_observation_type, \ diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index 4693d2794406f8..dda784a8ed6ea1 100644 --- a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -458,8 +458,10 @@ def _get_new_edge_or_node(edge_or_node: EdgeOrNode): if isinstance(edge_or_node, Node): _node = edge_or_node return original_to_replacement_node.get(_node, _node) - elif isinstance(edge_or_node, Tuple[Node, Node]): - src, dest = edge_or_node + # TODO: It really should be + # isinstance(edge_or_node, tuple) and len(edge_or_node) == 2 and all(isinstance(x, Node) for x in edge_or_node) + elif isinstance(edge_or_node, Tuple[Node, Node]): # type: ignore[arg-type] + src, dest = edge_or_node # type: ignore[misc] return ( original_to_replacement_node.get(src, src), original_to_replacement_node.get(dest, dest), diff --git a/torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py index 147c8cc983ad94..29a250ec7e257e 100644 --- a/torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/pt2e/quantizer/x86_inductor_quantizer.py @@ -58,13 +58,13 @@ def _supported_quantized_operators() -> Dict[str, List[OperatorPatternType]]: for conv_op, add_op, relu_op in conv_add_relu_options: if add_op is None: # Append Conv ReLU - supported_operators["conv2d"].append([conv_op, relu_op]) + supported_operators["conv2d"].append([conv_op, relu_op]) # type: ignore[list-item] elif relu_op is None: # Append Conv Add - supported_operators["conv2d"].append([conv_op, add_op]) + supported_operators["conv2d"].append([conv_op, add_op]) # type: ignore[list-item] else: # Append Conv Add ReLU - supported_operators["conv2d"].append([conv_op, add_op, relu_op]) + supported_operators["conv2d"].append([conv_op, add_op, relu_op]) # type: ignore[list-item] return copy.deepcopy(supported_operators) @@ -222,17 +222,17 @@ def _get_input_idx_for_binary_node( """ conv_gemm_node_idx = None extra_input_node_idx = None - if (binary_node.args[0].op == "call_function") and ( + if (binary_node.args[0].op == "call_function") and ( # type: ignore[union-attr] binary_node.args[0] == conv_gemm_node ): conv_gemm_node_idx = 0 extra_input_node_idx = 1 - elif (binary_node.args[1].op == "call_function") and ( + elif (binary_node.args[1].op == "call_function") and ( # type: ignore[union-attr] binary_node.args[1] == conv_gemm_node ): 
conv_gemm_node_idx = 1 extra_input_node_idx = 0 - extra_input_node = binary_node.args[extra_input_node_idx] + extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] assert isinstance(extra_input_node, Node) return conv_gemm_node_idx, extra_input_node_idx diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index 7b3de95f5f697b..7ae9a2724494a8 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -119,9 +119,9 @@ def _reference_dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, qu def reference_representation_rewrite(model: GraphModule) -> GraphModule: remove_tensor_overload_for_qdq_ops(model) for example_inputs, pattern, replacement in _EXAMPLE_INPUTS_PATTERN_AND_REPLACEMENTS: - pattern = get_aten_graph_module(pattern, example_inputs) - remove_tensor_overload_for_qdq_ops(pattern) - replacement = get_aten_graph_module(replacement, example_inputs) - remove_tensor_overload_for_qdq_ops(replacement) + pattern = get_aten_graph_module(pattern, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] + replacement = get_aten_graph_module(replacement, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] matches = replace_pattern(model, pattern, replacement) return model diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index eecb29a4237226..f32fc361c28f64 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -455,7 +455,7 @@ def allocate(size, tensor_key, version, during_trace=True): device = to_device(tensor_key.device) addr = tensor_key.storage.ptr - seg = snapshot['segments'][device] + seg = snapshot['segments'][device] # type: ignore[index] if seg['address'] is None or seg['address'] > addr: seg['address'] = addr seg['total_size'] = max(seg['total_size'], addr + size) # record max addr for now, we will make it the size later @@ -465,12 +465,12 @@ def allocate(size, tensor_key, version, during_trace=True): stack = [{'filename': 'none', 'line': 0, 'name': p.name} for p in stack] r = {'action': 'alloc', 'addr': addr, 'size': size, 'stream': 0, 'frames': stack, 'category': category} if during_trace: - snapshot['device_traces'][device].append(r) + snapshot['device_traces'][device].append(r) # type: ignore[index] return r def free(alloc, device): for e in ('free_requested', 'free_completed'): - snapshot['device_traces'][device].append({'action': e, + snapshot['device_traces'][device].append({'action': e, # type: ignore[index] 'addr': alloc['addr'], 'size': alloc['size'], 'stream': 0, @@ -499,7 +499,7 @@ def free(alloc, device): blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames']) for (tensor_key, version), event in kv_to_elem.items()] for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]): - seg = snapshot['segments'][device] + seg = snapshot['segments'][device] # type: ignore[index] last_addr = seg['address'] for _, addr, size, frames in blocks: if last_addr < addr: @@ -510,8 +510,8 @@ def free(alloc, device): if last_addr < seg['total_size']: seg['blocks'].append({'size': seg['total_size'] - last_addr, 'state': 'inactive'}) - snapshot['segments'] = [seg for seg in snapshot['segments'] if seg['blocks']] - for seg in snapshot['segments']: + snapshot['segments'] = [seg for seg in 
snapshot['segments'] if seg['blocks']] # type: ignore[attr-defined] + for seg in snapshot['segments']: # type: ignore[attr-defined, name-defined, no-redef] seg['total_size'] -= seg['address'] if not seg['blocks']: seg['blocks'].append({'size': seg['total_size'], 'state': 'inactive'}) diff --git a/torch/cuda/random.py b/torch/cuda/random.py index 5d6e5e6ca37ce6..d55f147b244014 100644 --- a/torch/cuda/random.py +++ b/torch/cuda/random.py @@ -1,5 +1,5 @@ import torch -from typing import cast, Iterable, List, Union +from typing import Iterable, List, Union from . import _lazy_init, _lazy_call, device_count, current_device from .. import Tensor @@ -56,7 +56,7 @@ def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'cu device = torch.device('cuda', device) def cb(): - idx = cast(torch.device, device).index + idx = device.index if idx is None: idx = current_device() default_generator = torch.cuda.default_generators[idx] diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 67b1af3f134808..acab7081c91c25 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -732,7 +732,7 @@ def _init_from_local_tensor( local_tensor: torch.Tensor, sharding_spec: shard_spec.ShardingSpec, *global_size: Sequence[int], - process_group: dist.ProcessGroup = None, + process_group: Optional[dist.ProcessGroup] = None, init_rrefs=False, ) -> "ShardedTensor": """ diff --git a/torch/distributed/_spmd/api.py b/torch/distributed/_spmd/api.py index 04fdbaeb77bc36..4a7e7b99acfdea 100644 --- a/torch/distributed/_spmd/api.py +++ b/torch/distributed/_spmd/api.py @@ -389,7 +389,7 @@ def swap(fqn_prefix: str, module: torch.nn.Module) -> None: # can trace operations applied to them. def stateless_func(func, params, buffers, named_states, args, kwargs): with stateless._reparametrize_module( - cast(nn.Module, mod), {**params, **buffers} + mod, {**params, **buffers} ), _rematerialize_optimizer( opt, named_states, params ) if opt else nullcontext(): diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index 97efb0c9837727..01c99ebc3e9538 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -210,7 +210,7 @@ def full( def zeros( *size, requires_grad: bool = False, - dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, layout: torch.layout = torch.strided, device_mesh: Optional[DeviceMesh] = None, placements: Optional[Sequence[Placement]] = None, diff --git a/torch/distributed/_tensor/examples/checkpoint_example.py b/torch/distributed/_tensor/examples/checkpoint_example.py index a6672f21f9fde5..fb5141b17a027c 100644 --- a/torch/distributed/_tensor/examples/checkpoint_example.py +++ b/torch/distributed/_tensor/examples/checkpoint_example.py @@ -133,7 +133,7 @@ def output_fn(outputs, device_mesh): ) -def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module: +def checkpoint(model: nn.Module, mesh: DeviceMesh) -> nn.Module: # type: ignore[empty-body] """ checkpoint save/load models with DTensor parameters """ diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py index 99ef7a8a9e14f8..724d919452efc5 100644 --- a/torch/distributed/_tensor/ops/utils.py +++ b/torch/distributed/_tensor/ops/utils.py @@ -40,7 +40,7 @@ def wrapper(impl): def as_list( x: Union[List[object], object] # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. 
-) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: +) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type] # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args, # which is an object but treated as a list by the tracer. Therefore, keep # `immutable_list` intact here as well. diff --git a/torch/distributed/algorithms/_comm_hooks/default_hooks.py b/torch/distributed/algorithms/_comm_hooks/default_hooks.py index 52acea85e9d662..ccdbf66049bdec 100644 --- a/torch/distributed/algorithms/_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/_comm_hooks/default_hooks.py @@ -1,6 +1,7 @@ import functools import torch import torch.distributed as dist +from typing import Optional class DefaultState: @@ -127,7 +128,7 @@ def _low_precision_hook(prec: torch.dtype, state: LowPrecisionState, grad: torch allreduce_hook(state, grad) _decompress(state, grad) -def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None): +def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None): r""" This FSDP communication hook implements a simple gradient compression approach that casts ``grad`` to half-precision floating-point format (``torch.float16``). @@ -144,7 +145,7 @@ def fp16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: tor fp16_hook = functools.partial(_low_precision_hook, torch.float16) return fp16_hook(state, grad, output) -def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: torch.Tensor = None): +def bf16_compress_hook(state: LowPrecisionState, grad: torch.Tensor, output: Optional[torch.Tensor] = None): r""" This FSDP communication hook implements a simple gradient compression approach that casts ``grad`` to half-precision floating-point format (``torch.float16``). diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index 7ee063cad68e8c..ec370367bc5e2d 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -20,7 +20,7 @@ def load_state_dict( process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, no_dist: bool = False, - planner: LoadPlanner = None, + planner: Optional[LoadPlanner] = None, ) -> None: """ Loads a distributed ``state_dict`` in SPMD style. diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 1cf6e2d064e110..a99cd129aeb637 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -22,7 +22,7 @@ def save_state_dict( process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, no_dist: bool = False, - planner: SavePlanner = None, + planner: Optional[SavePlanner] = None, ) -> Metadata: """ Saves a distributed model in SPMD style. 
diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index dc34548e1d28e2..ed6c93078299a4 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -58,7 +58,7 @@ def broadcast( raise AssertionError("Data or Function is expected to be None if not successful") payload: Optional[T] = None - exception : Exception = None + exception : Optional[Exception] = None # if no pg is passed then execute if rank is 0 if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): # determine if it is an executable function or data payload only @@ -119,7 +119,7 @@ def all_gather( >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) """ payload: Optional[T] = None - exception : Exception = None + exception : Optional[Exception] = None success = True # determine if it is an executable function or data payload only if callable(data_or_fn): @@ -143,7 +143,7 @@ def all_gather( total_list = [None] * dist.get_world_size(pg) all_gather_object_enforce_type(pg, total_list, sync_obj) # Each rank will throw RuntimeError in case of failure on any rank. - stage_name: Optional[str] = cast(SyncPayload[T], total_list[0]).stage_name + stage_name = cast(SyncPayload[T], total_list[0]).stage_name exception_list: List[Tuple[int, Exception]] = [] ret_list: List[T] = [] error_msg: str = "" @@ -160,7 +160,7 @@ def all_gather( ret_list.append(sp.payload) if len(exception_list) > 0: - raise RuntimeError( + raise RuntimeError( # type: ignore[misc] error_msg, exception_list) from exception_list[0] return ret_list else: @@ -168,7 +168,7 @@ def all_gather( raise RuntimeError( f"all_gather failed with exception {sync_obj.exception}", ) from sync_obj.exception - return [sync_obj.payload] + return [sync_obj.payload] # type: ignore[list-item] # Note: use Any for typing for now so users can pass in diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 0a9f297de4fbcd..2ad55683c33da9 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -68,7 +68,7 @@ def add_value(self, metric_name: str, metric_value: int): # pyre-fixme[9]: group has type `str`; used as `None`. -def configure(handler: MetricHandler, group: str = None): +def configure(handler: MetricHandler, group: Optional[str] = None): if group is None: global _default_metrics_handler # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index aed35f916ad072..deca745fbaad4f 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -158,7 +158,7 @@ def __init__( file_path: str, max_interval: float = 10, daemon: bool = True, - log_event: Callable[[str, Optional[FileTimerRequest]], None] = None + log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None ) -> None: self._file_path = file_path self._max_interval = max_interval diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index a5f9511223c4a3..5b1619890e3690 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -476,7 +476,7 @@ def _flatten_optim_state_dict( len(fqns) == 1 ), f"use_orig_params is True but there are multiple FQNs, {fqns}." if optim is not None: # NamedOptimizer or KeyedOptimizer case. 
- state = optim.state.get(param, None) + state = optim.state.get(param, None) # type: ignore[call-overload] if state is not None: flat_osd_state[key] = copy.deepcopy(state) else: diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 53dc5d18f2e76a..93a46acbe4f54d 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -1433,7 +1433,7 @@ def _register_post_backward_reshard_only_hooks( hook_handle = register_multi_grad_hook( inp_tensors, functools.partial(_post_backward_reshard, state, handle) ) - handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined] + handle.flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment] @no_type_check @@ -1468,11 +1468,11 @@ def _wait_for_computation_stream( For example, this should be called in the FSDP root's pre-forward to respect optimizer step computation. """ - unshard_stream.wait_stream(computation_stream) + unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined] # Having the pre-all-gather stream wait for the current stream even if we # do not leverage the pre-all-gather stream is tolerable since this only # runs once per iteration - pre_unshard_stream.wait_stream(computation_stream) + pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined] def _reset_flat_param_grad_info_if_needed( diff --git a/torch/distributed/fsdp/_trace_utils.py b/torch/distributed/fsdp/_trace_utils.py index cb2ca8ad44a307..42d569f1ecb774 100644 --- a/torch/distributed/fsdp/_trace_utils.py +++ b/torch/distributed/fsdp/_trace_utils.py @@ -167,7 +167,7 @@ def _patched_create_proxy( kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, - proxy_factory_fn: Callable[[torch.fx.Node], torch.fx.Proxy] = None, + proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None, ) -> torch.fx.Proxy: """ Overrides ``create_proxy`` to save execution information to diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 85db607e742716..718b723d1fb52d 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -96,7 +96,7 @@ def _auto_wrap( ) recursive_wrap_kwargs["auto_wrap_policy"] = policy _warn_on_overridden_mixed_precision(overridden_module_classes) - _recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs) + _recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs) # type: ignore[arg-type] def _check_nested_wrapping(root_module: nn.Module): diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index d55b8be41b7fd0..0f6c1fed68b0ed 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -129,8 +129,8 @@ def __init__( self, remote_device: str, module_cls: Type[nn.Module], - args: Tuple = None, - kwargs: Dict[str, Any] = None, + args: Optional[Tuple] = None, + kwargs: Optional[Dict[str, Any]] = None, _module_interface_cls: Any = None, ): """ @@ -358,7 +358,7 @@ def half(self: T) -> T: # type: ignore[return] def bfloat16(self: T) -> T: # type: ignore[return] _raise_not_supported(self.bfloat16.__name__) - def to(self, *args, **kwargs) -> T: # type: ignore[return] + def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var] _raise_not_supported(self.to.__name__) def register_backward_hook( # type: ignore[return] @@ -377,7 +377,7 @@ def register_forward_pre_hook( 
# type: ignore[return] ) -> RemovableHandle: _raise_not_supported(self.register_forward_pre_hook.__name__) - def register_forward_hook( # type: ignore[return] + def register_forward_hook( # type: ignore[return, override] self, hook: Union[ Callable[[T, Tuple[Any, ...], Any], Optional[Any]], @@ -685,8 +685,8 @@ def __init__( self, remote_device: str, module_cls: Type[nn.Module], - args: Tuple = None, - kwargs: Dict[str, Any] = None, + args: Optional[Tuple] = None, + kwargs: Optional[Dict[str, Any]] = None, ): super().__init__(remote_device, module_cls, args, kwargs) diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index 575e54a92607af..70072b81fd680d 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -2,7 +2,7 @@ import warnings from copy import deepcopy -from typing import Any, Collection, Dict, List, Mapping, Union +from typing import Any, Collection, Dict, List, Mapping, Optional, Union import torch import torch.nn as nn @@ -63,8 +63,8 @@ def __init__( self, named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]], optimizer_class: optim.Optimizer, - param_groups: Collection[Mapping[str, Any]] = None, - module: nn.Module = None, + param_groups: Optional[Collection[Mapping[str, Any]]] = None, + module: Optional[nn.Module] = None, *args, **kwargs, ) -> None: @@ -154,7 +154,7 @@ def step(self, closure: Any = None) -> None: self._optimizer.step(closure=closure) @property - def state(self) -> Mapping[torch.Tensor, Any]: + def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override] return self._optimizer.state def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: diff --git a/torch/distributed/pipeline/sync/utils.py b/torch/distributed/pipeline/sync/utils.py index 37cac7e6e9c2a5..6f18a6e61abbed 100644 --- a/torch/distributed/pipeline/sync/utils.py +++ b/torch/distributed/pipeline/sync/utils.py @@ -1,12 +1,12 @@ from torch import nn -from typing import List +from typing import List, Optional __all__ = ["partition_model"] def partition_model( module: nn.Sequential, balance: List[int], - devices: List[int] = None): + devices: Optional[List[int]] = None): """ Given an :class:`nn.Sequential ` module, partitions the model across multiple GPU devices according the provided ``balance`` diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index be965f058b4dab..960d5935304dc4 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -49,7 +49,7 @@ def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: Tuple[Any, ...]) -> N def input_reshard_backward_hook(_: torch.nn.Module, _i: Tuple[Any, ...], _o: Any) -> Any: nonlocal cx - cx.__exit__() # type: ignore[name-defined] + cx.__exit__() # type: ignore[name-defined, union-attr] if input_reshard_dim is None: return module diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index 46b0349f2043be..0277d22997197c 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -1,7 +1,7 @@ import math import warnings from numbers import Number -from typing import Union +from typing import Optional, Union import torch from torch import nan @@ -72,9 +72,9 @@ class Wishart(ExponentialFamily): def __init__(self, df: Union[torch.Tensor, Number], - covariance_matrix: torch.Tensor = None, - precision_matrix: torch.Tensor = None, - scale_tril: torch.Tensor 
= None, + covariance_matrix: Optional[torch.Tensor] = None, + precision_matrix: Optional[torch.Tensor] = None, + scale_tril: Optional[torch.Tensor] = None, validate_args=None): assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \ "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 54612b7c64bb15..b5010c1e509fab 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -23,7 +23,7 @@ def __init__( root: torch.nn.Module, graph: torch.fx.Graph, const_subgraph: Optional[torch.fx.Graph] = None, - fx_const_folded_attrs_name: str = None, + fx_const_folded_attrs_name: Optional[str] = None, device_for_folded_attrs: str = "cuda", ): super().__init__(root, graph) diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py index 4718c504e00620..be19e7b93ac8b8 100644 --- a/torch/fx/experimental/meta_tracer.py +++ b/torch/fx/experimental/meta_tracer.py @@ -259,7 +259,7 @@ def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], - meta_args : Dict[str, torch.Tensor] = None, + meta_args : Optional[Dict[str, torch.Tensor]] = None, concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule: tracer = MetaTracer() graph = tracer.trace(root, meta_args, concrete_args) diff --git a/torch/fx/experimental/unification/__init__.py b/torch/fx/experimental/unification/__init__.py index 4041a20fc51def..31446d0e61253d 100644 --- a/torch/fx/experimental/unification/__init__.py +++ b/torch/fx/experimental/unification/__init__.py @@ -1,4 +1,4 @@ -# type: ignore[attr-defined] +# mypy: disable-error-code=attr-defined from .core import unify, reify # noqa: F403 from .more import unifiable # noqa: F403 from .variable import var, isvar, vars, variables, Var # noqa: F403 diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 053ce93780d995..7693f528af56e1 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -15,7 +15,7 @@ logger.setLevel(logging.WARNING) class Partition: - def __init__(self, id: int = None, nodes: Iterable[Node] = None): + def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None): self.id = id self.nodes: Set[Node] = set(nodes) if nodes is not None else set() diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index a35d9af09ae6ad..37c31fdff19b6c 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -1,6 +1,6 @@ from functools import wraps from inspect import unwrap -from typing import Callable, List +from typing import Callable, List, Optional import logging logger = logging.getLogger(__name__) @@ -76,7 +76,7 @@ def wrapped_fn(gm): -def loop_pass(base_pass: Callable, n_iter: int = None, predicate: Callable = None): +def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None): """ Convenience wrapper for passes which need to be applied multiple times. 
diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 2181c171e4a89c..48db6ff99e5d89 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -202,7 +202,7 @@ def replace_pattern_with_filters( gm: GraphModule, pattern: Union[Callable, GraphModule], replacement: Union[Callable, GraphModule], - match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined] + match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined] ignore_literals: bool = False, ) -> List[ReplacedPatterns]: """ @@ -222,7 +222,7 @@ def _replace_pattern( gm: GraphModule, pattern: Union[Callable, GraphModule], replacement: Union[Callable, GraphModule], - match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined] + match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, # type: ignore[name-defined] ignore_literals: bool = False, ) -> List[ReplacedPatterns]: diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 9839330b260ac0..21434d95c77ae9 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1302,7 +1302,7 @@ def amin( @_apply_docstring_templates def argmax( input: Union[Tensor, MaskedTensor], - dim: int = None, + dim: Optional[int] = None, *, keepdim: Optional[bool] = False, dtype: Optional[DType] = None, @@ -1328,7 +1328,7 @@ def argmax( @_apply_docstring_templates def argmin( input: Union[Tensor, MaskedTensor], - dim: int = None, + dim: Optional[int] = None, *, keepdim: Optional[bool] = False, dtype: Optional[DType] = None, diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 37c0731fcf0e67..03ce028b7256cf 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -103,7 +103,7 @@ def __init__(self, *args): for idx, module in enumerate(args): self.add_module(str(idx), module) - def _get_item_by_idx(self, iterator, idx) -> T: + def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var] """Get the idx-th item of the iterator""" size = len(self) idx = operator.index(idx) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index bace244553e008..bb5931b76b442e 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -51,7 +51,7 @@ class _ConvNd(Module): 'out_channels', 'kernel_size'] __annotations__ = {'bias': Optional[torch.Tensor]} - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body] ... 
in_channels: int diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index 72311e180ad67a..6d934905e77baa 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -1,7 +1,7 @@ import contextlib import warnings from collections import defaultdict -from typing import Any, Dict, Iterator, Set, Tuple, Union +from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union import torch from torch import Tensor @@ -148,7 +148,7 @@ def functional_call( module: "torch.nn.Module", parameters_and_buffers: Dict[str, Tensor], args: Union[Any, Tuple], - kwargs: Dict[str, Any] = None, + kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, strict: bool = False, @@ -233,7 +233,7 @@ def _functional_call( module: "torch.nn.Module", parameters_and_buffers: Dict[str, Tensor], args: Union[Any, Tuple], - kwargs: Dict[str, Any] = None, + kwargs: Optional[Dict[str, Any]] = None, *, tie_weights: bool = True, strict: bool = False, diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index e6be0393d4a0e0..aae7eb6f8e3804 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -340,7 +340,7 @@ def serialize( ) -> None: import onnx - if not isinstance(export_output.model_proto, onnx.ModelProto): + if not isinstance(export_output.model_proto, onnx.ModelProto): # type: ignore[attr-defined] raise ValueError("export_output.ModelProto is not an onnx.ModelProto") destination.write(export_output.model_proto.SerializeToString()) @@ -348,7 +348,7 @@ def serialize( class ExportOutput: """An in-memory representation of a PyTorch model that has been exported to ONNX.""" - _model_proto: Final[onnx.ModelProto] + _model_proto: Final[onnx.ModelProto] # type: ignore[name-defined] _input_adapter: Final[io_adapter.InputAdapter] _output_adapter: Final[io_adapter.OutputAdapter] _diagnostic_context: Final[infra.DiagnosticContext] @@ -356,7 +356,7 @@ class ExportOutput: @_beartype.beartype def __init__( self, - model_proto: onnx.ModelProto, + model_proto: onnx.ModelProto, # type: ignore[name-defined] input_adapter: io_adapter.InputAdapter, output_adapter: io_adapter.OutputAdapter, diagnostic_context: infra.DiagnosticContext, @@ -367,7 +367,7 @@ def __init__( self._diagnostic_context = diagnostic_context @property - def model_proto(self) -> onnx.ModelProto: + def model_proto(self) -> onnx.ModelProto: # type: ignore[name-defined] """The exported ONNX model as an ``onnx.ModelProto``.""" return self._model_proto diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index da084c7d46e9f6..1b4a2d250eadd7 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -193,7 +193,7 @@ def generate_fx( wrapped_model, *model_args, tracing_mode=fx_mode, - fake_mode=fake_mode, + fake_mode=fake_mode, # type: ignore[arg-type] **model_kwargs, ) del graph_guard # Unused diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py index 48055c7431705b..cb51daa294105e 100644 --- a/torch/onnx/_internal/fx/serialization.py +++ b/torch/onnx/_internal/fx/serialization.py @@ -15,7 +15,7 @@ @_beartype.beartype def _create_tensor_proto_with_external_data( tensor: torch.Tensor, name: str, location: str, basepath: str -) -> onnx.TensorProto: +) -> onnx.TensorProto: # type: ignore[name-defined] """Create a TensorProto with external data from a PyTorch tensor. 
The external data is saved to os.path.join(basepath, location). @@ -38,13 +38,13 @@ def _create_tensor_proto_with_external_data( # FIXME: Avoid importing onnx into torch.onnx. import onnx - tensor_proto = onnx.TensorProto() + tensor_proto = onnx.TensorProto() # type: ignore[attr-defined] tensor_proto.name = name tensor_proto.data_type = jit_type_utils.JitScalarType.from_dtype( tensor.dtype ).onnx_type() tensor_proto.dims.extend(tensor.shape) - tensor_proto.data_location = onnx.TensorProto.EXTERNAL + tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined] # Settings for saving one tensor per file. # Offset is zero because there is no other tensor in the same file. @@ -86,7 +86,7 @@ def save_model_with_external_data( model_location: str, initializer_location: str, torch_load_paths: Tuple[Union[str, io.BytesIO], ...], - onnx_model: onnx.ModelProto, + onnx_model: onnx.ModelProto, # type: ignore[name-defined] rename_initializer: bool = False, ) -> None: """Load PyTorch tensors from files and add to "onnx_model" as external initializers. @@ -121,7 +121,7 @@ def save_model_with_external_data( # FIXME: Avoid importing onnx into torch.onnx. import onnx - onnx_model_with_initializers = onnx.ModelProto() + onnx_model_with_initializers = onnx.ModelProto() # type: ignore[attr-defined] onnx_model_with_initializers.CopyFrom(onnx_model) onnx_input_names = [input.name for input in onnx_model.graph.input] @@ -159,4 +159,4 @@ def save_model_with_external_data( onnx_model_with_initializers.graph.initializer.append(tensor_proto) # model_location should be a pure file name such as "file_name.onnx", not "folder/file_name.onnx". - onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) + onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location)) # type: ignore[attr-defined] diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/onnx_proto_utils.py index 92cc91ee56bf72..0b1e476a08e201 100644 --- a/torch/onnx/_internal/onnx_proto_utils.py +++ b/torch/onnx/_internal/onnx_proto_utils.py @@ -62,7 +62,7 @@ def export_as_test_case( shutil.rmtree(data_set_dir) os.makedirs(data_set_dir) - proto = onnx.load_model_from_string(model_bytes) + proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined] for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)): export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb")) @@ -112,12 +112,12 @@ def load_test_case(dir: str) -> Tuple[bytes, Any, Any]: inputs = {} input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb")) for input_file in input_files: - tensor = onnx.load_tensor(input_file) + tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined] inputs[tensor.name] = numpy_helper.to_array(tensor) outputs = {} output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb")) for output_file in output_files: - tensor = onnx.load_tensor(output_file) + tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined] outputs[tensor.name] = numpy_helper.to_array(tensor) return model_bytes, inputs, outputs @@ -227,7 +227,7 @@ def _add_onnxscript_fn( # size > 2GB, and if it for some reason did not, the model would fail on # serialization anyway in terms of the protobuf limitation. So we don't # need to worry about > 2GB model getting here. 
- model_proto = onnx.load_model_from_string(model_bytes) + model_proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined] # Iterate graph nodes to insert only the included custom # function_proto into model_proto diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index 7169dd298dd21c..abfa4677eb21cb 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -206,14 +206,14 @@ def _ort_session( def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]): try: import onnx - from onnx import reference as onnx_reference + from onnx import reference as onnx_reference # type: ignore[attr-defined] except ImportError: raise ImportError("onnx >= 1.13 is required for reference evaluator.") proto = ( - onnx.load(model) + onnx.load(model) # type: ignore[attr-defined] if isinstance(model, str) - else onnx.load_model_from_string(model.getvalue()) + else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined] ) onnx_session = onnx_reference.ReferenceEvaluator(proto) return onnx_session diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index f8c40a3aee5d9a..34d27bdaca6058 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -6,7 +6,7 @@ import functools import math -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Tuple, Optional from torch import Tensor import torch.utils.hooks as hooks @@ -208,7 +208,7 @@ def __init__(self, params, defaults): "an iterable of Tensors or dicts, but got " + torch.typename(params)) - self.state = defaultdict(dict) + self.state: Dict[int, Any] = defaultdict(dict) self.param_groups = [] param_groups = list(params) @@ -340,8 +340,8 @@ def _patch_step_function(self): self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__) hooked = getattr(self.__class__.step, "hooked", None) if not hooked: - self.__class__.step = self.profile_hook_step(self.__class__.step) - self.__class__.step.hooked = True + self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[method-assign] + self.__class__.step.hooked = True # type: ignore[attr-defined] def register_step_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: r"""Register an optimizer step pre hook which will be called before @@ -418,14 +418,15 @@ def pack_group(group): } @staticmethod - def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param_id: int = None, - param_groups: List[Dict[Any, Any]] = None, key=None) -> Tensor: + def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param_id: Optional[int] = None, + param_groups: Optional[List[Dict[Any, Any]]] = None, key=None) -> Tensor: # Floating-point types are a bit special here. They are the only ones # that are assumed to always match the type of params. 
# Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424 # UNLESS fused or capturable, see note [special device hosting for step] fused = False capturable = False + assert param_groups is not None for pg in param_groups: if param_id in pg["params"]: fused = pg["fused"] if "fused" in pg else False @@ -477,14 +478,14 @@ def cast(param, value, param_id=None, param_groups=None, key=None): elif isinstance(value, dict): return {k: cast(param, v, param_id=param_id, param_groups=param_groups, key=k) for k, v in value.items()} elif isinstance(value, container_abcs.Iterable): - return type(value)(cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) + return type(value)(cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] else: return value # Copy state assigned to params (and cast tensors to appropriate types). # State that is not assigned to params is copied as is (needed for # backward compatibility). - state = defaultdict(dict) + state: Dict[Any, Dict[Any, Any]] = defaultdict(dict) for k, v in state_dict['state'].items(): if k in id_map: param = id_map[k] @@ -521,7 +522,7 @@ def zero_grad(self, set_to_none: bool = True): if not hasattr(self, "_zero_grad_profile_name"): self._patch_step_function() if foreach: - per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + per_device_and_dtype_grads: Dict[Any, Dict[Any, List[Any]]] = defaultdict(lambda: defaultdict(list)) with torch.autograd.profiler.record_function(self._zero_grad_profile_name): for group in self.param_groups: for p in group['params']: diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index 196af20c934dae..ae95faf0d2bae7 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -603,7 +603,7 @@ def input_dtypes(event: _ProfilerEvent): def report_all_anti_patterns(prof, should_benchmark: bool = False, print_enable: bool = True, - json_report_dir: str = None): + json_report_dir: Optional[str] = None): report_dict: Dict = {} anti_patterns = [ ExtraCUDACopyPattern(prof, should_benchmark), diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 6a8ff17ee4fb04..608bd9b23cb889 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -242,7 +242,7 @@ def _memory_profile(self) -> MemoryProfile: assert self.profiler is not None and self.profiler.kineto_results is not None return MemoryProfile(self.profiler.kineto_results) - def export_memory_timeline(self, path: str, device: str = None) -> None: + def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: """Extract the memory information from the memory profile collected tree for a given device, and export a timeline plot consisting of [times, [sizes by category]], where times are timestamps and sizes diff --git a/torch/serialization.py b/torch/serialization.py index 611553b9810f06..ca362851687147 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -861,7 +861,7 @@ def load( pickle_module: Any = None, *, weights_only: bool = False, - mmap: bool = None, + mmap: Optional[bool] = None, **pickle_load_args: Any ) -> Any: # Reference: https://github.com/pytorch/pytorch/issues/54354 diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index 3886497a1d5d30..1ddfff96228927 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -422,9 +422,9 @@ def kaiser( def hamming(M: int, *, sym: bool 
= True, - dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, layout: torch.layout = torch.strided, - device: torch.device = None, + device: Optional[torch.device] = None, requires_grad: bool = False) -> Tensor: return general_hamming(M, sym=sym, dtype=dtype, layout=layout, device=device, requires_grad=requires_grad) @@ -469,9 +469,9 @@ def hamming(M: int, def hann(M: int, *, sym: bool = True, - dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, layout: torch.layout = torch.strided, - device: torch.device = None, + device: Optional[torch.device] = None, requires_grad: bool = False) -> Tensor: return general_hamming(M, alpha=0.5, @@ -521,9 +521,9 @@ def hann(M: int, def blackman(M: int, *, sym: bool = True, - dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, layout: torch.layout = torch.strided, - device: torch.device = None, + device: Optional[torch.device] = None, requires_grad: bool = False) -> Tensor: if dtype is None: dtype = torch.get_default_dtype() @@ -575,9 +575,9 @@ def blackman(M: int, def bartlett(M: int, *, sym: bool = True, - dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, layout: torch.layout = torch.strided, - device: torch.device = None, + device: Optional[torch.device] = None, requires_grad: bool = False) -> Tensor: if dtype is None: dtype = torch.get_default_dtype() @@ -644,9 +644,9 @@ def bartlett(M: int, def general_cosine(M, *, a: Iterable, sym: bool = True, - dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, layout: torch.layout = torch.strided, - device: torch.device = None, + device: Optional[torch.device] = None, requires_grad: bool = False) -> Tensor: if dtype is None: dtype = torch.get_default_dtype() @@ -721,9 +721,9 @@ def general_hamming(M, *, alpha: float = 0.54, sym: bool = True, - dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, layout: torch.layout = torch.strided, - device: torch.device = None, + device: Optional[torch.device] = None, requires_grad: bool = False) -> Tensor: return general_cosine(M, a=[alpha, 1. - alpha], diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index e36df73f278bff..0e4c217a50aafb 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -199,7 +199,7 @@ def __init__( self.compressed_tensor = compressed_tensor self.transposed = transposed - def __repr__(self) -> str: + def __repr__(self) -> str: # type: ignore[override] """Return string representation of SparseSemiStructuredTensor Returns: diff --git a/torch/storage.py b/torch/storage.py index c15bf548f0499f..43c4b89a2279de 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -3,7 +3,7 @@ import torch from ._utils import _type, _cuda, _hpu from torch.types import Storage -from typing import Any, TypeVar, Type, Union, cast, Dict as _Dict +from typing import cast, Any, Dict as _Dict, Optional as _Optional, TypeVar, Type, Union import copy import collections from functools import lru_cache @@ -27,51 +27,51 @@ class _StorageBase: device: torch.device def __init__(self, *args, **kwargs): ... # noqa: E704 - def __len__(self) -> int: ... # noqa: E704 + def __len__(self) -> int: ... # type: ignore[empty-body] # noqa: E704 def __getitem__(self, idx): ... # noqa: E704 def __setitem__(self, *args, **kwargs): ... # noqa: E704 - def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704 - def new(self) -> T: ... # noqa: E704 - def nbytes(self) -> int: ... 
# noqa: E704 + def copy_(self, source: T, non_blocking: _Optional[bool] = None) -> T: ... # type: ignore[empty-body] # noqa: E704 + def new(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 + def nbytes(self) -> int: ... # type: ignore[empty-body] # noqa: E704 def size(self) -> int: return self.nbytes() - def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704 - def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704 - def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704 - def element_size(self) -> int: ... # noqa: E704 + def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 + def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 + def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 + def element_size(self) -> int: ... # type: ignore[empty-body, type-var] # noqa: E704 def get_device(self) -> int: return self.device.index - def data_ptr(self) -> int: ... # noqa: E704 + def data_ptr(self) -> int: ... # type: ignore[empty-body] # noqa: E704 # Defined in torch/csrc/generic/StorageSharing.cpp def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704 def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704 @classmethod - def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704 + def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704 @classmethod - def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704 + def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # type: ignore[empty-body] # noqa: E704 @classmethod - def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704 + def from_buffer(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704 @classmethod - def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704 + def _new_shared_filename_cpu(cls: Type[T], manager, obj, size, *, device=None, dtype=None) -> T: ... # type: ignore[empty-body] # noqa: E704 @classmethod - def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704 + def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704 @classmethod - def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704 - def _shared_decref(self) -> T: ... # noqa: E704 + def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: ... # type: ignore[empty-body] # noqa: E704 + def _shared_decref(self) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 def _write_file(self, *args, **kwargs): ... # noqa: E704 def resize_(self, size: int): ... # noqa: E704 - def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704 + def _weak_ref(self, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 def _set_from_file(self, *args, **kwargs): ... # noqa: E704 def _set_cdata(self, *args, **kwargs): ... # noqa: E704 def _share_cuda_(self, *args, **kwargs): ... # noqa: E704 - def is_shared(self) -> bool: ... # noqa: E704 + def is_shared(self) -> bool: ... # type: ignore[empty-body] # noqa: E704 @classmethod - def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704 + def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: ... 
# type: ignore[empty-body] # noqa: E704 def _shared_incref(self, *args, **kwargs): ... # noqa: E704 @classmethod def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704 @@ -80,9 +80,9 @@ def is_cuda(self): ... # noqa: E704 @property def is_hpu(self): ... # noqa: E704 @classmethod - def from_file(cls, filename, shared, nbytes) -> T: ... # noqa: E704 + def from_file(cls, filename, shared, nbytes) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 @classmethod - def _expired(cls, *args, **kwargs) -> T: ... # noqa: E704 + def _expired(cls, *args, **kwargs) -> T: ... # type: ignore[empty-body, misc, type-var] # noqa: E704 def _byteswap(self, *args, **kwargs): ... # noqa: E704 def __str__(self): @@ -712,12 +712,12 @@ def _getitem(self, idx): tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self) return tmp_tensor[idx_wrapped].item() - def copy_(self, source: T, non_blocking: bool = None): + def copy_(self, source: T, non_blocking: _Optional[bool] = None): _warn_typed_storage_removal() if isinstance(source, TypedStorage): - self._untyped_storage.copy_(source._untyped_storage, non_blocking) + self._untyped_storage.copy_(source._untyped_storage, non_blocking) # type: ignore[arg-type] else: - self._untyped_storage.copy_(source, non_blocking) + self._untyped_storage.copy_(source, non_blocking) # type: ignore[arg-type] return self def nbytes(self): @@ -728,7 +728,7 @@ def nbytes(self): def _nbytes(self): return self._untyped_storage.nbytes() - def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]: + def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> Union[T, str]: _warn_typed_storage_removal() if dtype is None: legacy_class = self._get_legacy_storage_class() @@ -741,14 +741,14 @@ def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]: else: return self._untyped_storage.type(dtype, non_blocking) - def cuda(self, device=None, non_blocking=False, **kwargs) -> T: + def cuda(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var] _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create CUDA storage with quantized dtype") cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs) return self._new_wrapped_storage(cuda_storage) - def hpu(self, device=None, non_blocking=False, **kwargs) -> T: + def hpu(self, device=None, non_blocking=False, **kwargs) -> T: # type: ignore[misc, type-var] _warn_typed_storage_removal() if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]: raise RuntimeError("Cannot create HPU storage with quantized dtype") diff --git a/torch/types.py b/torch/types.py index 612196d6ff8814..56db67669d8cf2 100644 --- a/torch/types.py +++ b/torch/types.py @@ -43,37 +43,37 @@ class Storage: dtype: torch.dtype _torch_load_uninitialized: bool - def __deepcopy__(self, memo) -> 'Storage': + def __deepcopy__(self, memo) -> 'Storage': # type: ignore[empty-body] ... - def _new_shared(self, int) -> 'Storage': + def _new_shared(self, int) -> 'Storage': # type: ignore[empty-body] ... def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int) -> None: ... - def element_size(self) -> int: + def element_size(self) -> int: # type: ignore[empty-body] ... - def is_shared(self) -> bool: + def is_shared(self) -> bool: # type: ignore[empty-body] ... 
- def share_memory_(self) -> 'Storage': + def share_memory_(self) -> 'Storage': # type: ignore[empty-body] ... - def nbytes(self) -> int: + def nbytes(self) -> int: # type: ignore[empty-body] ... - def cpu(self) -> 'Storage': + def cpu(self) -> 'Storage': # type: ignore[empty-body] ... - def data_ptr(self) -> int: + def data_ptr(self) -> int: # type: ignore[empty-body] ... - def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage': + def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> 'Storage': # type: ignore[empty-body] ... - def _new_with_file(self, f: Any, element_size: int) -> 'Storage': + def _new_with_file(self, f: Any, element_size: int) -> 'Storage': # type: ignore[empty-body] ... ... diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index 0adb5221aafde5..68b565b718ead8 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -100,7 +100,7 @@ def report_compile_source_on_error(): # specifically _PyCode_InitAddressRange, reveals that # this iterator is initialized from co_linetable and # co_firstfileno. So copy these we must! - code = code.replace( + code = code.replace( # type: ignore[call-arg] co_linetable=frame.f_code.co_linetable, # type: ignore[attr-defined] co_firstlineno=frame.f_code.co_firstlineno, # type: ignore[attr-defined] ) diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index 1914d5d66a0abf..83edab479578c7 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -184,7 +184,7 @@ def wrap_module_to(self: torch.nn.modules.module.T, def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str, - unsupported_dtype: List[torch.dtype] = None) -> None: + unsupported_dtype: Optional[List[torch.dtype]] = None) -> None: # Attribute is registered in the _StorageBase class # and UntypedStorage obtains through inheritance. 
@property # type: ignore[misc] @@ -255,7 +255,7 @@ def wrap_typed_storage_to(self: torch.storage.TypedStorage, def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True, for_storage: bool = False, - unsupported_dtype: List[torch.dtype] = None) -> None: + unsupported_dtype: Optional[List[torch.dtype]] = None) -> None: r""" generate_methods_for_privateuse1_backend(for_tensor, for_module, for_storage, unsupported_dtype) -> None diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index 35e480958a88ec..dcee32ace4031a 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -37,8 +37,8 @@ def bench_loop( model: Union[torch.nn.Module, Callable], sample_input: Union[torch.Tensor, Any], num_iters: int = 5, - optimizer: torch.optim.Optimizer = None, - loss_fn: Callable = None, + optimizer: Optional[torch.optim.Optimizer] = None, + loss_fn: Optional[Callable] = None, ): # Define the statement and setup for the benchmark if optimizer and loss_fn: @@ -73,8 +73,8 @@ def benchmark_compile( num_iters: int = 5, backend: Optional[str] = None, mode: Optional[str] = "default", - optimizer: torch.optim.Optimizer = None, - loss_fn : Union[torch.nn.Module, Callable] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + loss_fn : Union[torch.nn.Module, Callable, None] = None, ): """ Use this utility to benchmark torch.compile @@ -117,7 +117,7 @@ def bench_all( sample_input: Union[torch.Tensor, Any], num_iters : int = 5, optimizer: Optional[torch.optim.Optimizer] = None, - loss_fn : Union[torch.nn.Module, Callable] = None, + loss_fn : Union[torch.nn.Module, Callable, None] = None, ): """ This is a simple utility that can be used to benchmark torch.compile diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 6b3ab027137020..ec86f778023ba6 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -1066,7 +1066,7 @@ def __init__(self, loader): # pin_memory_thread once it is started. self._pin_memory_thread = pin_memory_thread else: - self._data_queue = self._worker_result_queue + self._data_queue = self._worker_result_queue # type: ignore[assignment] # In some rare cases, persistent workers (daemonic processes) # would be terminated before `__del__` of iterator is invoked diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index 62ac1107500678..06029e07851685 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -415,7 +415,7 @@ def __getattr__(self, attrname): # ? 
@functional_datapipe('trace_as_dataframe') -class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): +class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc] source_datapipe = None # TODO(VitalyFedyunin): Must implement all special functions of datapipes diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index d47c3b1a01125c..445400ecb59c32 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -376,7 +376,7 @@ def __iter__(self) -> "_IterDataPipeSerializationWrapper": self._datapipe_iter = iter(self._datapipe) return self - def __next__(self) -> T_co: + def __next__(self) -> T_co: # type: ignore[type-var] assert self._datapipe_iter is not None return next(self._datapipe_iter) diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 67bff8708064de..7c76e986b230d4 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -421,7 +421,7 @@ def __init__(self, datapipe: IterDataPipe[T_co], num_instances: int, self.main_datapipe_exhausted = False self._child_stop: List[bool] = [True for _ in range(num_instances)] - def _find_next(self, instance_id: int) -> T_co: + def _find_next(self, instance_id: int) -> T_co: # type: ignore[type-var] while True: if self.main_datapipe_exhausted or self._child_stop[instance_id]: raise StopIteration diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py index 9bf18802ff0807..85146f8345cbdc 100644 --- a/torch/utils/data/datapipes/map/combining.py +++ b/torch/utils/data/datapipes/map/combining.py @@ -40,7 +40,7 @@ def __init__(self, *datapipes: MapDataPipe): raise TypeError("Expected all inputs to be `Sized`") self.datapipes = datapipes # type: ignore[assignment] - def __getitem__(self, index) -> T_co: + def __getitem__(self, index) -> T_co: # type: ignore[type-var] offset = 0 for dp in self.datapipes: if index - offset < len(dp): diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index a7cb850aaf0b41..06028a679509ac 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -418,5 +418,5 @@ def random_split(dataset: Dataset[T], lengths: Sequence[Union[int, float]], if sum(lengths) != len(dataset): # type: ignore[arg-type] raise ValueError("Sum of input lengths does not equal the length of the input dataset!") - indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[call-overload] + indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload] return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)] diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index b7324eeadb9a9d..bbddc6735a1882 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -236,7 +236,7 @@ def __init__( mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None, depth: int = 2, display: bool = True, - custom_mapping: Dict[Any, Any] = None): + custom_mapping: Optional[Dict[Any, Any]] = None): self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int)) self.depth = depth self.parents = ["Global"] diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index edfeff8aadfb22..a64d5e9c35830a 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -1,6 
+1,6 @@ import gc import sys -from typing import NamedTuple, Tuple, List, Optional +from typing import Any, Dict, List, NamedTuple, Optional, Tuple import types import weakref import json @@ -100,7 +100,7 @@ def annotated_references(obj): need for a list. Descriptions are currently strings. """ - references = {} + references: Dict[int, List[str]] = {} def add_reference(name, obj): references.setdefault(id(obj), []).append(name) @@ -272,7 +272,7 @@ def create_graph(objects, *, context=None, filter=None): filter = is_cuda_tensor nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects] - node_referrers = [[] for obj in objects] + node_referrers: List[List[int]] = [[] for obj in objects] id_to_node = {id(obj): i for i, obj in enumerate(objects)} for obj in objects: @@ -299,8 +299,8 @@ def create_graph(objects, *, context=None, filter=None): to_keep.add(idx) referrers = node_referrers[idx] to_search.extend(referrers) - id_to_filtered_id = {} - filtered = [] + id_to_filtered_id: Dict[int, int] = {} + filtered: List[Any] = [] for i, n in enumerate(nodes): if i in to_keep: id_to_filtered_id[i] = len(id_to_filtered_id) diff --git a/torch/utils/weak.py b/torch/utils/weak.py index 12e77f1e27ab35..2a7d597c4f2a06 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -50,7 +50,7 @@ def __init__(self, key, callback=None): # cache the id of the key as we know this is definitely the hash # method self._id = id(key) - super().__init__(key, callback) + super().__init__(key, callback) # type: ignore[call-arg] def __call__(self): r = super().__call__() diff --git a/torchgen/yaml_utils.py b/torchgen/yaml_utils.py index 5869f74caf0e20..0278af84bf633a 100644 --- a/torchgen/yaml_utils.py +++ b/torchgen/yaml_utils.py @@ -2,12 +2,12 @@ try: from yaml import CSafeLoader as Loader except ImportError: - from yaml import SafeLoader as Loader # type: ignore[misc] + from yaml import SafeLoader as Loader # type: ignore[assignment, misc] try: from yaml import CSafeDumper as Dumper except ImportError: - from yaml import SafeDumper as Dumper # type: ignore[misc] + from yaml import SafeDumper as Dumper # type: ignore[assignment, misc] YamlDumper = Dumper
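
Note (not part of the patch): the Python-side hunks above repeat two fixes required by the mypy 1.4.1 bump — defaults of `None` must be annotated as `Optional[...]` now that implicit Optional is rejected (mypy has disabled it by default since 0.990), and the remaining false positives are silenced with narrowly scoped `# type: ignore[code]` comments instead of bare ignores. The standalone sketch below, with purely illustrative names, shows the before/after shape of both fixes.

```python
from typing import Dict, Optional


# Before the patch style: `kwargs: Dict[str, int] = None` -- mypy rejects this
# because None is not a Dict[str, int] once implicit Optional is off.
# After: spell the annotation as Optional[...] and keep the None default.
def scale_values(
    values: Dict[str, int],
    kwargs: Optional[Dict[str, int]] = None,
) -> Dict[str, int]:
    kwargs = kwargs or {}
    factor = kwargs.get("factor", 1)
    return {k: v * factor for k, v in values.items()}


class Base:
    def describe(self) -> str:
        return "base"


class Child(Base):
    # Widening the return type is an incompatible override, which mypy flags
    # with the [override] code; the patch-style fix is a targeted ignore for
    # that one code rather than a bare `# type: ignore`.
    def describe(self) -> object:  # type: ignore[override]
        return 42


if __name__ == "__main__":
    print(scale_values({"a": 2, "b": 3}, {"factor": 10}))
    print(scale_values({"a": 2, "b": 3}))
    print(Child().describe())
```

Scoping each ignore to a specific error code keeps mypy's other checks active on the same line, which is why the diff uses forms like `# type: ignore[return, override]` rather than suppressing everything.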