diff --git a/outlines/generate/sequence.py b/outlines/generate/sequence.py
index f293d912f..5b3546bad 100644
--- a/outlines/generate/sequence.py
+++ b/outlines/generate/sequence.py
@@ -242,3 +242,23 @@ def __call__(
             return result[0]
 
         return result
+
+
+def test_get_next_fsm_state():
+    raise NotImplementedError
+
+
+def test_get_next_instructions():
+    raise NotImplementedError
+
+
+def test_is_generation_finished():
+    raise NotImplementedError
+
+
+def test_update_token_ids():
+    raise NotImplementedError
+
+
+def test_update_attention_masks():
+    raise NotImplementedError
diff --git a/outlines/generator.py b/outlines/generator.py
new file mode 100644
index 000000000..6beb7fa89
--- /dev/null
+++ b/outlines/generator.py
@@ -0,0 +1,143 @@
+import math
+from dataclasses import dataclass
+from typing import Generator, List, Optional
+
+import torch
+
+from outlines.generate.samplers import Sampler
+from outlines.index.index import Index
+
+
+@dataclass
+class GenerationState:
+    token_ids: torch.Tensor
+    attention_masks: torch.Tensor
+    kv_cache: Optional[torch.Tensor] = None
+
+
+def process(generator: Generator, index: "Index", state: GenerationState):
+    """This generator drives the text generation process by
+    walking through the FSM."""
+    next(generator)
+
+    fsm_states = [0 for _ in range(state.token_ids.shape[0])]
+    while True:
+        logits_mask = get_next_instructions(index, fsm_states)
+
+        next_token_ids, kv_cache = generator.send(
+            ((state.token_ids, state.attention_masks, state.kv_cache), logits_mask)
+        )
+        next(generator)  # move the coroutine back to its receiving `yield`
+
+        token_ids = update_token_ids(state.token_ids, next_token_ids)
+        attention_masks = update_attention_masks(state.attention_masks)
+        state = GenerationState(token_ids, attention_masks, kv_cache)
+
+        fsm_states = get_next_fsm_states(index, fsm_states, next_token_ids)
+        is_finished = is_generation_finished(index, fsm_states)
+        if is_finished:
+            yield token_ids, next_token_ids
+            return
+
+        yield state
+
+
+def get_next_fsm_states(
+    index, fsm_states: List[int], next_token_ids: torch.Tensor
+) -> List[int]:
+    return [
+        index.next_state(fsm_state, token_id)
+        for fsm_state, token_id in zip(fsm_states, next_token_ids)
+    ]
+
+
+def get_next_instructions(index, fsm_states: List[int]) -> List:
+    return [index.next_instruction(state) for state in fsm_states]
+
+
+def is_generation_finished(index, fsm_states: List[int]) -> bool:
+    return all([index.is_finished(state) for state in fsm_states])
+
+
+def update_token_ids(
+    token_ids: torch.Tensor, next_token_ids: torch.Tensor
+) -> torch.Tensor:
+    return torch.concatenate([token_ids, next_token_ids], dim=-1)
+
+
+def update_attention_masks(attention_masks: torch.Tensor) -> torch.Tensor:
+    return torch.concatenate(
+        [
+            attention_masks,
+            torch.ones(
+                attention_masks.shape[:-1] + (1,), device=attention_masks.device
+            ),
+        ],
+        dim=-1,
+    )
+
+
+def token_generator(model, sampler: "Sampler", samples: int, rng: torch.Generator):
+    """Generator that yields a token every time it is called.
+
+    This process is designed to be steered by another supervising
+    process that supplies the current sequence and the indices
+    of the tokens to mask before sampling.
+
+    Parameters
+    ----------
+    model
+        A model that takes a sequence of tokens as an input and
+        returns the next-token logits and an updated KV cache.
+    sampler
+        A function that samples new tokens from the biased
+        next-token logits.
+    samples
+        The number of samples to draw for each sequence.
+    rng
+        The random number generator used for sampling.
+
+    Yields
+    ------
+    A tensor with the sampled tokens.
+ + """ + while True: + (token_ids, attention_masks, kv_cache), logits_mask = yield + + try: + logits, new_kv_cache = model(token_ids, attention_masks, kv_cache) + except IndexError: # Exceeding the context length + return + + biased_logits = bias_logits(logits, logits_mask) + next_token_ids = sampler(biased_logits, samples, rng) + + yield next_token_ids, new_kv_cache + + +def bias_logits( + logits: torch.Tensor, + ids_to_mask: List, +) -> torch.Tensor: + """Mask the logits. + + The function iterates over a nested list where each list corresponds to the + indices that need to be masked for each row in the array. + + Parameters + ---------- + logits + Two dimensional tensor that contains the next-token probability + distribution. + ids_to_mask + The ids to mask in each dimension. + + Returns + ------- + A view of the original logits tensor where some values are masked. + + """ + for i, ids in enumerate(ids_to_mask): + logits[i, ids] = -math.inf + return logits diff --git a/tests/text/test_generator.py b/tests/text/test_generator.py new file mode 100644 index 000000000..b9e79982b --- /dev/null +++ b/tests/text/test_generator.py @@ -0,0 +1,139 @@ +import math + +import pytest +import torch + +from outlines.generate.generator import bias_logits, token_generator + + +def test_generator_error(): + def model(*_): + raise IndexError + + def sampler(): + return None + + generator = token_generator(model, sampler, 1, None) + next(generator) + with pytest.raises(StopIteration): + generator.send(((None, None, None), None)) + + +@pytest.mark.parametrize( + "logits,indices_to_mask,expected", + [ + ( + torch.tensor([[1, 2, 3, 4]], dtype=torch.float), + [[]], + torch.tensor([[1, 2, 3, 4]], dtype=torch.float), + ), + ( + torch.tensor([[1, 2, 3, 4]], dtype=torch.float), + [[1]], + torch.tensor([[1, -math.inf, 3, 4]], dtype=torch.float), + ), + ( + torch.tensor([[1, 2, 3, 4]], dtype=torch.float), + [[1, 3]], + torch.tensor([[1, -math.inf, 3, -math.inf]], dtype=torch.float), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float), + [[0], [2]], + torch.tensor([[-math.inf, 2, 3], [4, 5, -math.inf]], dtype=torch.float), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float), + [[1], [0, 2]], + torch.tensor( + [[1, -math.inf, 3], [-math.inf, 5, -math.inf]], dtype=torch.float + ), + ), + ], +) +def test_bias_logits(logits, indices_to_mask, expected): + masked_logits = bias_logits(logits, indices_to_mask) + assert torch.equal(masked_logits, expected) + + +def test_generator_1d(): + def model(*_): + return torch.tensor([[0, 1, 2, 3]], dtype=torch.float), None + + def sampler(biased_logits, *_): + return torch.argmax(biased_logits) + + # 1D, no bias + generator = token_generator(model, sampler, 1, None) + next(generator) + result, _ = generator.send(((None, None, None), [[]])) + assert result == 3 + + # 1D, bias one + generator = token_generator(model, sampler, 1, None) + next(generator) + result, _ = generator.send(((None, None, None), [[3]])) + assert result == 2 + + # 1D, bias two + generator = token_generator(model, sampler, 1, None) + next(generator) + result, _ = generator.send(((None, None, None), [[2, 3]])) + assert result == 1 + + +def test_generator_2d(): + def model(*_): + return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.float), None + + def sampler(biased_logits, *_): + return torch.argmax(biased_logits, dim=1) + + # 2D, no bias + generator = token_generator(model, sampler, 1, None) + next(generator) + result, _ = generator.send(((None, None, None), [[]])) + assert 
+
+    # 2D, bias one each
+    generator = token_generator(model, sampler, 1, None)
+    next(generator)
+    result, _ = generator.send(((None, None, None), [[3], [3]]))
+    assert torch.equal(result, torch.tensor([2, 2]))
+
+    # 2D, bias one
+    generator = token_generator(model, sampler, 1, None)
+    next(generator)
+    result, _ = generator.send(((None, None, None), [[3], []]))
+    assert torch.equal(result, torch.tensor([2, 3]))
+
+    # 2D, bias different number
+    generator = token_generator(model, sampler, 1, None)
+    next(generator)
+    result, _ = generator.send(((None, None, None), [[3], [2, 3]]))
+    assert torch.equal(result, torch.tensor([2, 1]))
+
+
+@pytest.mark.xfail
+def test_get_next_fsm_states():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_get_next_instructions():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_is_generation_finished():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_update_token_ids():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_update_attention_masks():
+    raise NotImplementedError
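
Reviewer note: below is a minimal, self-contained sketch of how the pieces added in this patch fit together. `ToyModel`, `ToyIndex`, and `greedy_sampler` are hypothetical stand-ins, not part of the patch: the model returns flat logits over a four-token vocabulary, and the index allows every token and stops after three generation steps.

```python
import torch

from outlines.generator import GenerationState, process, token_generator


class ToyModel:
    """Hypothetical stand-in for a language model: flat logits, no KV cache."""

    def __call__(self, token_ids, attention_masks, kv_cache):
        return torch.ones(token_ids.shape[0], 4), None


def greedy_sampler(biased_logits, samples, rng):
    """Hypothetical sampler: pick the best token, keep a trailing dim of 1."""
    return torch.argmax(biased_logits, dim=-1, keepdim=True)


class ToyIndex:
    """Hypothetical stand-in for an Index: mask nothing, stop after 3 steps."""

    def next_state(self, fsm_state, token_id):
        return fsm_state + 1

    def next_instruction(self, fsm_state):
        return []  # no token ids to mask

    def is_finished(self, fsm_state):
        return fsm_state >= 3


state = GenerationState(
    token_ids=torch.tensor([[0]]),  # a single prompt token
    attention_masks=torch.ones(1, 1),
)
generator = token_generator(ToyModel(), greedy_sampler, 1, torch.Generator())

# `process` yields intermediate states and, on the final step, the token ids.
*_, (token_ids, last_tokens) = process(generator, ToyIndex(), state)
print(token_ids)  # the prompt token followed by the three generated tokens
```

Note the handshake this relies on: `token_generator` has two `yield` points per step (one to receive the sequence and mask, one to emit the sampled tokens), so the supervising `process` loop must call `next(generator)` after each `send` to move the coroutine back to its receiving `yield` before the next step.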