Skip to content

Commit

Permalink
[BE] Enable ruff's UP rules and autoformat torchgen/ (pytorch#105423)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#105423
Approved by: https://github.com/Skylion007
  • Loading branch information
justinchuby authored and pytorchmergebot committed Jul 18, 2023
1 parent 6ca3d7e commit 964d29f
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 30 deletions.
16 changes: 8 additions & 8 deletions torchgen/api/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def from_outputs(
outputs=outputs,
)
elif size > 1:
if any((not a.type.is_tensor_like() for a in outputs)):
if any(not a.type.is_tensor_like() for a in outputs):
raise RuntimeError(f"Unsupported output type: {outputs}")
return PythonOutArgument(
name="out",
Expand Down Expand Up @@ -882,10 +882,10 @@ def topt_default_init(name: str) -> Optional[str]:


def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
if len(returns) <= 1 or all((r.name is None for r in returns)):
if len(returns) <= 1 or all(r.name is None for r in returns):
return []
else:
if any((r.name is None for r in returns)):
if any(r.name is None for r in returns):
# When building on Windows, `PyStructSequence_UnnamedField` could not be
# resolved by the linker for some reason, which cause error in building:
#
Expand Down Expand Up @@ -1163,7 +1163,7 @@ def dispatch_lambda_return_str(f: NativeFunction) -> str:
# mutable reference to temporary. Maybe we could assign it to a
# variable itself.)
returns_without_annotation = tuple(
(Return(r.name, r.type, None) for r in f.func.returns)
Return(r.name, r.type, None) for r in f.func.returns
)
return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
if return_str not in SUPPORTED_RETURN_TYPES:
Expand Down Expand Up @@ -1195,7 +1195,7 @@ def cpp_dispatch_exprs(
exprs: Tuple[str, ...] = tuple()
if not isinstance(python_signature, PythonSignatureDeprecated):
# By default the exprs are consistent with the C++ signature.
exprs = tuple((a.name for a in cpp_args))
exprs = tuple(a.name for a in cpp_args)
else:
# For deprecated python signature we may need fill in some constants.
exprs = tuple(
Expand Down Expand Up @@ -1426,7 +1426,7 @@ def dispatch_lambda_exprs(
f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
)
if not all(
(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS.keys())
a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS.keys()
):
raise RuntimeError(
f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
Expand Down Expand Up @@ -1454,7 +1454,7 @@ def dispatch_lambda_exprs(
raise RuntimeError(
f"{f.func}: dtype in tensor_options_args without output arg"
)
if not all((a in tensor_options_args_names for a in ("layout", "device"))):
if not all(a in tensor_options_args_names for a in ("layout", "device")):
raise RuntimeError(
f"{f.func}: incomplete tensor options for output check"
)
Expand All @@ -1473,6 +1473,6 @@ def dispatch_lambda_exprs(
)

return DispatchLambdaArgumentExprs(
exprs=tuple((lambda_args_exprs[a.name] for a in lambda_args)),
exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
inits=inits,
)
2 changes: 1 addition & 1 deletion torchgen/code_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class CodeTemplate:

@staticmethod
def from_file(filename: str) -> "CodeTemplate":
with open(filename, "r") as f:
with open(filename) as f:
return CodeTemplate(f.read(), filename)

def __init__(self, pattern: str, filename: str = "") -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchgen/executorch/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def parse_et_yaml(
"""Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
of fields to persist from native_functions.yaml to functions.yaml
"""
with open(path, "r") as f:
with open(path) as f:
es = yaml.load(f, Loader=LineLoader)

et_kernel = extract_kernel_fields(es)
Expand Down
4 changes: 2 additions & 2 deletions torchgen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
def parse_tags_yaml(path: str) -> Set[str]:
global _GLOBAL_PARSE_TAGS_YAML_CACHE
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
with open(path, "r") as f:
with open(path) as f:
es = yaml.load(f, Loader=LineLoader)
_GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)

Expand All @@ -233,7 +233,7 @@ def parse_native_yaml(

# if a loaded yaml is provided, use that instead of reading from path
if loaded_yaml is None:
with open(path, "r") as f:
with open(path) as f:
es = yaml.load(f, Loader=LineLoader)
else:
es = loaded_yaml
Expand Down
6 changes: 3 additions & 3 deletions torchgen/gen_backend_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def parse_backend_yaml(
)
}

with open(backend_yaml_path, "r") as f:
with open(backend_yaml_path) as f:
yaml_values = yaml.load(f, Loader=YamlLoader)
assert isinstance(yaml_values, dict)

Expand Down Expand Up @@ -253,9 +253,9 @@ def error_on_missing_kernels(
full_codegen: Optional[List[OperatorName]] = None,
) -> None:
try:
with open(kernel_defn_file_path, "r") as f:
with open(kernel_defn_file_path) as f:
backend_defns = f.read()
except IOError as e:
except OSError as e:
raise AssertionError(
f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
) from e
Expand Down
6 changes: 3 additions & 3 deletions torchgen/gen_executorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def translate_native_yaml(
None
"""
if use_aten_lib:
with open(aten_yaml_path, "r") as aten_yaml:
with open(aten_yaml_path) as aten_yaml:
out_file.writelines(aten_yaml.readlines())
return

Expand Down Expand Up @@ -604,7 +604,7 @@ def translate_native_yaml(
or os.stat(native_yaml_path).st_size == 0
):
return
with open(native_yaml_path, "r") as native_yaml:
with open(native_yaml_path) as native_yaml:
native_es = yaml.load(native_yaml, Loader=LineLoader)
if not native_es:
return
Expand Down Expand Up @@ -641,7 +641,7 @@ def parse_yaml(
Union[Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ETKernelIndex],
]:
if path and os.path.exists(path) and os.stat(path).st_size > 0:
with open(path, "r") as f:
with open(path) as f:
es = yaml.load(f, Loader=LineLoader)

# Check for kernel index structure
Expand Down
6 changes: 3 additions & 3 deletions torchgen/gen_lazy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def parse_native_functions_keys(
)
}

with open(backend_yaml_path, "r") as f:
with open(backend_yaml_path) as f:
yaml_values = yaml.load(f, Loader=YamlLoader)
assert isinstance(yaml_values, dict)

Expand All @@ -134,10 +134,10 @@ def validate_shape_inference_header(
shape_inference_hdr: str, expected_shape_infr_decls: List[str]
) -> None:
try:
with open(shape_inference_hdr, "r") as f:
with open(shape_inference_hdr) as f:
shape_infr_decls = f.read()
shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
except IOError as e:
except OSError as e:
raise AssertionError(
f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
) from e
Expand Down
2 changes: 1 addition & 1 deletion torchgen/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Location:
line: int

def __str__(self) -> str:
return "{}:{}".format(self.file, self.line)
return f"{self.file}:{self.line}"


# Valid values of the 'variants' field in native_functions.yaml
Expand Down
2 changes: 1 addition & 1 deletion torchgen/selective_build/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def from_yaml_dict(
if "debug_info" in op_info:
di_list = op_info["debug_info"]
assert isinstance(di_list, list)
debug_info = tuple((str(x) for x in di_list))
debug_info = tuple(str(x) for x in di_list)

return SelectiveBuildOperator(
name=op_name,
Expand Down
4 changes: 2 additions & 2 deletions torchgen/selective_build/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder":
di_list = data["debug_info"]
assert isinstance(di_list, list)

debug_info = tuple((str(x) for x in di_list))
debug_info = tuple(str(x) for x in di_list)

operators = {}
operators_dict = data.get("operators", {})
Expand Down Expand Up @@ -141,7 +141,7 @@ def from_yaml_str(config_contents: str) -> "SelectiveBuilder":

@staticmethod
def from_yaml_path(config_path: str) -> "SelectiveBuilder":
with open(config_path, "r") as f:
with open(config_path) as f:
contents = yaml.safe_load(f)
return SelectiveBuilder.from_yaml_dict(contents)

Expand Down
10 changes: 5 additions & 5 deletions torchgen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def context(msg_fn: Callable[[], str]) -> Iterator[None]:
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
raise AssertionError("Unhandled type: {}".format(type(x).__name__))
raise AssertionError(f"Unhandled type: {type(x).__name__}")


@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -137,9 +137,9 @@ def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
def _write_if_changed(self, filename: str, contents: str) -> None:
old_contents: Optional[str]
try:
with open(filename, "r") as f:
with open(filename) as f:
old_contents = f.read()
except IOError:
except OSError:
old_contents = None
if contents != old_contents:
# Create output directory if it doesn't exist
Expand All @@ -157,7 +157,7 @@ def substitute_with_template(
# TODO: Update the comment reference to the correct location
if "generated_comment" not in env:
comment = "@" + "generated by torchgen/gen.py"
comment += " from {}".format(os.path.basename(template_path))
comment += f" from {os.path.basename(template_path)}"
env["generated_comment"] = comment
template = _read_template(template_path)
return template.substitute(env)
Expand All @@ -172,7 +172,7 @@ def write_with_template(
template_fn: str,
env_callable: Callable[[], Union[str, Dict[str, Any]]],
) -> None:
filename = "{}/{}".format(self.install_dir, filename)
filename = f"{self.install_dir}/{filename}"
assert filename not in self.filenames, "duplicate file write {filename}"
self.filenames.add(filename)
if not self.dry_run:
Expand Down

0 comments on commit 964d29f

Please sign in to comment.