Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add filters to prompt function #1371

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions outlines/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, cast
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast

import jinja2
import pydantic
Expand Down Expand Up @@ -40,7 +40,9 @@ def __call__(self, *args, **kwargs) -> str:
return self.template.render(**kwargs)

@classmethod
def from_str(cls, content: str):
def from_str(
cls, content: str, filters: Union[List[Callable], Dict[str, Callable]] = []
):
"""
Create an instance of the class from a string.

Expand All @@ -53,10 +55,12 @@ def from_str(cls, content: str):
-------
An instance of the class with the provided content as a template.
"""
return cls(cls._template_from_str(content), None)
return cls(cls._template_from_str(content, filters), None)

@classmethod
def from_file(cls, path: Path):
def from_file(
cls, path: Path, filters: Union[List[Callable], Dict[str, Callable]] = []
):
"""
Create a Prompt instance from a file containing a Jinja template.

Expand All @@ -75,10 +79,12 @@ def from_file(cls, path: Path):
"""
# We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is
# split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance)
return cls(cls._template_from_file(path), None)
return cls(cls._template_from_file(path, filters), None)

@classmethod
def _template_from_str(_, content: str) -> jinja2.Template:
def _template_from_str(
_, content: str, filters: Union[List[Callable], Dict[str, Callable]] = []
) -> jinja2.Template:
# Dedent, and remove extra linebreak
cleaned_template = inspect.cleandoc(content)

Expand Down Expand Up @@ -106,10 +112,14 @@ def _template_from_str(_, content: str) -> jinja2.Template:
env.filters["schema"] = get_schema
env.filters["args"] = get_fn_args

_add_filters(env, filters)

return env.from_string(cleaned_template)

@classmethod
def _template_from_file(_, path: Path) -> jinja2.Template:
def _template_from_file(
_, path: Path, filters: Union[List[Callable], Dict[str, Callable]] = []
) -> jinja2.Template:
file_directory = os.path.dirname(os.path.abspath(path))
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(file_directory),
Expand All @@ -118,10 +128,16 @@ def _template_from_file(_, path: Path) -> jinja2.Template:
keep_trailing_newline=True,
undefined=jinja2.StrictUndefined,
)

_add_filters(env, filters)

return env.get_template(os.path.basename(path))


def prompt(fn: Callable) -> Prompt:
def prompt(
fn: Optional[Callable] = None,
filters: Union[List[Callable], Dict[str, Callable]] = [],
) -> Callable:
"""Decorate a function that contains a prompt template.

This allows to define prompts in the docstring of a function and simplify their
Expand Down Expand Up @@ -152,11 +168,28 @@ def prompt(fn: Callable) -> Prompt:
...
>>> hal = ft.partial(solve_task, "HAL", "Travel to Jupiter")

Additional Jinja2 filters can be provided as keyword arguments to the decorator.

>>> def reverse(s: str) -> str:
... return s[::-1]
...
>>> @outlines.prompt(filters=[reverse])
... def reverse_prompt(text):
... '''{{ text | reverse }}'''
...
>>> prompt = reverse_prompt("Hello")
>>> print(prompt)
... "olleH"

Returns
-------
A `Prompt` callable class which will render the template when called.

"""
if fn is None:
return lambda fn: prompt(
fn, cast(Union[List[Callable], Dict[str, Callable]], filters)
)

signature = inspect.signature(fn)

Expand All @@ -166,11 +199,22 @@ def prompt(fn: Callable) -> Prompt:
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")

template = Prompt._template_from_str(cast(str, docstring))
template = Prompt._template_from_str(cast(str, docstring), filters)

return Prompt(template, signature)


def _add_filters(
env: jinja2.Environment, filters: Union[List[Callable], Dict[str, Callable]]
):
if isinstance(filters, list):
for filter_fn in filters:
env.filters[filter_fn.__name__] = filter_fn
elif isinstance(filters, dict):
for name, filter_fn in filters.items():
env.filters[name] = filter_fn


def get_fn_name(fn: Callable):
"""Returns the name of a callable."""
if not callable(fn):
Expand Down
34 changes: 34 additions & 0 deletions tests/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,40 @@ def args_prompt(fn):
)


def test_prompt_with_additional_filters_as_dict():
def reverse(s: str) -> str:
return s[::-1]

@outlines.prompt(filters=dict(reverse=reverse))
def test_tpl(variable):
"""{{ variable | reverse }} test"""

assert list(test_tpl.signature.parameters) == ["variable"]

p = test_tpl("test")
assert p == "tset test"

p = test_tpl(variable="example")
assert p == "elpmaxe test"


def test_prompt_with_additional_filters_as_list():
def reverse(s: str) -> str:
return s[::-1]

@outlines.prompt(filters=[reverse])
def test_tpl(variable):
"""{{ variable | reverse }} test"""

assert list(test_tpl.signature.parameters) == ["variable"]

p = test_tpl("test")
assert p == "tset test"

p = test_tpl(variable="example")
assert p == "elpmaxe test"


@pytest.fixture
def temp_prompt_file():
test_dir = tempfile.mkdtemp()
Expand Down