diff --git a/outlines/prompts.py b/outlines/prompts.py index 86519adaf..1cc264226 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -40,7 +40,7 @@ def __call__(self, *args, **kwargs) -> str: return self.template.render(**kwargs) @classmethod - def from_str(cls, content: str): + def from_str(cls, content: str, filters: Dict[str, Callable] = {}): """ Create an instance of the class from a string. @@ -53,10 +53,10 @@ def from_str(cls, content: str): ------- An instance of the class with the provided content as a template. """ - return cls(cls._template_from_str(content), None) + return cls(cls._template_from_str(content, filters), None) @classmethod - def from_file(cls, path: Path): + def from_file(cls, path: Path, filters: Dict[str, Callable] = {}): """ Create a Prompt instance from a file containing a Jinja template. @@ -75,10 +75,12 @@ def from_file(cls, path: Path): """ # We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is # split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance) - return cls(cls._template_from_file(path), None) + return cls(cls._template_from_file(path, filters), None) @classmethod - def _template_from_str(_, content: str) -> jinja2.Template: + def _template_from_str( + _, content: str, filters: Dict[str, Callable] = {} + ) -> jinja2.Template: # Dedent, and remove extra linebreak cleaned_template = inspect.cleandoc(content) @@ -93,12 +95,7 @@ def _template_from_str(_, content: str) -> jinja2.Template: # used to continue to the next line without linebreak. cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) - env = jinja2.Environment( - trim_blocks=True, - lstrip_blocks=True, - keep_trailing_newline=True, - undefined=jinja2.StrictUndefined, - ) + env = create_jinja_env(None, filters) env.filters["name"] = get_fn_name env.filters["description"] = get_fn_description env.filters["source"] = get_fn_source @@ -109,19 +106,19 @@ def _template_from_str(_, content: str) -> jinja2.Template: return env.from_string(cleaned_template) @classmethod - def _template_from_file(_, path: Path) -> jinja2.Template: + def _template_from_file( + _, path: Path, filters: Dict[str, Callable] = {} + ) -> jinja2.Template: file_directory = os.path.dirname(os.path.abspath(path)) - env = jinja2.Environment( - loader=jinja2.FileSystemLoader(file_directory), - trim_blocks=True, - lstrip_blocks=True, - keep_trailing_newline=True, - undefined=jinja2.StrictUndefined, - ) + env = create_jinja_env(jinja2.FileSystemLoader(file_directory), filters) + return env.get_template(os.path.basename(path)) -def prompt(fn: Callable) -> Prompt: +def prompt( + fn: Optional[Callable] = None, + filters: Dict[str, Callable] = {}, +) -> Callable: """Decorate a function that contains a prompt template. This allows to define prompts in the docstring of a function and simplify their @@ -152,11 +149,26 @@ def prompt(fn: Callable) -> Prompt: ... >>> hal = ft.partial(solve_task, "HAL", "Travel to Jupiter") + Additional Jinja2 filters can be provided as keyword arguments to the decorator. + + >>> def reverse(s: str) -> str: + ... return s[::-1] + ... + >>> @outlines.prompt(filters={ 'reverse': reverse }) + ... def reverse_prompt(text): + ... '''{{ text | reverse }}''' + ... + >>> prompt = reverse_prompt("Hello") + >>> print(prompt) + ... "olleH" + Returns ------- A `Prompt` callable class which will render the template when called. """ + if fn is None: + return lambda fn: prompt(fn, cast(Dict[str, Callable], filters)) signature = inspect.signature(fn) @@ -166,11 +178,28 @@ def prompt(fn: Callable) -> Prompt: if docstring is None: raise TypeError("Could not find a template in the function's docstring.") - template = Prompt._template_from_str(cast(str, docstring)) + template = Prompt._template_from_str(cast(str, docstring), filters) return Prompt(template, signature) +def create_jinja_env( + loader: Optional[jinja2.BaseLoader], filters: Dict[str, Callable] +) -> jinja2.Environment: + env = jinja2.Environment( + loader=loader, + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=jinja2.StrictUndefined, + ) + + for name, filter_fn in filters.items(): + env.filters[name] = filter_fn + + return env + + def get_fn_name(fn: Callable): """Returns the name of a callable.""" if not callable(fn): diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 4cc4d8ff1..f59c04ac0 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -321,6 +321,23 @@ def args_prompt(fn): ) +def test_prompt_with_additional_filters(): + def reverse(s: str) -> str: + return s[::-1] + + @outlines.prompt(filters=dict(reverse=reverse)) + def test_tpl(variable): + """{{ variable | reverse }} test""" + + assert list(test_tpl.signature.parameters) == ["variable"] + + p = test_tpl("test") + assert p == "tset test" + + p = test_tpl(variable="example") + assert p == "elpmaxe test" + + @pytest.fixture def temp_prompt_file(): test_dir = tempfile.mkdtemp()