Make stream output tokens with whitespaces
rlouf committed Dec 7, 2023 · 1 parent f526844 · commit e09397f
Showing 1 changed file with 100 additions and 12 deletions.
112 changes: 100 additions & 12 deletions outlines/generate/api.py
@@ -27,28 +27,106 @@ def format_sequence(self, sequence):

     def __call__(
         self,
-        prompt,
-        kv_cache: Optional[torch.tensor] = None,
+        prompts: Union[str, List[str]],
         rng: Optional[torch.Generator] = None,
+        kv_cache: Optional[torch.tensor] = None,
     ) -> Union[str, List[str]]:
-        sequence_generator = self.stream(prompt, kv_cache, rng)
-        tokens = [token for token in sequence_generator]
-        sequences = [
-            self.format_sequence("".join(sequence)) for sequence in list(zip(*tokens))
+        """Generate the full text sequence.
+
+        Since `SequenceGenerator.stream` calls the tokenizer at every step this
+        method loops over the generator returned by `sequence_generator` itself
+        so the tokenizer is called only once after all token ids have been
+        generated.
+
+        Parameters
+        ----------
+        prompts
+            A string or list of strings that are passed to the model before
+            generating the first token.
+        kv_cache
+            A tensor containing the past key-value cache. It can for instance
+            be used when we are interleaving prompting and model calls. Defaults
+            to `None`.
+        rng
+            The random number generator. Defaults to a non-seeded `torch.Generator`
+            instance.
+
+        Returns
+        -------
+        A string or list of strings that contain the generated text.
+
+        """
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        prompt_lengths = [len(prompt) for prompt in prompts]
+
+        if rng is None:
+            rng = torch.Generator(device=self.device)
+            rng.seed()
+
+        init_state = init_generator_state(
+            self.tokenizer, self.device, prompts, kv_cache
+        )
+        num_sequences = len(prompts)
+        init_fsm_states = [FSMState(0) for _ in range(num_sequences)]
+
+        states = sequence_generator(
+            self.generate_token, self.fsm, init_state, init_fsm_states, rng
+        )
+
+        while True:
+            try:
+                last_state = next(states)
+            except StopIteration:
+                break
+
+        sequences = self.tokenizer.decode(last_state.token_ids)
+        generated = [
+            sequence[length:] for sequence, length in zip(sequences, prompt_lengths)
         ]
-        return sequences if len(sequences) > 1 else sequences[0]
+        formatted = [self.format_sequence(sequence) for sequence in generated]
+
+        return formatted if len(formatted) > 1 else formatted[0]

     def stream(
         self,
-        prompt: str,
-        kv_cache: Optional[torch.tensor] = None,
+        prompts: Union[str, List[str]],
         rng: Optional[torch.Generator] = None,
+        kv_cache: Optional[torch.tensor] = None,
     ) -> Iterator[Union[List[str], str]]:
+        """Generate the text sequence one token at a time.
+
+        Since `Tokenizer.decode` strips the whitespaces from the tokens we have
+        no choice but to decode the generated token ids at each step and compare
+        the current decoded strings to the previously decoded strings.
+
+        Parameters
+        ----------
+        prompts
+            A string or list of strings that are passed to the model before
+            generating the first token.
+        kv_cache
+            A tensor containing the past key-value cache. It can for instance
+            be used when we are interleaving prompting and model calls. Defaults
+            to `None`.
+        rng
+            The random number generator. Defaults to a non-seeded `torch.Generator`
+            instance.
+
+        Returns
+        -------
+        A string or list of strings that contain the generated text.
+
+        """
         if rng is None:
             rng = torch.Generator(device=self.device)
             rng.seed()
 
-        init_state = init_generator_state(self.tokenizer, self.device, prompt, kv_cache)
+        init_state = init_generator_state(
+            self.tokenizer, self.device, prompts, kv_cache
+        )
 
         token_ids = init_state[1]
         num_sequences = token_ids.shape[0]
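
The reworked `__call__` above no longer joins streamed tokens. It drains the state generator without decoding, calls `tokenizer.decode` once on the final token ids, and strips each prompt from its completion by character length. A minimal sketch of that decode-once pattern, with a toy `decode` standing in for the real tokenizer (all names below are illustrative, not part of the commit):

from typing import Iterable, List

def decode(token_ids: List[List[int]]) -> List[str]:
    # Toy stand-in for Tokenizer.decode: maps each id to a string piece.
    vocab = {0: "Hi", 1: " there", 2: "!"}
    return ["".join(vocab[i] for i in ids) for ids in token_ids]

def generate_all(prompts: List[str], states: Iterable[List[List[int]]]) -> List[str]:
    prompt_lengths = [len(p) for p in prompts]
    last_state: List[List[int]] = []
    for last_state in states:  # drain the generator, no per-step decoding
        pass
    sequences = decode(last_state)  # the tokenizer runs exactly once
    # Slice off the prompt prefix by character length, as the commit does.
    return [seq[length:] for seq, length in zip(sequences, prompt_lengths)]

print(generate_all(["Hi"], [[[0, 1]], [[0, 1, 2]]]))  # [' there!']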
@@ -60,14 +138,24 @@ def stream(
         )
 
         def token_generator() -> Iterator[Union[List[str], str]]:
+            previously_generated_sequences = ["" for _ in range(num_sequences)]
+            num_generated = 0
             while True:
                 try:
                     sequence = next(states)
+                    num_generated += 1
                 except StopIteration:
                     return
 
-                next_token_ids = sequence.token_ids[:, -1]
-                next_tokens = self.tokenizer.decode(next_token_ids)
+                generated_token_ids = sequence.token_ids[:, -num_generated:]
+                generated_sequences = self.tokenizer.decode(generated_token_ids)
+                next_tokens = [
+                    token[len(sequence) :]
+                    for token, sequence in zip(
+                        generated_sequences, previously_generated_sequences
+                    )
+                ]
+                previously_generated_sequences = generated_sequences
 
                 yield next_tokens

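The core fix lives in `token_generator`: decoding only the newest token id drops the leading whitespace that tokenizers attach to word-initial tokens, so each step now re-decodes the entire generated suffix and yields the difference from the previously decoded string. A self-contained sketch of the diffing trick, using a toy whitespace-stripping vocabulary rather than the library's tokenizer:

from typing import Iterator, List

VOCAB = {0: "Hello", 1: " world", 2: "!"}

def decode(ids: List[int]) -> str:
    # Like many tokenizers, strip outer whitespace: decoding token 1 alone
    # would give "world" and silently lose its leading space.
    return "".join(VOCAB[i] for i in ids).strip()

def stream_tokens(all_ids: List[int]) -> Iterator[str]:
    previously_decoded = ""
    for n in range(1, len(all_ids) + 1):
        decoded = decode(all_ids[:n])  # decode the whole generated suffix
        yield decoded[len(previously_decoded):]  # emit only the new part
        previously_decoded = decoded

print(list(stream_tokens([0, 1, 2])))  # ['Hello', ' world', '!']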

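After this change both entry points accept a string or a list of strings, and streamed chunks keep their whitespace. Hypothetical usage, assuming `generator` is an already-constructed `SequenceGenerator`:

completions = generator(["Prompt A", "Prompt B"])  # decoded once, prompts stripped
for next_tokens in generator.stream("A prompt"):  # one decode-and-diff per step
    print(next_tokens)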