From e653630e71542ada1158afe9044797abf64b5bf3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 29 Feb 2024 19:02:32 +0000 Subject: [PATCH] [Model] Update Mixtral to have well-formed TIR Inside a `T.block`, loop variables may not be referenced directly; they must instead be accessed through the output of the corresponding `T.axis.remap` (e.g. `vj` rather than `j`). --- mlc_llm/relax_model/mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mixtral.py b/mlc_llm/relax_model/mixtral.py index ba50af9c6d..c9caa1ffdd 100644 --- a/mlc_llm/relax_model/mixtral.py +++ b/mlc_llm/relax_model/mixtral.py @@ -350,14 +350,14 @@ def top2_softmax_func( for j in T.unroll(2): with T.block("cast"): vj = T.axis.remap("S", [j]) - local_top_k_f32[vj] = T.cast(local_top_k[j], "float32") + local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32") with T.block("max"): local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1]) for j in T.unroll(2): with T.block("output"): vj = T.axis.remap("S", [j]) out[vi, vj] = T.cast( - T.exp(local_top_k_f32[j] - local_top_k_max[0]) + T.exp(local_top_k_f32[vj] - local_top_k_max[0]) / ( T.exp(local_top_k_f32[0] - local_top_k_max[0]) + T.exp(local_top_k_f32[1] - local_top_k_max[0])