[Llama] Add descriptive names for symbolic variables (#180)
Because the symbolic variables may appear in many places throughout
the resulting module, they should have more descriptive names, which
can be understood outside of their original contexts.
Lunderberg authored Jan 30, 2024
1 parent 253da78 · commit 34c7137
Showing 3 changed files with 24 additions and 24 deletions.
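As a quick illustration of the motivation: a symbolic dimension's name is the only thing that survives wherever the variable later appears, whether in printed IR, in error messages, or in bound maps keyed by variable name. A minimal sketch (the hidden size 4096 is illustrative, not taken from this change):

```python
from tvm import relax, tir

# The same symbolic dimension, declared once with a terse name and once with a
# descriptive one. Only the name is visible at the point of use.
m = tir.SizeVar("m", "int64")
num_tok = tir.SizeVar("num_tokens_including_cache", "int64")

# The printed struct info reads roughly as:
#   R.Tensor((1, m, 4096), dtype="float16")
#   R.Tensor((1, num_tokens_including_cache, 4096), dtype="float16")
print(relax.TensorStructInfo([1, m, 4096], "float16"))
print(relax.TensorStructInfo([1, num_tok, 4096], "float16"))
```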
26 changes: 13 additions & 13 deletions mlc_llm/relax_model/llama.py
@@ -849,7 +849,7 @@ def __init__(

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
cache_len = te.var("cache_len", "int64")
cache_len = te.var("cached_rotary_embedding_len", "int64")
self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
############ End ############
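The renamed `cache_len` dimension is the number of precomputed rotary-embedding positions shared by the cos and sin tables. A sketch of the kind of tables this dimension indexes, with illustrative values rather than anything taken from this file:

```python
import numpy as np

# Illustrative values only: 2048 positions (per the comment above) and an
# assumed head dimension of 128.
cached_rotary_embedding_len = 2048
head_dim = 128

positions = np.arange(cached_rotary_embedding_len)[:, None]           # (len, 1)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2) / head_dim))  # (head_dim/2,)
freqs = positions * inv_freq                                           # (len, head_dim/2)
emb = np.concatenate([freqs, freqs], axis=-1)                          # (len, head_dim)
cos_cached, sin_cached = np.cos(emb), np.sin(emb)
assert cos_cached.shape == (cached_rotary_embedding_len, head_dim)
```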
@@ -907,7 +907,7 @@ def create_embed_func(
) -> None:
func_name = "embed"

seq_len = tvm.tir.SizeVar("m", "int64")
seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
with bb.function(func_name):
model = LlamaEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64"))
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
@@ -934,8 +934,8 @@ def create_prefill_func_for_single_seq(
func_name = "prefill_with_embed" if sep_embed else "prefill"

bsz = 1
seq_len = tvm.tir.SizeVar("n", "int64")
all_seq_len = tvm.tir.SizeVar("m", "int64")
seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -980,8 +980,8 @@ def create_prefill_func_for_batching(
) -> None:
func_name = "prefill_with_embed"

bsz = tir.SizeVar("nseq", "int64")
total_seq_len = tvm.tir.SizeVar("m", "int64")
bsz = tir.SizeVar("batch_size", "int64")
total_seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -1019,7 +1019,7 @@ def create_decoding_func_for_single_seq(
func_name = "decode"

bsz = 1
all_seq_len = tvm.tir.SizeVar("m", "int64")
all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")

with bb.function(func_name):
model = LlamaForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"))
@@ -1058,7 +1058,7 @@ def create_decoding_func_for_batching(
) -> None:
func_name = "decode_with_embed"

bsz = tir.SizeVar("nseq", "int64")
bsz = tir.SizeVar("batch_size", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -1157,7 +1157,7 @@ def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConf

def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
with bb.function("softmax_with_temperature"):
bsz = tvm.tir.SizeVar("nseq", "int64")
bsz = tvm.tir.SizeVar("batch_size", "int64")
logits = nn.Placeholder(
(bsz, 1, tvm.tir.SizeVar("vocab_size", "int64")),
dtype="float32",
@@ -1192,8 +1192,8 @@ def kv_cache_transpose_append(
var_pos2seqidx: T.handle,
layer_id: T.int64,
):
- nseq = T.int64()
- ntoken = T.SizeVar("ntoken", "int64")
+ nseq = T.SizeVar("batch_size", "int64")
+ ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
npage = T.int64()
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
@@ -1493,10 +1493,10 @@ def get_model(args, hf_config):
mod = bb.get()

tir_bound_map = dict()
tir_bound_map["n"] = (
tir_bound_map["num_tokens_excluding_cache"] = (
args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length
)
tir_bound_map["m"] = config.max_sequence_length
tir_bound_map["num_tokens_including_cache"] = config.max_sequence_length
tir_bound_map["vocab_size"] = args.max_vocab_size
if enable_batching:
tir_bound_map["nseq"] = args.max_batch_size
18 changes: 9 additions & 9 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -558,7 +558,7 @@ def __init__(

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
cache_len = te.var("cache_len", "int64")
cache_len = te.var("cached_rotary_embedding_len", "int64")
self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
############ End ############
@@ -762,8 +762,8 @@ def create_evaluate_func(
"""Evaluate logits for the last token in each sequence. Same as prefill but without KV cache."""
func_name = "evaluate"

num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
num_seq = tvm.tir.SizeVar("num_seq", "int64")
num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
num_seq = tvm.tir.SizeVar("batch_size", "int64")

with bb.function(func_name):
model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed)
@@ -815,8 +815,8 @@ def create_encoding_func(
"""
func_name = "prefill_with_embed" if sep_embed else "prefill"

- num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
- num_seq = tvm.tir.SizeVar("num_seq", "int64")
+ num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+ num_seq = tvm.tir.SizeVar("batch_size", "int64")

num_inputs = 5

@@ -885,7 +885,7 @@ def create_decoding_func(
"""Batched decoding with vLLM paged KV cache."""
func_name = "decode"

num_seq = tvm.tir.SizeVar("num_seq", "int64")
num_seq = tvm.tir.SizeVar("batch_size", "int64")

func_names = ["decode"]

@@ -952,9 +952,9 @@ def create_evaluate_multi_query_func(
) -> None:
func_name = "evaluate_multi_query"

- num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
- num_past_token = tvm.tir.SizeVar("num_past_token", "int64")
- num_seq = tvm.tir.SizeVar("num_seq", "int64")
+ num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
+ num_past_token = tvm.tir.SizeVar("num_tokens_in_cache", "int64")
+ num_seq = tvm.tir.SizeVar("batch_size", "int64")
seq_lens_sum = tvm.tir.SizeVar("seq_lens_sum", "int64")

num_inputs = 8
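For the multi-query evaluate function, the renamed sizes can be read as batch-level totals. A hypothetical example of how they relate (values invented purely for illustration):

```python
# Hypothetical per-sequence lengths for a batch of two sequences.
query_lens = [3, 5]   # new tokens to evaluate per sequence
past_lens = [10, 2]   # tokens already held in the KV cache per sequence

batch_size = len(query_lens)                  # formerly "num_seq"         -> 2
num_tokens_excluding_cache = sum(query_lens)  # formerly "num_query_token" -> 8
num_tokens_in_cache = sum(past_lens)          # formerly "num_past_token"  -> 12
seq_lens_sum = sum(q + p for q, p in zip(query_lens, past_lens))           # -> 20
```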
4 changes: 2 additions & 2 deletions python/mlc_chat/nn/kv_cache.py
@@ -141,7 +141,7 @@ def tir_kv_cache_transpose_append(
var_position_map: T.handle,
):
T.func_attr({"tir.noalias": T.bool(True)})
ntoken = T.SizeVar("ntoken", "int64")
ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
@@ -181,7 +181,7 @@ def tir_kv_cache_debug_get_kv(
layer_id: T.int64,
):
T.func_attr({"tir.noalias": T.bool(True)})
seqlen = T.SizeVar("seqlen", "int64")
seqlen = T.SizeVar("num_tokens_including_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
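For reference, a minimal TVMScript sketch of declaring a named `T.SizeVar` and reusing it across buffer shapes inside a PrimFunc; the buffer shapes and dtype here are illustrative, not taken from the kernels above:

```python
from tvm.script import tir as T

@T.prim_func
def copy_rows(var_src: T.handle, var_dst: T.handle):
    T.func_attr({"tir.noalias": T.bool(True)})
    # A SizeVar is a nonnegative symbolic extent; its name is what shows up in
    # printed IR and in bound maps keyed by variable name.
    num_tokens_excluding_cache = T.SizeVar("num_tokens_excluding_cache", "int64")
    src = T.match_buffer(var_src, (num_tokens_excluding_cache, 128), "float16")
    dst = T.match_buffer(var_dst, (num_tokens_excluding_cache, 128), "float16")
    for i, j in T.grid(num_tokens_excluding_cache, 128):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            dst[vi, vj] = src[vi, vj]
```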