Commit: Add generator that samples the next tokens
Showing 3 changed files with 295 additions and 0 deletions.
outlines/generate/generator.py
@@ -0,0 +1,136 @@
import math
from dataclasses import dataclass
from typing import Generator, List, Optional

import torch

from outlines.generate.samplers import Sampler
from outlines.index.index import Index


@dataclass
class GenerationState:
    token_ids: torch.Tensor
    attention_masks: torch.Tensor
    kv_cache: Optional[torch.Tensor] = None
def process(generator: Generator, index: "Index", state: GenerationState):
    """This generator drives the text generation process by
    walking through the FSM."""
    next(generator)  # Prime the token generator coroutine

    fsm_states = [0 for _ in range(state.token_ids.shape[0])]
    while True:
        logits_mask = get_next_instructions(index, fsm_states)

        # Send the state as a plain tuple; the coroutine unpacks three fields
        # (a dataclass instance itself is not unpackable)
        next_token_ids, kv_cache = generator.send(
            ((state.token_ids, state.attention_masks, state.kv_cache), logits_mask)
        )

        token_ids = update_token_ids(state.token_ids, next_token_ids)
        attention_masks = update_attention_masks(state.attention_masks)
        state = GenerationState(token_ids, attention_masks, kv_cache)

        fsm_states = get_next_fsm_states(index, fsm_states, next_token_ids)
        is_finished = is_generation_finished(index, fsm_states)
        if is_finished:
            yield token_ids, next_token_ids
            return

        yield state
def get_next_fsm_states(
    index, fsm_states: List[int], next_token_ids: torch.Tensor
) -> List[int]:
    return [
        index.next_state(fsm_state, token_id)
        for fsm_state, token_id in zip(fsm_states, next_token_ids)
    ]


def get_next_instructions(index, fsm_states: List[int]) -> List:
    return [index.next_instruction(state) for state in fsm_states]


def is_generation_finished(index, fsm_states: List[int]) -> bool:
    return all([index.is_finished(state) for state in fsm_states])


def update_token_ids(
    token_ids: torch.Tensor, next_token_ids: torch.Tensor
) -> torch.Tensor:
    # Append the newly sampled tokens along the sequence dimension
    return torch.concatenate([token_ids, next_token_ids], dim=-1)


def update_attention_masks(attention_masks: torch.Tensor) -> torch.Tensor:
    # Extend each mask with a 1 for the token that was just appended
    return torch.concatenate(
        [
            attention_masks,
            torch.ones(
                attention_masks.shape[:-1] + (1,), device=attention_masks.device
            ),
        ],
        dim=-1,
    )
def token_generator(model, sampler: "Sampler", samples: int, rng: torch.Generator):
    """Generator that yields a token every time it is called.

    This process is designed to be steered by another supervising
    process that supplies the current sequence and the indices
    of the tokens to mask before sampling.

    Parameters
    ----------
    model
        A model that takes a sequence of tokens as an input and
        returns a probability distribution over the next tokens.
    sampler
        A function that samples tokens from a probability
        distribution over the next tokens.

    Yields
    ------
    A tensor with the sampled tokens.

    """
    next_token_ids, new_kv_cache = None, None
    while True:
        # A single `yield` both returns the previous result and receives the
        # next (state, mask) pair, so every `send` call gets sampled tokens
        (token_ids, attention_masks, kv_cache), logits_mask = yield (
            next_token_ids,
            new_kv_cache,
        )

        try:
            logits, new_kv_cache = model(token_ids, attention_masks, kv_cache)
        except IndexError:  # Exceeding the context length
            return

        biased_logits = bias_logits(logits, logits_mask)
        next_token_ids = sampler(biased_logits, samples, rng)
def bias_logits(
    logits: torch.Tensor,
    ids_to_mask: List,
) -> torch.Tensor:
    """Mask the logits.

    The function iterates over a nested list where each sublist contains the
    indices that need to be masked for the corresponding row of the tensor.

    Parameters
    ----------
    logits
        Two-dimensional tensor that contains the next-token logits.
    ids_to_mask
        The token ids to mask in each row.

    Returns
    -------
    The original logits tensor, masked in place.

    """
    for i, ids in enumerate(ids_to_mask):
        logits[i, ids] = -math.inf
    return logits
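
As an illustration of the coroutine protocol, here is a minimal sketch (not part of the commit): toy_model and greedy_sampler are hypothetical stand-ins for a real model and sampler. The supervising code primes the generator with next() and then exchanges (state, mask) pairs for sampled tokens via send().

import torch

from outlines.generate.generator import token_generator


def toy_model(token_ids, attention_masks, kv_cache):
    # Hypothetical model: fixed logits over a four-token vocabulary, no cache
    return torch.tensor([[0.0, 1.0, 2.0, 3.0]]), None


def greedy_sampler(biased_logits, samples, rng):
    # Hypothetical sampler: pick the highest-scoring token in each row
    return torch.argmax(biased_logits, dim=-1)


generator = token_generator(toy_model, greedy_sampler, 1, None)
next(generator)  # prime the coroutine
# Mask token id 3; the greedy sampler then picks token id 2
next_token_ids, kv_cache = generator.send(((None, None, None), [[3]]))
assert next_token_ids.item() == 2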
tests/generate/test_generator.py
@@ -0,0 +1,139 @@
import math

import pytest
import torch

from outlines.generate.generator import bias_logits, token_generator


def test_generator_error():
    def model(*_):
        raise IndexError

    def sampler():
        return None

    generator = token_generator(model, sampler, 1, None)
    next(generator)
    with pytest.raises(StopIteration):
        generator.send(((None, None, None), None))
@pytest.mark.parametrize(
    "logits,indices_to_mask,expected",
    [
        (
            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
            [[]],
            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
        ),
        (
            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
            [[1]],
            torch.tensor([[1, -math.inf, 3, 4]], dtype=torch.float),
        ),
        (
            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
            [[1, 3]],
            torch.tensor([[1, -math.inf, 3, -math.inf]], dtype=torch.float),
        ),
        (
            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float),
            [[0], [2]],
            torch.tensor([[-math.inf, 2, 3], [4, 5, -math.inf]], dtype=torch.float),
        ),
        (
            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float),
            [[1], [0, 2]],
            torch.tensor(
                [[1, -math.inf, 3], [-math.inf, 5, -math.inf]], dtype=torch.float
            ),
        ),
    ],
)
def test_bias_logits(logits, indices_to_mask, expected):
    masked_logits = bias_logits(logits, indices_to_mask)
    assert torch.equal(masked_logits, expected)
def test_generator_1d():
    def model(*_):
        return torch.tensor([[0, 1, 2, 3]], dtype=torch.float), None

    def sampler(biased_logits, *_):
        return torch.argmax(biased_logits)

    # 1D, no bias
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[]]))
    assert result == 3

    # 1D, bias one
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[3]]))
    assert result == 2

    # 1D, bias two
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[2, 3]]))
    assert result == 1
def test_generator_2d():
    def model(*_):
        return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.float), None

    def sampler(biased_logits, *_):
        return torch.argmax(biased_logits, dim=1)

    # 2D, no bias
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[]]))
    assert torch.equal(result, torch.tensor([3, 3]))

    # 2D, bias one each
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[3], [3]]))
    assert torch.equal(result, torch.tensor([2, 2]))

    # 2D, bias one
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[3], []]))
    assert torch.equal(result, torch.tensor([2, 3]))

    # 2D, bias different number
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[3], [2, 3]]))
    assert torch.equal(result, torch.tensor([2, 1]))
# Placeholders for behavior that is not covered yet; the `test_` prefix is
# required for pytest to collect them.
@pytest.mark.xfail
def test_get_next_fsm_states():
    raise NotImplementedError


@pytest.mark.xfail
def test_get_next_instructions():
    raise NotImplementedError


@pytest.mark.xfail
def test_is_generation_finished():
    raise NotImplementedError


@pytest.mark.xfail
def test_update_token_ids():
    raise NotImplementedError


@pytest.mark.xfail
def test_update_attention_masks():
    raise NotImplementedError
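
For completeness, a hedged end-to-end sketch of how process, token_generator, and an index fit together. ToyIndex, toy_model, and greedy_sampler are hypothetical stand-ins (the real index comes from outlines.index.index); this index masks nothing and simply reports completion after two steps.

import torch

from outlines.generate.generator import GenerationState, process, token_generator


class ToyIndex:
    # Hypothetical index: never masks tokens, counts steps, finishes after two
    def next_instruction(self, state):
        return []  # no token ids to mask for this row

    def next_state(self, state, token_id):
        return state + 1

    def is_finished(self, state):
        return state >= 2


def toy_model(token_ids, attention_masks, kv_cache):
    # Hypothetical model: fixed logits over a four-token vocabulary, no cache
    return torch.tensor([[0.0, 1.0, 2.0, 3.0]]), None


def greedy_sampler(biased_logits, samples, rng):
    # keepdim=True so the result can be concatenated along the sequence axis
    return torch.argmax(biased_logits, dim=-1, keepdim=True)


init = GenerationState(torch.tensor([[0]]), torch.ones((1, 1)), None)
generator = token_generator(toy_model, greedy_sampler, 1, None)

# `process` yields intermediate GenerationStates, then a final
# (token_ids, next_token_ids) pair once the index reports completion.
*_, (token_ids, last_tokens) = process(generator, ToyIndex(), init)
assert torch.equal(token_ids, torch.tensor([[0, 3, 3]]))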