diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index 1a6dfde48..f76e449a1 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from outlines.generate.samplers import Sampler + from outlines.index.index import Index @dataclass @@ -15,7 +16,7 @@ class GenerationState: kv_cache: Optional[torch.Tensor] = None -def process(generator: Generator, index, state: GenerationState): +def process(generator: Generator, index: "Index", state: GenerationState): """This generator drives the text generation process by walking through the FSM.""" next(generator) diff --git a/outlines/index/index.py b/outlines/index/index.py new file mode 100644 index 000000000..89e49393f --- /dev/null +++ b/outlines/index/index.py @@ -0,0 +1,11 @@ +from typing import Callable, NamedTuple, NewType + +import torch + +State = NewType("State", int) + + +class Index(NamedTuple): + next_instruction: Callable[[State], torch.Tensor] + next_state: Callable[[State, torch.Tensor], State] + is_final: Callable[[State], bool]