From c7fd32462a329f6c60b9a1a7b114e7ce953b9efd Mon Sep 17 00:00:00 2001 From: Sambhav Kothari Date: Thu, 7 Dec 2023 23:17:18 +0000 Subject: [PATCH] Add support for pydantic v2 in hera runner Signed-off-by: Sambhav Kothari --- examples/workflows/scripts/callable_script.py | 9 +++--- src/hera/shared/_pydantic.py | 16 ++++++++-- src/hera/shared/serialization.py | 23 ++++++++++---- src/hera/workflows/runner.py | 30 +++++++++++-------- tests/script_runner/artifact_loaders.py | 6 ++-- tests/script_runner/parameter_inputs.py | 17 ++--------- tests/test_examples.py | 2 +- tests/test_runner.py | 21 ++----------- 8 files changed, 60 insertions(+), 64 deletions(-) diff --git a/examples/workflows/scripts/callable_script.py b/examples/workflows/scripts/callable_script.py index 24aa0218b..b292de79b 100644 --- a/examples/workflows/scripts/callable_script.py +++ b/examples/workflows/scripts/callable_script.py @@ -1,14 +1,13 @@ from typing import List, Union +from hera.shared.serialization import serialize + try: from typing import Annotated # type: ignore except ImportError: from typing_extensions import Annotated # type: ignore -try: - from pydantic.v1 import BaseModel -except ImportError: - from pydantic import BaseModel +from pydantic import BaseModel from hera.shared import global_config from hera.workflows import Parameter, Script, Steps, Workflow, script @@ -92,7 +91,7 @@ def function_kebab_object(annotated_input_value: Annotated[Input, Parameter(name with Workflow(name="my-workflow") as w: with Steps(name="my-steps") as s: my_function(arguments={"input": Input(a=2, b="bar", c=42)}) - str_function(arguments={"input": Input(a=2, b="bar", c=42).json()}) + str_function(arguments={"input": serialize(Input(a=2, b="bar", c=42))}) another_function(arguments={"inputs": [Input(a=2, b="bar", c=42), Input(a=2, b="bar", c=42.0)]}) function_kebab(arguments={"a-but-kebab": 3, "b-but-kebab": "bar"}) function_kebab_object(arguments={"input-value": Input(a=3, b="bar", c="42")}) diff --git a/src/hera/shared/_pydantic.py b/src/hera/shared/_pydantic.py index a1e9b637d..ac4c64018 100644 --- a/src/hera/shared/_pydantic.py +++ b/src/hera/shared/_pydantic.py @@ -1,26 +1,36 @@ """Module that holds the underlying base Pydantic models for Hera objects.""" +from functools import partial +_PYDANTIC_VERSION = 1 # The pydantic v1 interface is used for both pydantic v1 and v2 in order to support # users across both versions. + try: + from pydantic import ( + validate_call as validate_arguments, + ) from pydantic.v1 import ( BaseModel as PydanticBaseModel, Field, ValidationError, root_validator, - validate_arguments, validator, ) -except ImportError: + + _PYDANTIC_VERSION = 2 +except (ImportError, ModuleNotFoundError): from pydantic import ( # type: ignore[assignment,no-redef] BaseModel as PydanticBaseModel, Field, ValidationError, root_validator, - validate_arguments, + validate_arguments as validate_call, validator, ) + validate_arguments = partial(validate_call, config=dict(smart_union=True)) # type: ignore + _PYDANTIC_VERSION = 1 + __all__ = [ "BaseModel", diff --git a/src/hera/shared/serialization.py b/src/hera/shared/serialization.py index 4314f6f7a..c2b55d062 100644 --- a/src/hera/shared/serialization.py +++ b/src/hera/shared/serialization.py @@ -3,11 +3,16 @@ from json import JSONEncoder from typing import Any, Optional +from pydantic import BaseModel + # NOTE: Use the original BaseModel in order to support serializing user-defined models, -# which won't use our hera.shared._pydantic import. This does still require that the -# user-defined models are using v1 pydantic models for now (either from a pydantic v1 -# installation or `pydantic.v1` import from a pydantic v2 installation). -from hera.shared._pydantic import PydanticBaseModel +# for hera internal models, we still need to support v1 base models. +from hera.shared._pydantic import _PYDANTIC_VERSION + +try: + from pydantic.v1 import BaseModel as V1BaseModel +except (ImportError, ModuleNotFoundError): + V1BaseModel = None # type: ignore MISSING = object() """`MISSING` is a placeholder that indicates field value nullity. @@ -22,8 +27,14 @@ class PydanticEncoder(JSONEncoder): def default(self, o: Any): """Return the default representation of the given object.""" - if isinstance(o, PydanticBaseModel): - return o.dict(by_alias=True) + if _PYDANTIC_VERSION == 1: + if isinstance(o, BaseModel): + return o.dict(by_alias=True) + else: + if isinstance(o, BaseModel): + return o.model_dump(by_alias=True, mode="json") + if isinstance(o, V1BaseModel): + return o.dict(by_alias=True) return super().default(o) diff --git a/src/hera/workflows/runner.py b/src/hera/workflows/runner.py index 9587d2e7a..591667bd8 100644 --- a/src/hera/workflows/runner.py +++ b/src/hera/workflows/runner.py @@ -6,7 +6,7 @@ import json import os from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from hera.shared._pydantic import validate_arguments from hera.shared.serialization import serialize @@ -63,7 +63,9 @@ def _parse(value, key, f): The parsed value. """ - if _is_str_kwarg_of(key, f) or _is_artifact_loaded(key, f) or _is_output_kwarg(key, f): + if _is_str_kwarg_of(key, f): + return cast(type, _get_type(key, f))(value) + if _is_artifact_loaded(key, f) or _is_output_kwarg(key, f): return value try: return json.loads(value) @@ -71,22 +73,24 @@ def _parse(value, key, f): return value -def _is_str_kwarg_of(key: str, f: Callable): - """Check if param `key` of function `f` has a type annotation of a subclass of str.""" +def _get_type(key: str, f: Callable) -> Optional[type]: type_ = inspect.signature(f).parameters[key].annotation - if type_ is inspect.Parameter.empty: - # Untyped args are interpreted according to json spec - # ie. we will try to load it via json.loads in _parse - return False + return None if get_origin(type_) is None: - return issubclass(type_, str) - + return type_ origin_type = cast(type, get_origin(type_)) if origin_type is Annotated: - return issubclass(get_args(type_)[0], str) + return get_args(type_)[0] + return origin_type - return issubclass(origin_type, str) + +def _is_str_kwarg_of(key: str, f: Callable): + """Check if param `key` of function `f` has a type annotation of a subclass of str.""" + type_ = _get_type(key, f) + if type_ is None: + return False + return issubclass(type_, str) def _is_artifact_loaded(key, f): @@ -272,7 +276,7 @@ def _runner(entrypoint: str, kwargs_list: List) -> Any: # using smart union by default just in case clients do not rely on it. This means that if a function uses a union # type for any of its inputs, then this will at least try to map those types correctly if the input object is # not a pydantic model with smart_union enabled - function = validate_arguments(function, config=dict(smart_union=True)) + function = validate_arguments(function) function = _ignore_unmatched_kwargs(function) if os.environ.get("hera__script_annotations", None) is not None: diff --git a/tests/script_runner/artifact_loaders.py b/tests/script_runner/artifact_loaders.py index 6a3f5c12f..d96c1bab2 100644 --- a/tests/script_runner/artifact_loaders.py +++ b/tests/script_runner/artifact_loaders.py @@ -5,10 +5,10 @@ from pathlib import Path +from pydantic import BaseModel from tests.helper import ARTIFACT_PATH from hera.shared import global_config -from hera.shared._pydantic import BaseModel from hera.workflows import script from hera.workflows.artifact import Artifact, ArtifactLoader @@ -16,8 +16,8 @@ class MyArtifact(BaseModel): - a = "a" - b = "b" + a: str = "a" + b: str = "b" @script(constructor="runner") diff --git a/tests/script_runner/parameter_inputs.py b/tests/script_runner/parameter_inputs.py index 59040e256..f4259e549 100644 --- a/tests/script_runner/parameter_inputs.py +++ b/tests/script_runner/parameter_inputs.py @@ -6,8 +6,9 @@ except ImportError: from typing_extensions import Annotated +from pydantic import BaseModel + from hera.shared import global_config -from hera.shared._pydantic import BaseModel from hera.workflows import Parameter, script global_config.experimental_features["script_annotations"] = True @@ -61,17 +62,3 @@ def str_parameter_expects_jsonstr_list(my_json_str: str) -> list: @script() def annotated_str_parameter_expects_jsonstr_dict(my_json_str: Annotated[str, "some metadata"]) -> list: return json.loads(my_json_str) - - -class MyStr(str): - pass - - -@script() -def str_subclass_parameter_expects_jsonstr_dict(my_json_str: MyStr) -> list: - return json.loads(my_json_str) - - -@script() -def str_subclass_annotated_parameter_expects_jsonstr_dict(my_json_str: Annotated[MyStr, "some metadata"]) -> list: - return json.loads(my_json_str) diff --git a/tests/test_examples.py b/tests/test_examples.py index eff949316..085d528df 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -27,7 +27,7 @@ HERA_REGENERATE = os.environ.get("HERA_REGENERATE") CI_MODE = os.environ.get("CI") -LOWEST_SUPPORTED_PY_VERSION = (3, 8) +LOWEST_SUPPORTED_PY_VERSION = (3, 9) def _generate_yaml(path: Path) -> bool: diff --git a/tests/test_runner.py b/tests/test_runner.py index 566ce80b1..bcd06d628 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -12,10 +12,10 @@ from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError import tests.helper as test_module from hera.shared import GlobalConfig -from hera.shared._pydantic import ValidationError from hera.shared.serialization import serialize from hera.workflows.runner import _run, _runner from hera.workflows.script import RunnerScriptConstructor @@ -78,18 +78,6 @@ {"my": "dict"}, id="str-json-annotated-param-as-dict", ), - pytest.param( - "tests.script_runner.parameter_inputs:str_subclass_parameter_expects_jsonstr_dict", - [{"name": "my_json_str", "value": json.dumps({"my": "dict"})}], - {"my": "dict"}, - id="str-subclass-json-param-as-dict", - ), - pytest.param( - "tests.script_runner.parameter_inputs:str_subclass_annotated_parameter_expects_jsonstr_dict", - [{"name": "my_json_str", "value": json.dumps({"my": "dict"})}], - {"my": "dict"}, - id="str-subclass-json-annotated-param-as-dict", - ), ), ) def test_parameter_loading( @@ -475,12 +463,9 @@ def test_script_annotations_artifact_input_loader_error( importlib.reload(module) - # WHEN - with pytest.raises(ValidationError) as e: - _ = _runner(f"{module.__name__}:{function_name}", kwargs_list) - # THEN - assert "value is not a valid integer" in str(e.value) + with pytest.raises(ValidationError): + _ = _runner(f"{module.__name__}:{function_name}", kwargs_list) @pytest.mark.parametrize(