Skip to content

Commit

Permalink
Use generate.json in Outlines
Browse files Browse the repository at this point in the history
And the `output_type` should now be a Pydantic model or a JSON Schema str
  • Loading branch information
yvan-sraka committed Jan 10, 2025
1 parent ac77adb commit 214a2f3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 78 deletions.
10 changes: 2 additions & 8 deletions outlines/function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import importlib.util
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

Expand All @@ -11,13 +10,8 @@
from outlines.generate.api import SequenceGenerator
from outlines.prompts import Prompt


# FIXME: This causes all the tests to fail...
warnings.warn(
"The 'function' module is deprecated and will be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)
# Print a deprecation message instead of raising a warning
print("The 'function' module is deprecated and will be removed in a future release.")


@dataclass
Expand Down
41 changes: 25 additions & 16 deletions outlines/outline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import ast
import json
from dataclasses import dataclass

from pydantic import BaseModel

from outlines import generate


@dataclass
class Outline:
Expand All @@ -19,34 +23,39 @@ class Outline:
Examples
--------
from outlines import models
from pydantic import BaseModel
from outlines import models, Outline
class OutputModel(BaseModel):
result: int
model = models.transformers("gpt2")
def template(a: int) -> str:
return f"What is 2 times {a}?"
def template(a: int) -> str:
return f"What is 2 times {a}?"
fn = Outline(model, template, int)
fn = Outline(model, template, OutputModel)
result = fn(3)
print(result) # Expected output: 6
result = fn(3)
print(result) # Expected output: OutputModel(result=6)
"""

def __init__(self, model, template, output_type):
self.model = model
if not (isinstance(output_type, str) or issubclass(output_type, BaseModel)):
raise TypeError(
"output_type must be a Pydantic model or a JSON Schema string"
)
self.template = template
self.output_type = output_type
self.generator = generate.json(model, output_type)

def __call__(self, *args):
prompt = self.template(*args)
response = self.model.generate(prompt)
response = self.generator(prompt)
try:
parsed_response = ast.literal_eval(response.strip())
if isinstance(parsed_response, self.output_type):
return parsed_response
else:
raise ValueError(
f"Response type {type(parsed_response)} does not match expected type {self.output_type}"
)
if isinstance(self.output_type, str):
return json.loads(response)
return self.output_type.model_validate_json(response)
except (ValueError, SyntaxError):
# If `outlines.generate.json` works as intended, this error should never be raised.
raise ValueError(f"Unable to parse response: {response.strip()}")
81 changes: 27 additions & 54 deletions tests/test_outline.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,41 @@
from unittest.mock import MagicMock
from unittest.mock import Mock, patch

import pytest
from pydantic import BaseModel

from outlines.outline import Outline
from outlines import Outline


def test_outline_int_output():
model = MagicMock()
model.generate.return_value = "6"
class OutputModel(BaseModel):
result: int

def template(a: int) -> str:
return f"What is 2 times {a}?"

fn = Outline(model, template, int)
result = fn(3)
assert result == 6
def template(a: int) -> str:
return f"What is 2 times {a}?"


def test_outline_str_output():
model = MagicMock()
model.generate.return_value = "'Hello, world!'"
def test_outline():
mock_model = Mock()
mock_generator = Mock()
mock_generator.return_value = '{"result": 6}'

def template(a: int) -> str:
return f"Say 'Hello, world!' {a} times"
with patch("outlines.generate.json", return_value=mock_generator):
outline_instance = Outline(mock_model, template, OutputModel)
result = outline_instance(3)

fn = Outline(model, template, str)
result = fn(1)
assert result == "Hello, world!"
assert result.result == 6


def test_outline_str_input():
model = MagicMock()
model.generate.return_value = "'Hi, Mark!'"
def test_outline_with_json_schema():
mock_model = Mock()
mock_generator = Mock()
mock_generator.return_value = '{"result": 6}'

def template(a: str) -> str:
return f"Say hi to {a}"
with patch("outlines.generate.json", return_value=mock_generator):
outline_instance = Outline(
mock_model,
template,
'{"type": "object", "properties": {"result": {"type": "integer"}}}',
)
result = outline_instance(3)

fn = Outline(model, template, str)
result = fn(1)
assert result == "Hi, Mark!"


def test_outline_invalid_output():
model = MagicMock()
model.generate.return_value = "not a number"

def template(a: int) -> str:
return f"What is 2 times {a}?"

fn = Outline(model, template, int)
with pytest.raises(ValueError):
fn(3)


def test_outline_mismatched_output_type():
model = MagicMock()
model.generate.return_value = "'Hello, world!'"

def template(a: int) -> str:
return f"What is 2 times {a}?"

fn = Outline(model, template, int)
with pytest.raises(
ValueError,
match="Unable to parse response: 'Hello, world!'",
):
fn(3)
assert result["result"] == 6

0 comments on commit 214a2f3

Please sign in to comment.