Merge branch 'batch-serving' into flash-decoding-engine
masahi committed Jan 30, 2024
2 parents 2eff7b0 + 34c7137 commit c51c2a4
Showing 4 changed files with 37 additions and 31 deletions.
26 changes: 13 additions & 13 deletions mlc_llm/relax_model/llama.py
@@ -849,7 +849,7 @@ def __init__(

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
cache_len = te.var("cache_len", "int64")
cache_len = te.var("cached_rotary_embedding_len", "int64")
self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
############ End ############
@@ -907,7 +907,7 @@ def create_embed_func(
) -> None:
func_name = "embed"

seq_len = tvm.tir.SizeVar("m", "int64")
seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
with bb.function(func_name):
model = LlamaEmbedTokensWrapper(config, tvm.tir.SizeVar("vocab_size", "int64"))
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)
@@ -934,8 +934,8 @@ def create_prefill_func_for_single_seq(
func_name = "prefill_with_embed" if sep_embed else "prefill"

bsz = 1
seq_len = tvm.tir.SizeVar("n", "int64")
all_seq_len = tvm.tir.SizeVar("m", "int64")
seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -980,8 +980,8 @@ def create_prefill_func_for_batching(
) -> None:
func_name = "prefill_with_embed"

bsz = tir.SizeVar("nseq", "int64")
total_seq_len = tvm.tir.SizeVar("m", "int64")
bsz = tir.SizeVar("batch_size", "int64")
total_seq_len = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -1019,7 +1019,7 @@ def create_decoding_func_for_single_seq(
func_name = "decode"

bsz = 1
all_seq_len = tvm.tir.SizeVar("m", "int64")
all_seq_len = tvm.tir.SizeVar("num_tokens_including_cache", "int64")

with bb.function(func_name):
model = LlamaForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"))
@@ -1058,7 +1058,7 @@ def create_decoding_func_for_batching(
) -> None:
func_name = "decode_with_embed"

bsz = tir.SizeVar("nseq", "int64")
bsz = tir.SizeVar("batch_size", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = LlamaForCausalLM(
@@ -1157,7 +1157,7 @@ def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConf

def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
with bb.function("softmax_with_temperature"):
bsz = tvm.tir.SizeVar("nseq", "int64")
bsz = tvm.tir.SizeVar("batch_size", "int64")
logits = nn.Placeholder(
(bsz, 1, tvm.tir.SizeVar("vocab_size", "int64")),
dtype="float32",
@@ -1192,8 +1192,8 @@ def kv_cache_transpose_append(
var_pos2seqidx: T.handle,
layer_id: T.int64,
):
- nseq = T.int64()
- ntoken = T.SizeVar("ntoken", "int64")
+ nseq = T.SizeVar("batch_size", "int64")
+ ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
npage = T.int64()
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
@@ -1493,10 +1493,10 @@ def get_model(args, hf_config):
mod = bb.get()

tir_bound_map = dict()
tir_bound_map["n"] = (
tir_bound_map["num_tokens_excluding_cache"] = (
args.prefill_chunk_size if args.prefill_chunk_size > 0 else config.max_sequence_length
)
tir_bound_map["m"] = config.max_sequence_length
tir_bound_map["num_tokens_including_cache"] = config.max_sequence_length
tir_bound_map["vocab_size"] = args.max_vocab_size
if enable_batching:
tir_bound_map["nseq"] = args.max_batch_size
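For readers unfamiliar with how these symbolic shapes are consumed: the strings passed to tvm.tir.SizeVar are the names that tir_bound_map keys refer to, so the rename makes the bound table self-documenting. A minimal sketch, not part of this commit, with purely illustrative bound values:

from tvm import tir

# Descriptive symbolic shape variables, as introduced by this commit.
# Unlike a plain tir.Var, a SizeVar is known to be non-negative, which helps
# TVM's arithmetic simplifier prove bounds on buffer indices.
batch_size = tir.SizeVar("batch_size", "int64")
num_tokens = tir.SizeVar("num_tokens_excluding_cache", "int64")

# Illustrative upper bounds, mirroring how get_model() fills tir_bound_map;
# the keys are exactly the name hints given to the SizeVars above.
tir_bound_map = {
    "num_tokens_excluding_cache": 4096,  # e.g. args.prefill_chunk_size
    "num_tokens_including_cache": 4096,  # e.g. config.max_sequence_length
    "batch_size": 64,                    # e.g. args.max_batch_size
}

print(batch_size, num_tokens, tir_bound_map["batch_size"])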
18 changes: 9 additions & 9 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -558,7 +558,7 @@ def __init__(

# Set the cached sin/cos to the maximum of 2048 and max seq len.
# This will be eliminated further with online rotary embedding calculation.
cache_len = te.var("cache_len", "int64")
cache_len = te.var("cached_rotary_embedding_len", "int64")
self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached")
self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached")
############ End ############
@@ -762,8 +762,8 @@ def create_evaluate_func(
"""Evaluate logits for the last token in each sequence. Same as prefill but without KV cache."""
func_name = "evaluate"

num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
num_seq = tvm.tir.SizeVar("num_seq", "int64")
num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
num_seq = tvm.tir.SizeVar("batch_size", "int64")

with bb.function(func_name):
model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed)
@@ -815,8 +815,8 @@ def create_encoding_func(
"""
func_name = "prefill_with_embed" if sep_embed else "prefill"

num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
num_seq = tvm.tir.SizeVar("num_seq", "int64")
num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
num_seq = tvm.tir.SizeVar("batch_size", "int64")

num_inputs = 5

@@ -885,7 +885,7 @@ def create_decoding_func(
"""Batched decoding with vLLM paged KV cache."""
func_name = "decode"

num_seq = tvm.tir.SizeVar("num_seq", "int64")
num_seq = tvm.tir.SizeVar("batch_size", "int64")

func_names = ["decode"]

@@ -952,9 +952,9 @@ def create_evaluate_multi_query_func(
) -> None:
func_name = "evaluate_multi_query"

num_query_token = tvm.tir.SizeVar("num_query_token", "int64")
num_past_token = tvm.tir.SizeVar("num_past_token", "int64")
num_seq = tvm.tir.SizeVar("num_seq", "int64")
num_query_token = tvm.tir.SizeVar("num_tokens_excluding_cache", "int64")
num_past_token = tvm.tir.SizeVar("num_tokens_in_cache", "int64")
num_seq = tvm.tir.SizeVar("batch_size", "int64")
seq_lens_sum = tvm.tir.SizeVar("seq_lens_sum", "int64")

num_inputs = 8
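A hypothetical worked example, not from the commit, of what the renamed extents in evaluate_multi_query mean (the totals-across-the-batch interpretation is assumed from the surrounding code): two sequences each evaluate 4 new query tokens against caches of 5 and 3 tokens.

query_lens = [4, 4]  # new tokens evaluated per sequence
past_lens = [5, 3]   # tokens already held in the KV cache per sequence

batch_size = len(query_lens)                  # was "num_seq"
num_tokens_excluding_cache = sum(query_lens)  # 8, was "num_query_token"
num_tokens_in_cache = sum(past_lens)          # 8, was "num_past_token"
seq_lens_sum = sum(q + p for q, p in zip(query_lens, past_lens))  # 16

print(batch_size, num_tokens_excluding_cache, num_tokens_in_cache, seq_lens_sum)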
4 changes: 2 additions & 2 deletions python/mlc_chat/nn/kv_cache.py
@@ -141,7 +141,7 @@ def tir_kv_cache_transpose_append(
var_position_map: T.handle,
):
T.func_attr({"tir.noalias": T.bool(True)})
ntoken = T.SizeVar("ntoken", "int64")
ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
@@ -181,7 +181,7 @@ def tir_kv_cache_debug_get_kv(
layer_id: T.int64,
):
T.func_attr({"tir.noalias": T.bool(True)})
seqlen = T.SizeVar("seqlen", "int64")
seqlen = T.SizeVar("num_tokens_including_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
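These kv_cache.py renames change only the name hints of the symbolic extents. A minimal TVMScript sketch, assumed rather than taken from this file, showing the two kinds of symbolic extents that appear above (T.SizeVar, which TVM treats as non-negative, versus a plain T.int64 variable):

from tvm.script import tir as T

@T.prim_func
def copy_tokens(var_src: T.handle, var_dst: T.handle):
    T.func_attr({"tir.noalias": T.bool(True)})
    # SizeVar: non-negative by construction, so index bounds simplify.
    num_tokens = T.SizeVar("num_tokens_excluding_cache", "int64")
    # Plain symbolic extent with no sign information attached.
    hidden_size = T.int64()
    src = T.match_buffer(var_src, (num_tokens, hidden_size), "float16")
    dst = T.match_buffer(var_dst, (num_tokens, hidden_size), "float16")
    for i, j in T.grid(num_tokens, hidden_size):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            dst[vi, vj] = src[vi, vj]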
20 changes: 13 additions & 7 deletions serve/mlc_serve/model/model_common.py
@@ -76,13 +76,19 @@ def _is_safe_to_sample(prob_like):
logits = torch.from_dlpack(logits)
num_seq = len(sampling_params)

- mask_random = torch.tensor(
+ mask_random_cpu = torch.tensor(
[p.sampling_type == SamplingType.RANDOM for p in sampling_params],
dtype=torch.bool,
)
- mask_greedy = torch.logical_not(mask_random)
+ mask_greedy_cpu = torch.logical_not(mask_random_cpu)
+ if logits.device == torch.device("cpu"):
+     mask_random_dvc = mask_random_cpu
+     mask_greedy_dvc = mask_greedy_cpu
+ else:  # gpu
+     mask_random_dvc = mask_random_cpu.to(logits.device)
+     mask_greedy_dvc = mask_greedy_cpu.to(logits.device)

- logits_greedy = logits[mask_greedy]
+ logits_greedy = logits[mask_greedy_dvc]

if logits_greedy.shape[0] > 0:
res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy()
@@ -141,7 +147,7 @@ def _is_safe_to_sample(prob_like):
.to(device=logits.device)
)

- logits_random = logits[mask_random]
+ logits_random = logits[mask_random_dvc]

if divide_by_temperature:
t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
@@ … @@
torch.cuda.nvtx.range_pop()
return None

- res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0]
+ res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy()

if logits_random.shape[0] == num_seq:
torch.cuda.nvtx.range_pop()
return res_random

res = np.empty((num_seq,), dtype=np.int32)
- res[mask_random] = res_random
+ res[mask_random_cpu] = res_random

if logits_greedy.shape[0] > 0:
- res[mask_greedy] = res_greedy
+ res[mask_greedy_cpu] = res_greedy

torch.cuda.nvtx.range_pop()
return res
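The model_common.py change keeps two copies of each sampling mask: a device-side copy for indexing the logits tensor, which may live on the GPU, and a CPU copy for scattering results into the NumPy output array. A standalone sketch of the pattern with assumed shapes and data, not the serving code itself:

import numpy as np
import torch

# Assumed setup: 4 sequences, vocab size 8; the middle two sample randomly.
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(4, 8, device=device)

mask_random_cpu = torch.tensor([False, True, True, False], dtype=torch.bool)
mask_greedy_cpu = torch.logical_not(mask_random_cpu)
# Device-side copies index `logits`; indexing a CUDA tensor with a CPU
# boolean mask triggers an implicit transfer (or an error, depending on
# the PyTorch version).
mask_random_dvc = mask_random_cpu.to(device)
mask_greedy_dvc = mask_greedy_cpu.to(device)

res_greedy = torch.argmax(logits[mask_greedy_dvc], -1).cpu().numpy()
probs = torch.softmax(logits[mask_random_dvc], dim=-1)
# Index on-device first, then move the small result back to the host.
res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy()

# The CPU masks index the NumPy result array, which always lives on the host.
res = np.empty((4,), dtype=np.int32)
res[mask_random_cpu.numpy()] = res_random
res[mask_greedy_cpu.numpy()] = res_greedy
print(res)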
