Skip to content

Commit

Permalink
Add support for pydantic v2 in hera runner
Browse files Browse the repository at this point in the history
Signed-off-by: Sambhav Kothari <[email protected]>
  • Loading branch information
sambhav committed Dec 7, 2023
1 parent b1ab2b8 commit c7fd324
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 64 deletions.
9 changes: 4 additions & 5 deletions examples/workflows/scripts/callable_script.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import List, Union

from hera.shared.serialization import serialize

try:
from typing import Annotated # type: ignore
except ImportError:
from typing_extensions import Annotated # type: ignore

try:
from pydantic.v1 import BaseModel
except ImportError:
from pydantic import BaseModel
from pydantic import BaseModel

from hera.shared import global_config
from hera.workflows import Parameter, Script, Steps, Workflow, script
Expand Down Expand Up @@ -92,7 +91,7 @@ def function_kebab_object(annotated_input_value: Annotated[Input, Parameter(name
with Workflow(name="my-workflow") as w:
with Steps(name="my-steps") as s:
my_function(arguments={"input": Input(a=2, b="bar", c=42)})
str_function(arguments={"input": Input(a=2, b="bar", c=42).json()})
str_function(arguments={"input": serialize(Input(a=2, b="bar", c=42))})
another_function(arguments={"inputs": [Input(a=2, b="bar", c=42), Input(a=2, b="bar", c=42.0)]})
function_kebab(arguments={"a-but-kebab": 3, "b-but-kebab": "bar"})
function_kebab_object(arguments={"input-value": Input(a=3, b="bar", c="42")})
16 changes: 13 additions & 3 deletions src/hera/shared/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
"""Module that holds the underlying base Pydantic models for Hera objects."""
from functools import partial

_PYDANTIC_VERSION = 1
# The pydantic v1 interface is used for both pydantic v1 and v2 in order to support
# users across both versions.

try:
from pydantic import (
validate_call as validate_arguments,
)
from pydantic.v1 import (
BaseModel as PydanticBaseModel,
Field,
ValidationError,
root_validator,
validate_arguments,
validator,
)
except ImportError:

_PYDANTIC_VERSION = 2
except (ImportError, ModuleNotFoundError):
from pydantic import ( # type: ignore[assignment,no-redef]
BaseModel as PydanticBaseModel,
Field,
ValidationError,
root_validator,
validate_arguments,
validate_arguments as validate_call,
validator,
)

validate_arguments = partial(validate_call, config=dict(smart_union=True)) # type: ignore
_PYDANTIC_VERSION = 1


__all__ = [
"BaseModel",
Expand Down
23 changes: 17 additions & 6 deletions src/hera/shared/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from json import JSONEncoder
from typing import Any, Optional

from pydantic import BaseModel

# NOTE: Use the original BaseModel in order to support serializing user-defined models,
# which won't use our hera.shared._pydantic import. This does still require that the
# user-defined models are using v1 pydantic models for now (either from a pydantic v1
# installation or `pydantic.v1` import from a pydantic v2 installation).
from hera.shared._pydantic import PydanticBaseModel
# for hera internal models, we still need to support v1 base models.
from hera.shared._pydantic import _PYDANTIC_VERSION

try:
from pydantic.v1 import BaseModel as V1BaseModel
except (ImportError, ModuleNotFoundError):
V1BaseModel = None # type: ignore

MISSING = object()
"""`MISSING` is a placeholder that indicates field value nullity.
Expand All @@ -22,8 +27,14 @@ class PydanticEncoder(JSONEncoder):

def default(self, o: Any):
"""Return the default representation of the given object."""
if isinstance(o, PydanticBaseModel):
return o.dict(by_alias=True)
if _PYDANTIC_VERSION == 1:
if isinstance(o, BaseModel):
return o.dict(by_alias=True)
else:
if isinstance(o, BaseModel):
return o.model_dump(by_alias=True, mode="json")
if isinstance(o, V1BaseModel):
return o.dict(by_alias=True)
return super().default(o)


Expand Down
30 changes: 17 additions & 13 deletions src/hera/workflows/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from hera.shared._pydantic import validate_arguments
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -63,30 +63,34 @@ def _parse(value, key, f):
The parsed value.
"""
if _is_str_kwarg_of(key, f) or _is_artifact_loaded(key, f) or _is_output_kwarg(key, f):
if _is_str_kwarg_of(key, f):
return cast(type, _get_type(key, f))(value)
if _is_artifact_loaded(key, f) or _is_output_kwarg(key, f):
return value
try:
return json.loads(value)
except json.JSONDecodeError:
return value


def _is_str_kwarg_of(key: str, f: Callable):
"""Check if param `key` of function `f` has a type annotation of a subclass of str."""
def _get_type(key: str, f: Callable) -> Optional[type]:
type_ = inspect.signature(f).parameters[key].annotation

if type_ is inspect.Parameter.empty:
# Untyped args are interpreted according to json spec
# ie. we will try to load it via json.loads in _parse
return False
return None
if get_origin(type_) is None:
return issubclass(type_, str)

return type_
origin_type = cast(type, get_origin(type_))
if origin_type is Annotated:
return issubclass(get_args(type_)[0], str)
return get_args(type_)[0]
return origin_type

return issubclass(origin_type, str)

def _is_str_kwarg_of(key: str, f: Callable):
"""Check if param `key` of function `f` has a type annotation of a subclass of str."""
type_ = _get_type(key, f)
if type_ is None:
return False
return issubclass(type_, str)


def _is_artifact_loaded(key, f):
Expand Down Expand Up @@ -272,7 +276,7 @@ def _runner(entrypoint: str, kwargs_list: List) -> Any:
# using smart union by default just in case clients do not rely on it. This means that if a function uses a union
# type for any of its inputs, then this will at least try to map those types correctly if the input object is
# not a pydantic model with smart_union enabled
function = validate_arguments(function, config=dict(smart_union=True))
function = validate_arguments(function)
function = _ignore_unmatched_kwargs(function)

if os.environ.get("hera__script_annotations", None) is not None:
Expand Down
6 changes: 3 additions & 3 deletions tests/script_runner/artifact_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@

from pathlib import Path

from pydantic import BaseModel
from tests.helper import ARTIFACT_PATH

from hera.shared import global_config
from hera.shared._pydantic import BaseModel
from hera.workflows import script
from hera.workflows.artifact import Artifact, ArtifactLoader

global_config.experimental_features["script_annotations"] = True


class MyArtifact(BaseModel):
a = "a"
b = "b"
a: str = "a"
b: str = "b"


@script(constructor="runner")
Expand Down
17 changes: 2 additions & 15 deletions tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
except ImportError:
from typing_extensions import Annotated

from pydantic import BaseModel

from hera.shared import global_config
from hera.shared._pydantic import BaseModel
from hera.workflows import Parameter, script

global_config.experimental_features["script_annotations"] = True
Expand Down Expand Up @@ -61,17 +62,3 @@ def str_parameter_expects_jsonstr_list(my_json_str: str) -> list:
@script()
def annotated_str_parameter_expects_jsonstr_dict(my_json_str: Annotated[str, "some metadata"]) -> list:
return json.loads(my_json_str)


class MyStr(str):
pass


@script()
def str_subclass_parameter_expects_jsonstr_dict(my_json_str: MyStr) -> list:
return json.loads(my_json_str)


@script()
def str_subclass_annotated_parameter_expects_jsonstr_dict(my_json_str: Annotated[MyStr, "some metadata"]) -> list:
return json.loads(my_json_str)
2 changes: 1 addition & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
HERA_REGENERATE = os.environ.get("HERA_REGENERATE")
CI_MODE = os.environ.get("CI")

LOWEST_SUPPORTED_PY_VERSION = (3, 8)
LOWEST_SUPPORTED_PY_VERSION = (3, 9)


def _generate_yaml(path: Path) -> bool:
Expand Down
21 changes: 3 additions & 18 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from unittest.mock import MagicMock, patch

import pytest
from pydantic import ValidationError

import tests.helper as test_module
from hera.shared import GlobalConfig
from hera.shared._pydantic import ValidationError
from hera.shared.serialization import serialize
from hera.workflows.runner import _run, _runner
from hera.workflows.script import RunnerScriptConstructor
Expand Down Expand Up @@ -78,18 +78,6 @@
{"my": "dict"},
id="str-json-annotated-param-as-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_subclass_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
{"my": "dict"},
id="str-subclass-json-param-as-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_subclass_annotated_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
{"my": "dict"},
id="str-subclass-json-annotated-param-as-dict",
),
),
)
def test_parameter_loading(
Expand Down Expand Up @@ -475,12 +463,9 @@ def test_script_annotations_artifact_input_loader_error(

importlib.reload(module)

# WHEN
with pytest.raises(ValidationError) as e:
_ = _runner(f"{module.__name__}:{function_name}", kwargs_list)

# THEN
assert "value is not a valid integer" in str(e.value)
with pytest.raises(ValidationError):
_ = _runner(f"{module.__name__}:{function_name}", kwargs_list)


@pytest.mark.parametrize(
Expand Down

0 comments on commit c7fd324

Please sign in to comment.