Clean SequenceGenerator
rlouf committed Feb 9, 2024
1 parent 49b069a commit 94e2a38
Showing 2 changed files with 30 additions and 33 deletions.
45 changes: 18 additions & 27 deletions outlines/generate/api.py
@@ -48,34 +48,23 @@ def __init__(
     def get_generated_token_ids(
         self,
         prompt_token_ids: torch.Tensor,
-        prompts: List[str],
         token_ids: torch.Tensor,
-        num_samples: int,
     ) -> List[torch.Tensor]:
         """Get the tokens generated so far.
         Parameters
         ----------
-        init_state
-            The initial state of the generation.
-        prompts
-            The prompts passed to the generator.
         prompt_token_ids
             Tensor that contains the token ids of the sequences' prompts.
         token_ids
             The generated token ids.
-        num_samples
-            The number of samples taken for each sequence
         Returns
         -------
         A tensor that contains the token ids that have been generated so far.
         """
-        prompt_lengths = [
-            len(prompt_token_ids[i])
-            for _ in range(num_samples)
-            for i in range(len(prompts))
-        ]
-
+        prompt_lengths = [len(prompt) for prompt in prompt_token_ids]
         token_ids = [
             cur_token_ids[length:]
             for cur_token_ids, length in zip(token_ids, prompt_lengths)
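
The simplification works because prompt_token_ids is assumed to already hold one row per generated sequence (prompts repeated per sample upstream), so the prompt lengths can be read off directly. A minimal sketch of the slicing, with illustrative values:

    import torch

    # One prompt row per generated sequence; lengths may differ.
    prompt_token_ids = [torch.tensor([5, 6]), torch.tensor([7, 8, 9])]
    # Full sequences: prompt tokens followed by newly generated tokens.
    token_ids = [torch.tensor([5, 6, 11, 12]), torch.tensor([7, 8, 9, 13])]

    prompt_lengths = [len(prompt) for prompt in prompt_token_ids]
    generated = [seq[length:] for seq, length in zip(token_ids, prompt_lengths)]
    # generated == [tensor([11, 12]), tensor([13])]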
@@ -240,7 +229,7 @@ def __call__(
             if max_tokens or stop_sequences:
                 token_ids = last_state.token_ids
                 generated_token_ids = self.get_generated_token_ids(
-                    prompt_token_ids, prompts, token_ids, num_samples
+                    prompt_token_ids, token_ids
                 )
                 if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                     break
@@ -252,9 +241,7 @@ def __call__(
                     break
 
         token_ids = last_state.token_ids
-        generated_token_ids = self.get_generated_token_ids(
-            prompt_token_ids, prompts, token_ids, num_samples
-        )
+        generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)
 
         generated = self.tokenizer.decode(generated_token_ids)
         stripped = [
@@ -265,14 +252,16 @@ def __call__(
 
         # We reshape the output to (batch_size, sample_size)
         output = []
-        for i in range(len(prompts)):
+        for i in range(batch_size):
             output.append(formatted[i : i + num_samples])
 
         # We remove leading dimensions for the output
-        if len(prompts) == 1 and num_samples == 1:
+        if batch_size == 1 and num_samples == 1:
             return output[0][0]
+        elif batch_size == 1:
+            return output[0]
         elif num_samples == 1:
-            return [seq[0] for seq in output]
+            return [samples[0] for samples in output]
         else:
             return output
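
For reference, the squeezing contract this return block implements, as a standalone sketch (the function name and values are illustrative, not library API):

    def squeeze_output(output, batch_size, num_samples):
        # output is a nested list of shape (batch_size, num_samples)
        if batch_size == 1 and num_samples == 1:
            return output[0][0]  # single sequence
        elif batch_size == 1:
            return output[0]  # all samples for the single prompt
        elif num_samples == 1:
            return [samples[0] for samples in output]  # one sequence per prompt
        else:
            return output  # full (batch_size, num_samples) nesting

    assert squeeze_output([["a"]], 1, 1) == "a"
    assert squeeze_output([["a", "b"]], 1, 2) == ["a", "b"]
    assert squeeze_output([["a"], ["b"]], 2, 1) == ["a", "b"]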

@@ -394,17 +383,19 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
                         generated_sequences, is_stop_at_reached
                     )
                 ]
-                # We reshape the output to (sample_size, batch_size)
+
+                # We reshape the output to (batch_size, sample_size)
                 output = []
-                step = len(prompts)
-                for i in range(0, len(next_tokens), step):
-                    output.append(next_tokens[i : i + step])
+                for i in range(batch_size):
+                    output.append(next_tokens[i : i + num_samples])
 
                 # We remove leading dimensions for the output
-                if len(prompts) == 1 and num_samples == 1:
+                if batch_size == 1 and num_samples == 1:
                     yield output[0][0]
-                elif num_samples == 1:
+                elif batch_size == 1:
                     yield output[0]
+                elif num_samples == 1:
+                    yield [samples[0] for samples in output]
                 else:
                     yield output
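
This token_generator backs the streaming interface of SequenceGenerator. A hedged usage sketch (the model id and exact constructor arguments depend on the installed outlines version):

    from outlines import generate, models

    model = models.transformers("gpt2")  # any transformers model id
    generator = generate.text(model)

    # stream() yields one decoded token (or nested lists of tokens) per step,
    # squeezed the same way as the block above.
    for token in generator.stream("An example prompt"):
        print(token, end="")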

18 changes: 12 additions & 6 deletions outlines/generate/generator.py
@@ -10,6 +10,10 @@
 from outlines.fsm.fsm import FSM
 
 
+class ContextLengthExceededError(Exception):
+    pass
+
+
 @dataclasses.dataclass(frozen=True)
 class GenerationState:
     token_ids: torch.Tensor
@@ -24,7 +28,7 @@ def sequence_generator(
     sampler: Callable,
     fsms: List["FSM"],
     token_ids: torch.Tensor,
-    weights: torch.Tensor,
+    sequence_weights: torch.Tensor,
     attention_masks: torch.Tensor,
     fsm_states: List[FSMState],
     rng: torch.Generator = torch.Generator(),
@@ -43,7 +47,7 @@ def sequence_generator(
     token_ids
         A tensor of token ids on which the sequence distribution is conditioned, of
         shape ``(n_seqs, n_prompt_tokens)``
-    weights
+    sequence_weights
         A tensor that contains the initial weights of the sequences, of shape
         ``(n_seqs,)``
     attention_masks
@@ -66,13 +70,15 @@
         try:
             logits, kv_cache = model(token_ids, attention_masks, kv_cache)
         except IndexError:  # Exceeding the context length
-            raise IndexError(
+            raise ContextLengthExceededError(
                 "The input length exceeds the context length of the model."
             )
 
         allowed_tokens = get_allowed_tokens(fsms, fsm_states)
         biased_logits = bias_logits(logits, allowed_tokens)
-        next_token_ids, ancestors, weights = sampler(biased_logits, weights, rng)
+        next_token_ids, ancestors, sequence_weights = sampler(
+            biased_logits, sequence_weights, rng
+        )
 
         token_ids = update_token_ids(token_ids, next_token_ids, ancestors)
         attention_masks = update_attention_masks(attention_masks, ancestors)
@@ -88,7 +94,7 @@
                 token_ids,
                 kv_cache,
                 logits,
-                weights,
+                sequence_weights,
                 fsm_states,
             )
             return
@@ -97,7 +103,7 @@
             token_ids,
             kv_cache,
             logits,
-            weights,
+            sequence_weights,
             fsm_states,
         )

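With the dedicated exception in place, callers can distinguish context overflow from other indexing errors. A minimal sketch, assuming the surrounding setup (model, sampler, fsms, tensors) is already built as elsewhere in this module, and assuming `model` is the first parameter (not shown in the hunks above):

    from outlines.generate.generator import (
        ContextLengthExceededError,
        sequence_generator,
    )

    try:
        # model, sampler, fsms, token_ids, sequence_weights, attention_masks,
        # fsm_states, and rng are placeholders for a real setup.
        for state in sequence_generator(
            model, sampler, fsms, token_ids, sequence_weights,
            attention_masks, fsm_states, rng=rng,
        ):
            last_state = state
    except ContextLengthExceededError:
        # The prompt plus generated tokens exceed the model's context window.
        print("Input too long for the model's context length.")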
