Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constrained decoding #1243

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,35 @@ def test_sampler_mixed(seed: int):
continue
for nth_output in sequence_output:
assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_constrained(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

allowed_token_ids = torch.randint(0, 32000, (batch_size, 100))

seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=0,
allowed_token_ids=allowed_token_ids[i].tolist()),
block_tables={0: [1]},
))

_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token in allowed_token_ids[i].tolist()
assert nth_output.output_token == expected[i].item()
23 changes: 22 additions & 1 deletion vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,14 @@ def forward(
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)

# Apply presence and frequency penalties.
# Apply sampling constraints
_apply_allowed_token_ids(logits, input_metadata)

# Get the output tokens
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]

# Apply presence and frequency penalties.
presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
assert len(presence_penalties) == logits.shape[0]
Expand Down Expand Up @@ -500,3 +505,19 @@ def _sample(
category_start_idx += num_tokens

return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]


def _apply_allowed_token_ids(
logits: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
vocab_size = logits.shape[1]
mask = torch.zeros(vocab_size, dtype=torch.bool)
for i, seq_group in enumerate(input_metadata.seq_groups):
_, sampling_params = seq_group
assert isinstance(sampling_params, SamplingParams)
allowed_token_ids = sampling_params.allowed_token_ids
if allowed_token_ids:
mask[:] = True
mask[allowed_token_ids] = False
logits[i, mask] = -float("inf")
7 changes: 7 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class SamplingParams:
stop_token_ids: List of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
the stop tokens are sepcial tokens.
allowed_token_ids: List of tokens allowed to be generated by the model.
The returned output will only contain the allowed tokens.
ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
Expand All @@ -78,6 +80,7 @@ def __init__(
early_stopping: Union[bool, str] = False,
stop: Union[None, str, List[str]] = None,
stop_token_ids: List[int] = None,
allowed_token_ids: List[int] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
Expand All @@ -103,6 +106,10 @@ def __init__(
self.stop_token_ids = []
else:
self.stop_token_ids = list(stop_token_ids)
if allowed_token_ids is None:
self.allowed_token_ids = []
else:
self.allowed_token_ids = list(allowed_token_ids)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs
Expand Down