Skip to content

Commit

Permalink
Add optional support for pydantic v1 models when v2 is installed (#897)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [ ] Fixes #<!--issue number goes here-->
- [x] Tests added
- [x] Documentation/examples added
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, when pydantic v2 is used, users can only use v2 models.

This PR adds the ability to optionally provide v1 models to scripts.

Signed-off-by: Sambhav Kothari <[email protected]>
  • Loading branch information
sambhav authored Dec 15, 2023
1 parent 38a4254 commit 4cf2340
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 14 deletions.
140 changes: 140 additions & 0 deletions examples/workflows/scripts/callable-script-v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
name: my-workflow
spec:
templates:
- name: my-steps
steps:
- - arguments:
parameters:
- name: input
value: '{"a": 2, "b": "bar", "c": 42}'
name: my-function
template: my-function
- - arguments:
parameters:
- name: input
value: '{"a": 2, "b": "bar", "c": 42}'
name: str-function
template: str-function
- - arguments:
parameters:
- name: inputs
value: '[{"a": 2, "b": "bar", "c": 42}, {"a": 2, "b": "bar", "c": 42.0}]'
name: another-function
template: another-function
- - arguments:
parameters:
- name: a-but-kebab
value: '3'
- name: b-but-kebab
value: bar
name: function-kebab
template: function-kebab
- - arguments:
parameters:
- name: input-value
value: '{"a": 3, "b": "bar", "c": "42"}'
name: function-kebab-object
template: function-kebab-object
- inputs:
parameters:
- name: input
name: my-function
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.scripts.callable_script_v1:my_function
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__pydantic_mode
value: '1'
image: my-image-with-python-source-code-and-dependencies
source: '{{inputs.parameters}}'
- inputs:
parameters:
- name: input
name: str-function
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.scripts.callable_script_v1:str_function
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__pydantic_mode
value: '1'
image: my-image-with-python-source-code-and-dependencies
source: '{{inputs.parameters}}'
- inputs:
parameters:
- name: inputs
name: another-function
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.scripts.callable_script_v1:another_function
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__pydantic_mode
value: '1'
image: my-image-with-python-source-code-and-dependencies
source: '{{inputs.parameters}}'
- inputs:
parameters:
- default: '2'
name: a-but-kebab
- default: foo
name: b-but-kebab
- default: '42.0'
name: c-but-kebab
name: function-kebab
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.scripts.callable_script_v1:function_kebab
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__pydantic_mode
value: '1'
image: my-image-with-python-source-code-and-dependencies
source: '{{inputs.parameters}}'
- inputs:
parameters:
- name: input-value
name: function-kebab-object
script:
args:
- -m
- hera.workflows.runner
- -e
- examples.workflows.scripts.callable_script_v1:function_kebab_object
command:
- python
env:
- name: hera__script_annotations
value: ''
- name: hera__pydantic_mode
value: '1'
image: my-image-with-python-source-code-and-dependencies
source: '{{inputs.parameters}}'
100 changes: 100 additions & 0 deletions examples/workflows/scripts/callable_script_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import List, Union

from hera.shared.serialization import serialize

try:
from typing import Annotated # type: ignore
except ImportError:
from typing_extensions import Annotated # type: ignore
try:
from pydantic.v1 import BaseModel
except (ImportError, ModuleNotFoundError):
from pydantic import BaseModel


from hera.shared import global_config
from hera.workflows import Parameter, RunnerScriptConstructor, Script, Steps, Workflow, script

# Note, setting constructor to runner is only possible if the source code is available
# along with dependencies include hera in the image.
# Callable is a robust mode that allows you to run any python function
# and is compatible with pydantic. It automatically parses the input
# and serializes the output.
global_config.image = "my-image-with-python-source-code-and-dependencies"
global_config.set_class_defaults(Script, constructor=RunnerScriptConstructor(pydantic_mode=1))
# Script annotations is still an experimental feature and we need to explicitly opt in to it
# Note that experimental features are subject to breaking changes in future releases of the same major version
global_config.experimental_features["script_annotations"] = True


# An optional pydantic input type
# hera can automatically de-serialize argo
# arguments into types denoted by your function's signature
# as long as they are de-serializable by pydantic
# This provides auto-magic input parsing with validation
# provided by pydantic.
class Input(BaseModel):
a: int
b: str = "foo"
c: Union[str, int, float]

class Config:
smart_union = True


# An optional pydantic output type
# hera can automatically serialize the output
# of your function into a json string
# as long as they are serializable by pydantic or json serializable
# This provides auto-magic output serialization with validation
# provided by pydantic.
class Output(BaseModel):
output: List[Input]


@script()
def my_function(input: Input) -> Output:
return Output(output=[input])


# Note that the input type is a list of Input
# hera can also automatically de-serialize
# composite types like lists and dicts
@script()
def another_function(inputs: List[Input]) -> Output:
return Output(output=inputs)


# it also works with raw json strings
# but those must be explicitly marked as
# a string type
@script()
def str_function(input: str) -> Output:
# Example function to ensure we are not json parsing
# string types before passing it to the function
return Output(output=[Input.parse_raw(input)])


# Use the script_annotations feature to seamlessly enable aliased kebab-case names
# as your template interface, while using regular snake_case in the Python code
@script()
def function_kebab(
a_but_kebab: Annotated[int, Parameter(name="a-but-kebab")] = 2,
b_but_kebab: Annotated[str, Parameter(name="b-but-kebab")] = "foo",
c_but_kebab: Annotated[float, Parameter(name="c-but-kebab")] = 42.0,
) -> Output:
return Output(output=[Input(a=a_but_kebab, b=b_but_kebab, c=c_but_kebab)])


@script()
def function_kebab_object(annotated_input_value: Annotated[Input, Parameter(name="input-value")]) -> Output:
return Output(output=[annotated_input_value])


with Workflow(name="my-workflow") as w:
with Steps(name="my-steps") as s:
my_function(arguments={"input": Input(a=2, b="bar", c=42)})
str_function(arguments={"input": serialize(Input(a=2, b="bar", c=42))})
another_function(arguments={"inputs": [Input(a=2, b="bar", c=42), Input(a=2, b="bar", c=42.0)]})
function_kebab(arguments={"a-but-kebab": 3, "b-but-kebab": "bar"})
function_kebab_object(arguments={"input-value": Input(a=3, b="bar", c="42")})
11 changes: 3 additions & 8 deletions src/hera/shared/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Module that holds the underlying base Pydantic models for Hera objects."""
from functools import partial

_PYDANTIC_VERSION = 1
from typing import Literal

_PYDANTIC_VERSION: Literal[1, 2] = 1
# The pydantic v1 interface is used for both pydantic v1 and v2 in order to support
# users across both versions.

try:
from pydantic import ( # type: ignore
validate_call as validate_arguments,
)
from pydantic.v1 import ( # type: ignore
BaseModel as PydanticBaseModel,
Field,
Expand All @@ -24,11 +22,9 @@
Field,
ValidationError,
root_validator,
validate_arguments as validate_call,
validator,
)

validate_arguments = partial(validate_call, config=dict(smart_union=True)) # type: ignore
_PYDANTIC_VERSION = 1


Expand All @@ -38,7 +34,6 @@
"PydanticBaseModel", # Export for serialization.py to cover user-defined models
"ValidationError",
"root_validator",
"validate_arguments",
"validator",
]

Expand Down
15 changes: 13 additions & 2 deletions src/hera/workflows/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from hera.shared._pydantic import validate_arguments
from hera.shared._pydantic import _PYDANTIC_VERSION
from hera.shared.serialization import serialize
from hera.workflows import Artifact, Parameter
from hera.workflows.artifact import ArtifactLoader
Expand Down Expand Up @@ -274,7 +274,18 @@ def _runner(entrypoint: str, kwargs_list: List) -> Any:
# The imported validate_arguments uses smart union by default just in case clients do not rely on it. This means that if a function uses a union
# type for any of its inputs, then this will at least try to map those types correctly if the input object is
# not a pydantic model with smart_union enabled
function = validate_arguments(function)
_pydantic_mode = int(os.environ.get("hera__pydantic_mode", _PYDANTIC_VERSION))
if _pydantic_mode == 2:
from pydantic import validate_call # type: ignore

function = validate_call(function)
else:
if _PYDANTIC_VERSION == 1:
from pydantic import validate_arguments
else:
from pydantic.v1 import validate_arguments # type: ignore
function = validate_arguments(function, config=dict(smart_union=True)) # type: ignore

function = _ignore_unmatched_kwargs(function)

if os.environ.get("hera__script_annotations", None) is not None:
Expand Down
16 changes: 15 additions & 1 deletion src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Any,
Callable,
List,
Literal,
Optional,
Tuple,
Type,
Expand All @@ -26,7 +27,7 @@

from hera.expr import g
from hera.shared import BaseMixin, global_config
from hera.shared._pydantic import root_validator, validator
from hera.shared._pydantic import _PYDANTIC_VERSION, root_validator, validator
from hera.workflows._context import _context
from hera.workflows._mixins import (
CallableTemplateMixin,
Expand Down Expand Up @@ -710,6 +711,17 @@ class RunnerScriptConstructor(ScriptConstructor):
DEFAULT_HERA_OUTPUTS_DIRECTORY: str = "/tmp/hera-outputs"
"""Used as the default value for when the outputs_directory is not set"""

pydantic_mode: Optional[Literal[1, 2]] = None
"""Used for selecting the pydantic version used for BaseModels.
Allows for using pydantic.v1 BaseModels with pydantic v2.
Defaults to the installed version of Pydantic."""

@validator("pydantic_mode", always=True)
def _pydantic_mode(cls, value: Optional[Literal[1, 2]]) -> Optional[Literal[1, 2]]:
if value and value > _PYDANTIC_VERSION:
raise ValueError("v2 pydantic mode only available for pydantic>=2")
return value

def transform_values(self, cls: Type[Script], values: Any) -> Any:
"""A function that can inspect the Script instance and generate the source field."""
if not callable(values.get("source")):
Expand Down Expand Up @@ -740,6 +752,8 @@ def transform_script_template_post_build(
script.env.append(EnvVar(name="hera__script_annotations", value=""))
if self.outputs_directory:
script.env.append(EnvVar(name="hera__outputs_directory", value=self.outputs_directory))
if self.pydantic_mode:
script.env.append(EnvVar(name="hera__pydantic_mode", value=str(self.pydantic_mode)))
return script


Expand Down
21 changes: 20 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@

from pydantic import BaseModel

try:
from pydantic.v1 import BaseModel as V1BaseModel
except (ImportError, ModuleNotFoundError):
from pydantic import BaseModel as V1BaseModel

from hera.shared import global_config
from hera.workflows import Parameter, script
from hera.workflows import Parameter, RunnerScriptConstructor, script

global_config.experimental_features["script_annotations"] = True

Expand All @@ -23,6 +28,15 @@ class Output(BaseModel):
output: List[Input]


class InputV1(V1BaseModel):
a: int
b: str = "foo"


class OutputV1(V1BaseModel):
output: List[InputV1]


@script()
def annotated_basic_types(
a_but_kebab: Annotated[int, Parameter(name="a-but-kebab")] = 2,
Expand All @@ -36,6 +50,11 @@ def annotated_object(annotated_input_value: Annotated[Input, Parameter(name="inp
return Output(output=[annotated_input_value])


@script(constructor=RunnerScriptConstructor(pydantic_mode=1))
def annotated_object_v1(annotated_input_value: Annotated[InputV1, Parameter(name="input-value")]) -> OutputV1:
return OutputV1(output=[annotated_input_value])


@script()
def annotated_parameter_no_name(
annotated_input_value: Annotated[Input, Parameter(description="a value to input")],
Expand Down
Loading

0 comments on commit 4cf2340

Please sign in to comment.