diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index a5f55d50fbb76..64778a3910554 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -182,3 +182,35 @@ def test_sampler_mixed(seed: int):
             continue
         for nth_output in sequence_output:
             assert nth_output.output_token in expected_tokens
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_constrained(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+
+    allowed_token_ids = torch.randint(0, 32000, (batch_size, 100))
+
+    seq_group_metadata_list = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=0,
+                    allowed_token_ids=allowed_token_ids[i].tolist()),
+                block_tables={0: [1]},
+            ))
+
+    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
+    sampler_output = sampler(embedding=None,
+                             hidden_states=input_tensor,
+                             input_metadata=input_metadata)
+    expected = torch.argmax(fake_logits, dim=-1)
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output:
+            assert nth_output.output_token in allowed_token_ids[i].tolist()
+            assert nth_output.output_token == expected[i].item()
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index a17dfcfbaf16f..8c3e28fe8ce57 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -46,9 +46,14 @@ def forward(
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
 
-        # Apply presence and frequency penalties.
+        # Apply sampling constraints
+        _apply_allowed_token_ids(logits, input_metadata)
+
+        # Get the output tokens
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
+
+        # Apply presence and frequency penalties.
         presence_penalties, frequency_penalties = _get_penalties(
             input_metadata)
         assert len(presence_penalties) == logits.shape[0]
@@ -500,3 +505,19 @@ def _sample(
         category_start_idx += num_tokens
 
     return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
+
+
+def _apply_allowed_token_ids(
+    logits: torch.Tensor,
+    input_metadata: InputMetadata,
+) -> None:
+    vocab_size = logits.shape[1]
+    mask = torch.zeros(vocab_size, dtype=torch.bool)
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        _, sampling_params = seq_group
+        assert isinstance(sampling_params, SamplingParams)
+        allowed_token_ids = sampling_params.allowed_token_ids
+        if allowed_token_ids:
+            mask[:] = True
+            mask[allowed_token_ids] = False
+            logits[i, mask] = -float("inf")
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 5206eb0b8c4d4..d215141eb71ad 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -56,6 +56,8 @@ class SamplingParams:
         stop_token_ids: List of tokens that stop the generation when they are
             generated. The returned output will contain the stop tokens unless
            the stop tokens are sepcial tokens.
+        allowed_token_ids: List of tokens allowed to be generated by the model.
+            The returned output will only contain the allowed tokens.
         ignore_eos: Whether to ignore the EOS token and continue generating
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
@@ -78,6 +80,7 @@ def __init__(
         early_stopping: Union[bool, str] = False,
         stop: Union[None, str, List[str]] = None,
         stop_token_ids: List[int] = None,
+        allowed_token_ids: List[int] = None,
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
@@ -103,6 +106,10 @@ def __init__(
             self.stop_token_ids = []
         else:
             self.stop_token_ids = list(stop_token_ids)
+        if allowed_token_ids is None:
+            self.allowed_token_ids = []
+        else:
+            self.allowed_token_ids = list(allowed_token_ids)
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
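
The new _apply_allowed_token_ids helper constrains decoding by setting the logits of every disallowed token to -inf, so that neither greedy nor stochastic sampling can ever select them. A minimal standalone sketch of that masking technique follows; the shapes, values, and variable names here are illustrative only and are not part of the patch.

# Sketch of the logit-masking technique used by _apply_allowed_token_ids;
# shapes, values, and names are illustrative only.
import torch

vocab_size = 8
logits = torch.randn(2, vocab_size)      # one row per sequence group
allowed = [[1, 3, 5], []]                # an empty list means "no constraint"

mask = torch.zeros(vocab_size, dtype=torch.bool, device=logits.device)
for i, allowed_token_ids in enumerate(allowed):
    if allowed_token_ids:
        mask[:] = True                   # start by banning every token
        mask[allowed_token_ids] = False  # then re-enable the allowed ids
        logits[i, mask] = -float("inf")  # banned tokens can never be picked

# With temperature=0 (greedy sampling), argmax now falls inside the allowed set.
assert torch.argmax(logits[0]).item() in allowed[0]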
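
For context, a hypothetical end-to-end use of the new SamplingParams field through the offline LLM entrypoint, assuming this patch is applied, might look like the sketch below; the model name, prompt, and token ids are placeholders rather than part of the change.

# Hypothetical usage of SamplingParams.allowed_token_ids once this patch is
# applied; model name, prompt, and token ids are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    temperature=0,
    max_tokens=8,
    allowed_token_ids=[262, 318, 373, 11, 13],  # restrict decoding to this whitelist
)
outputs = llm.generate(["The quick brown fox"], params)
# Every generated token id should come from the whitelist above.
print(outputs[0].outputs[0].token_ids)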