Skip to content

Commit

Permalink
Hoist only once
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh-nod committed Feb 4, 2025
1 parent 1eb66d2 commit a34f8ee
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
3 changes: 3 additions & 0 deletions iree/turbine/kernel/_support/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ def _subs(
return value.subs(map)
return value

def __hash__(self):
    """Return a hash derived solely from this object's ``start`` attribute."""
    # Only ``start`` feeds the hash; no other attribute is consulted here.
    return hash(self.start)

def subs(self, map: dict[IndexExpr, IndexExpr]):
start = self._subs(self.start, map)
size = self._subs(self.size, map)
Expand Down
35 changes: 32 additions & 3 deletions iree/turbine/kernel/wave/hoisting.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,43 @@ def remove_unused_captured_vars(reduction: CustomOp, subgraph: fx.Graph):
get_custom(captured_vars[captured_idx]).erase()
# Order of captured_vars in subgraph does not necessarily match order of root
# implicit_capture, especially if we introduce instruction reorderings.
root_capture_idx = new_implicit_captures.index(
captured_vars[captured_idx].meta["lifted"]
)
lifted = captured_vars[captured_idx].meta["lifted"]
root_capture_idx = new_implicit_captures.index(lifted)
new_implicit_captures.pop(root_capture_idx)
reduction.update_arg("implicit_captures", new_implicit_captures)


def remove_redundant_hoists(
    hoisted_ops: list[CustomOp], reduction: CustomOp, subgraph: fx.Graph
):
    """Fold hoisted ops that duplicate an earlier hoisted op.

    Two hoisted ops are treated as duplicates when they share the same op
    type and the same sequence of index values. The first op seen for a key
    is kept, and every use of a later duplicate is redirected to the kept
    op's fx node. ``Allocate`` ops are never folded; each one replaces the
    recorded entry for its key instead.

    Afterwards, any captured variable of ``reduction`` whose ``"lifted"``
    metadata is a key of the resulting mapping has that metadata replaced by
    the mapped value.

    Args:
        hoisted_ops: Hoisted ops to scan, in the order they should be
            considered (earlier ops win).
        reduction: Op whose captured variables are inspected.
        subgraph: Graph handed to ``reduction.captured_vars``.

    Returns:
        Dict mapping the fx node of each kept op to the most recently seen
        duplicate op that was folded into it.
    """
    redundant_hoists = {}
    already_hoisted = {}
    for hoisted_op in hoisted_ops:
        # NOTE(review): the key ignores the dimension keys of ``index`` and
        # any other operands of the op (e.g. which memory a read targets) --
        # confirm that equal index values alone imply the ops are equivalent.
        key = (type(hoisted_op), tuple(hoisted_op.index.values()))
        if key in already_hoisted and not isinstance(hoisted_op, Allocate):
            kept_node = already_hoisted[key]
            redundant_hoists[kept_node] = hoisted_op
            hoisted_op.replace_all_uses_with(kept_node)
            continue
        already_hoisted[key] = hoisted_op.fx_node

    # NOTE(review): the mapping is keyed by the *kept* fx node and its value is
    # the duplicate op, so a captured var lifted from a kept node is re-pointed
    # at the duplicate op object here -- confirm this direction is intended.
    for captured_var in reversed(reduction.captured_vars(subgraph)):
        lifted = captured_var.meta["lifted"]
        if lifted in redundant_hoists:
            captured_var.meta["lifted"] = redundant_hoists[lifted]

    return redundant_hoists


def hoist_loop_invariant_ops(trace: CapturedTrace, constraints: list[Constraint]):
"""Hoists ops that are loop-invariant from reduction subgraphs to outer root graph."""
root_graph = trace.get_root_graph()
hoisted_ops = []
for node in root_graph.nodes:
custom_node = get_custom(node)
match custom_node:
Expand All @@ -101,6 +128,8 @@ def hoist_loop_invariant_ops(trace: CapturedTrace, constraints: list[Constraint]
if isinstance(hoistable_op, Read):
root_var = hoistable_op.memory.meta["lifted"]
new_op.update_arg("memory", root_var)
hoisted_ops.append(new_op)
remove_redundant_hoists(hoisted_ops, custom_node, subgraph)
# Clear/Remove unused captured var to correct codegen. Ops inside
# scf.for will be indexing/loading from the wrong bindings otherwise.
remove_unused_captured_vars(custom_node, subgraph)

0 comments on commit a34f8ee

Please sign in to comment.