[BE]: Apply ruff FURB 118. (pytorch#124743)
Replaces various lambdas with operator.itemgetter, which is more efficient (it is a builtin implemented in C). This is particularly useful when lambdas are used as 'key' functions.

Pull Request resolved: pytorch#124743
Approved by: https://github.com/albanD, https://github.com/malfet
Skylion007 authored and pytorchmergebot committed Apr 26, 2024
1 parent fc2aa23 commit 2f3b0be
Showing 25 changed files with 66 additions and 46 deletions.
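
The pattern in a nutshell: operator.itemgetter builds a C-level callable that fetches an item by index or mapping key, so it can stand in for a hand-written lambda anywhere a 'key' function is expected. A minimal sketch, not part of this diff; timings vary by interpreter and input:

    import operator
    from timeit import timeit

    pairs = [("b", 2), ("a", 3), ("c", 1)]
    # Index form: same ordering as key=lambda x: x[1], without a Python-level call.
    assert sorted(pairs, key=operator.itemgetter(1)) == sorted(pairs, key=lambda x: x[1])

    # Mapping-key form: itemgetter("started_at") mirrors lambda job: job["started_at"].
    jobs = [{"started_at": "01:00"}, {"started_at": "02:00"}]
    jobs.sort(key=operator.itemgetter("started_at"), reverse=True)

    # Rough micro-benchmark of the claimed speedup.
    data = list(range(100_000))
    t_lambda = timeit(lambda: sorted(enumerate(data), key=lambda x: x[1]), number=10)
    t_getter = timeit(lambda: sorted(enumerate(data), key=operator.itemgetter(1)), number=10)
    print(f"lambda: {t_lambda:.3f}s  itemgetter: {t_getter:.3f}s")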
3 changes: 2 additions & 1 deletion .github/scripts/get_workflow_job_id.py
@@ -4,6 +4,7 @@

 import argparse
 import json
+import operator
 import os
 import re
 import sys
@@ -126,7 +127,7 @@ def find_job_id_name(args: Any) -> Tuple[str, str]:

     # Sort the jobs list by start time, in descending order. We want to get the most
     # recently scheduled job on the runner.
-    jobs.sort(key=lambda job: job["started_at"], reverse=True)
+    jobs.sort(key=operator.itemgetter("started_at"), reverse=True)

     for job in jobs:
         if job["runner_name"] == args.runner_name:
10 changes: 5 additions & 5 deletions torch/_export/serde/serialize.py
@@ -2630,12 +2630,12 @@ def replace_use(a):
        n.metadata.clear()

    # Stage 4: Aggregate values.
-    sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=lambda x: x[0]))
+    sorted_tensor_values = dict(sorted(graph.tensor_values.items(), key=operator.itemgetter(0)))
    sorted_sym_int_values = dict(
-        sorted(graph.sym_int_values.items(), key=lambda x: x[0])
+        sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
    )
    sorted_sym_bool_values = dict(
-        sorted(graph.sym_bool_values.items(), key=lambda x: x[0])
+        sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
    )

    # Stage 5: Recurse in subgraphs.
@@ -2683,8 +2683,8 @@ def canonicalize(ep: ExportedProgram) -> ExportedProgram:
    """
    ep = copy.deepcopy(ep)

-    opset_version = dict(sorted(ep.opset_version.items(), key=lambda x: x[0]))
-    range_constraints = dict(sorted(ep.range_constraints.items(), key=lambda x: x[0]))
+    opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0)))
+    range_constraints = dict(sorted(ep.range_constraints.items(), key=operator.itemgetter(0)))
    module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn)
    signature = ep.graph_module.signature
    graph = ep.graph_module.graph
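
A note on the dict(sorted(...)) idiom above: Python dicts preserve insertion order (3.7+), so rebuilding from items sorted with itemgetter(0) yields a dict that iterates in key order. A small illustration, not from the PR:

    import operator

    d = {"b": 2, "c": 3, "a": 1}
    sorted_d = dict(sorted(d.items(), key=operator.itemgetter(0)))
    assert list(sorted_d) == ["a", "b", "c"]  # iteration now follows sorted keys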
3 changes: 2 additions & 1 deletion torch/_functorch/benchmark_utils.py
@@ -2,6 +2,7 @@

 import contextlib
 import json
+import operator
 import os
 import time

@@ -94,7 +95,7 @@ def get_sorted_gpu_events(events):
        if not is_gpu_compute_event(event):
            continue
        sorted_gpu_events.append(event)
-    return sorted(sorted_gpu_events, key=lambda x: x["ts"])
+    return sorted(sorted_gpu_events, key=operator.itemgetter("ts"))


 def get_duration(sorted_gpu_events):
6 changes: 3 additions & 3 deletions torch/_functorch/partitioners.py
@@ -407,7 +407,7 @@ def _count_ops(graph):
    for node in graph.nodes:
        if node.op == "call_function":
            cnt[node.target.__name__] += 1
-    print(sorted(cnt.items(), key=lambda x: x[1], reverse=True))
+    print(sorted(cnt.items(), key=operator.itemgetter(1), reverse=True))


 @functools.lru_cache(None)
@@ -432,7 +432,7 @@ def sort_depths(args, depth_map):
    arg_depths = {
        arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node)
    }
-    return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True)
+    return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True)


 def reordering_to_mimic_autograd_engine(gm):
@@ -1315,7 +1315,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int:
    )
    print(
        "Count of Ops Rematerialized: ",
-        sorted(counts.items(), key=lambda x: x[1], reverse=True),
+        sorted(counts.items(), key=operator.itemgetter(1), reverse=True),
    )
    return fw_module, bw_module
4 changes: 3 additions & 1 deletion torch/_functorch/top_operators_github_usage.py
@@ -4,6 +4,8 @@
 From https://docs.google.com/spreadsheets/d/12R3nCOLskxPYjjiNkdqy4OdQ65eQp_htebXGODsjSeA/edit#gid=0
 Try to keep this list in sync with that.
 """
+import operator
+
 top_torch = [
     ("t", 6837449),
     ("tensor", 585786),
@@ -618,7 +620,7 @@ def get_nn_functional_top_list():
        top_nn_functional_[functional_name] += count

    top_nn_functional_ = list(top_nn_functional_.items())
-    top_nn_functional_.sort(key=lambda x: x[1], reverse=True)
+    top_nn_functional_.sort(key=operator.itemgetter(1), reverse=True)
    return top_nn_functional_

2 changes: 1 addition & 1 deletion torch/_inductor/pattern_matcher.py
@@ -894,7 +894,7 @@ def run_node(self, node) -> Any:
        for n in output_nodes
        if isinstance(n, torch.fx.Node)
    ]
-    last_node = min(indices, key=lambda tup: tup[0])[1]
+    last_node = min(indices, key=operator.itemgetter(0))[1]

    def percolate_tags(node, recompute_tag, input_stops):
        queue = [node]
2 changes: 1 addition & 1 deletion torch/_inductor/scheduler.py
@@ -2290,7 +2290,7 @@ def get_possible_fusions_with_highest_priority(self, possible_fusions):
        )
        # return the possible fusions with highest priority
        possible_fusions_with_highest_priority = min(
-            possible_fusions_group_by_priority.items(), key=lambda item: item[0]
+            possible_fusions_group_by_priority.items(), key=operator.itemgetter(0)
        )[1]
        assert len(possible_fusions_with_highest_priority) > 0
        return possible_fusions_with_highest_priority
2 changes: 1 addition & 1 deletion torch/_refs/__init__.py
@@ -3597,7 +3597,7 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:

    # derive permute order by sorting urtensor strides
    enumerated_stride = list(enumerate(urtensor_stride))
-    enumerated_stride.sort(key=lambda item: item[1], reverse=True)
+    enumerated_stride.sort(key=operator.itemgetter(1), reverse=True)
    permute_order, sorted_stride = zip(*enumerated_stride)

    # add new and expand dimensions according to urtensor
4 changes: 3 additions & 1 deletion torch/_refs/linalg/__init__.py
@@ -63,6 +63,8 @@ def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_nam
    )


+import operator
+
 # Utilities should come BEFORE this import
 from torch._decomp import register_decomposition
 from torch._decomp.decompositions import pw_cast_for_opmath
@@ -165,7 +167,7 @@ def _backshift_permutation(dim0, dim1, ndim):

 def _inverse_permutation(perm):
     # Given a permutation, returns its inverse. It's equivalent to argsort on an array
-    return [i for i, j in sorted(enumerate(perm), key=lambda i_j: i_j[1])]
+    return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))]
3 changes: 2 additions & 1 deletion torch/ao/quantization/fx/_equalize.py
@@ -19,6 +19,7 @@
     maybe_get_next_module,
     node_arg_is_weight,
 )
+import operator

 CUSTOM_MODULE_SUPP_LIST: List[Any] = []

@@ -810,7 +811,7 @@ def get_equalization_qconfig_dict(

    # Sort the layer_sqnr_dictionary values and get the layers with the lowest
    # SQNR values (aka highest quantization errors)
-    layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=lambda item: item[1])
+    layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1))
    layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize]

    # Constructs an equalization_qconfig_dict that specifies to only equalize
3 changes: 2 additions & 1 deletion torch/cuda/_memory_viz.py
@@ -9,6 +9,7 @@
 from itertools import groupby
 import base64
 import warnings
+import operator

 cache = lru_cache(None)

@@ -492,7 +493,7 @@ def free(alloc, device):
    # create the final snapshot state
    blocks_at_end = [(to_device(tensor_key.device), event['addr'], event['size'], event['frames'])
                     for (tensor_key, version), event in kv_to_elem.items()]
-    for device, blocks in groupby(sorted(blocks_at_end), key=lambda x: x[0]):
+    for device, blocks in groupby(sorted(blocks_at_end), key=operator.itemgetter(0)):
        seg = snapshot['segments'][device]  # type: ignore[index]
        last_addr = seg['address']
        for _, addr, size, frames in blocks:
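
itertools.groupby only merges adjacent runs, which is why the hunk above sorts blocks_at_end before grouping; itemgetter(0) then pulls out the device as the grouping key. A standalone sketch with hypothetical data:

    import operator
    from itertools import groupby

    blocks = [(0, "addr1"), (1, "addr3"), (0, "addr2")]
    # sorted() orders tuples by first element, so equal devices become adjacent.
    for device, group in groupby(sorted(blocks), key=operator.itemgetter(0)):
        print(device, list(group))  # 0 [(0, 'addr1'), (0, 'addr2')], then 1 [...]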
3 changes: 2 additions & 1 deletion torch/distributed/_shard/sharding_spec/api.py
@@ -15,6 +15,7 @@

 import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
 from torch.distributed._shard.op_registry_utils import _decorator_func
+import operator

 if TYPE_CHECKING:
     # Only include ShardedTensor when do type checking, exclude it
@@ -214,7 +215,7 @@ def _infer_sharding_spec_from_shards_metadata(shards_metadata):
    if chunk_sharding_dim is not None:
        # Ensure we infer the correct placement order from offsets
        placements = [
-            x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0])
+            x for _, x in sorted(zip(chunk_offset_list, placements), key=operator.itemgetter(0))
        ]

        from .chunk_sharding_spec import ChunkShardingSpec
3 changes: 2 additions & 1 deletion torch/distributed/_tools/memory_tracker.py
@@ -17,6 +17,7 @@
 import torch.nn as nn
 from torch.utils.hooks import RemovableHandle
 from torch.utils._python_dispatch import TorchDispatchMode
+import operator


 BYTES_PER_MB = 1024 * 1024.0
@@ -148,7 +149,7 @@ def summary(self, top: int = 20) -> None:
        print("------------------------------------------------")
        print(f"The number of cuda retries are: {self._num_cuda_retries}")
        print(f"Top {top} ops that generates memory are:")
-        for k, v in sorted(op_diff.items(), key=lambda item: item[1], reverse=True)[
+        for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[
            :top
        ]:
            print(f"{k}: {v}MB")
5 changes: 3 additions & 2 deletions torch/distributed/checkpoint/filesystem.py
@@ -1,6 +1,7 @@
 import collections
 import dataclasses
 import io
+import operator
 import os
 import pickle
 import queue
@@ -177,7 +178,7 @@ def start_loading(self) -> None:
        if self.started:
            return
        self.started = True
-        self.items.sort(key=lambda x: x[0])
+        self.items.sort(key=operator.itemgetter(0))
        self._refill()

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
@@ -218,7 +219,7 @@ def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[Writ

    for wi in tensor_w:
        # TODO replace with headq
-        idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
+        idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0]
        buckets[idx].append(wi)
        bucket_sizes[idx] += _item_size(wi)
3 changes: 2 additions & 1 deletion torch/distributed/checkpoint/planner.py
@@ -1,5 +1,6 @@
 import abc
 import io
+import operator
 from dataclasses import dataclass
 from enum import auto, Enum
 from functools import reduce
@@ -67,7 +68,7 @@ def tensor_storage_size(self) -> Optional[int]:
        if self.tensor_data is None:
            return None

-        numels = reduce(lambda x, y: x * y, self.tensor_data.size, 1)
+        numels = reduce(operator.mul, self.tensor_data.size, 1)
        dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
        return numels * dtype_size
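
FURB118 covers operator-reimplementing lambdas generally, not just item access: here reduce(operator.mul, sizes, 1) multiplies the dimensions to get the element count, and the initializer 1 makes the empty (0-d tensor) case come out as 1. A quick illustrative check:

    import operator
    from functools import reduce

    assert reduce(operator.mul, (2, 3, 4), 1) == 24  # numel of a 2x3x4 tensor
    assert reduce(operator.mul, (), 1) == 1          # 0-d tensor has one element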
2 changes: 1 addition & 1 deletion torch/export/unflatten.py
@@ -840,7 +840,7 @@ def _reorder_submodules(
        _reorder_submodules(child, fqn_order, prefix=fqn + ".")
        delattr(parent, name)
        children.append((fqn_order[fqn], name, child))
-    children.sort(key=lambda x: x[0])
+    children.sort(key=operator.itemgetter(0))
    for _, name, child in children:
        parent.register_module(name, child)
2 changes: 1 addition & 1 deletion torch/fx/experimental/accelerator_partitioner.py
@@ -259,7 +259,7 @@ def find_device_for(partition: Partition):
    # Find devices for all the partitions without a device
    found_device = True
    for partition in no_device_partitions:
-        device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=lambda item: item[1]))
+        device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1)))
        found_device = find_device_for(partition)
        if not found_device:
            break
@@ -1,5 +1,6 @@
 from .utils import _toposort, groupby
 from .variadic import isvariadic
+import operator

 __all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature",
            "edge", "ordering"]
@@ -111,7 +112,7 @@ def ordering(signatures):
    """
    signatures = list(map(tuple, signatures))
    edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
-    edges = groupby(lambda x: x[0], edges)
+    edges = groupby(operator.itemgetter(0), edges)
    for s in signatures:
        if s not in edges:
            edges[s] = []
3 changes: 2 additions & 1 deletion torch/profiler/_utils.py
@@ -1,4 +1,5 @@
 import functools
+import operator
 import re
 from collections import deque
 from dataclasses import dataclass
@@ -316,7 +317,7 @@ def rank_events(self, length):
            event
            for _, event in sorted(
                zip(heuristic_score_list, event_list),
-                key=lambda x: x[0],
+                key=operator.itemgetter(0),
                reverse=True,
            )
        ]