diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py
index b21553abd1..145c522810 100644
--- a/mlc_llm/relax_model/llama.py
+++ b/mlc_llm/relax_model/llama.py
@@ -849,7 +849,7 @@ def __init__(
 
         # Set the cached sin/cos to the maximum of 2048 and max seq len.
         # This will be eliminated further with online rotary embedding calculation.
-        cache_len = te.var("cache_len", "int64")
+        cache_len = te.var("cached_rotary_embedding_len", "int64")
         self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
         self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
         ############ End ############
@@ -907,7 +907,7 @@ def create_embed_func(
 ) -> None:
     func_name = "embed"
 
-    seq_len = tvm.tir.SizeVar("m", "int64")
+    seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
     with bb.function(func_name):
         model = LlamaEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64"))
         param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
@@ -934,8 +934,8 @@ def create_prefill_func_for_single_seq(
     func_name = "prefill_with_embed" if sep_embed else "prefill"
 
     bsz = 1
-    seq_len = tvm.tir.SizeVar("n", "int64")
-    all_seq_len = tvm.tir.SizeVar("m", "int64")
+    seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+    all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")
     hidden_size = config.hidden_size
     with bb.function(func_name):
         model = LlamaForCausalLM(
@@ -980,8 +980,8 @@ def create_prefill_func_for_batching(
 ) -> None:
     func_name = "prefill_with_embed"
 
-    bsz = tir.SizeVar("nseq", "int64")
-    total_seq_len = tvm.tir.SizeVar("m", "int64")
+    bsz = tir.SizeVar("batch_size", "int64")
+    total_seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
     hidden_size = config.hidden_size
     with bb.function(func_name):
         model = LlamaForCausalLM(
@@ -1019,7 +1019,7 @@ def create_decoding_func_for_single_seq(
     func_name = "decode"
 
     bsz = 1
-    all_seq_len = tvm.tir.SizeVar("m", "int64")
+    all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")
 
     with bb.function(func_name):
         model = LlamaForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"))
@@ -1058,7 +1058,7 @@ def create_decoding_func_for_batching(
 ) -> None:
     func_name = "decode_with_embed"
 
-    bsz = tir.SizeVar("nseq", "int64")
+    bsz = tir.SizeVar("batch_size", "int64")
     hidden_size = config.hidden_size
     with bb.function(func_name):
         model = LlamaForCausalLM(
@@ -1157,7 +1157,7 @@ def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConf
 
 def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
     with bb.function("softmax_with_temperature"):
-        bsz = tvm.tir.SizeVar("nseq", "int64")
+        bsz = tvm.tir.SizeVar("batch_size", "int64")
         logits = nn.Placeholder(
             (bsz, 1, tvm.tir.SizeVar("vocab_size", "int64")),
             dtype="float32",
@@ -1192,8 +1192,8 @@ def kv_cache_transpose_append(
         var_pos2seqidx: T.handle,
         layer_id: T.int64,
     ):
-        nseq = T.int64()
-        ntoken = T.SizeVar("ntoken", "int64")
+        nseq = T.SizeVar("batch_size", "int64")
+        ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
         npage = T.int64()
         page_size = T.SizeVar("page_size", "int64")
         num_pages = T.int64()
@@ -1493,10 +1493,10 @@ def get_model(args, hf_config):
     mod = bb.get()
 
     tir_bound_map = dict()
-    tir_bound_map["n"] = (
+    tir_bound_map["num_tokens_excluding_cache"] = (
         args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length
     )
-    tir_bound_map["m"] = config.max_sequence_length
+    tir_bound_map["num_tokens_including_cache"] = config.max_sequence_length
     tir_bound_map["vocab_size"] = args.max_vocab_size
     if enable_batching:
         tir_bound_map["nseq"] = args.max_batch_size
diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py
index 97e7656c9b..1dbd91dd73 100644
--- a/mlc_llm/relax_model/llama_batched_vllm.py
+++ b/mlc_llm/relax_model/llama_batched_vllm.py
@@ -558,7 +558,7 @@ def __init__(
 
         # Set the cached sin/cos to the maximum of 2048 and max seq len.
         # This will be eliminated further with online rotary embedding calculation.
-        cache_len = te.var("cache_len", "int64")
+        cache_len = te.var("cached_rotary_embedding_len", "int64")
         self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
         self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
         ############ End ############
@@ -762,8 +762,8 @@ def create_evaluate_func(
     """Evaluate logits for the last token in each sequence. Same as prefill but without KV cache."""
     func_name = "evaluate"
 
-    num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
-    num_seq = tvm.tir.SizeVar("num_seq", "int64")
+    num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+    num_seq = tvm.tir.SizeVar("batch_size", "int64")
 
     with bb.function(func_name):
         model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed)
@@ -815,8 +815,8 @@ def create_encoding_func(
     """
     func_name = "prefill_with_embed" if sep_embed else "prefill"
 
-    num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
-    num_seq = tvm.tir.SizeVar("num_seq", "int64")
+    num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+    num_seq = tvm.tir.SizeVar("batch_size", "int64")
 
     num_inputs = 5
 
@@ -885,7 +885,7 @@ def create_decoding_func(
     """Batched decoding with vLLM paged KV cache."""
     func_name = "decode"
 
-    num_seq = tvm.tir.SizeVar("num_seq", "int64")
+    num_seq = tvm.tir.SizeVar("batch_size", "int64")
 
     func_names = ["decode"]
 
@@ -952,9 +952,9 @@ def create_evaluate_multi_query_func(
 ) -> None:
     func_name = "evaluate_multi_query"
 
-    num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
-    num_past_token = tvm.tir.SizeVar("num_past_token", "int64")
-    num_seq = tvm.tir.SizeVar("num_seq", "int64")
+    num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+    num_past_token = tvm.tir.SizeVar("num_tokens_in_cache", "int64")
+    num_seq = tvm.tir.SizeVar("batch_size", "int64")
     seq_lens_sum = tvm.tir.SizeVar("seq_lens_sum", "int64")
 
     num_inputs = 8
diff --git a/python/mlc_chat/nn/kv_cache.py b/python/mlc_chat/nn/kv_cache.py
index 1629da82fa..92bc4d325b 100644
--- a/python/mlc_chat/nn/kv_cache.py
+++ b/python/mlc_chat/nn/kv_cache.py
@@ -141,7 +141,7 @@ def tir_kv_cache_transpose_append(
         var_position_map: T.handle,
     ):
         T.func_attr({"tir.noalias": T.bool(True)})
-        ntoken = T.SizeVar("ntoken", "int64")
+        ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
         page_size = T.SizeVar("page_size", "int64")
         num_pages = T.int64()
         pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
@@ -181,7 +181,7 @@ def tir_kv_cache_debug_get_kv(
         layer_id: T.int64,
     ):
         T.func_attr({"tir.noalias": T.bool(True)})
-        seqlen = T.SizeVar("seqlen", "int64")
+        seqlen = T.SizeVar("num_tokens_including_cache", "int64")
         page_size = T.SizeVar("page_size", "int64")
         num_pages = T.int64()
         pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)