From 35611ab0f125069080a19e17d53373e69720a413 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 16 Nov 2023 17:03:02 +0100 Subject: [PATCH] Add `Index` type --- outlines/generate/generator.py | 3 ++- outlines/index/index.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 outlines/index/index.py diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index 1a6dfde48..f76e449a1 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from outlines.generate.samplers import Sampler + from outlines.index.index import Index @dataclass @@ -15,7 +16,7 @@ class GenerationState: kv_cache: Optional[torch.Tensor] = None -def process(generator: Generator, index, state: GenerationState): +def process(generator: Generator, index: "Index", state: GenerationState): """This generator drives the text generation process by walking through the FSM.""" next(generator) diff --git a/outlines/index/index.py b/outlines/index/index.py new file mode 100644 index 000000000..89e49393f --- /dev/null +++ b/outlines/index/index.py @@ -0,0 +1,11 @@ +from typing import Callable, NamedTuple, NewType + +import torch + +State = NewType("State", int) + + +class Index(NamedTuple): + next_instruction: Callable[[State], torch.Tensor] + next_state: Callable[[State, torch.Tensor], State] + is_final: Callable[[State], bool]