Skip to content

Commit

Permalink
WIP - Add user interface for text generation
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Nov 23, 2023
1 parent a1aa65e commit 615e550
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
65 changes: 65 additions & 0 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import json as pyjson
from typing import Callable, List, Optional, Union

import pydantic._internal._model_construction.ModelMetaClass as ModelMetaClass

from outlines.generate.generator import SequenceGenerator
from outlines.generate.samplers import Sampler, multinomial
from outlines.index.index import RegexFSM, StopAtTokenFSM
from outlines.index.json_schema import (
build_regex_from_object,
get_schema_from_signature,
)
from outlines.index.types import python_types_to_regex


def text(model, max_tokens: Optional[int] = None, *, sampler: Sampler = multinomial):
    """Build a sequence generator for free-form text.

    A `StopAtTokenFSM` is created from the tokenizer's end-of-sequence token
    id and `max_tokens`, then handed to a `SequenceGenerator` together with
    the model, the sampler and the model's device.

    Parameters
    ----------
    model
        Object exposing `tokenizer.eos_token_id` and `device`.
    max_tokens
        Forwarded to `StopAtTokenFSM`; presumably the cap on the number of
        generated tokens (confirm against `StopAtTokenFSM`).
    sampler
        Keyword-only. Forwarded unchanged to `SequenceGenerator`.

    Returns
    -------
    The `SequenceGenerator` instance.

    """
    stop_fsm = StopAtTokenFSM(model.tokenizer.eos_token_id, max_tokens)
    return SequenceGenerator(stop_fsm, model, sampler, model.device)


def regex(
    model,
    regex: str,
    max_tokens: Optional[int] = None,
    sampler: Sampler = multinomial,
):
    """Build a sequence generator driven by a regular expression.

    A `RegexFSM` is created from the pattern, the model's tokenizer and
    `max_tokens`, then handed to a `SequenceGenerator` together with the
    model, the sampler and the model's device.

    Parameters
    ----------
    model
        Object exposing `tokenizer` and `device`.
    regex
        The regular expression passed to `RegexFSM`. Note that inside this
        function the name shadows the function itself.
    max_tokens
        Forwarded to `RegexFSM`. Defaults to `None`, matching the other
        generator constructors in this module.
    sampler
        Forwarded unchanged to `SequenceGenerator`.

    Returns
    -------
    The `SequenceGenerator` instance.

    """
    fsm = RegexFSM(regex, model.tokenizer, max_tokens)

    device = model.device
    generator = SequenceGenerator(fsm, model, sampler, device)

    return generator


def format(
    model,
    python_type,
    max_tokens: Optional[int] = None,
    sampler: Sampler = multinomial,
):
    """Build a sequence generator for the textual form of a Python type.

    The type is converted to a regular expression with
    `python_types_to_regex` and generation is delegated to `regex`.

    Parameters
    ----------
    model
        Forwarded unchanged to `regex`.
    python_type
        The value passed to `python_types_to_regex`; which types are
        supported is decided by that helper.
    max_tokens
        Forwarded to `regex`. Defaults to `None`, matching the other
        generator constructors in this module.
    sampler
        Forwarded unchanged to `regex`.

    Returns
    -------
    Whatever `regex` returns for the derived pattern.

    """
    regex_str = python_types_to_regex(python_type)
    return regex(model, regex_str, max_tokens, sampler)


def choice(
    model,
    choices: List[str],
    max_tokens: Optional[int] = None,
    sampler: Sampler = multinomial,
):
    """Build a sequence generator that outputs one of the given strings.

    The choices are combined into a single alternation pattern
    `(a|b|...)` and generation is delegated to `regex`.

    Parameters
    ----------
    model
        Forwarded unchanged to `regex`.
    choices
        The literal strings to choose from. Each one is escaped before being
        joined, so regex metacharacters in a choice (`.`, `(`, `|`, `?`, ...)
        are matched literally instead of altering the pattern.
    max_tokens
        Forwarded to `regex`.
    sampler
        Forwarded unchanged to `regex`.

    Returns
    -------
    Whatever `regex` returns for the alternation pattern.

    """
    # Local import: the module's top-level imports do not include `re`.
    import re

    regex_str = r"(" + r"|".join(re.escape(option) for option in choices) + r")"
    return regex(model, regex_str, max_tokens, sampler)


def json(
    model,
    schema_object: Union[str, ModelMetaClass, Callable],
    max_tokens: Optional[int] = None,
    sampler: Sampler = multinomial,
):
    """Build a sequence generator constrained by a JSON schema.

    The schema is normalised to a JSON string, converted to a regular
    expression with `build_regex_from_object`, and generation is delegated
    to `regex`.

    Parameters
    ----------
    model
        Forwarded unchanged to `regex`.
    schema_object
        One of:
        - a `ModelMetaClass` instance (a Pydantic model class), whose
          `model_json_schema()` is serialised;
        - a callable, whose schema is derived with
          `get_schema_from_signature` and serialised;
        - a string, used as-is (assumed to already be a serialised JSON
          schema).
    max_tokens
        Forwarded to `regex`.
    sampler
        Forwarded unchanged to `regex`.

    Returns
    -------
    Whatever `regex` returns for the schema-derived pattern.

    Raises
    ------
    TypeError
        If `schema_object` is none of the supported kinds.

    """
    # The model-class check must precede `callable`: classes are callable too.
    if isinstance(schema_object, ModelMetaClass):
        schema = pyjson.dumps(schema_object.model_json_schema())
    elif callable(schema_object):
        schema = pyjson.dumps(get_schema_from_signature(schema_object))
    elif isinstance(schema_object, str):
        # Previously fell through with `schema` unbound (UnboundLocalError).
        schema = schema_object
    else:
        raise TypeError(
            "schema_object must be a JSON schema string, a Pydantic model class "
            f"or a callable, got {type(schema_object).__name__}"
        )

    regex_str = build_regex_from_object(schema)

    return regex(model, regex_str, max_tokens, sampler)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ module = [
"perscache.*",
"PIL",
"PIL.Image",
- "pydantic",
+ "pydantic.*",
"pytest",
"referencing.*",
"scipy.*",
Expand Down

0 comments on commit 615e550

Please sign in to comment.