From e09397f82846bc44b66324c08e27bcf5e653f839 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Thu, 7 Dec 2023 10:47:32 +0100
Subject: [PATCH] Make `stream` output tokens with whitespaces

---
 outlines/generate/api.py | 112 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 100 insertions(+), 12 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index f6c747767..7130bf709 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -27,28 +27,106 @@ def format_sequence(self, sequence):
     def __call__(
         self,
-        prompt,
-        kv_cache: Optional[torch.tensor] = None,
+        prompts: Union[str, List[str]],
         rng: Optional[torch.Generator] = None,
+        kv_cache: Optional[torch.tensor] = None,
     ) -> Union[str, List[str]]:
-        sequence_generator = self.stream(prompt, kv_cache, rng)
-        tokens = [token for token in sequence_generator]
-        sequences = [
-            self.format_sequence("".join(sequence)) for sequence in list(zip(*tokens))
+        """Generate the full text sequence.
+
+        Since `SequenceGenerator.stream` calls the tokenizer at every step,
+        this method instead iterates over the generator returned by
+        `sequence_generator` itself, so the tokenizer is called only once,
+        after all the token ids have been generated.
+
+        Parameters
+        ----------
+        prompts
+            A string or list of strings that are passed to the model before
+            generating the first token.
+        kv_cache
+            A tensor containing the past key-value cache. It can be used, for
+            instance, when interleaving prompting and model calls. Defaults to
+            `None`.
+        rng
+            The random number generator. Defaults to a non-seeded `torch.Generator`
+            instance.
+
+        Returns
+        -------
+        A string or list of strings that contain the generated text.
+
+        """
+
+        if isinstance(prompts, str):
+            prompts = [prompts]
+
+        prompt_lengths = [len(prompt) for prompt in prompts]
+
+        if rng is None:
+            rng = torch.Generator(device=self.device)
+            rng.seed()
+
+        init_state = init_generator_state(
+            self.tokenizer, self.device, prompts, kv_cache
+        )
+        num_sequences = len(prompts)
+        init_fsm_states = [FSMState(0) for _ in range(num_sequences)]
+
+        states = sequence_generator(
+            self.generate_token, self.fsm, init_state, init_fsm_states, rng
+        )
+
+        while True:
+            try:
+                last_state = next(states)
+            except StopIteration:
+                break
+
+        sequences = self.tokenizer.decode(last_state.token_ids)
+        generated = [
+            sequence[length:] for sequence, length in zip(sequences, prompt_lengths)
         ]
-        return sequences if len(sequences) > 1 else sequences[0]
+        formatted = [self.format_sequence(sequence) for sequence in generated]
+
+        return formatted if len(formatted) > 1 else formatted[0]

     def stream(
         self,
-        prompt: str,
-        kv_cache: Optional[torch.tensor] = None,
+        prompts: Union[str, List[str]],
         rng: Optional[torch.Generator] = None,
+        kv_cache: Optional[torch.tensor] = None,
     ) -> Iterator[Union[List[str], str]]:
+        """Generate the text sequence one token at a time.
+
+        Since `Tokenizer.decode` strips whitespace from the tokens, we have no
+        choice but to decode the generated token ids at each step and compare
+        the current decoded strings to the previously decoded strings.
+
+        Parameters
+        ----------
+        prompts
+            A string or list of strings that are passed to the model before
+            generating the first token.
+        kv_cache
+            A tensor containing the past key-value cache. It can be used, for
+            instance, when interleaving prompting and model calls. Defaults to
+            `None`.
+        rng
+            The random number generator. Defaults to a non-seeded `torch.Generator`
+            instance.
+
+        Returns
+        -------
+        A string or list of strings that contain the generated text.
+
+        """
         if rng is None:
             rng = torch.Generator(device=self.device)
             rng.seed()
-        init_state = init_generator_state(self.tokenizer, self.device, prompt, kv_cache)
+        init_state = init_generator_state(
+            self.tokenizer, self.device, prompts, kv_cache
+        )
         token_ids = init_state[1]
         num_sequences = token_ids.shape[0]
@@ -60,14 +138,24 @@ def stream(
         )

         def token_generator() -> Iterator[Union[List[str], str]]:
+            previously_generated_sequences = ["" for _ in range(num_sequences)]
+            num_generated = 0
             while True:
                 try:
                     sequence = next(states)
+                    num_generated += 1
                 except StopIteration:
                     return

-                next_token_ids = sequence.token_ids[:, -1]
-                next_tokens = self.tokenizer.decode(next_token_ids)
+                generated_token_ids = sequence.token_ids[:, -num_generated:]
+                generated_sequences = self.tokenizer.decode(generated_token_ids)
+                next_tokens = [
+                    token[len(sequence) :]
+                    for token, sequence in zip(
+                        generated_sequences, previously_generated_sequences
+                    )
+                ]
+                previously_generated_sequences = generated_sequences

                 yield next_tokens
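
Note (illustration only, not part of the patch): the whitespace-recovery trick that the new `stream` relies on can be sketched in isolation. The snippet below is a hypothetical, self-contained Python example; the toy VOCAB and decode() merely stand in for a real `Tokenizer.decode`, which strips leading whitespace when a token is decoded on its own.

from typing import Iterator, List

# Toy vocabulary: some tokens carry a leading space, as in real BPE vocabularies.
VOCAB = {0: "Hello", 1: " world", 2: "!", 3: " How", 4: " are", 5: " you"}


def decode(token_ids: List[int]) -> str:
    """Toy stand-in for `Tokenizer.decode`: strips leading whitespace, so
    decoding a single token such as ` world` on its own would lose its space."""
    return "".join(VOCAB[token_id] for token_id in token_ids).lstrip()


def stream_tokens(token_ids: List[int]) -> Iterator[str]:
    """Yield one token string at a time, whitespace included.

    Instead of decoding each id on its own, decode every id generated so far
    and yield only the suffix that was absent from the previous decoding,
    which is the same idea as the `token_generator` closure in the patch.
    """
    previously_decoded = ""
    generated: List[int] = []
    for token_id in token_ids:
        generated.append(token_id)
        decoded = decode(generated)               # decode the cumulative ids
        yield decoded[len(previously_decoded):]   # only the newly added text
        previously_decoded = decoded


if __name__ == "__main__":
    pieces = list(stream_tokens([0, 1, 2, 3, 4, 5]))
    print(pieces)           # ['Hello', ' world', '!', ' How', ' are', ' you']
    print("".join(pieces))  # Hello world! How are you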