resolved merge conflict, move metadata preparation earlier
sunggg committed Feb 2, 2024
1 parent c419d84 commit 469e817
Showing 4 changed files with 122 additions and 185 deletions.
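Context for the commit title ("move metadata preparation earlier"): the changes below move construction of the batched sampling tensors (SamplingMetadata) out of sample_from_logits and into the caller, where it can be built on a separate CUDA copy stream; the main stream then only waits on that stream right before adjusting and sampling the logits. The snippet below is a minimal, self-contained sketch of that copy-stream pattern — the function names, shapes, and temperature values are illustrative assumptions, not the mlc_serve API.

# Minimal sketch of the "prepare early on a copy stream" pattern (assumptions:
# CUDA is available; names, shapes, and dtypes are illustrative, not mlc_serve's).
import torch

def prepare_sampling_tensors(temperatures, copy_stream, device="cuda"):
    # Build the per-request sampling tensors on a dedicated copy stream,
    # so the work can overlap with the model's forward pass on the main stream.
    with torch.cuda.stream(copy_stream):
        return torch.tensor(temperatures, dtype=torch.float32, device=device)

def sample(logits, temperatures, copy_stream):
    # Synchronization point: wait until the sampling tensors are on the GPU,
    # then adjust the logits and sample.
    torch.cuda.current_stream().wait_stream(copy_stream)
    probs = torch.softmax(logits / temperatures.unsqueeze(1), dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    logits = torch.randn(2, 32000, device="cuda")
    temps = prepare_sampling_tensors([0.7, 1.0], copy_stream)
    print(sample(logits, temps, copy_stream))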
1 change: 0 additions & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -67,7 +67,6 @@ class EvalMultiQueryRequest:


RequestType = Union[PrefillRequest, DecodeRequest, EvalMultiQueryRequest]
RequestsType = Sequence[RequestType]


@dataclass
147 changes: 80 additions & 67 deletions serve/mlc_serve/model/model_common.py
@@ -3,6 +3,7 @@
import structlog
import numpy as np
import torch
import tvm

from .paged_cache_manager import CacheManager
from ..engine import (
@@ -19,9 +20,9 @@
PrefillRequest,
EvalMultiQueryRequest,
RequestType,
RequestsType,
TextGenerationResult,
)
from .sampler import sample, adjust_logits, SamplingMetadata


LOG = structlog.stdlib.get_logger(__name__)
@@ -280,72 +281,69 @@ def _is_safe_to_sample(prob_like):
"""


def update_tokens_frequency(
request: RequestType,
new_token: int
):
if not new_token in request.sampling_params.appeared_tokens_freq:
request.sampling_params.appeared_tokens_freq[new_token] = 0
request.sampling_params.appeared_tokens_freq[new_token] += 1


def append_text_gen_res(
outputs: List[TextGenerationResult],
def prepare_textgen_result(
request: RequestType,
new_token: List[int],
sequence_id: SequenceId,
logprob_info: Optional[RawLogprobsInfos],
err_msg: Optional[str]=None,
err_msg: Optional[str] = None,
) -> List[TextGenerationResult]:
if sequence_id.sequence_index == PROMPT_SEQEUNCE_INDEX:
assert isinstance(request, PrefillRequest)
for seq_id in range(request.num_sequence): # type: ignore
outputs.append(
TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=new_token,
error=err_msg,
logprob_info=logprob_info,
)
)
else:
outputs.append(
TextGenerationResult(
sequence_id=sequence_id,
return TextGenerationResult(
sequence_id=SequenceId(sequence_id.request_id, seq_id),
generated_tokens=new_token,
error=err_msg,
logprob_info=logprob_info,
)
else:
return TextGenerationResult(
sequence_id=sequence_id,
generated_tokens=new_token,
error=err_msg,
logprob_info=logprob_info,
)
return outputs


def sample_from_logits(
logits, #: Union[tvm.nd.NDArray, torch.Tensor],
logits: Union[tvm.nd.NDArray, torch.Tensor],
sequence_ids: List[SequenceId],
requests: RequestsType,
vocab_size,
request_maps: dict[SequenceId, RequestType],
sampling_metadata: SamplingMetadata,
vocab_size: int,
copy_stream: torch.cuda.Stream,
torch_dtype: torch.dtype,
torch_dev: str,
) -> List[TextGenerationResult]:
pass
"""
assert logits.shape[0] == len(requests)
assert logits.shape[0] == len(request_maps)
# Convert to torch tensors if logits are in tvm ndarray
if isinstance(logits, tvm.nd.NDArray):
logits = torch.from_dlpack(logits)

sampling_params = [req.sampling_params for req in requests]
outputs: List[TextGenerationResult] = []
# synchronization point for sampling tensors
# wait until all the tensors are loaded on GPU
torch.cuda.current_stream().wait_stream(copy_stream)
logits = adjust_logits(logits, sampling_metadata, vocab_size)

try:
next_tokens, logprob_infos = sample(logits, sampling_params, vocab_size)
assert next_tokens is not None
for i, (sequence_id, new_token) in enumerate(zip(sequence_ids, next_tokens)):
update_tokens_frequency(requests[i], new_token)
outputs = append_text_gen_res(
outputs,
requests[i],
[new_token],
sequence_id,
get_logprob_infos(i, logprob_infos),
next_tokens_map = sample(
sequence_ids,
logits,
sampling_metadata,
)
outputs: List[TextGenerationResult] = []
for sequence_id, new_token in next_tokens_map:
request = request_maps[sequence_id]
request.sampling_params.output_tokens.append(new_token)
outputs.append(
prepare_textgen_result(
request,
[new_token],
sequence_id,
None, # get_logprob_infos(i, logprob_infos),
)
)
return outputs
except RuntimeError:
# Fallback to per-token sampling in case some logits values are corrupted.
@@ -354,38 +352,53 @@ def sample_from_logits(
" or element < 0"
)

for i, (sequence_id, logits_per_token, sampling_param) in enumerate(
zip(sequence_ids, torch.from_dlpack(logits), sampling_params)
for sequence_id, logits_per_token, sampling_param in zip(
sequence_ids, torch.from_dlpack(logits), sampling_metadata.sampling_params
):
maybe_new_token, logprob_infos = sample(
# NOTE: Rerun the preparation for simplicity.
# Assume this code path is taken rarely and the recomputation overhead is
# marginal.
with torch.cuda.stream(copy_stream):
sampling_metadata = SamplingMetadata.from_sampling_params(
[sampling_param],
torch_dtype,
torch_dev,
vocab_size,
)
torch.cuda.current_stream().wait_stream(copy_stream)

# TODO:logprob
maybe_next_tokens_map = sample(
[sequence_id],
torch.unsqueeze(logits_per_token, 0),
[sampling_param],
sampling_metadata,
vocab_size,
check_safety=True,
)
if maybe_new_token is not None:
new_token = maybe_new_token[0]
update_tokens_frequency(requests[i], new_token)
outputs = append_text_gen_res(
outputs,
requests[i],
[new_token],
sequence_id,
get_logprob_infos(0, logprob_infos),
# Valid sample
request = request_maps[sequence_id]
if maybe_next_tokens_map is not None:
request.sampling_params.output_tokens.append(new_token)
outputs.append(
prepare_textgen_result(
request,
[new_token], # new_token
sequence_id,
None, # get_logprob_infos(0, logprob_infos),
)
)
else:
outputs = append_text_gen_res(
outputs,
requests[i],
[], # new_token
sequence_id,
get_logprob_infos(0, logprob_infos),
err_msg,
outputs.append(
prepare_textgen_result(
request,
[], # new_token
sequence_id,
None, # get_logprob_infos(0, logprob_infos),
err_msg,
)
)

return outputs
"""


def prepare_inputs(
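The try/except RuntimeError structure in the sample_from_logits hunk above keeps batched sampling on the fast path and only falls back to per-row sampling (with a safety check) when the batched call trips over corrupted probabilities. A minimal, self-contained sketch of that pattern follows, using a toy torch sampler rather than the real sample()/SamplingMetadata API.

# Hedged sketch of the batch-then-per-row fallback (not the mlc_serve sampler;
# plain torch, toy logic). A NaN/inf row makes torch.multinomial raise
# RuntimeError for the whole batch, so each row is retried in isolation and the
# bad ones are reported as None instead of failing every request in the batch.
from typing import List, Optional
import torch

def _sample_row(logits_row: torch.Tensor, check_safety: bool = False) -> Optional[int]:
    probs = torch.softmax(logits_row, dim=-1)
    if check_safety and not bool(torch.isfinite(probs).all()):
        return None  # surface an error for this request only
    return int(torch.multinomial(probs, num_samples=1))

def sample_with_fallback(logits: torch.Tensor) -> List[Optional[int]]:
    try:
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1).tolist()
    except RuntimeError:
        # Per-row fallback: isolate the corrupted rows.
        return [_sample_row(row, check_safety=True) for row in logits]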
5 changes: 3 additions & 2 deletions serve/mlc_serve/model/sampler.py
@@ -2,8 +2,7 @@
import numpy as np
import structlog
from dataclasses import dataclass
from typing import List, Union, Optional, Tuple
import tvm
from typing import List, Optional, Tuple
from ..engine import (
SamplingParams,
SamplingType,
@@ -272,6 +271,7 @@ class SamplingMetadata:
apply_penalty: bool
apply_bias: bool
sampling_tensors: SamplingTensors
sampling_params: List[SamplingParams]

@classmethod
def from_sampling_params(
@@ -359,6 +359,7 @@ def from_sampling_params(
apply_penalty,
apply_bias,
sampling_tensors,
sampling_params,
)


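The visible change in sampler.py is the new sampling_params field, which keeps the original per-request SamplingParams list alongside the batched tensors so the per-token fallback in model_common.py can rebuild single-request metadata. A toy sketch of the idea, with simplified, assumed fields rather than the real mlc_serve dataclasses:

# Toy illustration (assumed/simplified fields, not mlc_serve's real classes):
# carrying the source SamplingParams list on the batched metadata lets callers
# rebuild per-request metadata without threading the list through separately.
from dataclasses import dataclass
from typing import List

@dataclass
class SamplingParams:
    temperature: float = 1.0

@dataclass
class SamplingMetadata:
    apply_penalty: bool
    apply_bias: bool
    sampling_params: List[SamplingParams]

    @classmethod
    def from_sampling_params(cls, params: List[SamplingParams]) -> "SamplingMetadata":
        return cls(apply_penalty=False, apply_bias=False, sampling_params=params)

batched = SamplingMetadata.from_sampling_params([SamplingParams(0.7), SamplingParams(1.0)])
# Fallback path: rebuild metadata one request at a time.
per_request = [SamplingMetadata.from_sampling_params([p]) for p in batched.sampling_params]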