diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py
index c38e855b86..826569a392 100644
--- a/serve/mlc_serve/model/sampler.py
+++ b/serve/mlc_serve/model/sampler.py
@@ -161,6 +161,7 @@ class SamplingMetadata:
     apply_penalty: bool
     apply_bias: bool
     has_logprob: bool
+    logprob_batch_indices: List[int]
     sampling_tensors: SamplingTensors
     sampling_params: List[SamplingParams]
 
@@ -192,6 +193,7 @@ def from_sampling_params(
         batch_size = len(sampling_params)
         # index 0 is for non-logprob requests
         has_logprob = False
+        logprob_batch_indices = []
         list_mask_top_logprob = np.full(
             ((LOGPROB_TOP_K_MAX) + 1, batch_size), False, dtype=bool
         )
@@ -216,9 +218,12 @@
                 list_mask_random.append(False)
                 idx_greedy += 1
 
-            # param.top_logprobs is zero if logprob is not used
-            list_mask_top_logprob[param.top_logprobs][batch_idx] = param.logprobs
-            has_logprob |= param.logprobs
+            if param.logprobs:
+                logprob_batch_indices.append(batch_idx)
+                # param.top_logprobs is zero if logprob is not used
+                list_mask_top_logprob[param.top_logprobs][batch_idx] = param.logprobs
+                has_logprob |= True
+
             apply_penalty |= (
                 abs(param.presence_penalty) >= SAMPLING_EPS
                 or abs(param.frequency_penalty) >= SAMPLING_EPS
@@ -287,6 +292,7 @@
             apply_penalty,
             apply_bias,
             has_logprob,
+            logprob_batch_indices,
             sampling_tensors,
             sampling_params,
         )
@@ -366,8 +372,8 @@ class SamplingOutput:
 
 def sample(
     logits: torch.Tensor,
-    sampling_metadata,
-    check_safety=False,
+    sampling_metadata: SamplingMetadata,
+    check_safety: bool = False,
 ) -> SamplingOutput:
     def _is_safe_to_sample(prob_like):
         return (
@@ -415,25 +421,21 @@ def _is_safe_to_sample(prob_like):
         all_top_logprobs, all_top_tokens = torch.topk(
             extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
         )
-
-        for batch_idx in range(batch_size):
+        mask = sampling_metadata.sampling_tensors.mask_top_logprob
+        top_tokens = all_top_tokens[mask]
+        top_logprobs = all_top_logprobs[mask]
+        for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices):
             next_token = next_tokens[batch_idx]
             is_logprobs = sampling_metadata.sampling_params[batch_idx].logprobs
-            mask = sampling_metadata.sampling_tensors.mask_top_logprob
+
             top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs
-            print(
-                all_top_tokens.shape,
-                top_k,
-                all_top_tokens[mask][batch_idx].shape,
-                all_top_tokens[mask][batch_idx][:top_k],
-                all_top_tokens[mask][batch_idx][:top_k].shape,
-            )
+
             if is_logprobs:
                 logprob_infos[batch_idx] = RawLogprobsInfo(
                     current_token_id=next_token,
                     current_logprob=logprobs[batch_idx][next_token],
-                    top_token_ids=all_top_tokens[mask][batch_idx][:top_k],
-                    top_logprobs=all_top_logprobs[mask][batch_idx][:top_k],
+                    top_token_ids=top_tokens[idx][:top_k],
+                    top_logprobs=top_logprobs[idx][:top_k],
                 )
 
     return SamplingOutput(next_tokens, logprob_infos)
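
For context, here is a minimal, self-contained sketch (not the mlc-serve code itself) of the gather pattern this diff introduces: record `logprob_batch_indices` once while building the metadata, run `topk` over the whole batch, select only the logprob rows with a boolean mask, and then iterate over just those rows instead of every `batch_idx`. The toy batch, `wants_logprobs`, and per-request `top_logprobs` values are invented for illustration, and the real `mask_top_logprob` is a 2-D mask over `(top-k bucket, batch)` rather than the 1-D mask used here.

```python
import torch

LOGPROB_TOP_K_MAX = 5  # mirrors the constant referenced in the diff

batch_size, vocab_size = 4, 16
logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)

# Hypothetical per-request settings: only requests 1 and 3 asked for logprobs.
wants_logprobs = [False, True, False, True]
top_logprobs = [0, 3, 0, 2]  # zero when logprobs are not requested

# Analogue of SamplingMetadata.logprob_batch_indices: built once per batch.
logprob_batch_indices = [i for i, w in enumerate(wants_logprobs) if w]
# Analogue of sampling_tensors.mask_top_logprob, collapsed to one dimension.
mask = torch.tensor(wants_logprobs)

# One top-k over the full batch, then a single masked gather.
all_top_logprobs, all_top_tokens = torch.topk(
    logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
)
top_tokens = all_top_tokens[mask]    # rows only for logprob requests
top_vals = all_top_logprobs[mask]

# Walk only the logprob requests; idx indexes the gathered rows,
# batch_idx indexes the original batch.
for idx, batch_idx in enumerate(logprob_batch_indices):
    k = top_logprobs[batch_idx]
    print(batch_idx, top_tokens[idx][:k].tolist(), top_vals[idx][:k].tolist())
```

The point of the `(idx, batch_idx)` pairing is that the masked tensors `top_tokens`/`top_vals` are compacted (one row per logprob request), while `next_tokens`, `sampling_params`, and `logprob_infos` in the patched code remain indexed by the full batch, so both indices are needed in the loop body.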