From 4e2bbd390d64850a47475b5436fb611000d080bf Mon Sep 17 00:00:00 2001 From: Viktor Ferenczi Date: Sun, 1 Oct 2023 14:02:14 +0200 Subject: [PATCH 1/2] Allowed tokens --- tests/samplers/test_sampler.py | 32 +++++++- vllm/model_executor/layers/sampler.py | 107 +++++++++++++++----------- vllm/sampling_params.py | 41 ++++++---- 3 files changed, 119 insertions(+), 61 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index a5f55d50fbb76..41b74cd7a9947 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -26,7 +26,7 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int + batch_size: int ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), @@ -182,3 +182,33 @@ def test_sampler_mixed(seed: int): continue for nth_output in sequence_output: assert nth_output.output_token in expected_tokens + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_constrained(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) + + allowed_token_ids = torch.randint(0, 32000, (batch_size, 100)) + + seq_group_metadata_list = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, allowed_token_ids=allowed_token_ids[i].tolist()), + block_tables={0: [1]}, + )) + + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) + sampler_output = sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata) + expected = torch.argmax(fake_logits, dim=-1) + for i, sequence_output in enumerate(sampler_output): + for nth_output in sequence_output: + assert nth_output.output_token in allowed_token_ids[i].tolist() + assert nth_output.output_token == expected[i].item() diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index a17dfcfbaf16f..c74b6c3709b6e 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -33,11 +33,11 @@ def __init__(self, vocab_size: int) -> None: self.vocab_size = vocab_size def forward( - self, - embedding: torch.Tensor, - hidden_states: torch.Tensor, - input_metadata: InputMetadata, - embedding_bias: Optional[torch.Tensor] = None, + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + embedding_bias: Optional[torch.Tensor] = None, ) -> SamplerOutput: # Get the hidden states that we use for sampling. hidden_states = _prune_hidden_states(hidden_states, input_metadata) @@ -46,9 +46,14 @@ def forward( logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size) - # Apply presence and frequency penalties. + # Apply sampling constraints + _apply_allowed_token_ids(logits, input_metadata) + + # Get the output tokens output_tokens = _get_output_tokens(input_metadata) assert len(output_tokens) == logits.shape[0] + + # Apply presence and frequency penalties. 
presence_penalties, frequency_penalties = _get_penalties( input_metadata) assert len(presence_penalties) == logits.shape[0] @@ -99,8 +104,8 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, def _prune_hidden_states( - hidden_states: torch.Tensor, - input_metadata: InputMetadata, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, ) -> torch.Tensor: last_token_indices = {t: [] for t in SamplingType} start_idx = 0 @@ -153,10 +158,10 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: def _apply_penalties( - logits: torch.Tensor, - output_tokens: List[List[int]], - presence_penalties: List[float], - frequency_penalties: List[float], + logits: torch.Tensor, + output_tokens: List[List[int]], + presence_penalties: List[float], + frequency_penalties: List[float], ) -> torch.Tensor: num_seqs, vocab_size = logits.shape for i in range(num_seqs): @@ -219,8 +224,8 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]: def _get_top_p_top_k( - input_metadata: InputMetadata, - vocab_size: int, + input_metadata: InputMetadata, + vocab_size: int, ) -> Tuple[List[float], List[int]]: top_ps: List[float] = [] top_ks: List[int] = [] @@ -237,9 +242,9 @@ def _get_top_p_top_k( def _apply_top_p_top_k( - logits: torch.Tensor, - top_ps: List[float], - top_ks: List[int], + logits: torch.Tensor, + top_ps: List[float], + top_ks: List[int], ) -> torch.Tensor: p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) @@ -266,8 +271,8 @@ def _apply_top_p_top_k( def _get_topk_logprobs( - logprobs: torch.Tensor, - num_logprobs: Optional[int], + logprobs: torch.Tensor, + num_logprobs: Optional[int], ) -> List[Dict[int, float]]: num_seqs = logprobs.size(0) if num_logprobs is None or num_logprobs == 0: @@ -288,12 +293,12 @@ def _get_topk_logprobs( def _build_sequence_outputs( - parent_ids: List[int], - next_token_ids: List[int], - selected_token_logprobs: torch.Tensor, - parent_seq_ids: List[int], - parent_logprobs: torch.Tensor, - num_output_logprobs: Optional[int], + parent_ids: List[int], + next_token_ids: List[int], + selected_token_logprobs: torch.Tensor, + parent_seq_ids: List[int], + parent_logprobs: torch.Tensor, + num_output_logprobs: Optional[int], ) -> List[SequenceOutputs]: # Get top-k log probabilities for the next tokens. next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs) @@ -309,8 +314,8 @@ def _build_sequence_outputs( def _greedy_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - logprobs: torch.Tensor, + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + logprobs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: samples = torch.argmax(logprobs, dim=-1).cpu() sample_idx = 0 @@ -329,9 +334,9 @@ def _greedy_sample( def _random_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - probs: torch.Tensor, + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], + probs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: # Find the maximum best_of value of the prompt phase requests. max_best_of = 1 @@ -353,12 +358,12 @@ def _random_sample( "Prompt input should have only one seq.") parent_ids = [0] * sampling_params.best_of next_token_ids = random_samples[ - sample_idx, :sampling_params.best_of].tolist() + sample_idx, :sampling_params.best_of].tolist() else: # Generation phase. 
parent_ids = list(range(num_parent_seqs)) next_token_ids = random_samples[sample_idx:sample_idx + - num_parent_seqs, 0].tolist() + num_parent_seqs, 0].tolist() results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs assert sample_idx == probs.size(0) @@ -366,10 +371,10 @@ def _random_sample( def _beam_search_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - seq_data: Dict[int, SequenceData], - logprobs: torch.Tensor, + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], + seq_data: Dict[int, SequenceData], + logprobs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: # We sample 2 * beam_width candidates to make sure that with high # probability we can get `beam_width` candidates in addition to @@ -419,9 +424,9 @@ def _beam_search_sample( def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - input_metadata: InputMetadata, + probs: torch.Tensor, + logprobs: torch.Tensor, + input_metadata: InputMetadata, ) -> SamplerOutput: categorized_seq_group_ids = {t: [] for t in SamplingType} category_num_tokens = {t: 0 for t in SamplingType} @@ -442,9 +447,9 @@ def _sample( if num_tokens == 0: continue category_logprobs = logprobs[category_start_idx:category_start_idx + - num_tokens] + num_tokens] category_probs = probs[category_start_idx:category_start_idx + - num_tokens] + num_tokens] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample(seq_groups, category_logprobs) elif sampling_type == SamplingType.RANDOM: @@ -486,9 +491,9 @@ def _sample( num_results = len(next_token_ids) num_parent_seqs = len(seq_ids) parent_logprobs = category_logprobs[sample_idx:sample_idx + - num_parent_seqs] + num_parent_seqs] selected_token_logprobs = batched_logprobs_query_result[ - result_idx:result_idx + num_results] + result_idx:result_idx + num_results] seq_output = _build_sequence_outputs(parent_ids, next_token_ids, selected_token_logprobs, seq_ids, parent_logprobs, @@ -500,3 +505,19 @@ def _sample( category_start_idx += num_tokens return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))] + + +def _apply_allowed_token_ids( + logits: torch.Tensor, + input_metadata: InputMetadata, +) -> None: + vocab_size = logits.shape[1] + mask = torch.zeros(vocab_size, dtype=torch.bool) + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + assert isinstance(sampling_params, SamplingParams) + allowed_token_ids = sampling_params.allowed_token_ids + if allowed_token_ids: + mask[:] = True + mask[allowed_token_ids] = False + logits[i, mask] = -float("inf") diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 5206eb0b8c4d4..2fda24931b02c 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -56,6 +56,8 @@ class SamplingParams: stop_token_ids: List of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are sepcial tokens. + allowed_token_ids: List of tokens allowed to be generated by the model. + The returned output will only contain the allowed tokens. ignore_eos: Whether to ignore the EOS token and continue generating tokens after the EOS token is generated. max_tokens: Maximum number of tokens to generate per output sequence. 
@@ -65,23 +67,24 @@ class SamplingParams: """ def __init__( - self, - n: int = 1, - best_of: Optional[int] = None, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - use_beam_search: bool = False, - length_penalty: float = 1.0, - early_stopping: Union[bool, str] = False, - stop: Union[None, str, List[str]] = None, - stop_token_ids: List[int] = None, - ignore_eos: bool = False, - max_tokens: int = 16, - logprobs: Optional[int] = None, - skip_special_tokens: bool = True, + self, + n: int = 1, + best_of: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + use_beam_search: bool = False, + length_penalty: float = 1.0, + early_stopping: Union[bool, str] = False, + stop: Union[None, str, List[str]] = None, + stop_token_ids: List[int] = None, + allowed_token_ids: List[int] = None, + ignore_eos: bool = False, + max_tokens: int = 16, + logprobs: Optional[int] = None, + skip_special_tokens: bool = True, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -103,6 +106,10 @@ def __init__( self.stop_token_ids = [] else: self.stop_token_ids = list(stop_token_ids) + if allowed_token_ids is None: + self.allowed_token_ids = [] + else: + self.allowed_token_ids = list(allowed_token_ids) self.ignore_eos = ignore_eos self.max_tokens = max_tokens self.logprobs = logprobs From 2bec500e6cf3d2ef77fd365178356c1cc023390b Mon Sep 17 00:00:00 2001 From: Viktor Ferenczi Date: Sun, 1 Oct 2023 21:04:17 +0200 Subject: [PATCH 2/2] Formatting and coding style fixes --- tests/samplers/test_sampler.py | 6 +- vllm/model_executor/layers/sampler.py | 90 +++++++++++++-------------- vllm/sampling_params.py | 36 +++++------ 3 files changed, 67 insertions(+), 65 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 41b74cd7a9947..64778a3910554 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -26,7 +26,7 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int + batch_size: int ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), @@ -199,7 +199,9 @@ def test_sampler_constrained(seed: int): request_id=f"test_{i}", is_prompt=True, seq_data={0: SequenceData([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0, allowed_token_ids=allowed_token_ids[i].tolist()), + sampling_params=SamplingParams( + temperature=0, + allowed_token_ids=allowed_token_ids[i].tolist()), block_tables={0: [1]}, )) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c74b6c3709b6e..8c3e28fe8ce57 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -33,11 +33,11 @@ def __init__(self, vocab_size: int) -> None: self.vocab_size = vocab_size def forward( - self, - embedding: torch.Tensor, - hidden_states: torch.Tensor, - input_metadata: InputMetadata, - embedding_bias: Optional[torch.Tensor] = None, + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + embedding_bias: Optional[torch.Tensor] = None, ) -> SamplerOutput: # Get the hidden states that we use for sampling. 
hidden_states = _prune_hidden_states(hidden_states, input_metadata) @@ -104,8 +104,8 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, def _prune_hidden_states( - hidden_states: torch.Tensor, - input_metadata: InputMetadata, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, ) -> torch.Tensor: last_token_indices = {t: [] for t in SamplingType} start_idx = 0 @@ -158,10 +158,10 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: def _apply_penalties( - logits: torch.Tensor, - output_tokens: List[List[int]], - presence_penalties: List[float], - frequency_penalties: List[float], + logits: torch.Tensor, + output_tokens: List[List[int]], + presence_penalties: List[float], + frequency_penalties: List[float], ) -> torch.Tensor: num_seqs, vocab_size = logits.shape for i in range(num_seqs): @@ -224,8 +224,8 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]: def _get_top_p_top_k( - input_metadata: InputMetadata, - vocab_size: int, + input_metadata: InputMetadata, + vocab_size: int, ) -> Tuple[List[float], List[int]]: top_ps: List[float] = [] top_ks: List[int] = [] @@ -242,9 +242,9 @@ def _get_top_p_top_k( def _apply_top_p_top_k( - logits: torch.Tensor, - top_ps: List[float], - top_ks: List[int], + logits: torch.Tensor, + top_ps: List[float], + top_ks: List[int], ) -> torch.Tensor: p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) @@ -271,8 +271,8 @@ def _apply_top_p_top_k( def _get_topk_logprobs( - logprobs: torch.Tensor, - num_logprobs: Optional[int], + logprobs: torch.Tensor, + num_logprobs: Optional[int], ) -> List[Dict[int, float]]: num_seqs = logprobs.size(0) if num_logprobs is None or num_logprobs == 0: @@ -293,12 +293,12 @@ def _get_topk_logprobs( def _build_sequence_outputs( - parent_ids: List[int], - next_token_ids: List[int], - selected_token_logprobs: torch.Tensor, - parent_seq_ids: List[int], - parent_logprobs: torch.Tensor, - num_output_logprobs: Optional[int], + parent_ids: List[int], + next_token_ids: List[int], + selected_token_logprobs: torch.Tensor, + parent_seq_ids: List[int], + parent_logprobs: torch.Tensor, + num_output_logprobs: Optional[int], ) -> List[SequenceOutputs]: # Get top-k log probabilities for the next tokens. next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs) @@ -314,8 +314,8 @@ def _build_sequence_outputs( def _greedy_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - logprobs: torch.Tensor, + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + logprobs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: samples = torch.argmax(logprobs, dim=-1).cpu() sample_idx = 0 @@ -334,9 +334,9 @@ def _greedy_sample( def _random_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - probs: torch.Tensor, + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], + probs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: # Find the maximum best_of value of the prompt phase requests. max_best_of = 1 @@ -358,12 +358,12 @@ def _random_sample( "Prompt input should have only one seq.") parent_ids = [0] * sampling_params.best_of next_token_ids = random_samples[ - sample_idx, :sampling_params.best_of].tolist() + sample_idx, :sampling_params.best_of].tolist() else: # Generation phase. 
parent_ids = list(range(num_parent_seqs)) next_token_ids = random_samples[sample_idx:sample_idx + - num_parent_seqs, 0].tolist() + num_parent_seqs, 0].tolist() results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs assert sample_idx == probs.size(0) @@ -371,10 +371,10 @@ def _random_sample( def _beam_search_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - seq_data: Dict[int, SequenceData], - logprobs: torch.Tensor, + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], + seq_data: Dict[int, SequenceData], + logprobs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: # We sample 2 * beam_width candidates to make sure that with high # probability we can get `beam_width` candidates in addition to @@ -424,9 +424,9 @@ def _beam_search_sample( def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - input_metadata: InputMetadata, + probs: torch.Tensor, + logprobs: torch.Tensor, + input_metadata: InputMetadata, ) -> SamplerOutput: categorized_seq_group_ids = {t: [] for t in SamplingType} category_num_tokens = {t: 0 for t in SamplingType} @@ -447,9 +447,9 @@ def _sample( if num_tokens == 0: continue category_logprobs = logprobs[category_start_idx:category_start_idx + - num_tokens] + num_tokens] category_probs = probs[category_start_idx:category_start_idx + - num_tokens] + num_tokens] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample(seq_groups, category_logprobs) elif sampling_type == SamplingType.RANDOM: @@ -491,9 +491,9 @@ def _sample( num_results = len(next_token_ids) num_parent_seqs = len(seq_ids) parent_logprobs = category_logprobs[sample_idx:sample_idx + - num_parent_seqs] + num_parent_seqs] selected_token_logprobs = batched_logprobs_query_result[ - result_idx:result_idx + num_results] + result_idx:result_idx + num_results] seq_output = _build_sequence_outputs(parent_ids, next_token_ids, selected_token_logprobs, seq_ids, parent_logprobs, @@ -508,13 +508,13 @@ def _sample( def _apply_allowed_token_ids( - logits: torch.Tensor, - input_metadata: InputMetadata, + logits: torch.Tensor, + input_metadata: InputMetadata, ) -> None: vocab_size = logits.shape[1] mask = torch.zeros(vocab_size, dtype=torch.bool) for i, seq_group in enumerate(input_metadata.seq_groups): - seq_ids, sampling_params = seq_group + _, sampling_params = seq_group assert isinstance(sampling_params, SamplingParams) allowed_token_ids = sampling_params.allowed_token_ids if allowed_token_ids: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 2fda24931b02c..d215141eb71ad 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -67,24 +67,24 @@ class SamplingParams: """ def __init__( - self, - n: int = 1, - best_of: Optional[int] = None, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - use_beam_search: bool = False, - length_penalty: float = 1.0, - early_stopping: Union[bool, str] = False, - stop: Union[None, str, List[str]] = None, - stop_token_ids: List[int] = None, - allowed_token_ids: List[int] = None, - ignore_eos: bool = False, - max_tokens: int = 16, - logprobs: Optional[int] = None, - skip_special_tokens: bool = True, + self, + n: int = 1, + best_of: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + use_beam_search: bool = False, + length_penalty: float = 1.0, + 
early_stopping: Union[bool, str] = False, + stop: Union[None, str, List[str]] = None, + stop_token_ids: List[int] = None, + allowed_token_ids: List[int] = None, + ignore_eos: bool = False, + max_tokens: int = 16, + logprobs: Optional[int] = None, + skip_special_tokens: bool = True, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n
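
As a rough usage sketch of the feature introduced above (assuming vLLM's offline `LLM` API; the model name and token IDs below are placeholders, not values taken from the patch), the new `allowed_token_ids` parameter would constrain generation like this:

    # Hypothetical usage sketch for the allowed_token_ids parameter added in this patch.
    # The model name and the token ID list are illustrative placeholders.
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # any model supported by vLLM

    # Restrict decoding to a small set of token IDs, e.g. the tokenizer IDs of
    # "yes"/"no" (the IDs used here are placeholders only).
    params = SamplingParams(
        temperature=0.0,                 # greedy decoding
        max_tokens=1,                    # a single constrained token
        allowed_token_ids=[4420, 2362],  # placeholder token IDs
    )

    outputs = llm.generate(["Is the sky blue? Answer yes or no:"], params)
    for output in outputs:
        # The generated token comes from allowed_token_ids, because the sampler
        # masks every other logit to -inf before computing probabilities.
        print(output.outputs[0].text)

Masking disallowed logits to -inf (rather than zeroing probabilities after softmax) keeps the remaining probability mass correctly renormalized over the allowed set, so temperature, top-p, and top-k continue to behave as usual within that subset.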