From 0a11393905a5d51b6ae5b9ce8a984b5656584b67 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Thu, 23 Nov 2023 09:06:31 +0100
Subject: [PATCH] WIP - Add user interface for text generation

---
 outlines/generate/api.py | 95 ++++++++++++++++++++++++++++++++++++++++
 pyproject.toml           |  2 +-
 2 files changed, 96 insertions(+), 1 deletion(-)
 create mode 100644 outlines/generate/api.py

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
new file mode 100644
index 000000000..085858b2b
--- /dev/null
+++ b/outlines/generate/api.py
@@ -0,0 +1,95 @@
+import json as pyjson
+from typing import Callable, List, Optional, Type, Union
+
+from pydantic import BaseModel
+
+from outlines.generate.generator import SequenceGenerator
+from outlines.generate.samplers import Sampler, multinomial
+from outlines.index.index import RegexFSM, StopAtTokenFSM
+from outlines.index.json_schema import (
+    build_regex_from_object,
+    get_schema_from_signature,
+)
+from outlines.index.types import python_types_to_regex
+
+
+def text(model, max_tokens: Optional[int] = None, *, sampler: Sampler = multinomial):
+    """Build a generator that stops on the end-of-sequence token.
+
+    The generator is driven by a ``StopAtTokenFSM`` created from
+    ``model.tokenizer.eos_token_id`` and ``max_tokens``.
+    """
+    fsm = StopAtTokenFSM(model.tokenizer.eos_token_id, max_tokens)
+    return SequenceGenerator(fsm, model, sampler, model.device)
+
+
+def regex(
+    model,
+    regex: str,
+    max_tokens: Optional[int] = None,
+    sampler: Sampler = multinomial,
+):
+    """Build a generator driven by a ``RegexFSM`` for the pattern ``regex``.
+
+    ``max_tokens`` is optional, as in ``text``, ``choice`` and ``json``.
+    """
+    fsm = RegexFSM(regex, model.tokenizer, max_tokens)
+    return SequenceGenerator(fsm, model, sampler, model.device)
+
+
+def format(
+    model,
+    python_type,
+    max_tokens: Optional[int] = None,
+    sampler: Sampler = multinomial,
+):
+    """Build a regex generator from ``python_types_to_regex(python_type)``."""
+    regex_str = python_types_to_regex(python_type)
+    return regex(model, regex_str, max_tokens, sampler)
+
+
+def choice(
+    model,
+    choices: List[str],
+    max_tokens: Optional[int] = None,
+    sampler: Sampler = multinomial,
+):
+    """Build a regex generator for the alternation of ``choices``.
+
+    The choices are joined with ``|`` verbatim (no escaping), so regex
+    metacharacters inside a choice are interpreted as regex syntax.
+    """
+    regex_str = r"(" + r"|".join(choices) + r")"
+    return regex(model, regex_str, max_tokens, sampler)
+
+
+def json(
+    model,
+    schema_object: Union[str, Type[BaseModel], Callable],
+    max_tokens: Optional[int] = None,
+    sampler: Sampler = multinomial,
+):
+    """Build a regex generator from a JSON schema description.
+
+    ``schema_object`` may be a Pydantic model class (its
+    ``model_json_schema()`` is serialized), a callable (the schema comes
+    from ``get_schema_from_signature``), or a string, which is passed to
+    ``build_regex_from_object`` unchanged.
+
+    Raises ``TypeError`` for any other kind of ``schema_object``.
+    """
+    if isinstance(schema_object, type(BaseModel)):
+        schema = pyjson.dumps(schema_object.model_json_schema())
+    elif callable(schema_object):
+        schema = pyjson.dumps(get_schema_from_signature(schema_object))
+    elif isinstance(schema_object, str):
+        # A string is taken to be the schema itself and passed through as-is.
+        schema = schema_object
+    else:
+        raise TypeError(
+            "schema_object must be a str, a Pydantic model class or a callable, "
+            f"got {type(schema_object).__name__}"
+        )
+
+    regex_str = build_regex_from_object(schema)
+    return regex(model, regex_str, max_tokens, sampler)
diff --git a/pyproject.toml b/pyproject.toml
index bcab189cb..b3fbd66da 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,7 +96,7 @@ module = [
     "perscache.*",
     "PIL",
     "PIL.Image",
-    "pydantic",
+    "pydantic.*",
     "pytest",
     "referencing.*",
     "scipy.*",