better
sunggg committed Feb 7, 2024
1 parent ced3c12 commit 8fb9f33
Showing 1 changed file with 19 additions and 17 deletions.
serve/mlc_serve/model/sampler.py: 19 additions & 17 deletions
@@ -161,6 +161,7 @@ class SamplingMetadata:
     apply_penalty: bool
     apply_bias: bool
     has_logprob: bool
+    logprob_batch_indices: List[int]
     sampling_tensors: SamplingTensors
     sampling_params: List[SamplingParams]
@@ -192,6 +193,7 @@ def from_sampling_params(
         batch_size = len(sampling_params)
         # index 0 is for non-logprob requests
         has_logprob = False
+        logprob_batch_indices = []
         list_mask_top_logprob = np.full(
             ((LOGPROB_TOP_K_MAX) + 1, batch_size), False, dtype=bool
         )
@@ -216,9 +218,12 @@ def from_sampling_params(
                 list_mask_random.append(False)
                 idx_greedy += 1

-            # param.top_logprobs is zero if logprob is not used
-            list_mask_top_logprob[param.top_logprobs][batch_idx] = param.logprobs
-            has_logprob |= param.logprobs
+            if param.logprobs:
+                logprob_batch_indices.append(batch_idx)
+                # param.top_logprobs is zero if logprob is not used
+                list_mask_top_logprob[param.top_logprobs][batch_idx] = param.logprobs
+                has_logprob |= True

             apply_penalty |= (
                 abs(param.presence_penalty) >= SAMPLING_EPS
                 or abs(param.frequency_penalty) >= SAMPLING_EPS
@@ -287,6 +292,7 @@ def from_sampling_params(
             apply_penalty,
             apply_bias,
             has_logprob,
+            logprob_batch_indices,
             sampling_tensors,
             sampling_params,
         )
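Taken together, the `from_sampling_params` hunks gate all logprob bookkeeping behind `if param.logprobs:` and record which batch entries asked for logprobs. A minimal standalone sketch of that collection logic (`ToyParam` is a hypothetical stand-in for `SamplingParams`; the mask layout follows the diff, with row 0 reserved for non-logprob requests):

```python
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np

LOGPROB_TOP_K_MAX = 5


@dataclass
class ToyParam:
    # Hypothetical stand-in for SamplingParams; only the fields used here.
    logprobs: bool = False
    top_logprobs: int = 0  # stays zero when logprobs is disabled


def collect_logprob_indices(params: List[ToyParam]) -> Tuple[List[int], np.ndarray]:
    batch_size = len(params)
    # Row 0 is for non-logprob requests; rows 1..LOGPROB_TOP_K_MAX key on top-k size.
    mask = np.full((LOGPROB_TOP_K_MAX + 1, batch_size), False, dtype=bool)
    indices: List[int] = []
    for batch_idx, param in enumerate(params):
        if param.logprobs:
            indices.append(batch_idx)
            mask[param.top_logprobs][batch_idx] = True
    return indices, mask


params = [
    ToyParam(),
    ToyParam(logprobs=True, top_logprobs=3),
    ToyParam(logprobs=True, top_logprobs=5),
]
indices, mask = collect_logprob_indices(params)
print(indices)  # [1, 2]: ascending batch order, matching the enumeration order in sample()
```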
@@ -366,8 +372,8 @@ class SamplingOutput:

 def sample(
     logits: torch.Tensor,
-    sampling_metadata,
-    check_safety=False,
+    sampling_metadata: SamplingMetadata,
+    check_safety: bool = False,
 ) -> SamplingOutput:
     def _is_safe_to_sample(prob_like):
         return (
@@ -415,25 +421,21 @@ def _is_safe_to_sample(prob_like):
     all_top_logprobs, all_top_tokens = torch.topk(
         extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
     )
-
-    for batch_idx in range(batch_size):
+    mask = sampling_metadata.sampling_tensors.mask_top_logprob
+    top_tokens = all_top_tokens[mask]
+    top_logprobs = all_top_logprobs[mask]
+    for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices):
         next_token = next_tokens[batch_idx]
-        is_logprobs = sampling_metadata.sampling_params[batch_idx].logprobs
-        mask = sampling_metadata.sampling_tensors.mask_top_logprob
-
         top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs
-        print(
-            all_top_tokens.shape,
-            top_k,
-            all_top_tokens[mask][batch_idx].shape,
-            all_top_tokens[mask][batch_idx][:top_k],
-            all_top_tokens[mask][batch_idx][:top_k].shape,
-        )
-
-        if is_logprobs:
-            logprob_infos[batch_idx] = RawLogprobsInfo(
-                current_token_id=next_token,
-                current_logprob=logprobs[batch_idx][next_token],
-                top_token_ids=all_top_tokens[mask][batch_idx][:top_k],
-                top_logprobs=all_top_logprobs[mask][batch_idx][:top_k],
-            )
+        logprob_infos[batch_idx] = RawLogprobsInfo(
+            current_token_id=next_token,
+            current_logprob=logprobs[batch_idx][next_token],
+            top_token_ids=top_tokens[idx][:top_k],
+            top_logprobs=top_logprobs[idx][:top_k],
+        )

     return SamplingOutput(next_tokens, logprob_infos)
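The rewritten loop in `sample` relies on one property of boolean-mask indexing: selected rows come back in ascending index order, which is exactly the order in which `logprob_batch_indices` was built. That is what lets the masking (`all_top_tokens[mask]`) be hoisted out of the loop, and the per-iteration `is_logprobs` check and debug `print` be dropped. A toy sketch of the invariant (shapes simplified; the real `mask_top_logprob` also encodes each request's top-k size):

```python
import torch

batch_size, k = 4, 5
all_top_tokens = torch.arange(batch_size * k).reshape(batch_size, k)

logprob_batch_indices = [1, 3]  # requests that asked for logprobs, ascending
mask = torch.zeros(batch_size, dtype=torch.bool)
mask[logprob_batch_indices] = True

top_tokens = all_top_tokens[mask]  # shape (2, k): rows for batches 1 and 3, in order

for idx, batch_idx in enumerate(logprob_batch_indices):
    # Row `idx` of the masked tensor is row `batch_idx` of the original.
    assert torch.equal(top_tokens[idx], all_top_tokens[batch_idx])
```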