Skip to content

Commit

Permalink
Address Martin's commments, round iree-org#2
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh-nod committed Jul 16, 2024
1 parent bb4918a commit 8107d73
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 20 deletions.
4 changes: 3 additions & 1 deletion lit_tests/kernel/wave/promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def run(func: Callable[[], None]) -> Callable[[], None]:
return func


def get_read_nodes(graph: fx.Graph) -> list[fx.Node]:
def get_read_nodes(graph: fx.Graph) -> list[CustomOp]:
custom_nodes: list[CustomOp] = [get_custom(node) for node in graph.nodes]
return [node for node in custom_nodes if isinstance(node, Read)]

Expand Down Expand Up @@ -154,6 +154,8 @@ def test_gemm():
# CHECK-NEXT: %allocate_1
# CHECK-SAME: ((N, K), f16, SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-NEXT: %write
# CHECK-SAME: (%reduction, %c, 4)

# Reduction subgraph:
# CHECK: %acc
Expand Down
24 changes: 11 additions & 13 deletions shark_turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,32 +197,30 @@ def update_arg(self, idx_or_name: int | str, value: CustomOp | fx.Node):
else:
raise IndexError("Index out of range")

def copy(self, new_name: Optional[str] = None) -> Self:
"""Returns a duplicate of this node."""
self.graph.inserting_after(self.fx_node)
new_node = self.graph.node_copy(self.fx_node)
new_node.tkw_op = self
if new_name:
new_node.name = new_name
return get_custom(new_node)

def copy_to_new_graph(
self, new_graph: fx.Graph, new_name: Optional[str] = None
def copy(
self, new_name: Optional[str] = None, new_graph: Optional[fx.Graph] = None
) -> Self:
"""Returns a duplicate of this node."""
new_node = new_graph.node_copy(self.fx_node)
graph = new_graph
if new_graph is None:
graph = self.graph
graph.inserting_after(self.fx_node)
new_node = graph.node_copy(self.fx_node)
new_node.tkw_op = self
if new_name:
new_node.name = new_name
return get_custom(new_node)

def replace_all_uses_with(self, new_node: CustomOp):
def replace_all_uses_with(self, new_node: CustomOp | fx.Node):
"""Replace all uses of the current node with the new node."""
for user in self.users:
user.update_arg(user.node_args.index(self), new_node)

def erase(self):
"""Erase the current node from the graph where it exists."""
assert (
not self.fx_node.users
), f"Attempting to erase {self.fx_node} which has {len(self.fx.users)} users!"
self.graph.erase_node(self.fx_node)

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions shark_turbine/kernel/wave/hoisting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
logger = get_logger("turbine.wave.hoisting")


def get_allocs_(graph: fx.Graph) -> list[CustomOp]:
def get_allocs(graph: fx.Graph) -> list[CustomOp]:
return [
custom_node
for node in graph.nodes
Expand All @@ -24,8 +24,8 @@ def hoist_allocs(trace: CapturedTrace):
case Reduction():
with root_graph.inserting_before(custom_node.fx_node):
subgraph = trace.get_subgraph(custom_node.subgraph_name)
allocs = get_allocs_(subgraph)
allocs = get_allocs(subgraph)
for alloc in allocs:
new_alloc = alloc.copy_to_new_graph(root_graph)
new_alloc = alloc.copy(new_graph=root_graph)
alloc.replace_all_uses_with(new_alloc)
alloc.erase()
Empty file.
8 changes: 5 additions & 3 deletions shark_turbine/kernel/wave/promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
logger = get_logger("turbine.wave.promotion")


def apply_promotion_pattern_(custom_node: Read | Write, allocate_node: Allocate):
def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate):
match custom_node:
case Read(
memory, elements_per_thread
Expand All @@ -21,9 +21,11 @@ def apply_promotion_pattern_(custom_node: Read | Write, allocate_node: Allocate)
Write(
custom_node.fx_node, allocate_node.fx_node, elements_per_thread
).add_to_graph(custom_node.graph)
case _:
logger.error(f"Attempted to promoted unsupported operator {custom_node}")


def promote_node(node: CustomOp, address_space: IndexSymbol):
def promote_node(node: Read | Write, address_space: IndexSymbol):
"""Promotes the given operand in the provided graph
to the specified address space.
Expand All @@ -38,4 +40,4 @@ def promote_node(node: CustomOp, address_space: IndexSymbol):
node.type.symbolic_shape, node.type.dtype, address_space
)
allocate_node.add_to_graph(node.graph)
apply_promotion_pattern_(node, allocate_node)
apply_promotion_pattern(node, allocate_node)

0 comments on commit 8107d73

Please sign in to comment.