Skip to content

Commit

Permalink
Align prompt with tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jul 19, 2023
1 parent b464334 commit 9dfdb7b
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 7 deletions.
107 changes: 101 additions & 6 deletions outlines/text/generate/sequence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import List, Optional, Tuple, Union
import itertools
import math
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import interegular
import torch

from outlines.text.parsing import find_partial_matches


class Sequence:
"""Represents a sequence generation method."""
Expand Down Expand Up @@ -171,6 +177,98 @@ def update_token_ids(

return new_token_ids

def find_boundary_tokens(self, prompt: str) -> Dict[int, List[int]]:
"""Find a list of tokens that cross the prompt boundary."""

vocabulary = {
token_id: self.model.tokenizer.decode([token_id])[0]
for token_id in range(len(self.model.tokenizer.vocabulary))
}
prompt_fsm = interegular.parse_pattern(prompt).to_fsm()

prompt_token_ids, _ = self.model.tokenizer.encode(prompt)
prompt_tokens = self.model.tokenizer.decode(prompt_token_ids[0])

token_idx_in_prompt = [0] + list(
itertools.accumulate([len(t) for t in prompt_tokens])
)[:-1]

boundary_tokens = defaultdict(list)
for token_id, token in vocabulary.items():
pmatches = find_partial_matches(prompt_fsm, token)
for pmatch in pmatches:
end_idx, states = pmatch
if end_idx is not None and states[-1] == len(prompt):
if states[0] in token_idx_in_prompt:
boundary_tokens[token_idx_in_prompt.index(states[0])].append(
token_id
)

return boundary_tokens

def align_prompt_tokens(
self, prompt: Union[str, List[str]], rng: torch.Generator
) -> Tuple[torch.LongTensor, torch.LongTensor]:
"""Align the prompt with the vocabulary."""

prompts = prompt
if isinstance(prompts, str):
prompts = [prompts]

masks = []
truncated_attention_masks = []
truncated_token_idss = []
attention_masks = []
for prompt in prompts:
boundary_tokens = self.find_boundary_tokens(prompt)

token_ids, attention_mask = self.model.tokenizer.encode(prompt)
token_ids = token_ids.to(self.device)
attention_mask = attention_mask.to(self.device)

last_token = min(boundary_tokens.keys())
truncated_token_ids = token_ids[:, :last_token]
truncated_attention_mask = attention_mask[:, :last_token]

allowed_tokens = boundary_tokens[last_token]
mask = torch.full(
(len(self.model.tokenizer.vocabulary),), -math.inf, device=self.device
)
mask[allowed_tokens] = 0

masks.append(mask)
truncated_attention_masks.append(truncated_attention_mask.squeeze())
attention_masks.append(attention_mask.squeeze())
truncated_token_idss.append(truncated_token_ids.squeeze())

# Pad left and stack
from torch.nn.utils.rnn import pad_sequence

mask = torch.vstack(masks)
truncated_attention_mask = pad_sequence(
[t.flip(dims=[0]) for t in truncated_attention_masks],
batch_first=True,
padding_value=0,
).flip(dims=[1])
attention_mask = pad_sequence(
[a.flip(dims=[0]) for a in attention_masks],
batch_first=True,
padding_value=0,
).flip(dims=[1])
truncated_token_ids = pad_sequence(
[t.flip(dims=[0]) for t in truncated_token_idss],
batch_first=True,
padding_value=self.model.tokenizer.pad_token_id,
).flip(dims=[1])

probs = self.model(truncated_token_ids, truncated_attention_mask)
probs = probs + mask
probs = torch.nn.functional.softmax(probs, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1)
token_ids = torch.concatenate([truncated_token_ids, next_token_ids], axis=-1)

return token_ids, attention_mask

@torch.inference_mode()
def __call__(
self,
Expand All @@ -192,14 +290,11 @@ def __call__(
The full sequence that contains the prompts and the generated string.
"""
token_ids, attention_mask = self.model.tokenizer.encode(prompt)

token_ids = token_ids.to(self.device)
attention_mask = attention_mask.to(self.device)

if rng is None:
rng = torch.Generator(device=self.device)

token_ids, attention_mask = self.align_prompt_tokens(prompt, rng)

num_prompt_tokens = token_ids.shape[-1]

if samples > 1:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ module = [
"scipy.*",
"tenacity.*",
"tiktoken.*",
"torch",
"torch.*",
"transformers.*",
"lark.*",
"regex.*",
Expand Down

0 comments on commit 9dfdb7b

Please sign in to comment.