diff --git a/docs/api/sample.md b/docs/api/sample.md deleted file mode 100644 index 1f962ea16..000000000 --- a/docs/api/sample.md +++ /dev/null @@ -1 +0,0 @@ -::: outlines.text.generate.sample diff --git a/docs/api/samplers.md b/docs/api/samplers.md new file mode 100644 index 000000000..b5ccd47cc --- /dev/null +++ b/docs/api/samplers.md @@ -0,0 +1 @@ +::: outlines.text.generate.samplers diff --git a/mkdocs.yml b/mkdocs.yml index 33835f8b7..e74724583 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -112,5 +112,5 @@ nav: - api/fsm.md - api/parsing.md - api/regex.md - - api/sample.md + - api/samplers.md - api/continuation.md diff --git a/outlines/text/generate/continuation.py b/outlines/text/generate/continuation.py index 8f70ccf9c..d55acc49f 100644 --- a/outlines/text/generate/continuation.py +++ b/outlines/text/generate/continuation.py @@ -5,7 +5,7 @@ from outlines.text.generate.sequence import Sequence if TYPE_CHECKING: - from outlines.text.generate.sample import Sampler + from outlines.text.generate.samplers import Sampler class Continuation(Sequence): diff --git a/outlines/text/generate/regex.py b/outlines/text/generate/regex.py index 71271467a..68f9ee694 100644 --- a/outlines/text/generate/regex.py +++ b/outlines/text/generate/regex.py @@ -12,7 +12,7 @@ from outlines.text.types import python_types_to_regex if TYPE_CHECKING: - from outlines.text.generate.sample import Sampler + from outlines.text.generate.samplers import Sampler class Regex(Continuation): diff --git a/outlines/text/generate/sample.py b/outlines/text/generate/samplers.py similarity index 100% rename from outlines/text/generate/sample.py rename to outlines/text/generate/samplers.py diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py index 8550c2e8a..948266dcd 100644 --- a/outlines/text/generate/sequence.py +++ b/outlines/text/generate/sequence.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from outlines.models.transformers import KVCacheType, Transformers - from outlines.text.generate.sample import Sampler + from outlines.text.generate.samplers import Sampler class Sequence: @@ -45,7 +45,7 @@ def __init__( model.tokenizer.pad_token_id, device=model.device ) if sampler is None: - from outlines.text.generate.sample import multinomial + from outlines.text.generate.samplers import multinomial self.sampler = multinomial else: diff --git a/outlines/text/generator.py b/outlines/text/generator.py index b972bfc3e..12cf9ba7f 100644 --- a/outlines/text/generator.py +++ b/outlines/text/generator.py @@ -1,9 +1,10 @@ import math -from typing import Generator, List +from typing import TYPE_CHECKING, Generator, List import torch -from outlines.text.generate.sample import Sampler +if TYPE_CHECKING: + from outlines.text.generate.samplers import Sampler def process(generator: Generator, index, token_ids: torch.Tensor): @@ -28,7 +29,7 @@ def process(generator: Generator, index, token_ids: torch.Tensor): yield next_token_id, token_ids -def token_generator(model, sampler: Sampler, samples: int, rng: torch.Generator): +def token_generator(model, sampler: "Sampler", samples: int, rng: torch.Generator): """Generator that yields a token every time it is called. This process is designed to be steered by another supervising diff --git a/tests/text/generate/test_sample.py b/tests/text/generate/test_samplers.py similarity index 95% rename from tests/text/generate/test_sample.py rename to tests/text/generate/test_samplers.py index 884bb4a30..5e1543fc7 100644 --- a/tests/text/generate/test_sample.py +++ b/tests/text/generate/test_samplers.py @@ -2,7 +2,11 @@ import torch -from outlines.text.generate.sample import greedy, multinomial, vectorized_random_choice +from outlines.text.generate.samplers import ( + greedy, + multinomial, + vectorized_random_choice, +) def test_greedy():