From 5f905c7607783d5b7b8c8ff2da9b0c38bd081a41 Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Wed, 28 Aug 2024 22:04:22 -0700 Subject: [PATCH] Address comments #1 --- shark_turbine/kernel/wave/constraints.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/shark_turbine/kernel/wave/constraints.py b/shark_turbine/kernel/wave/constraints.py index 0306db983..45ab8d1f7 100644 --- a/shark_turbine/kernel/wave/constraints.py +++ b/shark_turbine/kernel/wave/constraints.py @@ -116,13 +116,13 @@ def compute_access_pattern_using_vector_shapes( def apply( self, - mma_index: int, + constraint_index: int, dim: IndexSymbol, elements_per_thread: int | IndexSymbol, ) -> IndexSequence: if self.vector_shapes is not None: return self.compute_access_pattern_using_vector_shapes( - dim, mma_index, elements_per_thread + dim, constraint_index, elements_per_thread ) lane = self.linearized_thread_id match self.mma_type: @@ -146,7 +146,9 @@ def apply( 1, # K ] return IndexSequence( - offset[mma_index], size[mma_index], stride[mma_index] + offset[constraint_index], + size[constraint_index], + stride[constraint_index], )