Add generator that samples the next tokens
rlouf committed Nov 16, 2023
1 parent 37bb364 commit 22f5533
Showing 3 changed files with 295 additions and 0 deletions.
20 changes: 20 additions & 0 deletions outlines/generate/sequence.py
@@ -242,3 +242,23 @@ def __call__(
            return result[0]

        return result


def test_get_next_fsm_state():
    raise NotImplementedError


def test_get_next_instructions():
    raise NotImplementedError


def test_is_generation_finished():
    raise NotImplementedError


def test_update_token_ids():
    raise NotImplementedError


def test_update_attention_masks():
    raise NotImplementedError
136 changes: 136 additions & 0 deletions outlines/generate/generator.py
@@ -0,0 +1,136 @@
import math
from dataclasses import dataclass
from typing import Generator, List, Optional

import torch

from outlines.generate.samplers import Sampler
from outlines.index.index import Index


@dataclass
class GenerationState:
    token_ids: torch.Tensor
    attention_masks: torch.Tensor
    kv_cache: Optional[torch.Tensor] = None


def process(generator: Generator, index: "Index", state: GenerationState):
    """This generator drives the text generation process by
    walking through the FSM."""
    next(generator)

    fsm_states = [0 for _ in range(state.token_ids.shape[0])]
    while True:
        logits_mask = get_next_instructions(index, fsm_states)

        # `token_generator` unpacks the state as a plain tuple.
        next_token_ids, kv_cache = generator.send(
            ((state.token_ids, state.attention_masks, state.kv_cache), logits_mask)
        )
        # Step the generator past its output `yield` so it is ready to
        # receive the next state.
        next(generator)

        token_ids = update_token_ids(state.token_ids, next_token_ids)
        attention_masks = update_attention_masks(state.attention_masks)
        state = GenerationState(token_ids, attention_masks, kv_cache)

        fsm_states = get_next_fsm_states(index, fsm_states, next_token_ids)
        is_finished = is_generation_finished(index, fsm_states)
        if is_finished:
            yield token_ids, next_token_ids
            return

        yield state


def get_next_fsm_states(
    index, fsm_states: List[int], next_token_ids: torch.Tensor
) -> List[int]:
    return [
        index.next_state(fsm_state, token_id)
        for fsm_state, token_id in zip(fsm_states, next_token_ids)
    ]


def get_next_instructions(index, fsm_states: List[int]) -> torch.Tensor:
    return [index.next_instruction(state) for state in fsm_states]


def is_generation_finished(index, fsm_states: List[int]) -> bool:
    return all([index.is_finished(state) for state in fsm_states])
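
These three helpers are the only surface `process` needs from the index, so any object exposing them can steer generation. A hypothetical minimal interface (the class name and type hints below are invented for illustration, not part of this commit):

class MinimalIndex:
    """Sketch of the interface the helpers above assume."""

    def next_instruction(self, fsm_state: int) -> List[int]:
        """Return the token ids to mask in this FSM state."""
        ...

    def next_state(self, fsm_state: int, token_id: torch.Tensor) -> int:
        """Return the FSM state reached after appending `token_id`."""
        ...

    def is_finished(self, fsm_state: int) -> bool:
        """Return True when the FSM has reached a final state."""
        ...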


def update_token_ids(
    token_ids: torch.Tensor, next_token_ids: torch.Tensor
) -> torch.Tensor:
    return torch.concatenate([token_ids, next_token_ids], dim=-1)


def update_attention_masks(attention_masks: torch.Tensor) -> torch.Tensor:
    return torch.concatenate(
        [
            attention_masks,
            torch.ones(
                attention_masks.shape[:-1] + (1,), device=attention_masks.device
            ),
        ],
        dim=-1,
    )
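
Both update helpers grow the last (sequence) dimension by one and leave the batch dimension untouched. A quick shape check, assuming a batch of two sequences of length three:

token_ids = torch.zeros((2, 3), dtype=torch.long)      # (batch, seq)
next_token_ids = torch.ones((2, 1), dtype=torch.long)  # one new token per row
assert update_token_ids(token_ids, next_token_ids).shape == (2, 4)

attention_masks = torch.ones((2, 3))
assert update_attention_masks(attention_masks).shape == (2, 4)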


def token_generator(model, sampler: "Sampler", samples: int, rng: torch.Generator):
    """Generator that yields a token every time it is called.

    This process is designed to be steered by another supervising
    process that supplies the current sequence and the indices
    of the tokens to mask before sampling.

    Parameters
    ----------
    model
        A model that takes a sequence of tokens as an input and
        returns a probability distribution over the next tokens.
    sampler
        A function that samples tokens from a probability
        distribution over the next tokens.
    samples
        The number of samples to draw for each sequence.
    rng
        The random number generator used by the sampler.

    Yields
    ------
    A tensor with the sampled tokens.

    """
    while True:
        (token_ids, attention_masks, kv_cache), logits_mask = yield

        try:
            logits, new_kv_cache = model(token_ids, attention_masks, kv_cache)
        except IndexError:  # Exceeding the context length
            return

        biased_logits = bias_logits(logits, logits_mask)
        next_token_ids = sampler(biased_logits, samples, rng)

        yield next_token_ids, new_kv_cache
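
Because the loop body contains two yield points, the supervising process primes the coroutine once with `next()`, and after each `send()` it must advance the generator again before the next round. A minimal round-trip with toy stand-ins for the model and sampler (mirroring the tests in `tests/text/test_generator.py`):

def toy_model(*_):
    return torch.tensor([[0.0, 1.0, 2.0, 3.0]]), None

def toy_sampler(biased_logits, *_):
    return torch.argmax(biased_logits)

generator = token_generator(toy_model, toy_sampler, 1, None)
next(generator)  # prime up to the receiving `yield`
tokens, kv_cache = generator.send(((None, None, None), [[3]]))
assert tokens == 2  # highest unmasked logit
next(generator)  # step past the output `yield` before sending again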


def bias_logits(
    logits: torch.Tensor,
    ids_to_mask: List,
) -> torch.Tensor:
    """Mask the logits.

    The function iterates over a nested list where each inner list contains
    the indices that need to be masked in the corresponding row of the array.

    Parameters
    ----------
    logits
        Two dimensional tensor that contains the next-token probability
        distribution.
    ids_to_mask
        The ids to mask in each dimension.

    Returns
    -------
    The original logits tensor, modified in place so that the masked ids are
    set to `-math.inf`.

    """
    for i, ids in enumerate(ids_to_mask):
        logits[i, ids] = -math.inf
    return logits
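
For a sense of how the pieces fit together, here is a hypothetical smoke run that steers `process` with a two-state FSM. The toy index, model, and sampler are invented for illustration; only the wiring between `process` and `token_generator` reflects this module:

import torch

from outlines.generate.generator import GenerationState, process, token_generator


class ToyIndex:
    """Toy FSM that forces the token sequence 1, 2 over a 4-token vocabulary."""

    def next_instruction(self, fsm_state):
        allowed = fsm_state + 1
        return [i for i in range(4) if i != allowed]  # mask everything else

    def next_state(self, fsm_state, token_id):
        return fsm_state + 1

    def is_finished(self, fsm_state):
        return fsm_state >= 2


def toy_model(token_ids, attention_masks, kv_cache):
    return torch.zeros((token_ids.shape[0], 4)), None  # uniform logits


def toy_sampler(biased_logits, samples, rng):
    return torch.argmax(biased_logits, dim=-1, keepdim=True)


state = GenerationState(torch.tensor([[0]]), torch.ones((1, 1)))
generator = token_generator(toy_model, toy_sampler, 1, None)
for output in process(generator, ToyIndex(), state):
    pass

token_ids, last_tokens = output  # token_ids == tensor([[0, 1, 2]])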
139 changes: 139 additions & 0 deletions tests/text/test_generator.py
@@ -0,0 +1,139 @@
import math

import pytest
import torch

from outlines.generate.generator import bias_logits, token_generator


def test_generator_error():
    def model(*_):
        raise IndexError

    def sampler():
        return None

    generator = token_generator(model, sampler, 1, None)
    next(generator)
    with pytest.raises(StopIteration):
        generator.send(((None, None, None), None))


@pytest.mark.parametrize(
    "logits,indices_to_mask,expected",
    [
        (
            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
            [[]],
            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
        ),
        (
            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
            [[1]],
            torch.tensor([[1, -math.inf, 3, 4]], dtype=torch.float),
        ),
        (
            torch.tensor([[1, 2, 3, 4]], dtype=torch.float),
            [[1, 3]],
            torch.tensor([[1, -math.inf, 3, -math.inf]], dtype=torch.float),
        ),
        (
            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float),
            [[0], [2]],
            torch.tensor([[-math.inf, 2, 3], [4, 5, -math.inf]], dtype=torch.float),
        ),
        (
            torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float),
            [[1], [0, 2]],
            torch.tensor(
                [[1, -math.inf, 3], [-math.inf, 5, -math.inf]], dtype=torch.float
            ),
        ),
    ],
)
def test_bias_logits(logits, indices_to_mask, expected):
    masked_logits = bias_logits(logits, indices_to_mask)
    assert torch.equal(masked_logits, expected)


def test_generator_1d():
    def model(*_):
        return torch.tensor([[0, 1, 2, 3]], dtype=torch.float), None

    def sampler(biased_logits, *_):
        return torch.argmax(biased_logits)

    # 1D, no bias
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[]]))
    assert result == 3

    # 1D, bias one
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[3]]))
    assert result == 2

    # 1D, bias two
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[2, 3]]))
    assert result == 1


def test_generator_2d():
    def model(*_):
        return torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.float), None

    def sampler(biased_logits, *_):
        return torch.argmax(biased_logits, dim=1)

    # 2D, no bias
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[]]))
    assert torch.equal(result, torch.tensor([3, 3]))

    # 2D, bias one each
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[3], [3]]))
    assert torch.equal(result, torch.tensor([2, 2]))

    # 2D, bias one
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[3], []]))
    assert torch.equal(result, torch.tensor([2, 3]))

    # 2D, bias different number
    generator = token_generator(model, sampler, 1, None)
    next(generator)
    result, _ = generator.send(((None, None, None), [[3], [2, 3]]))
    assert torch.equal(result, torch.tensor([2, 1]))


@pytest.mark.xfail
def test_get_next_fsm_states():
    raise NotImplementedError


@pytest.mark.xfail
def test_get_next_instructions():
    raise NotImplementedError


@pytest.mark.xfail
def test_is_generation_finished():
    raise NotImplementedError


@pytest.mark.xfail
def test_update_token_ids():
    raise NotImplementedError


@pytest.mark.xfail
def test_update_attention_masks():
    raise NotImplementedError
