[Mixtral] Simplify indices calculation in moe layer
vinx13 committed Jan 24, 2024
1 parent 66a2e53 commit 1ea3b72
Showing 1 changed file with 21 additions and 56 deletions.
mlc_llm/relax_model/mixtral.py — 77 changes: 21 additions & 56 deletions
@@ -210,66 +210,31 @@ def get_indices(
         from tvm import relax
         from tvm.script import tir as T

-        @T.prim_func
-        def get_flattened_expert_indices_scheduled(
-            var_cumsum_colwise_flattened: T.handle,
-            var_expert_indices: T.handle,
-            var_flattened_expert_indices: T.handle,
-        ):
-            T.func_attr({"tir.is_scheduled": 1})
-            batch_size = T.SizeVar("batch_size", "int32")
-            cumsum_flattened_length = T.SizeVar("cumsum_flattened_length", "int32")
+        TX = 1024
+        experts_per_tok = T.int32(self.num_experts_per_tok)

-            cumsum_colwise_flattened = T.match_buffer(
-                var_cumsum_colwise_flattened, shape=[cumsum_flattened_length], dtype="int32"
-            )
-            expert_indices = T.match_buffer(
-                var_expert_indices, shape=[batch_size, self.num_experts_per_tok], dtype="int32"
-            )
-            flattened_expert_indices = T.match_buffer(
-                var_flattened_expert_indices,
-                shape=[batch_size * self.num_experts_per_tok],
-                dtype="int32",
-            )
+        @T.prim_func(private=True)
+        def _func(var_cumsum: T.handle, var_expert_indices: T.handle, var_indices: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": True})
+            batch_size = T.SizeVar("batch_size", "int32")
+            cumsum_len = T.SizeVar("cumsum_len", "int32")  # [experts_per_tok * batch_size]
+            cumsum = T.match_buffer(var_cumsum, [cumsum_len], "int32")
+            expert_indices = T.match_buffer(var_expert_indices, [batch_size, experts_per_tok], "int32")
+            indices = T.match_buffer(var_indices, [batch_size * experts_per_tok], "int32")
+            for bj_o in T.thread_binding(0, T.ceildiv(batch_size * experts_per_tok, TX), "blockIdx.x"):
+                for bj_i in T.thread_binding(0, TX, "threadIdx.x"):
+                    with T.block("indices"):
+                        T.reads(expert_indices[:, :], cumsum[:])
+                        T.writes(indices[:])
+                        if bj_o * TX + bj_i < batch_size * experts_per_tok:
+                            b: T.int32 = T.floordiv(bj_o * TX + bj_i, experts_per_tok)
+                            j: T.int32 = T.floormod(bj_o * TX + bj_i, experts_per_tok)
+                            e: T.int32 = expert_indices[b, j]
+                            indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j

-            for io in T.thread_binding(
-                0, T.ceildiv(cumsum_flattened_length, T.int32(1024)), "blockIdx.x"
-            ):
-                for ii in T.thread_binding(
-                    0, T.min(cumsum_flattened_length, T.int32(1024)), "threadIdx.x"
-                ):
-                    with T.block("get_indices"):
-                        vi = T.axis.spatial(cumsum_flattened_length, io * T.int32(1024) + ii)
-                        T.where(io * T.int32(1024) + ii < cumsum_flattened_length)
-                        T.reads(
-                            cumsum_colwise_flattened[vi - 1 : vi - 1 + 2],
-                            expert_indices[:, 0 : self.num_experts_per_tok],
-                        )
-                        T.writes(flattened_expert_indices[:])
-                        expert_idx = T.alloc_buffer(shape=(), dtype="int32", scope="local")
-                        if cumsum_colwise_flattened[vi] > T.if_then_else(
-                            vi == 0, T.int32(0), cumsum_colwise_flattened[vi - 1]
-                        ):
-                            idx: T.SizeVar("idx", "int32") = cumsum_colwise_flattened[vi] - 1
-                            instance_id: T.SizeVar("instance_id", "int32") = T.truncmod(
-                                vi, batch_size
-                            )
-                            expert_id: T.SizeVar("expert_id", "int32") = T.truncdiv(vi, batch_size)
-                            for j in T.serial(0, self.num_experts_per_tok):
-                                with T.block("select_expert"):
-                                    vj = T.axis.spatial(self.num_experts_per_tok, j)
-                                    vinstance_id = T.axis.spatial(batch_size, instance_id)
-                                    vexpert_id = T.axis.spatial(
-                                        T.truncdiv(cumsum_flattened_length, batch_size), expert_id
-                                    )
-                                    if expert_indices[vinstance_id, vj] == vexpert_id:
-                                        expert_idx[()] = vj
-                                        flattened_expert_indices[idx] = (
-                                            instance_id * self.num_experts_per_tok + expert_idx[()]
-                                        )

         bb = relax.BlockBuilder.current()
-        gvar = bb.add_func(get_flattened_expert_indices_scheduled, "get_flattened_expert_indices")
+        gvar = bb.add_func(_func, "get_flattened_expert_indices")
         return bb.emit(
             relax.call_tir(
                 gvar,
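For reference, the scatter that the new _func kernel performs can be reproduced on the host with plain NumPy. The sketch below is not part of the commit; the helper name flattened_expert_indices_ref is made up for illustration, and it assumes (as in the replaced implementation) that the cumsum buffer is the inclusive prefix sum over the expert-major one-hot routing mask.

import numpy as np

def flattened_expert_indices_ref(expert_indices: np.ndarray, num_experts: int) -> np.ndarray:
    """Host-side reference for the scatter done by the rewritten TIR kernel.

    expert_indices: [batch_size, experts_per_tok] int32 top-k expert ids per token.
    Returns indices: [batch_size * experts_per_tok] int32 with the flattened
    (token, slot) ids grouped contiguously by expert.
    """
    batch_size, experts_per_tok = expert_indices.shape

    # Expert-major one-hot routing mask and its inclusive prefix sum. This plays
    # the role of the `cumsum` buffer the kernel consumes (computed elsewhere).
    mask = np.zeros((num_experts, batch_size), dtype=np.int32)
    mask[expert_indices.reshape(-1), np.repeat(np.arange(batch_size), experts_per_tok)] = 1
    cumsum = np.cumsum(mask.reshape(-1), dtype=np.int32)

    # One scatter per (token, slot) pair, mirroring the kernel's single statement
    # indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j.
    indices = np.empty(batch_size * experts_per_tok, dtype=np.int32)
    for b in range(batch_size):
        for j in range(experts_per_tok):
            e = expert_indices[b, j]
            indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j
    return indices

# Example: 3 tokens, top-2 routing over 4 experts -> [0, 4, 2, 1, 3, 5]
print(flattened_expert_indices_ref(np.array([[0, 2], [1, 2], [0, 3]], dtype=np.int32), 4))

After the scatter, each expert owns a contiguous slice of indices, which is what lets the MoE layer gather its tokens with one batched matmul per expert; the simplification is that the new kernel indexes each (token, slot) pair directly instead of searching over the top-k slots as the old select_expert loop did.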
