diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py
index c58e834d25..c38e855b86 100644
--- a/serve/mlc_serve/model/sampler.py
+++ b/serve/mlc_serve/model/sampler.py
@@ -160,8 +160,6 @@ class SamplingMetadata:
     apply_top_p_top_k: bool
     apply_penalty: bool
     apply_bias: bool
-    # mapping for
-    logprob_index_map: dict[int, Tuple[bool, int, int]]
     has_logprob: bool
     sampling_tensors: SamplingTensors
     sampling_params: List[SamplingParams]
@@ -194,12 +192,10 @@ def from_sampling_params(
         batch_size = len(sampling_params)
         # index 0 is for non-logprob requests
         has_logprob = False
-        logprob_index_map = {}
         list_mask_top_logprob = np.full(
             ((LOGPROB_TOP_K_MAX) + 1, batch_size), False, dtype=bool
         )
         logit_bias_maxlen = 0
-        idxs_logprob = [-1] * ((LOGPROB_TOP_K_MAX) + 1)
         for batch_idx, param in enumerate(sampling_params):
             # Prepare temperature
             # NOTE: Zero temperature means deterministic sampling
@@ -221,15 +217,8 @@
                 idx_greedy += 1
 
             # param.top_logprobs is zero if logprob is not used
-            list_mask_top_logprob[param.top_logprobs] = True
+            list_mask_top_logprob[param.top_logprobs][batch_idx] = param.logprobs
             has_logprob |= param.logprobs
-            idxs_logprob[param.top_logprobs] += 1
-            logprob_index_map[batch_idx] = (
-                param.logprobs,
-                param.top_logprobs,
-                idxs_logprob[param.top_logprobs],
-            )
-
             apply_penalty |= (
                 abs(param.presence_penalty) >= SAMPLING_EPS
                 or abs(param.frequency_penalty) >= SAMPLING_EPS
@@ -297,7 +286,6 @@
             apply_top_p_top_k,
             apply_penalty,
             apply_bias,
-            logprob_index_map,
             has_logprob,
             sampling_tensors,
             sampling_params,
@@ -415,10 +403,6 @@ def _is_safe_to_sample(prob_like):
         (batch_size,), fill_value=None, dtype=RawLogprobsInfo
     )
     if sampling_metadata.has_logprob:
-        all_top_logprobs = np.empty((LOGPROB_TOP_K_MAX + 1, batch_size), dtype=np.int64)
-
-        all_top_logprobs = [torch.tensor([], device=logits.device)]
-        all_top_tokens = [torch.tensor([], device=logits.device)]
         # If everything is random sampling, save one extra softmax
         if not sampling_metadata.has_greedy:
             assert probs_random is not None
@@ -426,27 +410,23 @@
         else:
             logprobs = torch.log_softmax(logits, dim=-1)
 
-        for lp_idx in range(LOGPROB_TOP_K_MAX):
-            logprob_topk = lp_idx + 1
-            mask_t = sampling_metadata.sampling_tensors.mask_top_logprob[logprob_topk]
-            top_logprobs, top_tokens = torch.topk(
-                logprobs[mask_t], k=logprob_topk, dim=-1, largest=True, sorted=True
-            )
-            all_top_logprobs.append(top_logprobs)
-            all_top_tokens.append(top_tokens)
+        # Redundant, but vectorized
+        extended_logprobs = logprobs.repeat((LOGPROB_TOP_K_MAX + 1, 1, 1))
+        all_top_logprobs, all_top_tokens = torch.topk(
+            extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
+        )
 
-        # recover original batch order
-        # TODO: Can we vectorize this?
         for batch_idx in range(batch_size):
-            is_logprobs, logprob_topk, idx = sampling_metadata.logprob_index_map[
-                batch_idx
-            ]
             next_token = next_tokens[batch_idx]
+            is_logprobs = sampling_metadata.sampling_params[batch_idx].logprobs
+            mask = sampling_metadata.sampling_tensors.mask_top_logprob
+            top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs
             if is_logprobs:
                 logprob_infos[batch_idx] = RawLogprobsInfo(
                     current_token_id=next_token,
                     current_logprob=logprobs[batch_idx][next_token],
-                    top_token_ids=all_top_tokens[logprob_topk][idx],
-                    top_logprobs=all_top_logprobs[logprob_topk][idx],
+                    top_token_ids=all_top_tokens[mask][batch_idx][:top_k],
+                    top_logprobs=all_top_logprobs[mask][batch_idx][:top_k],
                 )
+
     return SamplingOutput(next_tokens, logprob_infos)
diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py
index baeb53fdb4..579d40f086 100644
--- a/serve/tests/unittest/test_engine_with_samplers.py
+++ b/serve/tests/unittest/test_engine_with_samplers.py
@@ -247,11 +247,16 @@ def _test_logprobs(
     engine,
     num_requests=10,
 ):
-    prompt = "hi could you please implement merge sort?"
+    prompts = [
+        "Hi, could you please implement merge sort?",
+        "What is the best city in the world?",
+        "Can you write a poem for Seattle?",
+        "Describe a lion for kids.",
+    ]
     requests = [
         create_request(
             idx=str(n),
-            prompt=prompt,
+            prompt=random.choice(prompts),
             temp=0,
             freq_pen=0,
             pre_pen=0,
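
For context on the vectorized top-k path above, the snippet below is a minimal, hypothetical sketch of the same idea outside the mlc-serve codebase: compute one batched torch.topk over the log-probabilities, then slice each request's first top_logprobs entries. The diff additionally repeats the logprobs per top-k bucket and selects rows through mask_top_logprob; that part is omitted here. The gather_top_logprobs name, the LOGPROB_TOP_K_MAX value, and the example shapes are assumptions for illustration only, not the library's API.

# Minimal, self-contained sketch (not the mlc-serve code path) of the idea behind
# the change: one batched torch.topk over the log-probabilities, then per-request
# slicing to each request's own top_logprobs count.
from typing import List, Optional, Tuple

import torch

LOGPROB_TOP_K_MAX = 5  # assumed cap on top_logprobs, mirroring the constant in the diff


def gather_top_logprobs(
    logits: torch.Tensor, top_logprobs_per_req: List[int]
) -> List[Optional[Tuple[torch.Tensor, torch.Tensor]]]:
    """Return (top_logprobs, top_token_ids) per request, or None if it asked for 0."""
    logprobs = torch.log_softmax(logits, dim=-1)
    # Single vectorized top-k for the whole batch; sorted=True guarantees that the
    # first `top_k` entries of each row are exactly that row's top-`top_k`.
    all_top_logprobs, all_top_tokens = torch.topk(
        logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
    )
    out: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = []
    for batch_idx, top_k in enumerate(top_logprobs_per_req):
        if top_k == 0:
            out.append(None)  # request did not ask for logprobs
        else:
            out.append(
                (all_top_logprobs[batch_idx, :top_k], all_top_tokens[batch_idx, :top_k])
            )
    return out


if __name__ == "__main__":
    logits = torch.randn(4, 32000)  # (batch, vocab), random stand-in for model output
    for i, entry in enumerate(gather_top_logprobs(logits, [0, 1, 3, 5])):
        print(i, None if entry is None else entry[0].shape)

Because the batched top-k is computed with sorted=True, slicing its first top_k columns yields the same tokens as calling topk with k=top_k per request, which is what lets a single vectorized call replace the per-bucket loop.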