Commit

wip
sunggg committed Feb 7, 2024
1 parent e08842f commit ced3c12
Showing 2 changed files with 26 additions and 34 deletions.
serve/mlc_serve/model/sampler.py (51 changes: 19 additions & 32 deletions)
@@ -160,8 +160,6 @@ class SamplingMetadata:
apply_top_p_top_k: bool
apply_penalty: bool
apply_bias: bool
# mapping for <batch index, Tuple(logprob, top-k, top-k index)>
logprob_index_map: dict[int, Tuple[bool, int, int]]
has_logprob: bool
sampling_tensors: SamplingTensors
sampling_params: List[SamplingParams]
@@ -194,12 +192,10 @@ def from_sampling_params(
batch_size = len(sampling_params)
# index 0 is for non-logprob requests
has_logprob = False
logprob_index_map = {}
list_mask_top_logprob = np.full(
((LOGPROB_TOP_K_MAX) + 1, batch_size), False, dtype=bool
)
logit_bias_maxlen = 0
idxs_logprob = [-1] * ((LOGPROB_TOP_K_MAX) + 1)
for batch_idx, param in enumerate(sampling_params):
# Prepare temperature
# NOTE: Zero temperature means deterministic sampling
@@ -221,15 +217,8 @@
idx_greedy += 1

# param.top_logprobs is zero if logprob is not used
list_mask_top_logprob[param.top_logprobs] = True
list_mask_top_logprob[param.top_logprobs][batch_idx] = param.logprobs
has_logprob |= param.logprobs
idxs_logprob[param.top_logprobs] += 1
logprob_index_map[batch_idx] = (
param.logprobs,
param.top_logprobs,
idxs_logprob[param.top_logprobs],
)

apply_penalty |= (
abs(param.presence_penalty) >= SAMPLING_EPS
or abs(param.frequency_penalty) >= SAMPLING_EPS
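For orientation, a minimal sketch of the bookkeeping this change switches to: instead of the removed per-request `logprob_index_map`, a boolean grid of shape `(LOGPROB_TOP_K_MAX + 1, batch_size)` records which batch entries asked for which top-k depth (row 0 is reserved for requests without logprobs). The values below are toy stand-ins, not the module's real constants:

```python
import numpy as np

LOGPROB_TOP_K_MAX = 5  # illustrative stand-in for the real constant in sampler.py
batch_size = 3
# (param.logprobs, param.top_logprobs) per request: none, top-2, top-5.
params = [(False, 0), (True, 2), (True, 5)]

# Row k marks the batch entries that requested top-k logprobs.
list_mask_top_logprob = np.full((LOGPROB_TOP_K_MAX + 1, batch_size), False, dtype=bool)
for batch_idx, (logprobs, top_logprobs) in enumerate(params):
    list_mask_top_logprob[top_logprobs][batch_idx] = logprobs

assert list_mask_top_logprob[2, 1] and list_mask_top_logprob[5, 2]
assert not list_mask_top_logprob[0].any()  # row 0 stays False: no logprobs requested
```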
@@ -297,7 +286,6 @@ def from_sampling_params(
apply_top_p_top_k,
apply_penalty,
apply_bias,
logprob_index_map,
has_logprob,
sampling_tensors,
sampling_params,
@@ -415,38 +403,37 @@ def _is_safe_to_sample(prob_like):
(batch_size,), fill_value=None, dtype=RawLogprobsInfo
)
if sampling_metadata.has_logprob:
all_top_logprobs = np.empty((LOGPROB_TOP_K_MAX + 1, batch_size), dtype=np.int64)

all_top_logprobs = [torch.tensor([], device=logits.device)]
all_top_tokens = [torch.tensor([], device=logits.device)]
# If everything is random sampling, save one extra softmax
if not sampling_metadata.has_greedy:
assert probs_random is not None
logprobs = torch.log(probs_random)
else:
logprobs = torch.log_softmax(logits, dim=-1)
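The `if not sampling_metadata.has_greedy` branch reuses the probabilities already computed for random sampling. A quick check of the identity it relies on, with toy shapes:

```python
import torch

logits = torch.randn(4, 32)
probs_random = torch.softmax(logits, dim=-1)  # already available for random sampling
# log of the existing probabilities matches a fresh log_softmax up to
# floating-point error, so the second softmax can be skipped.
assert torch.allclose(
    torch.log(probs_random), torch.log_softmax(logits, dim=-1), atol=1e-5
)
```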

for lp_idx in range(LOGPROB_TOP_K_MAX):
logprob_topk = lp_idx + 1
mask_t = sampling_metadata.sampling_tensors.mask_top_logprob[logprob_topk]
top_logprobs, top_tokens = torch.topk(
logprobs[mask_t], k=logprob_topk, dim=-1, largest=True, sorted=True
)
all_top_logprobs.append(top_logprobs)
all_top_tokens.append(top_tokens)
# Redundant (every k-bucket computes the same top-k) but vectorized
extended_logprobs = logprobs.repeat((LOGPROB_TOP_K_MAX + 1, 1, 1))
all_top_logprobs, all_top_tokens = torch.topk(
extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
)
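A shape sketch of the vectorized top-k above, under toy dimensions: every k-bucket sees an identical copy of the logprobs and the same top-`LOGPROB_TOP_K_MAX` result, trading duplicated work for a single `topk` launch instead of the per-k Python loop it replaces; the per-request `[:top_k]` slice later trims to the requested depth.

```python
import torch

LOGPROB_TOP_K_MAX, batch_size, vocab_size = 5, 3, 11  # toy stand-ins
logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=-1)

# One (identical) copy of the logprobs per k-bucket, then a single topk call.
extended_logprobs = logprobs.repeat((LOGPROB_TOP_K_MAX + 1, 1, 1))
all_top_logprobs, all_top_tokens = torch.topk(
    extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
)
assert extended_logprobs.shape == (LOGPROB_TOP_K_MAX + 1, batch_size, vocab_size)
assert all_top_tokens.shape == (LOGPROB_TOP_K_MAX + 1, batch_size, LOGPROB_TOP_K_MAX)
```

Whether the redundancy is worth it depends on batch and vocabulary size: the repeat allocates `LOGPROB_TOP_K_MAX + 1` copies of a logits-sized tensor in exchange for one kernel launch.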

# recover original batch order
# TODO: Can we vectorize this?
for batch_idx in range(batch_size):
is_logprobs, logprob_topk, idx = sampling_metadata.logprob_index_map[
batch_idx
]
next_token = next_tokens[batch_idx]
is_logprobs = sampling_metadata.sampling_params[batch_idx].logprobs
mask = sampling_metadata.sampling_tensors.mask_top_logprob
top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs
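# (wip) debug print: inspect shapes produced by the boolean-mask gather below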
print(
all_top_tokens.shape,
top_k,
all_top_tokens[mask][batch_idx].shape,
all_top_tokens[mask][batch_idx][:top_k],
all_top_tokens[mask][batch_idx][:top_k].shape,
)
if is_logprobs:
logprob_infos[batch_idx] = RawLogprobsInfo(
current_token_id=next_token,
current_logprob=logprobs[batch_idx][next_token],
top_token_ids=all_top_tokens[logprob_topk][idx],
top_logprobs=all_top_logprobs[logprob_topk][idx],
top_token_ids=all_top_tokens[mask][batch_idx][:top_k],
top_logprobs=all_top_logprobs[mask][batch_idx][:top_k],
)

return SamplingOutput(next_tokens, logprob_infos)
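One subtlety in the recovery loop above, sketched with toy tensors: indexing a 3-D tensor with a 2-D boolean mask flattens the selected `(k, batch)` rows in row-major order, so `all_top_tokens[mask][batch_idx]` lines up with the original batch only when the masked entries happen to appear in batch order. This is presumably what the wip debug print is checking:

```python
import torch

K, batch_size = 5, 2
# Toy stand-in for all_top_tokens, shape (K + 1, batch, K).
all_top_tokens = torch.arange((K + 1) * batch_size * K).reshape(K + 1, batch_size, K)

mask = torch.zeros(K + 1, batch_size, dtype=torch.bool)
mask[5, 0] = True  # request 0 asked for top-5
mask[2, 1] = True  # request 1 asked for top-2

gathered = all_top_tokens[mask]  # shape (2, K); rows in (k, batch) row-major order
# Row 0 comes from (k=2, batch=1) and row 1 from (k=5, batch=0): the gather is
# ordered by k first, not by batch index.
assert torch.equal(gathered[0], all_top_tokens[2, 1])
assert torch.equal(gathered[1], all_top_tokens[5, 0])
```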
serve/tests/unittest/test_engine_with_samplers.py (9 changes: 7 additions & 2 deletions)
@@ -247,11 +247,16 @@ def _test_logprobs(
engine,
num_requests=10,
):
prompt = "hi could you please implement merge sort?"
prompts = [
"Hi could you please implement merge sort?",
"What is the best city in the world?",
"Can you write a poem for Seattle?",
"Describe lion for kids.",
]
requests = [
create_request(
idx=str(n),
prompt=prompt,
prompt=random.choice(prompts),
temp=0,
freq_pen=0,
pre_pen=0,
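A note on the test change: `random.choice(prompts)` assumes `random` is imported at the top of test_engine_with_samplers.py, which this hunk does not show. A minimal sketch of the selection, with an optional seed for reproducible runs (the seeding is a suggestion, not part of the commit):

```python
import random

prompts = [
    "Hi could you please implement merge sort?",
    "What is the best city in the world?",
    "Can you write a poem for Seattle?",
    "Describe lion for kids.",
]

random.seed(0)  # optional: makes a failing prompt assignment reproducible
chosen = [random.choice(prompts) for _ in range(10)]  # one prompt per request
assert set(chosen) <= set(prompts)
```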
