Skip to content

Commit

Permalink
Add Outlines
Browse files Browse the repository at this point in the history
  • Loading branch information
yvan-sraka committed Jan 8, 2025
1 parent 3cc399d commit 26b14b9
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 0 deletions.
3 changes: 3 additions & 0 deletions outlines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Outlines is a Generative Model Programming Framework."""

import outlines.generate
import outlines.grammars
import outlines.models
Expand All @@ -7,6 +8,7 @@
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
from outlines.outline import Outline
from outlines.prompts import prompt

__all__ = [
Expand All @@ -17,4 +19,5 @@
"prompt",
"vectorize",
"grammars",
"Outline",
]
52 changes: 52 additions & 0 deletions outlines/outline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import ast
from dataclasses import dataclass


@dataclass
class Outline:
"""
Outline is a class that creates a callable object to generate responses
based on a given model and a prompt template.
Args:
model: The model to be used for generating responses.
template (function): A function that takes arguments and returns a prompt string.
output_type: The expected output type of the generated response.
Example:
from outlines import models
model = models.transformers("gpt2")
def template(a: int) -> str:
return f"What is 2 times {a}?"
fn = Outline(model, template, int)
result = fn(3)
print(result) # Expected output: 6
"""

def __init__(self, model, template, output_type):
self.model = model
self.template = template
self.output_type = output_type

def __call__(self, *args):
# Generate the prompt using the template function
prompt = self.template(*args)

# Generate the response using the model
response = self.model.generate(prompt)

# Process the response to match the expected output type
try:
parsed_response = ast.literal_eval(response.strip())
if isinstance(parsed_response, self.output_type):
return parsed_response
else:
raise ValueError(
f"Response type {type(parsed_response)} does not match expected type {self.output_type}"
)
except (ValueError, SyntaxError):
raise ValueError(f"Unable to parse response: {response.strip()}")
76 changes: 76 additions & 0 deletions tests/test_outline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from unittest.mock import MagicMock

import pytest

from outlines.outline import Outline


def test_outline_int_output():
# Mock the model
model = MagicMock()
model.generate.return_value = "6"

# Define the template function
def template(a: int) -> str:
return f"What is 2 times {a}?"

# Create an instance of Outline
fn = Outline(model, template, int)

# Test the callable object
result = fn(3)
assert result == 6


def test_outline_str_output():
# Mock the model
model = MagicMock()
model.generate.return_value = "'Hello, world!'"

# Define the template function
def template(a: int) -> str:
return f"Say hello {a} times"

# Create an instance of Outline
fn = Outline(model, template, str)

# Test the callable object
result = fn(1)
assert result == "Hello, world!"


def test_outline_invalid_output():
# Mock the model
model = MagicMock()
model.generate.return_value = "not a number"

# Define the template function
def template(a: int) -> str:
return f"What is 2 times {a}?"

# Create an instance of Outline
fn = Outline(model, template, int)

# Test the callable object with invalid output
with pytest.raises(ValueError):
fn(3)


def test_outline_mismatched_output_type():
# Mock the model
model = MagicMock()
model.generate.return_value = "'Hello, world!'"

# Define the template function
def template(a: int) -> str:
return f"What is 2 times {a}?"

# Create an instance of Outline
fn = Outline(model, template, int)

# Test the callable object with mismatched output type
with pytest.raises(
ValueError,
match="Response type <class 'str'> does not match expected type <class 'int'>",
):
fn(3)

0 comments on commit 26b14b9

Please sign in to comment.