Skip to content

Commit

Permalink
[Inductor] Optimize read write merging in FusedSchedulerNode ctor (py…
Browse files Browse the repository at this point in the history
…torch#105693)

Reduced optimizer compilation time by half, I think it will improve it in general as well.

Pull Request resolved: pytorch#105693
Approved by: https://github.com/jansel
  • Loading branch information
mlazos authored and pytorchmergebot committed Jul 21, 2023
1 parent 842616b commit 72b223c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
13 changes: 13 additions & 0 deletions torch/_inductor/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,19 @@ def merge(self, other: "ReadWrites"):
op_counts = other.op_counts
return ReadWrites(reads - writes, writes, index_exprs, op_counts=op_counts)

@staticmethod
def merge_list(read_writes: List["ReadWrites"]):
all_writes = set.union(*[rw.writes for rw in read_writes])
all_reads = set.union(*[rw.reads for rw in read_writes]) - all_writes
all_index_exprs = set.union(*[rw.index_exprs for rw in read_writes])

op_counts = collections.Counter()
for rw in read_writes:
if rw.op_counts is not None:
op_counts.update(rw.op_counts)

return ReadWrites(all_reads, all_writes, all_index_exprs, op_counts=op_counts)

def remove_reads(self, rem_reads):
return ReadWrites(
self.reads - rem_reads,
Expand Down
4 changes: 1 addition & 3 deletions torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,9 +511,7 @@ def __init__(self, scheduler: "Scheduler", snodes: List[SchedulerNode]):
)

self.set_read_writes(
functools.reduce(
dependencies.ReadWrites.merge, [x.read_writes for x in snodes]
)
dependencies.ReadWrites.merge_list([x.read_writes for x in snodes])
)

self.unmet_dependencies = {
Expand Down

0 comments on commit 72b223c

Please sign in to comment.