From 331d5df2d18d54839c06e70f1493584b98692c0d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Wed, 15 Nov 2023 12:15:37 +0100
Subject: [PATCH] Add generator that samples the next tokens

---
 outlines/generate/generator.py | 186 +++++++++++++++++++++++++--------
 outlines/generate/text.py      |  10 ++
 outlines/generator.py          | 136 ++++++++++++++++++++++++
 tests/text/test_generator.py   | 139 ++++++++++++++++++++++++
 4 files changed, 428 insertions(+), 43 deletions(-)
 create mode 100644 outlines/generate/text.py
 create mode 100644 outlines/generator.py
 create mode 100644 tests/text/test_generator.py

diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py
index f76e449a1..fba15c5af 100644
--- a/outlines/generate/generator.py
+++ b/outlines/generate/generator.py
@@ -1,6 +1,6 @@
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Generator, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
@@ -8,24 +8,117 @@
 from outlines.generate.samplers import Sampler
 from outlines.index.index import Index
 
+"""
+--------------------
+Functions vs classes
+--------------------
+
+def text_generator(model, tokenizer, sampler, fsm):
+    generate_token = token_generator(model, sampler)
+    generator = sequence_generator(generate_token, fsm, tokenizer)
+    return generator
+
+generator = text_generator(model, tokenizer, sampler, fsm)
+generator("prompt")
+
+def call_text_generator(prompt):
+    *_, sequence = generator(prompt)
+    return sequence
+
+--> We need to chain the token generator with a generator that yields a
+    `Sequence` state.
+--> Then exhaust this generator to get the final sequence when the
+    generating function is called directly.
+
+def json(model, tokenizer, sampler, fsm):
+    generate_token = token_generator(model, sampler)
+    generator = sequence_generator(generate_token, fsm, tokenizer)
+    return generator
+
+"""
+
+
-@dataclass
+@dataclass(frozen=True)
 class GenerationState:
     token_ids: torch.Tensor
     attention_masks: torch.Tensor
     kv_cache: Optional[torch.Tensor] = None
+    rng: Optional[torch.Generator] = None
+
+
+class SequenceGenerator:
+    def __init__(self, fsm, model, sampler):
+        self.generate_token = token_generator(model, sampler)
+        self.fsm = fsm
+        self.model = model
+        self.device = self.model.device
+
+    def init_generation_state(
+        self, prompt: Union[str, List[str]], rng: Optional[torch.Generator] = None
+    ):
+        """Initialize the generation state.
+
+        This method is responsible for encoding the prompt, moving token ids
+        to the device and initializing the random number generator.
+
+        Parameters
+        ----------
+        prompt
+            The prompt on which the generation is conditioned.
+        rng
+            The state of the random number generator.
+
+        Returns
+        -------
+        A `GenerationState` object.
+
+        """
+        token_ids, attention_masks = self.model.tokenizer.encode(prompt)
+        token_ids = token_ids.squeeze(0).to(self.device)
+        attention_masks = attention_masks.squeeze(0).to(self.device)
+
+        if rng is None:
+            rng = torch.Generator(device=self.device)
+            rng.seed()
+
+        return GenerationState(token_ids, attention_masks, None, rng)
+
+    def __call__(self, prompt, rng: Optional[torch.Generator] = None):
+        generator = self.stream(prompt, rng)
+        *_, last = generator
+        return last
+
+    def stream(self, prompt, rng: Optional[torch.Generator] = None):
+        self.state = self.init_generation_state(prompt, rng)
+        self.fsm_states = [0 for _ in range(self.state.token_ids.shape[0])]
+        return self
+
+    def __iter__(self):
+        """Generate new tokens based on the model and the FSM.
+
+        Parameters
+        ----------
+        token_generator
+            A generator that yields new tokens from a `GenerationState` and a list
+            of the ids of the tokens we need to mask.
+        index
+            The index that drives the generation.
+        """
 
-def process(generator: Generator, index: "Index", state: GenerationState):
-    """This generator drives the text generation process by
-    walking through the FSM."""
-    next(generator)
 
+def process(generator: Callable, index: "Index", state: GenerationState):
+    """Drive the text generation process by walking through the FSM."""
     fsm_states = [0 for _ in range(state.token_ids.shape[0])]
     while True:
-        logits_mask = get_next_instructions(index, fsm_states)
+        logits_masks = get_next_instructions(index, fsm_states)
 
-        next_token_ids, kv_cache = generator.send((state, logits_mask))
+        next_token_ids, kv_cache = next(
+            generator(
+                state.token_ids,
+                state.attention_masks,
+                state.kv_cache,
+                logits_masks,
+                1,
+                state.rng,
+            )
+        )
 
         token_ids = update_token_ids(state.token_ids, next_token_ids)
         attention_masks = update_attention_masks(state.attention_masks)
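Note: `SequenceGenerator` is still a stub at this point (`stream` returns `self`, but `__iter__` has no body yet). The intended calling convention, following the module note above, would look roughly like this; `model`, `sampler`, and `fsm` are hypothetical stand-ins, and this sketch shows intent rather than working code:

    # Hypothetical wiring; none of these objects ship with this patch.
    generator = SequenceGenerator(fsm, model, sampler)

    # Run to completion and keep only the final state:
    last_state = generator("Write a number:")

    # Or stream the intermediate states once __iter__ is implemented:
    for state in generator.stream("Write a number:"):
        print(state.token_ids)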
@@ -40,6 +133,48 @@ def process(generator: Generator, index: "Index", state: GenerationState):
     yield state
 
 
+def token_generator(model, sampler: "Sampler") -> Callable:
+    """Generate one token at a time.
+
+    This process is designed to be steered by another supervising
+    process that supplies the current sequence and the indices
+    of the tokens to mask before sampling.
+
+    Parameters
+    ----------
+    model
+        A model that takes a sequence of tokens as an input and
+        returns a probability distribution over the next tokens.
+    sampler
+        A function that samples tokens from a probability
+        distribution over the next tokens.
+
+    Returns
+    -------
+    A generator function that, given the current token ids, attention masks,
+    KV cache and logits masks, yields the sampled token ids and the updated
+    KV cache.
+
+    """
+
+    def generate(
+        token_ids,
+        attention_masks,
+        kv_cache,
+        logits_masks,
+        samples: int,
+        rng: torch.Generator,
+    ):
+        try:
+            logits, new_kv_cache = model(token_ids, attention_masks, kv_cache)
+        except IndexError:  # Exceeding the context length
+            return
+
+        biased_logits = bias_logits(logits, logits_masks)
+        next_token_ids = sampler(biased_logits, samples, rng)
+
+        yield next_token_ids, new_kv_cache
+
+    return generate
+
+
 def get_next_fsm_states(
     index, fsm_states: List[int], next_token_ids: torch.Tensor
 ) -> List[int]:
@@ -75,41 +210,6 @@ def update_attention_masks(attention_masks: torch.Tensor) -> torch.Tensor:
     )
 
 
-def token_generator(model, sampler: "Sampler", samples: int, rng: torch.Generator):
-    """Generator that yields a token every time it is called.
-
-    This process is designed to be steered by another supervising
-    process that supplies the current sequence and the indices
-    of the tokens to mask before sampling.
-
-    Parameters
-    ----------
-    model
-        A model that takes a sequence of tokens as an input and
-        returns a probability distribution over the next tokens.
-    sampler
-        A function that samples tokens from a probability
-        distribution over the next tokens.
-
-    Yields
-    ------
-    A tensor with the sampled tokens.
-
-    """
-    while True:
-        (token_ids, attention_masks, kv_cache), logits_mask = yield
-
-        try:
-            logits, new_kv_cache = model(token_ids, attention_masks, kv_cache)
-        except IndexError:  # Exceeding the context length
-            return
-
-        biased_logits = bias_logits(logits, logits_mask)
-        next_token_ids = sampler(biased_logits, samples, rng)
-
-        yield next_token_ids, new_kv_cache
-
-
 def bias_logits(
     logits: torch.Tensor,
     ids_to_mask: List,
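The diff above changes the calling convention: `token_generator` now closes over the model and the sampler and returns a one-shot generator function, which `process` drives with `next(...)` once per step. A minimal sketch of the contract with toy stand-ins (`toy_model` and `argmax_sampler` are illustrations, not part of the patch), showing per-row masking for a batch of two sequences:

    import torch

    from outlines.generate.generator import token_generator

    def toy_model(token_ids, attention_masks, kv_cache):
        # One row of logits per sequence in the batch; no KV cache.
        return torch.tensor([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]]), None

    def argmax_sampler(biased_logits, samples, rng):
        return torch.argmax(biased_logits, dim=-1)

    generate = token_generator(toy_model, argmax_sampler)

    # Each inner list masks one row: token 2 in the first row, nothing in the second.
    next_token_ids, kv_cache = next(generate(None, None, None, [[2], []], 1, None))
    assert next_token_ids.tolist() == [1, 2]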
diff --git a/outlines/generate/text.py b/outlines/generate/text.py
new file mode 100644
index 000000000..38bc8d532
--- /dev/null
+++ b/outlines/generate/text.py
@@ -0,0 +1,10 @@
+class text:
+    def __init__(self):
+        pass
+
+    def __call__(self, prompt):
+        pass
+
+    def __iter__(self):
+        # Placeholder for the streaming interface
+        pass
diff --git a/outlines/generator.py b/outlines/generator.py
new file mode 100644
index 000000000..6beb7fa89
--- /dev/null
+++ b/outlines/generator.py
@@ -0,0 +1,136 @@
+import math
+from dataclasses import dataclass
+from typing import Generator, List, Optional
+
+import torch
+
+from outlines.generate.samplers import Sampler
+from outlines.index.index import Index
+
+
+@dataclass
+class GenerationState:
+    token_ids: torch.Tensor
+    attention_masks: torch.Tensor
+    kv_cache: Optional[torch.Tensor] = None
+
+
+def process(generator: Generator, index: "Index", state: GenerationState):
+    """This generator drives the text generation process by
+    walking through the FSM."""
+    next(generator)
+
+    fsm_states = [0 for _ in range(state.token_ids.shape[0])]
+    while True:
+        logits_mask = get_next_instructions(index, fsm_states)
+
+        next_token_ids, kv_cache = generator.send((state, logits_mask))
+
+        token_ids = update_token_ids(state.token_ids, next_token_ids)
+        attention_masks = update_attention_masks(state.attention_masks)
+        state = GenerationState(token_ids, attention_masks, kv_cache)
+
+        fsm_states = get_next_fsm_states(index, fsm_states, next_token_ids)
+        is_finished = is_generation_finished(index, fsm_states)
+        if is_finished:
+            yield token_ids, next_token_ids
+            return
+
+        yield state
+
+
+def get_next_fsm_states(
+    index, fsm_states: List[int], next_token_ids: torch.Tensor
+) -> List[int]:
+    return [
+        index.next_state(fsm_state, token_id)
+        for fsm_state, token_id in zip(fsm_states, next_token_ids)
+    ]
+
+
+def get_next_instructions(index, fsm_states: List[int]) -> List:
+    return [index.next_instruction(state) for state in fsm_states]
+
+
+def is_generation_finished(index, fsm_states: List[int]) -> bool:
+    return all([index.is_finished(state) for state in fsm_states])
+
+
+def update_token_ids(
+    token_ids: torch.Tensor, next_token_ids: torch.Tensor
+) -> torch.Tensor:
+    return torch.concatenate([token_ids, next_token_ids], dim=-1)
+
+
+def update_attention_masks(attention_masks: torch.Tensor) -> torch.Tensor:
+    return torch.concatenate(
+        [
+            attention_masks,
+            torch.ones(
+                attention_masks.shape[:-1] + (1,), device=attention_masks.device
+            ),
+        ],
+        axis=-1,
+    )
+
+
+def token_generator(model, sampler: "Sampler", samples: int, rng: torch.Generator):
+    """Generator that yields a token every time it is called.
+
+    This process is designed to be steered by another supervising
+    process that supplies the current sequence and the indices
+    of the tokens to mask before sampling.
+
+    Parameters
+    ----------
+    model
+        A model that takes a sequence of tokens as an input and
+        returns a probability distribution over the next tokens.
+    sampler
+        A function that samples tokens from a probability
+        distribution over the next tokens.
+
+    Yields
+    ------
+    A tensor with the sampled tokens.
+
+    """
+    while True:
+        (token_ids, attention_masks, kv_cache), logits_mask = yield
+
+        try:
+            logits, new_kv_cache = model(token_ids, attention_masks, kv_cache)
+        except IndexError:  # Exceeding the context length
+            return
+
+        biased_logits = bias_logits(logits, logits_mask)
+        next_token_ids = sampler(biased_logits, samples, rng)
+
+        yield next_token_ids, new_kv_cache
+
+
+def bias_logits(
+    logits: torch.Tensor,
+    ids_to_mask: List,
+) -> torch.Tensor:
+    """Mask the logits.
+
+    The function iterates over a nested list where each list corresponds to the
+    indices that need to be masked for each row in the array.
+
+    Parameters
+    ----------
+    logits
+        Two dimensional tensor that contains the next-token probability
+        distribution.
+    ids_to_mask
+        The ids to mask in each dimension.
+
+    Returns
+    -------
+    The original logits tensor, modified in place.
+
+    """
+    for i, ids in enumerate(ids_to_mask):
+        logits[i, ids] = -math.inf
+    return logits
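The two update helpers simply grow the sequence dimension by one per step. A quick, self-contained shape check (illustrative only, assuming the `outlines/generator.py` module above is importable):

    import torch

    from outlines.generator import update_attention_masks, update_token_ids

    token_ids = torch.ones((2, 5), dtype=torch.long)         # batch of 2, length 5
    next_token_ids = torch.zeros((2, 1), dtype=torch.long)   # one new token per row

    assert update_token_ids(token_ids, next_token_ids).shape == (2, 6)
    assert update_attention_masks(torch.ones((2, 5))).shape == (2, 6)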
diff --git a/tests/text/test_generator.py b/tests/text/test_generator.py
new file mode 100644
index 000000000..b9e79982b
--- /dev/null
+++ b/tests/text/test_generator.py
@@ -0,0 +1,139 @@
+import math
+
+import pytest
+import torch
+
+from outlines.generate.generator import bias_logits, token_generator
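+
+# `token_generator` returns a one-shot generator function: each call builds a
+# generator that yields a single `(next_token_ids, kv_cache)` pair, or nothing
+# at all when the context length is exceeded. The tests below therefore drive
+# it with `next(...)` rather than the old `send(...)` protocol.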
+
+
+def test_generator_error():
+    def model(*_):
+        raise IndexError
+
+    def sampler():
+        return None
+
+    generate = token_generator(model, sampler)
+    with pytest.raises(StopIteration):
+        next(generate(None, None, None, None, 1, None))
+
+
+@pytest.mark.parametrize(
+    "logits,indices_to_mask,expected",
+    [
+        (
+            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
+            [[]],
+            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
+        ),
+        (
+            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
+            [[1]],
+            torch.tensor([[1, -math.inf, 3, 4]], dtype=torch.float),
+        ),
+        (
+            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
+            [[1, 3]],
+            torch.tensor([[1, -math.inf, 3, -math.inf]], dtype=torch.float),
+        ),
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float),
+            [[0], [2]],
+            torch.tensor([[-math.inf, 2, 3], [4, 5, -math.inf]], dtype=torch.float),
+        ),
+        (
+            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float),
+            [[1], [0, 2]],
+            torch.tensor(
+                [[1, -math.inf, 3], [-math.inf, 5, -math.inf]], dtype=torch.float
+            ),
+        ),
+    ],
+)
+def test_bias_logits(logits, indices_to_mask, expected):
+    masked_logits = bias_logits(logits, indices_to_mask)
+    assert torch.equal(masked_logits, expected)
+
+
+def test_generator_1d():
+    def model(*_):
+        return torch.tensor([[0, 1, 2, 3]], dtype=torch.float), None
+
+    def sampler(biased_logits, *_):
+        return torch.argmax(biased_logits)
+
+    generate = token_generator(model, sampler)
+
+    # 1D, no bias
+    result, _ = next(generate(None, None, None, [[]], 1, None))
+    assert result == 3
+
+    # 1D, bias one
+    result, _ = next(generate(None, None, None, [[3]], 1, None))
+    assert result == 2
+
+    # 1D, bias two
+    result, _ = next(generate(None, None, None, [[2, 3]], 1, None))
+    assert result == 1
+
+
+def test_generator_2d():
+    def model(*_):
+        return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.float), None
+
+    def sampler(biased_logits, *_):
+        return torch.argmax(biased_logits, dim=1)
+
+    generate = token_generator(model, sampler)
+
+    # 2D, no bias
+    result, _ = next(generate(None, None, None, [[], []], 1, None))
+    assert torch.equal(result, torch.tensor([3, 3]))
+
+    # 2D, bias one each
+    result, _ = next(generate(None, None, None, [[3], [3]], 1, None))
+    assert torch.equal(result, torch.tensor([2, 2]))
+
+    # 2D, bias one
+    result, _ = next(generate(None, None, None, [[3], []], 1, None))
+    assert torch.equal(result, torch.tensor([2, 3]))
+
+    # 2D, bias different number
+    result, _ = next(generate(None, None, None, [[3], [2, 3]], 1, None))
+    assert torch.equal(result, torch.tensor([2, 1]))
+
+
+@pytest.mark.xfail
+def test_get_next_fsm_states():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_get_next_instructions():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_is_generation_finished():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_update_token_ids():
+    raise NotImplementedError
+
+
+@pytest.mark.xfail
+def test_update_attention_masks():
+    raise NotImplementedError
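Taken together, the new pieces compose into a simple sampling loop. A minimal end-to-end sketch; `toy_model`, `greedy_sampler`, and the fixed four-token vocabulary are stand-ins for illustration, since the real entry points are still under construction in this patch:

    import torch

    from outlines.generate.generator import token_generator

    def toy_model(token_ids, attention_masks, kv_cache):
        # Always favours the highest token id; returns no KV cache.
        return torch.tensor([[0.0, 1.0, 2.0, 3.0]]), None

    def greedy_sampler(biased_logits, samples, rng):
        return torch.argmax(biased_logits, dim=-1, keepdim=True)

    generate = token_generator(toy_model, greedy_sampler)

    token_ids = torch.tensor([[0]])
    for _ in range(3):
        # Mask token 3 at every step, so greedy decoding falls back to token 2.
        next_token_ids, _ = next(generate(token_ids, None, None, [[3]], 1, None))
        token_ids = torch.cat([token_ids, next_token_ids], dim=-1)

    assert token_ids.tolist() == [[0, 2, 2, 2]]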