
Commit

Merge branch 'batch-serving' into parallel-sampling-eviction
masahi committed Feb 1, 2024
2 parents a4d6e01 + 583bb4b commit 7360392
Showing 3 changed files with 34 additions and 23 deletions.
serve/mlc_serve/model/model_common.py (9 changes: 9 additions & 0 deletions)
@@ -42,6 +42,15 @@ def get_num_cache_blocks(
     )
 
 
+def get_logprob_infos(
+    i: int,
+    logprob_infos: Optional[RawLogprobsInfos],
+) -> Optional[RawLogprobsInfos]:
+    if logprob_infos is None or logprob_infos[i] is None:
+        return None
+    return [logprob_infos[i]]
+
+
 def get_raw_logprob_info(
     logits,
     token_id,
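The helper moved into model_common.py does one small thing: it wraps the i-th logprob entry in a one-element list when it exists, and returns None otherwise. A minimal sketch of that behavior, using plain strings as stand-ins for the real RawLogprobsInfo entries (which this hunk does not show):

from typing import List, Optional

StandInLogprobsInfos = List[Optional[str]]  # placeholder for RawLogprobsInfos in this sketch only

def get_logprob_infos_sketch(i: int, logprob_infos: Optional[StandInLogprobsInfos]):
    # Same logic as the helper added above.
    if logprob_infos is None or logprob_infos[i] is None:
        return None
    return [logprob_infos[i]]

assert get_logprob_infos_sketch(0, None) is None           # logprobs were never requested
assert get_logprob_infos_sketch(1, ["a", None]) is None    # no entry for this sequence
assert get_logprob_infos_sketch(0, ["a", None]) == ["a"]   # wrapped in a one-element list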
serve/mlc_serve/model/tvm_model.py (22 changes: 7 additions & 15 deletions)
@@ -15,6 +15,7 @@
     sample,
     prepare_inputs,
     prepare_multi_query_decode_inputs,
+    get_logprob_infos,
     get_num_cache_blocks,
 )
 
@@ -207,15 +208,6 @@ def profile_memory_usage(self, seq_lens):
 
         return self.get_used_memory()
 
-    def get_logprob_infos(
-        self,
-        i: int,
-        logprob_infos: Optional[RawLogprobsInfos],
-    ) -> Optional[RawLogprobsInfos]:
-        if logprob_infos is None or logprob_infos[i] is None:
-            return None
-        return [logprob_infos[i]]
-
     def sample_from_logits(
         self,
         logits: Union[tvm.nd.NDArray, torch.Tensor],
@@ -248,7 +240,7 @@ def sample_from_logits(
                                 sequence_id=SequenceId(sequence_id.request_id, seq_id),
                                 generated_tokens=[new_token],
                                 error=None,
-                                logprob_info=self.get_logprob_infos(i, logprob_infos),
+                                logprob_info=get_logprob_infos(i, logprob_infos),
                             )
                         )
                 else:
@@ -257,7 +249,7 @@
                             sequence_id=sequence_id,
                             generated_tokens=[new_token],
                             error=None,
-                            logprob_info=self.get_logprob_infos(i, logprob_infos),
+                            logprob_info=get_logprob_infos(i, logprob_infos),
                         )
                     )
 
@@ -298,7 +290,7 @@ def sample_from_logits(
                                     ),
                                     generated_tokens=[new_token],  # type: ignore
                                     error=None,
-                                    logprob_info=self.get_logprob_infos(
+                                    logprob_info=get_logprob_infos(
                                         0, logprob_infos
                                     ),
                                 )
@@ -309,7 +301,7 @@
                                 sequence_id=sequence_id,
                                 generated_tokens=[new_token],  # type: ignore
                                 error=None,
-                                logprob_info=self.get_logprob_infos(0, logprob_infos),
+                                logprob_info=get_logprob_infos(0, logprob_infos),
                             )
                         )
                 else:
@@ -323,7 +315,7 @@
                                     ),
                                     generated_tokens=[],
                                     error=err_msg,
-                                    logprob_info=self.get_logprob_infos(
+                                    logprob_info=get_logprob_infos(
                                         0, logprob_infos
                                     ),
                                 )
@@ -334,7 +326,7 @@
                                 sequence_id=sequence_id,
                                 generated_tokens=[],
                                 error=err_msg,
-                                logprob_info=self.get_logprob_infos(0, logprob_infos),
+                                logprob_info=get_logprob_infos(0, logprob_infos),
                             )
                         )
 
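The remaining tvm_model.py hunks are mechanical: every call site that previously went through the removed method now calls the shared helper imported from model_common. A simplified sketch of one such call site, assuming the mlc_serve.model.model_common package path; a plain dict stands in for the result object, whose constructor lies outside the context shown in the hunks:

from mlc_serve.model.model_common import get_logprob_infos  # assumed package path

def append_token_result(outputs, i, sequence_id, new_token, logprob_infos):
    # Previously: logprob_info=self.get_logprob_infos(i, logprob_infos)
    outputs.append(
        dict(
            sequence_id=sequence_id,
            generated_tokens=[new_token],
            error=None,
            logprob_info=get_logprob_infos(i, logprob_infos),
        )
    )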
serve/tests/unittest/test_engine_with_samplers.py (26 changes: 18 additions & 8 deletions)
@@ -342,20 +342,30 @@ def _test_penalty(
 def _test_logprobs(
     model_artifact_path,
     use_staging_engine,
-    max_num_sequences=4,
-    max_input_len=512,
     num_requests=5,
     top_logprobs=3,
+    max_num_batched_tokens=2048
 ):
-    prompt = "hi"
+    prompt = "hi could you please implement merge sort?"
     engine = create_engine(
         model_artifact_path,
         use_staging_engine,
-        max_num_sequences,
-        max_input_len,
+        max_num_batched_tokens,
     )
-    s = 113
-    requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs, logprobs=True) for n in range(s, s+num_requests)]
+    requests = [
+        create_request(
+            idx=str(n),
+            prompt=prompt,
+            temp=0,
+            freq_pen=0,
+            pre_pen=0,
+            max_tokens=300,
+            stop=None,
+            ignore_eos=True,
+            top_logprobs=top_logprobs,
+            logprobs=True
+        ) for n in range(num_requests)
+    ]
     engine.add(requests)
 
     generated = ["" for _ in range(num_requests)]
@@ -366,7 +376,7 @@ def _test_logprobs(
         assert len(res.sequences) == 1
         seq = res.sequences[0]
 
-        assert seq.finish_reason is not None or len(list(seq.logprobs.content[0]["top_logprobs"])) == top_logprobs
+        assert seq.finish_reason is not None or len(seq.logprob_info[0].top_logprobs) == top_logprobs
 
         if seq.is_finished:
             assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
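The updated assertion says: while a sequence is still running, each step must report exactly top_logprobs alternatives for the new token. A rough sketch of the surrounding consumption loop, with the engine and sequence fields (has_pending_requests, step().outputs, delta) assumed from the rest of this test file rather than shown in the hunk:

def drain_and_check(engine, requests, num_requests, top_logprobs):
    generated = ["" for _ in range(num_requests)]
    while engine.has_pending_requests():        # assumed engine API
        for res in engine.step().outputs:
            assert len(res.sequences) == 1
            seq = res.sequences[0]
            if seq.finish_reason is None:
                # Every live step carries exactly `top_logprobs` alternatives.
                assert len(seq.logprob_info[0].top_logprobs) == top_logprobs
            if seq.is_finished:
                # ignore_eos=True forces generation all the way to max_tokens.
                expected = requests[int(res.request_id)].stopping_criteria.max_tokens
                assert seq.num_generated_tokens == expected
            else:
                generated[int(res.request_id)] += seq.delta  # assumed incremental-text field
    return generated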
