wip:logprob
sunggg committed Feb 2, 2024
1 parent a60f6f8 commit b4d8129
Showing 1 changed file with 54 additions and 75 deletions.
129 changes: 54 additions & 75 deletions serve/mlc_serve/model/sampler.py
@@ -100,47 +100,6 @@ def get_raw_logprob_info(
)


def get_logprob_indices(
sampling_params: List[SamplingParams],
num_seq: int,
) -> Tuple[List[Tuple[int, int, int]], List[Tuple[int, int, int]]]:
lgp_inds_greedy: List[Tuple[int, int, int]] = []
lgp_inds_random: List[Tuple[int, int, int]] = []

g_ind = 0
r_ind = 0
for i in range(num_seq):
sampling_param = sampling_params[i]
if sampling_param.sampling_type == SamplingType.RANDOM:
if sampling_param.logprobs:
lgp_inds_random.append((i, r_ind, sampling_param.top_logprobs))
r_ind = r_ind + 1
else:
if sampling_param.logprobs:
lgp_inds_greedy.append((i, g_ind, sampling_param.top_logprobs))
g_ind = g_ind + 1

return lgp_inds_greedy, lgp_inds_random


def get_raw_logprob_infos(
logprob_infos, #: RawLogprobsInfos,
indices: List[Tuple[int, int, int]],
logits: torch.Tensor,
token_ids: torch.Tensor,
): # -> RawLogprobsInfos:
for i, ind, top_logprobs in indices:
# i : batch (sequence) index
# ind : index within the greedy/random group
logprob_infos[i] = get_raw_logprob_info(
logits[ind],
token_ids[ind],
top_logprobs,
)

return logprob_infos


def check_logprob_infos(
logprob_infos, #: RawLogprobsInfos,
): # -> Optional[RawLogprobsInfos]:
@@ -154,11 +113,11 @@ def check_logprob_infos(
return None


# TODO: Add logprob
@dataclass
class SamplingTensors:
mask_random: torch.Tensor
mask_greedy: torch.Tensor
mask_top_logprob: torch.Tensor
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
@@ -175,6 +134,7 @@ def from_lists(
dtype,
dev,
list_mask_random: List[bool],
list_mask_top_logprob: List[List[bool]],
list_temperatures: List[float],
list_top_ps: List[float],
list_top_ks: List[int],
@@ -195,6 +155,8 @@ def from_lists(
mask_greedy = torch.logical_not(
mask_random,
)
# `mask_top_logprob` will be on cpu
mask_top_logprob = torch.from_numpy(list_mask_top_logprob)
temp = torch.tensor(
list_temperatures,
dtype=dtype,
@@ -249,9 +211,11 @@ def from_lists(
device="cpu",
pin_memory=True,
)

return cls(
mask_random,
mask_greedy,
mask_top_logprob,
temp.to(device=dev, non_blocking=True),
top_ps.to(device=dev, non_blocking=True),
top_ks.to(device=dev, non_blocking=True),
@@ -266,13 +230,16 @@ def from_lists(

@dataclass
class SamplingMetadata:
# mapping for <batch index, tuple(sampling type, sampling index)>
# mapping for <batch index, tuple(sampling type, greedy/random sampling index)>
index_map: dict[int, Tuple[SamplingType, int]]
has_random: bool
has_greedy: bool
apply_top_p_top_k: bool
apply_penalty: bool
apply_bias: bool
# mapping for <batch index, tuple(top_logprobs value, index within that top-k group)>
logprob_index_map: dict[int, Tuple[int, int]]
has_logprob: bool
sampling_tensors: SamplingTensors
sampling_params: List[SamplingParams]

@@ -303,6 +270,13 @@ def from_sampling_params(
idx_random = -1
idx_greedy = -1
batch_size = len(sampling_params)
# index 0 is for non-logprob requests
has_logprob = False
logprob_index_map = {}
list_mask_top_logprob = np.full(
(LOGPROB_TOP_K_MAX + 1, batch_size), False, dtype=bool
)
idxs_logprob = [-1] * (LOGPROB_TOP_K_MAX + 1)
for batch_idx, param in enumerate(sampling_params):
# Prepare temperature
# NOTE: Zero temperature means deterministic sampling
@@ -325,6 +299,15 @@ def from_sampling_params(
idx_greedy += 1
index_map[batch_idx] = (SamplingType.GREEDY, idx_greedy)

# param.top_logprobs is zero if logprob is not used
list_mask_top_logprob[param.top_logprobs][batch_idx] = True
has_logprob |= param.logprobs
idxs_logprob[param.top_logprobs] += 1
logprob_index_map[batch_idx] = (
param.top_logprobs,
idxs_logprob[param.top_logprobs],
)
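# Hypothetical example: for a batch whose top_logprobs values are
# [0, 2, 2, 5], logprob_index_map becomes
# {0: (0, 0), 1: (2, 0), 2: (2, 1), 3: (5, 0)}, i.e. requests are grouped
# by their requested top-k and indexed within each group.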

list_past_output_tokens.append(param.output_tokens)

apply_penalty |= (
@@ -357,6 +340,7 @@ def from_sampling_params(
dtype,
dev,
list_mask_random,
list_mask_top_logprob,
list_temperatures,
list_top_ps,
list_top_ks,
@@ -375,6 +359,8 @@ def from_sampling_params(
apply_top_p_top_k,
apply_penalty,
apply_bias,
logprob_index_map,
has_logprob,
sampling_tensors,
sampling_params,
)
@@ -467,13 +453,6 @@ def _is_safe_to_sample(prob_like):
sampling_tensors = sampling_metadata.sampling_tensors

batch_size = logits.shape[0]
logprob_infos = [None] * batch_size
# lgp_inds_greedy, lgp_inds_random = get_logprob_indices(
# sampling_metadata.sampling_params,
# batch_size,
# )
# (batch_index, rand_ind, sampling_param.top_logprobs)

mask_greedy_t, mask_random_t = (
sampling_tensors.mask_greedy,
sampling_tensors.mask_random,
@@ -482,38 +461,13 @@ def _is_safe_to_sample(prob_like):
logits_greedy = logits[mask_greedy_t]
res_greedy = torch.argmax(logits_greedy, -1)

# logprob_infos = get_raw_logprob_infos(
# logprob_infos, # logprobinfo
# lgp_inds_greedy, # indices
# logits_greedy, # logits
# res_greedy, # sampled token_ids
# )

# Case when there's only greedy sampling
# if logits_greedy.shape[0] == batch_size:
# torch.cuda.nvtx.range_pop()
# return res_greedy, check_logprob_infos(logprob_infos)

if sampling_metadata.has_random:
logits_random = logits[mask_random_t]
probs = torch.softmax(logits_random, dim=-1)
if check_safety and not _is_safe_to_sample(probs):
return None

res_random = _multinomial(probs, 1)[:, 0]

# logprob_infos = get_raw_logprob_infos(
# logprob_infos,
# lgp_inds_random,
# logits_random,
# res_random,
# )

# Case when there's only random sampling
# if logits_random.shape[0] == batch_size:
# torch.cuda.nvtx.range_pop()
# return res_random, check_logprob_infos(logprob_infos)

# Prepare output
# Send results to CPU and convert them into numpy
res_greedy = list(res_greedy.cpu().numpy()) if res_greedy is not None else list()
@@ -530,6 +484,31 @@ def _is_safe_to_sample(prob_like):
else:
assert sampling_idx < len(res_greedy)
next_tokens.append(res_greedy[sampling_idx])

logprob_infos = [None] * batch_size
if sampling_metadata.has_logprob:
all_top_logprobs, all_top_tokens = [[]], [[]]  # index 0 is a placeholder for non-logprob requests
logprobs = torch.log_softmax(logits, dim=-1)
for lp_idx in range(LOGPROB_TOP_K_MAX):
logprob_topk = lp_idx + 1
mask_t = sampling_metadata.sampling_tensors.mask_top_logprob[logprob_topk]
top_logprobs, top_tokens = torch.topk(
logprobs[mask_t], k=logprob_topk, dim=-1, largest=True, sorted=True
)
all_top_logprobs.append(top_logprobs)
all_top_tokens.append(top_tokens)
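# At this point all_top_logprobs[k] / all_top_tokens[k] hold one row per
# request whose top_logprobs == k, in the per-group order recorded in
# logprob_index_map (index 0 stays empty for non-logprob requests).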

# recover original batch order
for batch_idx in range(batch_size):
logprob_topk, idx = sampling_metadata.logprob_index_map[batch_idx]
if logprob_topk == 0:
# this request did not ask for logprobs; leave its entry as None
continue
next_token = next_tokens[batch_idx]
logprob_infos[batch_idx] = RawLogprobsInfo(
current_token_id=next_token,
current_logprob=logprobs[batch_idx, next_token],
top_token_ids=all_top_tokens[logprob_topk][idx],
top_logprobs=all_top_logprobs[logprob_topk][idx],
)

# TODO: Recover original order
# mixed: check_logprob_infos(logprob_infos)
return SamplingOutput(next_tokens, logprob_infos)
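
For reference, the following is a minimal standalone sketch (not part of the diff) of the grouping approach this commit introduces: requests are bucketed by their requested top-k so that a single torch.topk call serves every request in a bucket, and the results are then scattered back into the original batch order. The constant LOGPROB_TOP_K_MAX and the toy batch below are assumptions for illustration only.

import torch

LOGPROB_TOP_K_MAX = 5
batch_size, vocab_size = 4, 8
top_logprobs = [0, 2, 2, 5]  # 0 means the request did not ask for logprobs

logits = torch.randn(batch_size, vocab_size)
logprobs = torch.log_softmax(logits, dim=-1)

# One boolean mask row per possible k; row 0 collects non-logprob requests.
mask = torch.zeros(LOGPROB_TOP_K_MAX + 1, batch_size, dtype=torch.bool)
index_map = {}
counters = [-1] * (LOGPROB_TOP_K_MAX + 1)
for batch_idx, k in enumerate(top_logprobs):
    mask[k, batch_idx] = True
    counters[k] += 1
    index_map[batch_idx] = (k, counters[k])

# One topk call per non-empty bucket instead of one per request.
grouped = {}
for k in range(1, LOGPROB_TOP_K_MAX + 1):
    if mask[k].any():
        grouped[k] = torch.topk(logprobs[mask[k]], k=k, dim=-1)

# Scatter the grouped results back into the original batch order.
for batch_idx, (k, idx) in index_map.items():
    if k == 0:
        continue  # no logprobs requested for this entry
    values, token_ids = grouped[k]
    print(batch_idx, token_ids[idx].tolist(), values[idx].tolist())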
