From 1ea3b72a29869496a6d4e4d0ca8edaf16461dc63 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Wed, 24 Jan 2024 00:14:42 +0000
Subject: [PATCH] [Mixtral] Simplify indices calculation in moe layer

---
 mlc_llm/relax_model/mixtral.py | 77 ++++++++++------------------------
 1 file changed, 21 insertions(+), 56 deletions(-)

diff --git a/mlc_llm/relax_model/mixtral.py b/mlc_llm/relax_model/mixtral.py
index ea4b112780..ce40fac1b8 100644
--- a/mlc_llm/relax_model/mixtral.py
+++ b/mlc_llm/relax_model/mixtral.py
@@ -210,66 +210,31 @@ def get_indices(
         from tvm import relax
         from tvm.script import tir as T
 
-        @T.prim_func
-        def get_flattened_expert_indices_scheduled(
-            var_cumsum_colwise_flattened: T.handle,
-            var_expert_indices: T.handle,
-            var_flattened_expert_indices: T.handle,
-        ):
-            T.func_attr({"tir.is_scheduled": 1})
-            batch_size = T.SizeVar("batch_size", "int32")
-            cumsum_flattened_length = T.SizeVar("cumsum_flattened_length", "int32")
+        TX = 1024
+        experts_per_tok = T.int32(self.num_experts_per_tok)
 
-            cumsum_colwise_flattened = T.match_buffer(
-                var_cumsum_colwise_flattened, shape=[cumsum_flattened_length], dtype="int32"
-            )
-            expert_indices = T.match_buffer(
-                var_expert_indices, shape=[batch_size, self.num_experts_per_tok], dtype="int32"
-            )
-            flattened_expert_indices = T.match_buffer(
-                var_flattened_expert_indices,
-                shape=[batch_size * self.num_experts_per_tok],
-                dtype="int32",
-            )
+        @T.prim_func(private=True)
+        def _func(var_cumsum: T.handle, var_expert_indices: T.handle, var_indices: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True})
+            batch_size = T.SizeVar("batch_size", "int32")
+            cumsum_len = T.SizeVar("cumsum_len", "int32")  # [experts_per_tok * batch_size]
+            cumsum = T.match_buffer(var_cumsum, [cumsum_len], "int32")
+            expert_indices = T.match_buffer(var_expert_indices, [batch_size, experts_per_tok], "int32")
+            indices = T.match_buffer(var_indices, [batch_size * experts_per_tok], "int32")
+            for bj_o in T.thread_binding(0, T.ceildiv(batch_size * experts_per_tok, TX), "blockIdx.x"):
+                for bj_i in T.thread_binding(0, TX, "threadIdx.x"):
+                    with T.block("indices"):
+                        T.reads(expert_indices[:, :], cumsum[:])
+                        T.writes(indices[:])
+                        if bj_o * TX + bj_i < batch_size * experts_per_tok:
+                            b: T.int32 = T.floordiv(bj_o * TX + bj_i, experts_per_tok)
+                            j: T.int32 = T.floormod(bj_o * TX + bj_i, experts_per_tok)
+                            e: T.int32 = expert_indices[b, j]
+                            indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j
 
-            for io in T.thread_binding(
-                0, T.ceildiv(cumsum_flattened_length, T.int32(1024)), "blockIdx.x"
-            ):
-                for ii in T.thread_binding(
-                    0, T.min(cumsum_flattened_length, T.int32(1024)), "threadIdx.x"
-                ):
-                    with T.block("get_indices"):
-                        vi = T.axis.spatial(cumsum_flattened_length, io * T.int32(1024) + ii)
-                        T.where(io * T.int32(1024) + ii < cumsum_flattened_length)
-                        T.reads(
-                            cumsum_colwise_flattened[vi - 1 : vi - 1 + 2],
-                            expert_indices[:, 0 : self.num_experts_per_tok],
-                        )
-                        T.writes(flattened_expert_indices[:])
-                        expert_idx = T.alloc_buffer(shape=(), dtype="int32", scope="local")
-                        if cumsum_colwise_flattened[vi] > T.if_then_else(
-                            vi == 0, T.int32(0), cumsum_colwise_flattened[vi - 1]
-                        ):
-                            idx: T.SizeVar("idx", "int32") = cumsum_colwise_flattened[vi] - 1
-                            instance_id: T.SizeVar("instance_id", "int32") = T.truncmod(
-                                vi, batch_size
-                            )
-                            expert_id: T.SizeVar("expert_id", "int32") = T.truncdiv(vi, batch_size)
-                            for j in T.serial(0, self.num_experts_per_tok):
-                                with T.block("select_expert"):
-                                    vj = T.axis.spatial(self.num_experts_per_tok, j)
-                                    vinstance_id = T.axis.spatial(batch_size, instance_id)
-                                    vexpert_id = T.axis.spatial(
-                                        T.truncdiv(cumsum_flattened_length, batch_size), expert_id
-                                    )
-                                    if expert_indices[vinstance_id, vj] == vexpert_id:
-                                        expert_idx[()] = vj
-                            flattened_expert_indices[idx] = (
-                                instance_id * self.num_experts_per_tok + expert_idx[()]
-                            )
 
         bb = relax.BlockBuilder.current()
-        gvar = bb.add_func(get_flattened_expert_indices_scheduled, "get_flattened_expert_indices")
+        gvar = bb.add_func(_func, "get_flattened_expert_indices")
         return bb.emit(
             relax.call_tir(
                 gvar,
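Reviewer note: the rewrite replaces the old search-per-cumsum-slot kernel with a direct scatter. Each (token b, top-k slot j) pair writes its flattened position b * experts_per_tok + j into output slot cumsum[e * batch_size + b] - 1, where e is the chosen expert, so the inner "select_expert" search disappears. Below is a minimal NumPy sketch of the same computation, assuming cumsum is the inclusive cumsum of the one-hot (expert, token) mask flattened expert-major (the helper name flattened_expert_indices_ref, the mask construction, and the example values are illustrative, not from the patch):

import numpy as np

def flattened_expert_indices_ref(expert_indices: np.ndarray, num_experts: int) -> np.ndarray:
    """Reference sketch of the scatter performed by the new TIR kernel."""
    batch_size, experts_per_tok = expert_indices.shape
    # One-hot mask: mask[e, b] == 1 iff token b routed to expert e.
    mask = np.zeros((num_experts, batch_size), dtype=np.int32)
    for b in range(batch_size):
        for j in range(experts_per_tok):
            mask[expert_indices[b, j], b] = 1
    # Inclusive cumsum over the expert-major flattening, matching
    # the cumsum[e * batch_size + b] indexing in the kernel.
    cumsum = np.cumsum(mask.reshape(-1)).astype(np.int32)

    indices = np.empty(batch_size * experts_per_tok, dtype=np.int32)
    for b in range(batch_size):
        for j in range(experts_per_tok):
            e = expert_indices[b, j]
            # Same scatter as the TIR block "indices": the output slot is the
            # rank of pair (e, b) among all routed pairs, ordered by expert id
            # then by token position; the value is the flattened (token, slot)
            # index b * experts_per_tok + j.
            indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j
    return indices

# Hypothetical example: 3 tokens, each routed to 2 of 4 experts.
expert_indices = np.array([[0, 2], [1, 2], [0, 3]], dtype=np.int32)
print(flattened_expert_indices_ref(expert_indices, num_experts=4))  # [0 4 2 1 3 5]

With this layout each expert's tokens land in a contiguous run of the output, ordered by expert id and then by position in the batch, which is what the downstream grouped expert GEMM consumes.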