Add top-p processor to the multinomial sampler
rlouf committed Feb 12, 2024
1 parent 441e533 commit 7ce7d28
Showing 3 changed files with 118 additions and 12 deletions.
9 changes: 9 additions & 0 deletions docs/reference/samplers.md
@@ -60,6 +60,15 @@ You can ask Outlines to only consider the top-k logits at each step by specifyin
sampler = samplers.multinomial(3, top_k=10)
```

### Top-p sampling

You can ask Outlines to only consider the smallest set of highest-probability tokens whose cumulative probability reaches a threshold `p` (nucleus sampling). Specify the `top_p` keyword argument when initializing the sampler:

```python
sampler = samplers.multinomial(3, top_p=0.95)
```
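
As a rough, standalone illustration of what the threshold keeps (a sketch for this note, not code shipped by this commit), the snippet below walks a toy, already-sorted distribution and stops once the kept tokens cover `p = 0.9` of the probability mass:

```python
# Toy illustration of nucleus (top-p) filtering with p = 0.9.
probs = [0.5, 0.3, 0.15, 0.05]  # token probabilities, highest first

kept, cumulative = [], 0.0
for prob in probs:
    kept.append(prob)
    cumulative += prob
    if cumulative >= 0.9:  # stop once the kept mass reaches the threshold
        break

print(kept)  # [0.5, 0.3, 0.15]; the 0.05 token can no longer be sampled
```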

## Greedy sampler

You can also use the greedy sampler. For this you need to initialize the generator with the sampler:
46 changes: 41 additions & 5 deletions outlines/samplers.py
@@ -90,12 +90,20 @@ class MultinomialSampler:
"""

-    def __init__(self, samples: int = 1, *, top_k: Optional[int] = None):
+    def __init__(
+        self,
+        samples: int = 1,
+        *,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+    ):
self.samples = samples

self.logits_processor = lambda x: x
if top_k is not None:
self.logits_processor = keep_top_k_logits(top_k)
elif top_p is not None:
self.logits_processor = keep_top_p_logits(top_p)

def __call__(
self,
@@ -138,7 +146,7 @@ def __call__(
multinomial = MultinomialSampler


-def keep_top_k_logits(k) -> Callable[[torch.Tensor], torch.Tensor]:
+def keep_top_k_logits(k: int) -> Callable[[torch.Tensor], torch.Tensor]:
"""Build a function that masks logits values smaller than the top `k` ones.
Parameters
@@ -148,9 +156,7 @@ def keep_top_k_logits(k: int) -> Callable[[torch.Tensor], torch.Tensor]:
"""
if not isinstance(k, int) or k < 1:
-        raise ValueError(
-            f"`top_k` must be a strictly positive integers, got {k} instead."
-        )
+        raise ValueError(f"`k` must be a strictly positive integer, got {k} instead.")

def logits_processor(logits: torch.Tensor) -> torch.Tensor:
num_to_keep = min(k, logits.size(-1))
@@ -160,6 +166,36 @@ def logits_processor(logits: torch.Tensor) -> torch.Tensor:
return logits_processor


def keep_top_p_logits(p: float) -> Callable[[torch.Tensor], torch.Tensor]:
"""Build a function that masks the lowest probability tokens whose
cumulative probability is below a certain threshold.
Parameters
----------
p
The value of the threshold. We keep the highest probability tokens whose
cumulative distribution is greater than or equal to `p` and mask the
others. Its value must be between 0 (excluded) and 1 (included).
"""
if p <= 0.0 or p > 1.0:
raise ValueError(
f"`p` must be a floating point number between 0 (excluded) and 1 (included), got {p} instead."
)

    def logits_processor(logits: torch.Tensor) -> torch.Tensor:
        # Sort in ascending order so the cumulative probability of the
        # lowest-probability tokens can be compared against 1 - p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=False)
        cumulative_probabilities = torch.nn.functional.softmax(
            sorted_logits, dim=-1
        ).cumsum(dim=-1)

        # Mask every token whose cumulative probability (from the bottom) is at
        # most 1 - p, then scatter the mask back to the original token order.
        sorted_masked_idx = cumulative_probabilities <= (1 - p)
        mask_idx = torch.scatter(sorted_masked_idx, 1, sorted_idx, sorted_masked_idx)
        return logits.masked_fill(mask_idx, -math.inf)

return logits_processor


class BeamSearchSampler:
"""Beam Search sampling algorithm.
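
As a usage sketch (illustrative only, and assuming a local install of Outlines that includes this commit), the new processor can be exercised directly on a batch of logits; the same function is what `multinomial` wires in when `top_p` is passed:

```python
import torch

from outlines.samplers import keep_top_p_logits, multinomial

processor = keep_top_p_logits(0.9)
logits = torch.tensor([[0.1, 0.2, 3.0, 4.0]])

# The two lowest-probability tokens are masked to -inf; the surviving tokens
# (logits 3.0 and 4.0) together carry at least 90% of the probability mass.
print(processor(logits))

# Passing `top_p` to the multinomial sampler builds the same processor.
sampler = multinomial(3, top_p=0.9)
print(sampler.logits_processor(logits))
```

Note that `keep_top_p_logits` sorts the logits in ascending order and masks every token whose cumulative probability stays at or below `1 - p`; up to how ties and exact boundary values are handled, this keeps the same tokens as the more common descending-sort formulation that accumulates probability until it reaches `p`.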
75 changes: 68 additions & 7 deletions tests/test_samplers.py
@@ -10,6 +10,7 @@
beam_search,
greedy,
keep_top_k_logits,
keep_top_p_logits,
multinomial,
)

@@ -71,7 +72,13 @@ def test_multinomial():
assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))


-def test_topk():
+def test_top_k():
with pytest.raises(ValueError, match="`k` must be a strictly"):
keep_top_k_logits(-1)

with pytest.raises(ValueError, match="`k` must be a strictly"):
keep_top_k_logits(0.1)

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

logits_processor = keep_top_k_logits(1)
@@ -82,12 +89,6 @@ def test_topk():
result = logits_processor(logits)
assert result.equal(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))

with pytest.raises(ValueError, match="`top_k` must be a strictly"):
keep_top_k_logits(-1)

with pytest.raises(ValueError, match="`top_k` must be a strictly"):
keep_top_k_logits(0.1)

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
logits_processor = keep_top_k_logits(2)
result = logits_processor(logits)
@@ -98,6 +99,66 @@ def test_topk():
)


def test_top_p():
with pytest.raises(ValueError, match="`p` must be a floating point"):
keep_top_p_logits(-0.1)

with pytest.raises(ValueError, match="`p` must be a floating point"):
keep_top_p_logits(0.0)

with pytest.raises(ValueError, match="`p` must be a floating point"):
keep_top_p_logits(1.1)

logits = torch.tensor([[1.0, 1.01, 1.02, 4.0]])

logits_processor = keep_top_p_logits(0.1)
result = logits_processor(logits)
assert result.equal(torch.tensor([[-math.inf, -math.inf, -math.inf, 4.0]]))

logits_processor = keep_top_p_logits(0.95)
result = logits_processor(logits)
assert result.equal(torch.tensor([[-math.inf, 1.01, 1.02, 4.0]]))

logits_processor = keep_top_p_logits(1.0)
result = logits_processor(logits)
assert result.equal(torch.tensor([[1.0, 1.01, 1.02, 4.0]]))

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])

logits_processor = keep_top_p_logits(0.1)
result = logits_processor(logits)
assert result.equal(
torch.tensor(
[
[-math.inf, -math.inf, -math.inf, 4.0],
[-math.inf, -math.inf, -math.inf, 8.0],
]
)
)

logits_processor = keep_top_p_logits(0.95)
result = logits_processor(logits)
assert result.equal(
torch.tensor(
[
[-math.inf, 2.0, 3.0, 4.0],
[-math.inf, 6.0, 7.0, 8.0],
]
)
)

logits_processor = keep_top_p_logits(1.0)
result = logits_processor(logits)
assert result.equal(
torch.tensor(
[
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
]
)
)


def test_beam_search():
# Two beams, single sequence
sampler = BeamSearchSampler(2)
