Fix after #214 (#215)
* transfer prompt mask from sampling params to request state. use torch tensor instead of list

* fix prompt mask for EvalMultiQueryRequest

* clean code

* update sampler tests

* fix after rebase

* add device to comment
vvchernov authored Feb 16, 2024
1 parent 7495bd0 commit 5588d17
Showing 7 changed files with 56 additions and 43 deletions.
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/base.py
@@ -316,7 +316,7 @@ class RequestState:

request_id: RequestId
prompt_token_ids: list[int]
prompt_mask: Optional[list[bool]]
prompt_mask: Optional[torch.Tensor]
sampling_params: SamplingParams
generation_sequences: list[GenerationSequence]
stopping_criteria: StoppingCriteria
32 changes: 21 additions & 11 deletions serve/mlc_serve/engine/engine_common.py
@@ -2,11 +2,11 @@
Common utilities for engine classes.
"""

import torch
import time
from typing import Tuple, Deque, Dict, Optional, Callable, List
from collections import deque
from threading import Condition, Lock
import numpy as np

import structlog

@@ -58,10 +58,14 @@ def get_new_request_state(
# TODO: Currently, always create this. But we only need this for the requests with repetition penalty
# Follow-up and optimize when it has been stabilized.
# Create prompt mask for repetition penalty
tokens = np.array([prompt_token_ids], dtype=np.int64)
prompt_mask = np.zeros((vocab_size + 1,), dtype=bool)
prompt_mask[tokens] = True
prompt_mask = list(prompt_mask[:vocab_size])
tokens = torch.tensor(prompt_token_ids, dtype=torch.long)
vocab_size = request.sampling_params.vocab_size
bin_counts = torch.zeros((vocab_size + 1,),
dtype=torch.long,
device=tokens.device) # CPU
bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:vocab_size]
prompt_mask = bin_counts > 0

validation_err = None
if request.validate_tokens is not None:
@@ -93,7 +97,7 @@ def get_new_request_state(

# Based on vllm: https://github.com/vllm-project/vllm/pull/984
def detokenize_incrementally(
prompt_tokens: list[int],
prompt_tokens: List[int],
generation_sequence: GenerationSequence,
tokenizer: TokenizerP,
new_token_id: Optional[int] = None,
@@ -228,8 +232,8 @@ def prepare_logprob(

def prepare_output(
gen_seq: GenerationSequence,
new_token_ids: list[int],
prompt_token_ids: list[int],
new_token_ids: List[int],
prompt_token_ids: List[int],
logprob_info,
tokenizer: TokenizerP,
stopping_criteria: StoppingCriteria,
@@ -253,11 +257,11 @@ def prepare_output(


def get_requests_to_process(
current_states: list[RequestState],
current_states: List[RequestState],
cache_manager: KVCacheManager,
tokenizer: TokenizerP,
) -> Tuple[list[RequestType], bool, int]:
requests: list[RequestType] = []
) -> Tuple[List[RequestType], bool, int]:
requests: List[RequestType] = []
# TODO: consider having hybrid batch if the underlying attention kernel supports
# mixing prefill and decode.
is_prompt_batch = any(not state.is_prefilled for state in current_states)
@@ -289,10 +293,16 @@ def get_requests_to_process(
token_counts += len(state.prompt_token_ids)

for gen_seq in state.generation_sequences:
# TODO(vvchernov): This is for repetition penalty.
# It is not obvious that EvalMultiQueryRequest needs this,
# so an empty mask is used here instead of state.prompt_mask.
vocab_size = state.sampling_params.vocab_size
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
requests.append(
EvalMultiQueryRequest(
sequence_id=gen_seq.seq_id,
num_past_tokens=state.prompt_len,
prompt_mask=prompt_mask,
queries=EvictedTokens(gen_seq.generated_token_ids),
sampling_params=state.sampling_params,
)
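For reference, the new mask construction in get_new_request_state boils down to the following self-contained sketch (the helper name build_prompt_mask is illustrative and not part of the codebase; the mask stays on CPU and marks every token id that occurs in the prompt):

    import torch

    def build_prompt_mask(prompt_token_ids: list[int], vocab_size: int) -> torch.Tensor:
        # Count occurrences of every prompt token id; the extra slot at index
        # `vocab_size` mirrors the padding convention used elsewhere in the model code.
        tokens = torch.tensor(prompt_token_ids, dtype=torch.long)
        bin_counts = torch.zeros((vocab_size + 1,), dtype=torch.long, device=tokens.device)  # CPU
        bin_counts.scatter_add_(0, tokens, torch.ones_like(tokens))
        # Boolean mask over the real vocabulary: True where the token appears in the prompt.
        return bin_counts[:vocab_size] > 0

Unlike the previous list-of-bool representation, the resulting tensor can later be batched with torch.stack in the sampler without an extra conversion.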
6 changes: 4 additions & 2 deletions serve/mlc_serve/engine/model_module.py
@@ -3,6 +3,7 @@
"""
from dataclasses import dataclass
from typing import Optional, Protocol, Union, List, Sequence, Any
import torch

from .base import (
ChatMessage,
@@ -21,7 +22,7 @@ class PrefillRequest:
request_id: RequestId
# `token_ids` contains prompt token ids
token_ids: List[int]
prompt_mask: Optional[List[bool]]
prompt_mask: Optional[torch.Tensor]
# Number of sequences to generate
num_sequence: int
sampling_params: SamplingParams
@@ -37,7 +38,7 @@ class PrefillRequest:
class DecodeRequest:
sequence_id: SequenceId
prompt_token_counts: int
prompt_mask: Optional[List[bool]]
prompt_mask: Optional[torch.Tensor]
# Decoded tokens for this sequence
token_ids: List[int]
sampling_params: SamplingParams
@@ -65,6 +66,7 @@ def num_tokens(self):
class EvalMultiQueryRequest:
sequence_id: SequenceId
num_past_tokens: int
prompt_mask: Optional[torch.Tensor]
queries: Union[DraftTokens, EvictedTokens]
sampling_params: SamplingParams

2 changes: 1 addition & 1 deletion serve/mlc_serve/model/model_common.py
@@ -89,7 +89,7 @@ def sample_from_logits(
torch_dtype: torch.dtype,
torch_dev: str,
past_decode_tokens: List[List[int]],
prompt_masks: List[List[bool]],
prompt_masks: List[torch.Tensor],
) -> List[TextGenerationResult]:
batch_size = logits.shape[0]
assert batch_size == len(requests)
8 changes: 2 additions & 6 deletions serve/mlc_serve/model/sampler.py
@@ -129,11 +129,7 @@ def from_lists(
)
# `mask_top_logprob` will be on cpu
mask_top_logprob = torch.from_numpy(list_mask_top_logprob)
mask_prompt = torch.tensor(
list_mask_prompt,
dtype=torch.bool,
device="cpu",
)
mask_prompt = torch.stack(list_mask_prompt)
temp = torch.tensor(
list_temperatures,
dtype=dtype,
@@ -256,7 +252,7 @@ def from_sampling_params(
cls,
sampling_params: List[SamplingParams],
list_past_output_tokens: List[List[int]],
list_mask_prompt: List[List[bool]],
list_mask_prompt: List[torch.Tensor],
dtype: torch.dtype,
dev: str,
vocab_size: int,
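Because each request now carries its prompt mask as a tensor, from_lists can batch the per-request masks with a single torch.stack instead of rebuilding a tensor from nested Python lists. The penalty application itself lies outside this diff; the snippet below is only a hedged illustration of how a stacked boolean mask is conventionally consumed for repetition penalty (the function name and the penalties/output_masks arguments are assumptions, not code from this repository):

    import torch

    def apply_repetition_penalty(
        logits: torch.Tensor,               # [batch, vocab_size] float
        prompt_masks: list[torch.Tensor],   # per-request [vocab_size] bool masks
        output_masks: torch.Tensor,         # [batch, vocab_size] bool, tokens generated so far
        penalties: torch.Tensor,            # [batch, 1] float repetition penalties
    ) -> torch.Tensor:
        # Batch the per-request masks exactly as `from_lists` does.
        mask_prompt = torch.stack(prompt_masks).to(logits.device)
        seen = mask_prompt | output_masks
        # Conventional repetition penalty: shrink positive logits and amplify
        # negative ones for every token seen in the prompt or the output so far.
        penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
        return torch.where(seen, penalized, logits)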
7 changes: 4 additions & 3 deletions serve/mlc_serve/model/tvm_model.py
@@ -263,7 +263,7 @@ def generate_multi_query(
last_query_offsets: List[int] = []
sampling_params = []
past_decode_tokens = []
prompt_masks: List[List[bool]] = []
prompt_masks: List[torch.Tensor] = []
for request in requests:
assert not isinstance(request.queries, DraftTokens)
sequence_ids.append(request.sequence_id)
@@ -274,8 +274,9 @@
last_query_offsets[-1] + request.queries.num_tokens
)
sampling_params.append(request.sampling_params)
# TODO: Empty mask for now. This is for repetion penalty
prompt_masks.append([False])
# TODO(vvchernov): This is for repetition penalty.
# It is not obvious that EvalMultiQueryRequest needs this.
prompt_masks.append(request.prompt_mask)
# Use `vocab_size` as a padding
past_decode_tokens.append([self.vocab_size, *request.queries.token_ids])

42 changes: 23 additions & 19 deletions serve/tests/unittest/test_sampler.py
@@ -14,7 +14,9 @@ def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=No
if past_output_tokens is None:
past_output_tokens = [[] for _ in range(batch_size)]
if prompt_masks is None:
prompt_masks = [[] for _ in range(batch_size)]
# Prepare empty prompt mask
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
prompt_masks = [prompt_mask] * batch_size
_copy_stream: torch.cuda.Stream = torch.cuda.Stream()
with torch.cuda.stream(_copy_stream):
sampling_state = SamplingState.from_sampling_params(
@@ -29,7 +31,7 @@ def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=No
return sampling_state


def _test_temperature(temp=0, batch_size=1):
def test_temperature(temp=0, batch_size=1):
shape = (batch_size, vocab_size)
logits = torch.rand(shape, dtype=dtype, device=dev)
sampling_param = SamplingParams(temperature=temp)
@@ -41,7 +43,7 @@ def _test_temperature(temp=0, batch_size=1):
assert torch.allclose(expected, new_logits)


def _test_logit_bias_checker():
def test_logit_bias_checker():
# logit bias must be [-100, 100]
with pytest.raises(ValueError):
logit_bias = {1: 2, 3: 105, 2: 2}
@@ -78,7 +80,7 @@ def _test_logit_bias_checker():
get_sampling_state([sampling_param])


def _test_logit_bias():
def test_logit_bias():
# test single batch
batch_size = 1
shape = (batch_size, vocab_size)
@@ -112,7 +114,7 @@ def _test_logit_bias():
assert torch.allclose(expected, new_logits)


def _test_penalties_checker():
def test_penalties_checker():
get_sampling_state([SamplingParams(presence_penalty=-1.0)])
get_sampling_state([SamplingParams(frequency_penalty=-1.0)])
get_sampling_state([SamplingParams(repetition_penalty=0.7)])
@@ -143,15 +145,16 @@ def _test_penalties_checker():
)


def _test_penalties():
def test_penalties():
# TODO(vvchernov): Add test for repetition penalty
batch_size = 1
shape = (batch_size, vocab_size)
logits = torch.rand(shape, dtype=dtype, device=dev)
presence_penalties = [0.8]
frequency_penalties = [0.3]
past_output_tokens = [[2, 2, 2, 3]]
prompt_masks = [[False] * vocab_size] * batch_size
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
prompt_masks = [prompt_mask] * batch_size

def prepare_metadata(past_output_tokens):
count_map = []
Expand Down Expand Up @@ -202,7 +205,8 @@ def get_expected_result(
presence_penalties = [0.8, 0.7, -0.8]
frequency_penalties = [-0.3, 2.0, 1.2]
past_output_tokens = [[2, 2, 2, 3, 5], [3, 1, 2, 4], [3, 3, 1]]
prompt_masks = [[False] * vocab_size] * batch_size
prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
prompt_masks = [prompt_mask] * batch_size

count_map, mask = prepare_metadata(past_output_tokens)
expected = get_expected_result(
@@ -225,7 +229,7 @@ def get_expected_result(
assert torch.allclose(expected, new_logits)


def _test_top_p_top_k_checker():
def test_top_p_top_k_checker():
get_sampling_state([SamplingParams(top_p=0.8)])
get_sampling_state([SamplingParams(top_k=3)])

@@ -248,7 +252,7 @@ def _test_top_p_top_k_checker():
get_sampling_state([SamplingParams(top_k=-2)])


def _test_top_p_top_k():
def test_top_p_top_k():
def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
@@ -320,7 +324,7 @@ def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
assert torch.allclose(expected, new_logits)


def _test_mixture_of_requests():
def test_mixture_of_requests():
# Mixed greedy & top_p/top_ks
batch_size = 6
shape = (batch_size, vocab_size)
@@ -341,11 +345,11 @@ def _test_mixture_of_requests():


if __name__ == "__main__":
_test_temperature()
_test_logit_bias_checker()
_test_logit_bias()
_test_penalties_checker()
_test_penalties()
_test_top_p_top_k_checker()
_test_top_p_top_k()
_test_mixture_of_requests()
test_temperature()
test_logit_bias_checker()
test_logit_bias()
test_penalties_checker()
test_penalties()
test_top_p_top_k_checker()
test_top_p_top_k()
test_mixture_of_requests()
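The TODO in test_penalties still asks for a repetition-penalty test. With masks now being plain tensors, such a test could construct a non-empty prompt mask directly from token ids, for example (a sketch only, assuming the module-level vocab_size and the get_sampling_state helper above; make_prompt_mask does not exist in this commit):

    def make_prompt_mask(prompt_token_ids, vocab_size):
        # Boolean mask marking every token id that occurs in the prompt.
        mask = torch.zeros((vocab_size,), dtype=torch.bool)
        mask[torch.tensor(prompt_token_ids, dtype=torch.long)] = True
        return mask

    def run_repetition_penalty_case():
        prompt_masks = [make_prompt_mask([2, 2, 5], vocab_size)]
        sampling_params = [SamplingParams(repetition_penalty=1.2)]
        sampling_state = get_sampling_state(
            sampling_params, past_output_tokens=[[3, 3]], prompt_masks=prompt_masks
        )
        # sampling_state now carries a stacked [1, vocab_size] prompt mask for the sampler.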
