From a05f5580931e9794702c6aaebef860a87aac4ac4 Mon Sep 17 00:00:00 2001 From: Frederik Fix Date: Fri, 10 Jan 2025 13:38:30 +0100 Subject: [PATCH 1/4] add filters to prompt function --- outlines/prompts.py | 40 +++++++++++++++++++++++++++++++--------- tests/test_prompts.py | 16 ++++++++++++++++ 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/outlines/prompts.py b/outlines/prompts.py index 86519adaf..f40d27dd8 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -6,7 +6,7 @@ import textwrap from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, Optional, Type, cast +from typing import Any, Callable, List, Dict, Optional, Type, cast, Union import jinja2 import pydantic @@ -40,7 +40,7 @@ def __call__(self, *args, **kwargs) -> str: return self.template.render(**kwargs) @classmethod - def from_str(cls, content: str): + def from_str(cls, content: str, **filters: Dict[str, Callable]): """ Create an instance of the class from a string. @@ -53,10 +53,10 @@ def from_str(cls, content: str): ------- An instance of the class with the provided content as a template. """ - return cls(cls._template_from_str(content), None) + return cls(cls._template_from_str(content, **filters), None) @classmethod - def from_file(cls, path: Path): + def from_file(cls, path: Path, **filters: Dict[str, Callable]): """ Create a Prompt instance from a file containing a Jinja template. @@ -75,10 +75,10 @@ def from_file(cls, path: Path): """ # We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is # split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance) - return cls(cls._template_from_file(path), None) + return cls(cls._template_from_file(path, **filters), None) @classmethod - def _template_from_str(_, content: str) -> jinja2.Template: + def _template_from_str(_, content: str, **filters: Dict[str, Callable]) -> jinja2.Template: # Dedent, and remove extra linebreak cleaned_template = inspect.cleandoc(content) @@ -106,10 +106,13 @@ def _template_from_str(_, content: str) -> jinja2.Template: env.filters["schema"] = get_schema env.filters["args"] = get_fn_args + for name, filter_fn in filters.items(): + env.filters[name] = filter_fn + return env.from_string(cleaned_template) @classmethod - def _template_from_file(_, path: Path) -> jinja2.Template: + def _template_from_file(_, path: Path, **filters: Dict[str, Callable]) -> jinja2.Template: file_directory = os.path.dirname(os.path.abspath(path)) env = jinja2.Environment( loader=jinja2.FileSystemLoader(file_directory), @@ -118,10 +121,14 @@ def _template_from_file(_, path: Path) -> jinja2.Template: keep_trailing_newline=True, undefined=jinja2.StrictUndefined, ) + + for name, filter_fn in filters.items(): + env.filters[name] = filter_fn + return env.get_template(os.path.basename(path)) -def prompt(fn: Callable) -> Prompt: +def prompt(fn: Optional[Callable] = None, **filters: Dict[str, Callable]) -> Callable: """Decorate a function that contains a prompt template. This allows to define prompts in the docstring of a function and simplify their @@ -152,11 +159,26 @@ def prompt(fn: Callable) -> Prompt: ... >>> hal = ft.partial(solve_task, "HAL", "Travel to Jupiter") + Additional Jinja2 filters can be provided as keyword arguments to the decorator. + + >>> def reverse_string(s: str) -> str: + ... return s[::-1] + ... + >>> @outlines.prompt(reverse=reverse_string) + ... def reverse_prompt(text): + ... '''{{ text | reverse }}''' + ... + >>> prompt = reverse_prompt("Hello") + >>> print(prompt) + ... "olleH" + Returns ------- A `Prompt` callable class which will render the template when called. """ + if fn is None: + return lambda fn: prompt(fn, **filters) signature = inspect.signature(fn) @@ -166,7 +188,7 @@ def prompt(fn: Callable) -> Prompt: if docstring is None: raise TypeError("Could not find a template in the function's docstring.") - template = Prompt._template_from_str(cast(str, docstring)) + template = Prompt._template_from_str(cast(str, docstring), **filters) return Prompt(template, signature) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 4cc4d8ff1..64a6e9711 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -321,6 +321,22 @@ def args_prompt(fn): ) +def test_prompt_with_additional_filters(): + def reverse_string(s: str) -> str: + return s[::-1] + + @outlines.prompt(reverse=reverse_string) + def test_tpl(variable): + """{{ variable | reverse }} test""" + + assert list(test_tpl.signature.parameters) == ["variable"] + + p = test_tpl("test") + assert p == "tset test" + + p = test_tpl(variable="example") + assert p == "elpmaxe test" + @pytest.fixture def temp_prompt_file(): test_dir = tempfile.mkdtemp() From e7f229e3b20ba95ed692f5a839c9b42a69f84fb4 Mon Sep 17 00:00:00 2001 From: Frederik Fix Date: Fri, 10 Jan 2025 15:40:48 +0100 Subject: [PATCH 2/4] code style fixes --- outlines/prompts.py | 10 +++++++--- tests/test_prompts.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/outlines/prompts.py b/outlines/prompts.py index f40d27dd8..9f62bfc0d 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -6,7 +6,7 @@ import textwrap from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, List, Dict, Optional, Type, cast, Union +from typing import Any, Callable, Dict, Optional, Type, cast import jinja2 import pydantic @@ -78,7 +78,9 @@ def from_file(cls, path: Path, **filters: Dict[str, Callable]): return cls(cls._template_from_file(path, **filters), None) @classmethod - def _template_from_str(_, content: str, **filters: Dict[str, Callable]) -> jinja2.Template: + def _template_from_str( + _, content: str, **filters: Dict[str, Callable] + ) -> jinja2.Template: # Dedent, and remove extra linebreak cleaned_template = inspect.cleandoc(content) @@ -112,7 +114,9 @@ def _template_from_str(_, content: str, **filters: Dict[str, Callable]) -> jinja return env.from_string(cleaned_template) @classmethod - def _template_from_file(_, path: Path, **filters: Dict[str, Callable]) -> jinja2.Template: + def _template_from_file( + _, path: Path, **filters: Dict[str, Callable] + ) -> jinja2.Template: file_directory = os.path.dirname(os.path.abspath(path)) env = jinja2.Environment( loader=jinja2.FileSystemLoader(file_directory), diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 64a6e9711..44a9e4583 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -337,6 +337,7 @@ def test_tpl(variable): p = test_tpl(variable="example") assert p == "elpmaxe test" + @pytest.fixture def temp_prompt_file(): test_dir = tempfile.mkdtemp() From d0420a8d673033a9ad54356fd821c206747ee1af Mon Sep 17 00:00:00 2001 From: Frederik Fix Date: Sat, 11 Jan 2025 04:31:35 +0100 Subject: [PATCH 3/4] api style change --- outlines/prompts.py | 58 +++++++++++++++++++++++++++++-------------- tests/test_prompts.py | 23 ++++++++++++++--- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/outlines/prompts.py b/outlines/prompts.py index 9f62bfc0d..d50173a4e 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -6,7 +6,7 @@ import textwrap from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Dict, Optional, Type, cast +from typing import Any, Callable, Dict, List, Optional, Type, Union, cast import jinja2 import pydantic @@ -40,7 +40,9 @@ def __call__(self, *args, **kwargs) -> str: return self.template.render(**kwargs) @classmethod - def from_str(cls, content: str, **filters: Dict[str, Callable]): + def from_str( + cls, content: str, filters: Union[List[Callable], Dict[str, Callable]] = [] + ): """ Create an instance of the class from a string. @@ -53,10 +55,12 @@ def from_str(cls, content: str, **filters: Dict[str, Callable]): ------- An instance of the class with the provided content as a template. """ - return cls(cls._template_from_str(content, **filters), None) + return cls(cls._template_from_str(content, filters), None) @classmethod - def from_file(cls, path: Path, **filters: Dict[str, Callable]): + def from_file( + cls, path: Path, filters: Union[List[Callable], Dict[str, Callable]] = [] + ): """ Create a Prompt instance from a file containing a Jinja template. @@ -75,11 +79,11 @@ def from_file(cls, path: Path, **filters: Dict[str, Callable]): """ # We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is # split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance) - return cls(cls._template_from_file(path, **filters), None) + return cls(cls._template_from_file(path, filters), None) @classmethod def _template_from_str( - _, content: str, **filters: Dict[str, Callable] + _, content: str, filters: Union[List[Callable], Dict[str, Callable]] = [] ) -> jinja2.Template: # Dedent, and remove extra linebreak cleaned_template = inspect.cleandoc(content) @@ -108,14 +112,13 @@ def _template_from_str( env.filters["schema"] = get_schema env.filters["args"] = get_fn_args - for name, filter_fn in filters.items(): - env.filters[name] = filter_fn + _add_filters(env, filters) return env.from_string(cleaned_template) @classmethod def _template_from_file( - _, path: Path, **filters: Dict[str, Callable] + _, path: Path, filters: Union[List[Callable], Dict[str, Callable]] = [] ) -> jinja2.Template: file_directory = os.path.dirname(os.path.abspath(path)) env = jinja2.Environment( @@ -126,13 +129,17 @@ def _template_from_file( undefined=jinja2.StrictUndefined, ) - for name, filter_fn in filters.items(): - env.filters[name] = filter_fn + _add_filters(env, filters) return env.get_template(os.path.basename(path)) -def prompt(fn: Optional[Callable] = None, **filters: Dict[str, Callable]) -> Callable: +def prompt( + fn_or_filters: Optional[ + Union[Callable, List[Callable], Dict[str, Callable]] + ] = None, + filters: Union[List[Callable], Dict[str, Callable]] = [], +) -> Callable: """Decorate a function that contains a prompt template. This allows to define prompts in the docstring of a function and simplify their @@ -165,10 +172,10 @@ def prompt(fn: Optional[Callable] = None, **filters: Dict[str, Callable]) -> Cal Additional Jinja2 filters can be provided as keyword arguments to the decorator. - >>> def reverse_string(s: str) -> str: + >>> def reverse(s: str) -> str: ... return s[::-1] ... - >>> @outlines.prompt(reverse=reverse_string) + >>> @outlines.prompt([reverse]) ... def reverse_prompt(text): ... '''{{ text | reverse }}''' ... @@ -181,22 +188,35 @@ def prompt(fn: Optional[Callable] = None, **filters: Dict[str, Callable]) -> Cal A `Prompt` callable class which will render the template when called. """ - if fn is None: - return lambda fn: prompt(fn, **filters) + if not callable(fn_or_filters): + return lambda fn: prompt( + fn, cast(Union[List[Callable], Dict[str, Callable]], fn_or_filters) + ) - signature = inspect.signature(fn) + signature = inspect.signature(fn_or_filters) # The docstring contains the template that will be rendered to be used # as a prompt to the language model. - docstring = fn.__doc__ + docstring = fn_or_filters.__doc__ if docstring is None: raise TypeError("Could not find a template in the function's docstring.") - template = Prompt._template_from_str(cast(str, docstring), **filters) + template = Prompt._template_from_str(cast(str, docstring), filters) return Prompt(template, signature) +def _add_filters( + env: jinja2.Environment, filters: Union[List[Callable], Dict[str, Callable]] +): + if isinstance(filters, list): + for filter_fn in filters: + env.filters[filter_fn.__name__] = filter_fn + elif isinstance(filters, dict): + for name, filter_fn in filters.items(): + env.filters[name] = filter_fn + + def get_fn_name(fn: Callable): """Returns the name of a callable.""" if not callable(fn): diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 44a9e4583..7fe1d770f 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -321,11 +321,28 @@ def args_prompt(fn): ) -def test_prompt_with_additional_filters(): - def reverse_string(s: str) -> str: +def test_prompt_with_additional_filters_as_dict(): + def reverse(s: str) -> str: return s[::-1] - @outlines.prompt(reverse=reverse_string) + @outlines.prompt(dict(reverse=reverse)) + def test_tpl(variable): + """{{ variable | reverse }} test""" + + assert list(test_tpl.signature.parameters) == ["variable"] + + p = test_tpl("test") + assert p == "tset test" + + p = test_tpl(variable="example") + assert p == "elpmaxe test" + + +def test_prompt_with_additional_filters_as_list(): + def reverse(s: str) -> str: + return s[::-1] + + @outlines.prompt([reverse]) def test_tpl(variable): """{{ variable | reverse }} test""" From 04754dfcb52eecdfc160baf76525a04569e00a07 Mon Sep 17 00:00:00 2001 From: Frederik Fix Date: Sat, 11 Jan 2025 09:02:00 +0100 Subject: [PATCH 4/4] pass the filters with an argument --- outlines/prompts.py | 14 ++++++-------- tests/test_prompts.py | 4 ++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/outlines/prompts.py b/outlines/prompts.py index d50173a4e..9085d04ee 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -135,9 +135,7 @@ def _template_from_file( def prompt( - fn_or_filters: Optional[ - Union[Callable, List[Callable], Dict[str, Callable]] - ] = None, + fn: Optional[Callable] = None, filters: Union[List[Callable], Dict[str, Callable]] = [], ) -> Callable: """Decorate a function that contains a prompt template. @@ -175,7 +173,7 @@ def prompt( >>> def reverse(s: str) -> str: ... return s[::-1] ... - >>> @outlines.prompt([reverse]) + >>> @outlines.prompt(filters=[reverse]) ... def reverse_prompt(text): ... '''{{ text | reverse }}''' ... @@ -188,16 +186,16 @@ def prompt( A `Prompt` callable class which will render the template when called. """ - if not callable(fn_or_filters): + if fn is None: return lambda fn: prompt( - fn, cast(Union[List[Callable], Dict[str, Callable]], fn_or_filters) + fn, cast(Union[List[Callable], Dict[str, Callable]], filters) ) - signature = inspect.signature(fn_or_filters) + signature = inspect.signature(fn) # The docstring contains the template that will be rendered to be used # as a prompt to the language model. - docstring = fn_or_filters.__doc__ + docstring = fn.__doc__ if docstring is None: raise TypeError("Could not find a template in the function's docstring.") diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 7fe1d770f..9b67d6ecd 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -325,7 +325,7 @@ def test_prompt_with_additional_filters_as_dict(): def reverse(s: str) -> str: return s[::-1] - @outlines.prompt(dict(reverse=reverse)) + @outlines.prompt(filters=dict(reverse=reverse)) def test_tpl(variable): """{{ variable | reverse }} test""" @@ -342,7 +342,7 @@ def test_prompt_with_additional_filters_as_list(): def reverse(s: str) -> str: return s[::-1] - @outlines.prompt([reverse]) + @outlines.prompt(filters=[reverse]) def test_tpl(variable): """{{ variable | reverse }} test"""