
Commit

more cleanup
sunggg committed Feb 2, 2024
1 parent 469e817 commit a60f6f8
Showing 6 changed files with 105 additions and 192 deletions.
2 changes: 2 additions & 0 deletions serve/mlc_serve/engine/base.py
@@ -22,6 +22,8 @@ class RawLogprobsInfo:
top_token_ids: Optional[np.ndarray]
top_logprobs: Optional[np.ndarray]


# TODO(sunggg): can we delete this?
RawLogprobsInfos = List[Optional[RawLogprobsInfo]]


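For reference, a minimal self-contained sketch of the structures this hunk touches, assuming `RawLogprobsInfo` is a plain dataclass (the real class may carry fields the diff does not show). The example batch value mirrors how the old sampling code initialized one optional entry per sequence (`[None] * num_seq`):

```python
# Illustrative sketch only: reconstructs just the parts of engine/base.py
# visible in this hunk. The real RawLogprobsInfo may define additional fields.
from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class RawLogprobsInfo:
    top_token_ids: Optional[np.ndarray]
    top_logprobs: Optional[np.ndarray]


# TODO(sunggg): can we delete this?
# One optional entry per sequence in a batch; None when logprobs were not requested.
RawLogprobsInfos = List[Optional[RawLogprobsInfo]]


# Example: a batch of two sequences where only the first requested logprobs.
batch_infos: RawLogprobsInfos = [
    RawLogprobsInfo(
        top_token_ids=np.array([11, 42, 7]),
        top_logprobs=np.array([-0.1, -1.3, -2.7]),
    ),
    None,
]
```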
160 changes: 9 additions & 151 deletions serve/mlc_serve/model/model_common.py
@@ -22,7 +22,7 @@
RequestType,
TextGenerationResult,
)
from .sampler import sample, adjust_logits, SamplingMetadata
from .sampler import sample, adjust_logits, SamplingMetadata, SamplingOutput


LOG = structlog.stdlib.get_logger(__name__)
@@ -50,8 +50,6 @@ def get_num_cache_blocks(
)


"""
def get_logprob_infos(
i: int,
logprob_infos: Optional[RawLogprobsInfos],
@@ -139,146 +137,6 @@ def check_logprob_infos(
if check:
return logprob_infos
return None
def sample(
logits: Union[tvm.nd.NDArray, torch.Tensor],
sampling_params: List[SamplingParams],
vocab_size: int,
check_safety=False,
) -> Optional[Tuple[np.ndarray, Optional[RawLogprobsInfos]]]:
def _is_safe_to_sample(prob_like):
return (
torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
== 0
)
torch.cuda.nvtx.range_push(f"sample {logits.shape}")
logits = torch.from_dlpack(logits)
num_seq = len(sampling_params)
mask_random_cpu = torch.tensor(
[p.sampling_type == SamplingType.RANDOM for p in sampling_params],
dtype=torch.bool,
)
mask_greedy_cpu = torch.logical_not(mask_random_cpu)
if logits.device == torch.device("cpu"):
mask_random_dvc = mask_random_cpu
mask_greedy_dvc = mask_greedy_cpu
else: # gpu
mask_random_dvc = mask_random_cpu.to(logits.device)
mask_greedy_dvc = mask_greedy_cpu.to(logits.device)
logits_greedy = logits[mask_greedy_dvc]
logprob_infos: RawLogprobsInfos = [None] * num_seq
lgp_inds_greedy, lgp_inds_random = get_logprob_indices(
sampling_params,
num_seq,
)
if logits_greedy.shape[0] > 0:
res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy()
logprob_infos = get_raw_logprob_infos(
logprob_infos,
lgp_inds_greedy,
logits_greedy,
res_greedy,
)
# Case when there's only greedy sampling
if logits_greedy.shape[0] == num_seq:
torch.cuda.nvtx.range_pop()
return res_greedy, check_logprob_infos(logprob_infos)
temperatures = []
top_ps = []
top_ks = []
divide_by_temperature = False
do_top_p = False
do_top_k = False
for i in range(num_seq):
param = sampling_params[i]
freq = param.appeared_tokens_freq
if param.sampling_type == SamplingType.RANDOM:
temperatures.append(param.temperature)
top_ps.append(param.top_p)
top_ks.append(param.top_k if param.top_k != -1 else vocab_size)
divide_by_temperature |= temperatures[-1] != 1.0
do_top_p |= top_ps[-1] < 1.0
do_top_k |= top_ks[-1] != vocab_size
# TODO(vvchernov): need to strictly define order of using penalties and logit bias or
# prohibit simultaneous using of them. At the latter case it can be LogitProcessor
if (
not param.presence_penalty == 0.0 or not param.frequency_penalty == 0
) and bool(freq):
index = torch.from_numpy(np.array(list(freq.keys()))).to(
device=logits.device
)
src = (
torch.from_numpy(np.array(list(freq.values())))
.type_as(logits)
.to(device=logits.device)
)
logits[i][index] -= (
src * param.frequency_penalty + param.presence_penalty
)
if not param.repetition_penalty == 1.0 and bool(freq):
index = torch.from_numpy(np.array(list(freq.keys()))).to(
device=logits.device
)
logits[i][index] /= param.repetition_penalty
if param.logit_bias:
logits[i][param.logit_bias_index] += (
torch.Tensor(param.logit_bias_value)
.type_as(logits)
.to(device=logits.device)
)
logits_random = logits[mask_random_dvc]
if divide_by_temperature:
t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
logits_random.div_(t.unsqueeze(dim=1))
if do_top_p or do_top_k:
logits_random = _apply_top_p_top_k(logits_random, top_ps, top_ks)
probs = torch.softmax(logits_random, dim=-1)
if check_safety and not _is_safe_to_sample(probs):
torch.cuda.nvtx.range_pop()
return None
res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy()
logprob_infos = get_raw_logprob_infos(
logprob_infos,
lgp_inds_random,
logits_random,
res_random,
)
# Case when there's only random sampling
if logits_random.shape[0] == num_seq:
torch.cuda.nvtx.range_pop()
return res_random, check_logprob_infos(logprob_infos)
res = np.empty((num_seq,), dtype=np.int32)
res[mask_random_cpu] = res_random
if logits_greedy.shape[0] > 0:
res[mask_greedy_cpu] = res_greedy
torch.cuda.nvtx.range_pop()
return res, check_logprob_infos(logprob_infos)
"""


def prepare_textgen_result(
@@ -309,14 +167,15 @@ def prepare_textgen_result(
def sample_from_logits(
logits: Union[tvm.nd.NDArray, torch.Tensor],
sequence_ids: List[SequenceId],
request_maps: dict[SequenceId, RequestType],
requests: List[RequestType],
sampling_metadata: SamplingMetadata,
vocab_size: int,
copy_stream: torch.cuda.Stream,
torch_dtype: torch.dtype,
torch_dev: str,
) -> List[TextGenerationResult]:
assert logits.shape[0] == len(request_maps)
batch_size = logits.shape[0]
assert batch_size == len(requests)
# Convert to torch tensors if logits are in tvm ndarray
if isinstance(logits, tvm.nd.NDArray):
logits = torch.from_dlpack(logits)
@@ -327,14 +186,15 @@ def sample_from_logits(
logits = adjust_logits(logits, sampling_metadata, vocab_size)

try:
next_tokens_map = sample(
sequence_ids,
sampling_output: SamplingOutput = sample(
logits,
sampling_metadata,
)

outputs: List[TextGenerationResult] = []
for sequence_id, new_token in next_tokens_map:
request = request_maps[sequence_id]
for i, new_token in enumerate(sampling_output.next_tokens):
sequence_id = sequence_ids[i]
request = requests[i]
request.sampling_params.output_tokens.append(new_token)
outputs.append(
prepare_textgen_result(
@@ -369,10 +229,8 @@ def sample_from_logits(

# TODO:logprob
maybe_next_tokens_map = sample(
[sequence_id],
torch.unsqueeze(logits_per_token, 0),
sampling_metadata,
vocab_size,
check_safety=True,
)
# Valid sample
(The remaining 4 changed files are not shown here.)
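For reference, the `sample_from_logits` hunks above switch from a dict keyed by `SequenceId` (`request_maps`, `next_tokens_map`) to positional lists indexed in batch order (`requests`, `sampling_output.next_tokens`), with `sample` and `SamplingOutput` coming from the new `mlc_serve/model/sampler.py` module. Below is a minimal sketch of that consumption pattern using simplified stand-in types; the real `SamplingOutput`, `sample`, and `prepare_textgen_result` are not reproduced here:

```python
# Minimal, self-contained sketch of the positional-indexing pattern used in
# sample_from_logits after this commit. All types here are simplified stand-ins
# for the real mlc_serve classes (SamplingOutput, RequestType, SequenceId, ...).
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class SamplingOutput:
    # One sampled token id per row of the logits batch, in batch order.
    next_tokens: List[int]


@dataclass
class Request:
    output_tokens: List[int] = field(default_factory=list)


def consume_sampling_output(
    sampling_output: SamplingOutput,
    sequence_ids: List[str],
    requests: List[Request],
) -> List[Tuple[str, int]]:
    # The batch index i ties together the logits row, its sequence id, and its
    # request, replacing the previous dict lookup keyed by SequenceId.
    results = []
    for i, new_token in enumerate(sampling_output.next_tokens):
        sequence_id = sequence_ids[i]
        request = requests[i]
        request.output_tokens.append(new_token)
        results.append((sequence_id, new_token))
    return results


if __name__ == "__main__":
    out = SamplingOutput(next_tokens=[42, 7])
    reqs = [Request(), Request()]
    print(consume_sampling_output(out, ["seq-0", "seq-1"], reqs))
```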
