Add Beam Search sampler #618

Merged: 3 commits on Feb 9, 2024
1 change: 1 addition & 0 deletions README.md
@@ -33,6 +33,7 @@ First time here? Go to our [setup guide](https://outlines-dev.github.io/outlines
- [x] 🐍 Interleave completions with loops, conditionals, and custom Python functions
- [x] 💾 Caching of generations
- [x] 🗂️ Batch inference
- [x] 🎲 Sample with the greedy, multinomial and beam search algorithms (and more to come!)
- [x] 🚀 [Serve with vLLM](https://outlines-dev.github.io/outlines/reference/vllm)


21 changes: 21 additions & 0 deletions docs/reference/samplers.md
@@ -1,5 +1,7 @@
# Samplers

Outlines offers different sequence sampling algorithms, and we will integrate more in the future. You can read [this blog post](https://huggingface.co/blog/how-to-generate) for an overview of the different sampling algorithms.

## Multinomial sampling

Outlines defaults to the multinomial sampler without top-p or top-k sampling, and temperature equal to 1. Not specifying a sampler is equivalent to:
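The explicit form of that default is elided in this diff view; below is a minimal sketch of what it presumably looks like, assuming `samplers.multinomial()` can be called with no arguments to obtain the default behaviour (the model name and prompt mirror the beam search example later in this file):

```python
from outlines import models, generate, samplers


model = models.transformers("mistralai/Mistral-7B-Instruct-v0.2")

# Assumed default: one multinomial sample per prompt, no top-k or top-p
# filtering, temperature equal to 1.
sampler = samplers.multinomial()

generator = generate.text(model, sampler=sampler)
answer = generator("What is 2+2?")
```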
@@ -71,3 +73,22 @@ print(answer)
```

You cannot ask for multiple samples with the greedy sampler since it is not clear what the result should be.
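To make the restriction concrete, here is a small hedged illustration; it assumes the greedy sampler takes no sample-count argument and exposes a fixed `samples` attribute of 1, mirroring the `sampler.samples` attribute used in the `outlines/generate/api.py` changes below:

```python
from outlines import samplers

# greedy() is constructed without a sample count: it always produces a single
# deterministic sequence per prompt (assumption based on the api.py diff below).
sampler = samplers.greedy()
print(sampler.samples)  # 1
```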


## Beam Search

Outlines also comes with the Beam Search sampling algorithm:

```python
from outlines import models, generate, samplers


model = models.transformers("mistralai/Mistral-7B-Instruct-v0.2")
sampler = samplers.beam_search(beams=5)

generator = generate.text(model, sampler=sampler)
answer = generator("What is 2+2?")

print(answer)
# 4
```
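A note on the shape of the result: based on the output-reshaping logic this PR adds to `outlines/generate/api.py` (with a single prompt and several samples, `output[0]`, a list, is returned), the call above with `beams=5` would presumably yield five candidate completions rather than a single string. Continuing from the snippet above, a hedged sketch:

```python
# Assumption drawn from the reshaping code below: one completion per beam.
answers = generator("What is 2+2?")
print(len(answers))  # 5, under the assumption above
best = answers[0]
```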
79 changes: 40 additions & 39 deletions outlines/generate/api.py
@@ -4,7 +4,7 @@
import torch

from outlines.fsm.fsm import FSMState
from outlines.generate.generator import sequence_generator, token_generator
from outlines.generate.generator import sequence_generator


class SequenceGenerator:
@@ -18,12 +18,13 @@ def __init__(
max_tokens=None,
stop_at=None,
):
self.generate_token = token_generator(model, sampler)
self.fsm = fsm
self.model = model
self.sampler = sampler
self.tokenizer = model.tokenizer
self.device = device
self.max_tokens = max_tokens
self.num_particles = sampler.particles
self.num_samples = sampler.samples

if isinstance(stop_at, str):
stop_at = [stop_at]
@@ -47,34 +48,23 @@ def __init__(
def get_generated_token_ids(
self,
prompt_token_ids: torch.Tensor,
prompts: List[str],
token_ids: torch.Tensor,
num_samples: int,
) -> List[torch.Tensor]:
"""Get the tokens generated so far.

Parameters
----------
init_state
The initial state of the generation.
prompts
The prompts passed to the generator.
prompt_token_ids
Tensor that contains the token ids of the sequences' prompts.
token_ids
The generated token ids.
num_samples
The number of samples taken for each sequence

Returns
-------
A tensor that contains the token ids that have been generated so far.

"""
prompt_lengths = [
len(prompt_token_ids[i])
for _ in range(num_samples)
for i in range(len(prompts))
]

prompt_lengths = [len(prompt) for prompt in prompt_token_ids]
token_ids = [
cur_token_ids[length:]
for cur_token_ids, length in zip(token_ids, prompt_lengths)
@@ -198,7 +188,7 @@ def __call__(

stop_sequences = stop_at or self.stop_sequences
max_tokens = max_tokens or self.max_tokens
num_samples = self.num_particles
num_samples = self.num_samples

if rng is None:
rng = torch.Generator(device=self.device)
@@ -211,18 +201,23 @@
# To draw multiple samples we repeat the prompt as many times
# as there are samples. We copy the FSMs and initialize the
# FSM states.
num_samples = self.num_particles
num_samples = self.num_samples
batch_size = len(prompts)

prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
weights = torch.zeros(
(batch_size * num_samples), dtype=torch.float, device=self.device
)

states = sequence_generator(
self.generate_token,
self.model,
self.sampler,
fsms,
prompt_token_ids,
weights,
attention_masks,
fsm_states,
rng=rng,
@@ -234,7 +229,7 @@ def __call__(
if max_tokens or stop_sequences:
token_ids = last_state.token_ids
generated_token_ids = self.get_generated_token_ids(
prompt_token_ids, prompts, token_ids, num_samples
prompt_token_ids, token_ids
)
if max_tokens and len(generated_token_ids[0]) >= max_tokens:
break
@@ -246,9 +241,7 @@ def __call__(
break

token_ids = last_state.token_ids
generated_token_ids = self.get_generated_token_ids(
prompt_token_ids, prompts, token_ids, num_samples
)
generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids)

generated = self.tokenizer.decode(generated_token_ids)
stripped = [
@@ -257,17 +250,18 @@
]
formatted = [self.format_sequence(sequence) for sequence in stripped]

# We reshape the output to (sample_size, batch_size)
# We reshape the output to (batch_size, sample_size)
output = []
step = len(prompts)
for i in range(0, len(formatted), step):
output.append(formatted[i : i + step])
for i in range(batch_size):
output.append(formatted[i : i + num_samples])

# We remove leading dimensions for the output
if len(prompts) == 1 and num_samples == 1:
if batch_size == 1 and num_samples == 1:
return output[0][0]
elif num_samples == 1:
elif batch_size == 1:
return output[0]
elif num_samples == 1:
return [samples[0] for samples in output]
else:
return output

@@ -317,7 +311,7 @@ def stream(

stop_sequences = stop_at or self.stop_sequences
max_tokens = max_tokens or self.max_tokens
num_samples = self.num_particles
num_samples = self.num_samples

if rng is None:
rng = torch.Generator(device=self.device)
@@ -330,18 +324,23 @@ def stream(
# To draw multiple samples we repeat the prompt as many times
# as there are samples. We copy the FSMs and initialize the
# FSM states.
num_samples = self.num_particles
num_samples = self.num_samples
batch_size = len(prompts)

prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0)
attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0)
fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
weights = torch.zeros(
(batch_size * num_samples), dtype=torch.float, device=self.device
)

states = sequence_generator(
self.generate_token,
self.model,
self.sampler,
fsms,
prompt_token_ids,
weights,
attention_masks,
fsm_states,
rng=rng,
@@ -384,17 +383,19 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
generated_sequences, is_stop_at_reached
)
]
# We reshape the output to (sample_size, batch_size)

# We reshape the output to (batch_size, sample_size)
output = []
step = len(prompts)
for i in range(0, len(next_tokens), step):
output.append(next_tokens[i : i + step])
for i in range(batch_size):
output.append(next_tokens[i : i + num_samples])

# We remove leading dimensions for the output
if len(prompts) == 1 and num_samples == 1:
if batch_size == 1 and num_samples == 1:
yield output[0][0]
elif num_samples == 1:
elif batch_size == 1:
yield output[0]
elif num_samples == 1:
yield [samples[0] for samples in output]
else:
yield output
