diff --git a/src/hera/shared/_pydantic.py b/src/hera/shared/_pydantic.py index cc6fe18f0..a4ad57a10 100644 --- a/src/hera/shared/_pydantic.py +++ b/src/hera/shared/_pydantic.py @@ -64,7 +64,7 @@ class Config: """ allow_population_by_field_name = True - """support populating Hera object fields via keyed dictionaries""" + """support populating Hera object fields by their Field alias""" allow_mutation = True """supports mutating Hera objects post instantiation""" diff --git a/src/hera/workflows/__init__.py b/src/hera/workflows/__init__.py index 71dbdea54..6c65792b3 100644 --- a/src/hera/workflows/__init__.py +++ b/src/hera/workflows/__init__.py @@ -30,6 +30,7 @@ from hera.workflows.env_from import ConfigMapEnvFrom, SecretEnvFrom from hera.workflows.exceptions import InvalidDispatchType, InvalidTemplateCall, InvalidType from hera.workflows.http_template import HTTP +from hera.workflows.io import RunnerInput, RunnerOutput from hera.workflows.metrics import Counter, Gauge, Histogram, Label, Metric, Metrics from hera.workflows.operator import Operator from hera.workflows.parameter import Parameter @@ -148,6 +149,8 @@ "Resources", "RetryPolicy", "RetryStrategy", + "RunnerInput", + "RunnerOutput", "RunnerScriptConstructor", "S3Artifact", "ScaleIOVolume", diff --git a/src/hera/workflows/io/__init__.py b/src/hera/workflows/io/__init__.py new file mode 100644 index 000000000..5031e5333 --- /dev/null +++ b/src/hera/workflows/io/__init__.py @@ -0,0 +1,12 @@ +"""Hera IO models.""" +from importlib.util import find_spec + +if find_spec("pydantic.v1"): + from hera.workflows.io.v2 import RunnerInput, RunnerOutput +else: + from hera.workflows.io.v1 import RunnerInput, RunnerOutput # type: ignore + +__all__ = [ + "RunnerInput", + "RunnerOutput", +] diff --git a/src/hera/workflows/io.py b/src/hera/workflows/io/v1.py similarity index 98% rename from src/hera/workflows/io.py rename to src/hera/workflows/io/v1.py index b387728b6..9699dbed8 100644 --- a/src/hera/workflows/io.py +++ b/src/hera/workflows/io/v1.py @@ -1,4 +1,4 @@ -"""Input/output models for the Hera runner.""" +"""Pydantic V1 input/output models for the Hera runner.""" from collections import ChainMap from typing import Any, List, Optional, Union @@ -81,7 +81,7 @@ class RunnerOutput(BaseModel): """ exit_code: int = 0 - result: Any + result: Any = None @classmethod def _get_outputs(cls) -> List[Union[Artifact, Parameter]]: diff --git a/src/hera/workflows/io/v2.py b/src/hera/workflows/io/v2.py new file mode 100644 index 000000000..6512809cc --- /dev/null +++ b/src/hera/workflows/io/v2.py @@ -0,0 +1,110 @@ +"""Pydantic V2 input/output models for the Hera runner. + +RunnerInput/Output are only defined in this file if Pydantic v2 is installed. +""" +from collections import ChainMap +from typing import Any, List, Optional, Union + +from hera.shared.serialization import serialize +from hera.workflows.artifact import Artifact +from hera.workflows.parameter import Parameter + +try: + from inspect import get_annotations # type: ignore +except ImportError: + from hera.workflows._inspect import get_annotations # type: ignore + +try: + from typing import Annotated, get_args, get_origin # type: ignore +except ImportError: + from typing_extensions import Annotated, get_args, get_origin # type: ignore + +from importlib.util import find_spec + +if find_spec("pydantic.v1"): + from pydantic import BaseModel + + class RunnerInput(BaseModel): + """Input model usable by the Hera Runner. + + RunnerInput is a Pydantic model which users can create a subclass of. When a subclass + of RunnerInput is used as a function parameter type, the Hera Runner will take the fields + of the user's subclass to create template input parameters and artifacts. See the example + for the script_pydantic_io experimental feature. + """ + + @classmethod + def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> List[Parameter]: + parameters = [] + annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} + + for field in cls.model_fields: # type: ignore + if get_origin(annotations[field]) is Annotated: + if isinstance(get_args(annotations[field])[1], Parameter): + param = get_args(annotations[field])[1] + if object_override: + param.default = serialize(getattr(object_override, field)) + elif cls.model_fields[field].default: # type: ignore + # Serialize the value (usually done in Parameter's validator) + param.default = serialize(cls.model_fields[field].default) # type: ignore + parameters.append(param) + else: + # Create a Parameter from basic type annotations + if object_override: + parameters.append(Parameter(name=field, default=serialize(getattr(object_override, field)))) + else: + parameters.append(Parameter(name=field, default=cls.model_fields[field].default)) # type: ignore + return parameters + + @classmethod + def _get_artifacts(cls) -> List[Artifact]: + artifacts = [] + annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} + + for field in cls.model_fields: # type: ignore + if get_origin(annotations[field]) is Annotated: + if isinstance(get_args(annotations[field])[1], Artifact): + artifact = get_args(annotations[field])[1] + if artifact.path is None: + artifact.path = artifact._get_default_inputs_path() + artifacts.append(artifact) + return artifacts + + class RunnerOutput(BaseModel): + """Output model usable by the Hera Runner. + + RunnerOutput is a Pydantic model which users can create a subclass of. When a subclass + of RunnerOutput is used as a function return type, the Hera Runner will take the fields + of the user's subclass to create template output parameters and artifacts. See the example + for the script_pydantic_io experimental feature. + """ + + exit_code: int = 0 + result: Any = None + + @classmethod + def _get_outputs(cls) -> List[Union[Artifact, Parameter]]: + outputs = [] + annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} + + for field in cls.model_fields: # type: ignore + if field in {"exit_code", "result"}: + continue + if get_origin(annotations[field]) is Annotated: + if isinstance(get_args(annotations[field])[1], (Parameter, Artifact)): + outputs.append(get_args(annotations[field])[1]) + else: + # Create a Parameter from basic type annotations + outputs.append(Parameter(name=field, default=cls.model_fields[field].default)) # type: ignore + return outputs + + @classmethod + def _get_output(cls, field_name: str) -> Union[Artifact, Parameter]: + annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} + annotation = annotations[field_name] + if get_origin(annotation) is Annotated: + if isinstance(get_args(annotation)[1], (Parameter, Artifact)): + return get_args(annotation)[1] + + # Create a Parameter from basic type annotations + return Parameter(name=field_name, default=cls.model_fields[field_name].default) # type: ignore diff --git a/src/hera/workflows/runner.py b/src/hera/workflows/runner.py index 4af38a7c1..9ab2249e8 100644 --- a/src/hera/workflows/runner.py +++ b/src/hera/workflows/runner.py @@ -12,7 +12,21 @@ from hera.shared.serialization import serialize from hera.workflows import Artifact, Parameter from hera.workflows.artifact import ArtifactLoader -from hera.workflows.io import RunnerInput, RunnerOutput +from hera.workflows.io.v1 import ( + RunnerInput as RunnerInputV1, + RunnerOutput as RunnerOutputV1, +) + +try: + from hera.workflows.io.v2 import ( # type: ignore + RunnerInput as RunnerInputV2, + RunnerOutput as RunnerOutputV2, + ) +except ImportError: + from hera.workflows.io.v1 import ( # type: ignore + RunnerInput as RunnerInputV2, + RunnerOutput as RunnerOutputV2, + ) from hera.workflows.script import _extract_return_annotation_output try: @@ -20,6 +34,12 @@ except ImportError: from typing_extensions import Annotated, get_args, get_origin # type: ignore +try: + from pydantic.type_adapter import TypeAdapter # type: ignore + from pydantic.v1 import parse_obj_as # type: ignore +except ImportError: + from pydantic import parse_obj_as + def _ignore_unmatched_kwargs(f): """Make function ignore unmatched kwargs. @@ -33,18 +53,7 @@ def _ignore_unmatched_kwargs(f): def inner(**kwargs): # filter out kwargs that are not part of the function signature # and transform them to the correct type - if os.environ.get("hera__script_pydantic_io", None) is None: - filtered_kwargs = {key: _parse(value, key, f) for key, value in kwargs.items() if _is_kwarg_of(key, f)} - return f(**filtered_kwargs) - - # filter out kwargs that are not part of the function signature - # and transform them to the correct type. If any kwarg values are - # of RunnerType, pass them through without parsing. - filtered_kwargs = {} - for key, value in kwargs.items(): - if _is_kwarg_of(key, f): - type_ = _get_type(key, f) - filtered_kwargs[key] = value if type_ and issubclass(type_, RunnerInput) else _parse(value, key, f) + filtered_kwargs = {key: _parse(value, key, f) for key, value in kwargs.items() if _is_kwarg_of(key, f)} return f(**filtered_kwargs) return inner @@ -78,8 +87,21 @@ def _parse(value, key, f): if _is_str_kwarg_of(key, f) or _is_artifact_loaded(key, f) or _is_output_kwarg(key, f): return value try: - return json.loads(value) - except json.JSONDecodeError: + if os.environ.get("hera__script_annotations", None) is None: + return json.loads(value) + + type_ = _get_unannotated_type(key, f) + loaded_json_value = json.loads(value) + + if not type_: + return loaded_json_value + + _pydantic_mode = int(os.environ.get("hera__pydantic_mode", _PYDANTIC_VERSION)) + if _pydantic_mode == 1: + return parse_obj_as(type_, loaded_json_value) + else: + return TypeAdapter(type_).validate_python(loaded_json_value) + except (json.JSONDecodeError, TypeError): return value @@ -95,6 +117,22 @@ def _get_type(key: str, f: Callable) -> Optional[type]: return origin_type +def _get_unannotated_type(key: str, f: Callable) -> Optional[type]: + """Get the type of function param without the 'Annotated' outer type.""" + type_ = inspect.signature(f).parameters[key].annotation + if type_ is inspect.Parameter.empty: + return None + if get_origin(type_) is None: + return type_ + + origin_type = cast(type, get_origin(type_)) + if origin_type is Annotated: + return get_args(type_)[0] + + # Type could be a dict/list with subscript type + return type_ + + def _is_str_kwarg_of(key: str, f: Callable) -> bool: """Check if param `key` of function `f` has a type annotation of a subclass of str.""" type_ = _get_type(key, f) @@ -174,7 +212,7 @@ def map_annotated_artifact(param_name: str, artifact_annotation: Artifact) -> No elif artifact_annotation.loader is None: mapped_kwargs[param_name] = artifact_annotation.path - T = TypeVar("T", bound=RunnerInput) + T = TypeVar("T", bound=Union[RunnerInputV1, RunnerInputV2]) def map_runner_input(param_name: str, runner_input_class: T): """Map argo input kwargs to the fields of the given RunnerInput. @@ -215,18 +253,21 @@ def map_field(field: str) -> Optional[str]: map_annotated_artifact(param_name, func_param_annotation) else: mapped_kwargs[param_name] = kwargs[param_name] - elif get_origin(func_param.annotation) is None and issubclass(func_param.annotation, RunnerInput): + elif get_origin(func_param.annotation) is None and issubclass( + func_param.annotation, (RunnerInputV1, RunnerInputV2) + ): map_runner_input(param_name, func_param.annotation) else: mapped_kwargs[param_name] = kwargs[param_name] - return mapped_kwargs def _save_annotated_return_outputs( function_outputs: Union[Tuple[Any], Any], - output_annotations: List[Union[Tuple[type, Union[Parameter, Artifact]], Type[RunnerOutput]]], -) -> Optional[RunnerOutput]: + output_annotations: List[ + Union[Tuple[type, Union[Parameter, Artifact]], Union[Type[RunnerOutputV1], Type[RunnerOutputV2]]] + ], +) -> Optional[Union[RunnerOutputV1, RunnerOutputV2]]: """Save the outputs of the function to the specified output destinations. The output values are matched with the output annotations and saved using the schema: @@ -244,7 +285,7 @@ def _save_annotated_return_outputs( return_obj = None for output_value, dest in zip(function_outputs, output_annotations): - if isinstance(output_value, RunnerOutput): + if isinstance(output_value, (RunnerOutputV1, RunnerOutputV2)): if os.environ.get("hera__script_pydantic_io", None) is None: raise ValueError("hera__script_pydantic_io environment variable is not set") @@ -346,13 +387,13 @@ def _runner(entrypoint: str, kwargs_list: List) -> Any: if _pydantic_mode == 2: from pydantic import validate_call # type: ignore - function = validate_call(function) # TODO: v2 function blocks pydantic IO + function = validate_call(function) else: if _PYDANTIC_VERSION == 1: from pydantic import validate_arguments else: from pydantic.v1 import validate_arguments # type: ignore - function = validate_arguments(function, config=dict(smart_union=True)) # type: ignore + function = validate_arguments(function, config=dict(smart_union=True, arbitrary_types_allowed=True)) # type: ignore function = _ignore_unmatched_kwargs(function) @@ -395,7 +436,7 @@ def _run(): if not result: return - if isinstance(result, RunnerOutput): + if isinstance(result, (RunnerOutputV1, RunnerOutputV2)): print(serialize(result.result)) exit(result.exit_code) diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index b0e3525bc..cf202a347 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -41,7 +41,21 @@ from hera.workflows.artifact import ( Artifact, ) -from hera.workflows.io import RunnerInput, RunnerOutput +from hera.workflows.io.v1 import ( + RunnerInput as RunnerInputV1, + RunnerOutput as RunnerOutputV1, +) + +try: + from hera.workflows.io.v2 import ( # type: ignore + RunnerInput as RunnerInputV2, + RunnerOutput as RunnerOutputV2, + ) +except ImportError: + from hera.workflows.io.v1 import ( # type: ignore + RunnerInput as RunnerInputV2, + RunnerOutput as RunnerOutputV2, + ) from hera.workflows.models import ( EnvVar, Inputs as ModelInputs, @@ -364,11 +378,11 @@ def append_annotation(annotation: Union[Artifact, Parameter]): append_annotation(get_args(return_annotation)[1]) elif get_origin(return_annotation) is tuple: for annotation in get_args(return_annotation): - if isinstance(annotation, type) and issubclass(annotation, RunnerOutput): + if isinstance(annotation, type) and issubclass(annotation, (RunnerOutputV1, RunnerOutputV2)): raise ValueError("RunnerOutput cannot be part of a tuple output") append_annotation(get_args(annotation)[1]) - elif return_annotation and issubclass(return_annotation, RunnerOutput): + elif return_annotation and issubclass(return_annotation, (RunnerOutputV1, RunnerOutputV2)): if not global_config.experimental_features["script_pydantic_io"]: raise ValueError( ( @@ -437,7 +451,9 @@ class will be used as inputs, rather than the class itself. artifacts = [] for func_param in inspect.signature(source).parameters.values(): - if get_origin(func_param.annotation) is None and issubclass(func_param.annotation, RunnerInput): + if get_origin(func_param.annotation) is None and issubclass( + func_param.annotation, (RunnerInputV1, RunnerInputV2) + ): if not global_config.experimental_features["script_pydantic_io"]: raise ValueError( ( @@ -505,7 +521,7 @@ class will be used as inputs, rather than the class itself. def _extract_return_annotation_output(source: Callable) -> List: """Extract the output annotations from the return annotation of the function signature.""" - output: List[Union[Tuple[type, Union[Parameter, Artifact]], Type[RunnerOutput]]] = [] + output: List[Union[Tuple[type, Union[Parameter, Artifact]], Type[Union[RunnerOutputV1, RunnerOutputV2]]]] = [] return_annotation = inspect.signature(source).return_annotation origin_type = get_origin(return_annotation) @@ -515,7 +531,11 @@ def _extract_return_annotation_output(source: Callable) -> List: elif origin_type is tuple: for annotated_type in annotation_args: output.append(get_args(annotated_type)) - elif origin_type is None and isinstance(return_annotation, type) and issubclass(return_annotation, RunnerOutput): + elif ( + origin_type is None + and isinstance(return_annotation, type) + and issubclass(return_annotation, (RunnerOutputV1, RunnerOutputV2)) + ): output.append(return_annotation) return output diff --git a/tests/script_annotations/pydantic_duplicate_input_artifact_names.py b/tests/script_annotations/pydantic_duplicate_input_artifact_names.py index 287280c50..11732add4 100644 --- a/tests/script_annotations/pydantic_duplicate_input_artifact_names.py +++ b/tests/script_annotations/pydantic_duplicate_input_artifact_names.py @@ -1,6 +1,6 @@ from hera.shared import global_config from hera.workflows import Artifact, ArtifactLoader, Workflow, script -from hera.workflows.io import RunnerInput +from hera.workflows.io.v1 import RunnerInput try: from typing import Annotated # type: ignore diff --git a/tests/script_annotations/pydantic_io_invalid_multiple_inputs.py b/tests/script_annotations/pydantic_io_invalid_multiple_inputs.py index a192ac420..d3933e45a 100644 --- a/tests/script_annotations/pydantic_io_invalid_multiple_inputs.py +++ b/tests/script_annotations/pydantic_io_invalid_multiple_inputs.py @@ -1,6 +1,6 @@ from hera.shared import global_config from hera.workflows import Parameter, Workflow, script -from hera.workflows.io import RunnerInput +from hera.workflows.io.v1 import RunnerInput try: from typing import Annotated # type: ignore diff --git a/tests/script_annotations/pydantic_io_invalid_outputs.py b/tests/script_annotations/pydantic_io_invalid_outputs.py index cfff209bd..e57fae92b 100644 --- a/tests/script_annotations/pydantic_io_invalid_outputs.py +++ b/tests/script_annotations/pydantic_io_invalid_outputs.py @@ -3,7 +3,7 @@ from hera.shared import global_config from hera.workflows import Parameter, Workflow, script -from hera.workflows.io import RunnerOutput +from hera.workflows.io.v1 import RunnerOutput try: from typing import Annotated # type: ignore diff --git a/tests/script_annotations/pydantic_io.py b/tests/script_annotations/pydantic_io_v1.py similarity index 97% rename from tests/script_annotations/pydantic_io.py rename to tests/script_annotations/pydantic_io_v1.py index 6193e670e..ca2663a7f 100644 --- a/tests/script_annotations/pydantic_io.py +++ b/tests/script_annotations/pydantic_io_v1.py @@ -2,7 +2,7 @@ from typing import List from hera.workflows import Artifact, ArtifactLoader, Parameter, Workflow, script -from hera.workflows.io import RunnerInput, RunnerOutput +from hera.workflows.io.v1 import RunnerInput, RunnerOutput try: from typing import Annotated # type: ignore diff --git a/tests/script_annotations/pydantic_io_v2.py b/tests/script_annotations/pydantic_io_v2.py new file mode 100644 index 000000000..fb2643f74 --- /dev/null +++ b/tests/script_annotations/pydantic_io_v2.py @@ -0,0 +1,94 @@ +from pathlib import Path +from typing import List + +from hera.workflows import Artifact, ArtifactLoader, Parameter, Workflow, script + +try: + from hera.workflows.io.v2 import ( # type: ignore + RunnerInput, + RunnerOutput, + ) +except ImportError: + from hera.workflows.io.v1 import ( # type: ignore + RunnerInput, + RunnerOutput, + ) + +try: + from typing import Annotated # type: ignore +except ImportError: + from typing_extensions import Annotated # type: ignore + + +class ParamOnlyInput(RunnerInput): + my_int: int = 1 + my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 + + +class ParamOnlyOutput(RunnerOutput): + my_output_str: str = "my-default-str" + another_output: Annotated[Path, Parameter(name="second-output")] + + +@script(constructor="runner") +def pydantic_io_params( + my_input: ParamOnlyInput, +) -> ParamOnlyOutput: + pass + + +class ArtifactOnlyInput(RunnerInput): + my_file_artifact: Annotated[Path, Artifact(name="file-artifact")] + my_int_artifact: Annotated[ + int, Artifact(name="an-int-artifact", description="my desc", loader=ArtifactLoader.json) + ] + + +class ArtifactOnlyOutput(RunnerOutput): + an_artifact: Annotated[str, Artifact(name="artifact-output")] + + +@script(constructor="runner") +def pydantic_io_artifacts( + my_input: ArtifactOnlyInput, +) -> ArtifactOnlyOutput: + pass + + +class BothInput(RunnerInput): + param_int: Annotated[int, Parameter(name="param-int")] = 42 + artifact_int: Annotated[int, Artifact(name="artifact-int", loader=ArtifactLoader.json)] + + +class BothOutput(RunnerOutput): + param_int: Annotated[int, Parameter(name="param-int")] + artifact_int: Annotated[int, Artifact(name="artifact-int")] + + +@script(constructor="runner") +def pydantic_io( + my_input: BothInput, +) -> BothOutput: + pass + + +@script(constructor="runner") +def pydantic_io_with_defaults( + my_input: ParamOnlyInput = ParamOnlyInput(my_int=2, my_annotated_int=24), +) -> ParamOnlyOutput: + pass + + +@script(constructor="runner") +def pydantic_io_within_generic( + my_inputs: List[ParamOnlyInput] = [ParamOnlyInput(), ParamOnlyInput(my_int=2)], +) -> ParamOnlyOutput: + pass + + +with Workflow(generate_name="pydantic-io-") as w: + pydantic_io_params() + pydantic_io_artifacts() + pydantic_io() + pydantic_io_with_defaults() + pydantic_io_within_generic() diff --git a/tests/script_runner/pydantic_io.py b/tests/script_runner/pydantic_io_v1.py similarity index 94% rename from tests/script_runner/pydantic_io.py rename to tests/script_runner/pydantic_io_v1.py index 299d3bc14..585bc2c95 100644 --- a/tests/script_runner/pydantic_io.py +++ b/tests/script_runner/pydantic_io_v1.py @@ -4,9 +4,13 @@ from tests.helper import ARTIFACT_PATH from hera.shared import global_config -from hera.shared._pydantic import BaseModel from hera.workflows import Artifact, ArtifactLoader, Parameter, script -from hera.workflows.io import RunnerInput, RunnerOutput +from hera.workflows.io.v1 import RunnerInput, RunnerOutput + +try: + from pydantic.v1 import BaseModel +except ImportError: + from pydantic import BaseModel try: from typing import Annotated # type: ignore @@ -71,7 +75,7 @@ def pydantic_output_using_result() -> ParamOnlyOutput: class MyArtifact(BaseModel): - a: str = "a" + a: int = 42 b: str = "b" diff --git a/tests/script_runner/pydantic_io_v2.py b/tests/script_runner/pydantic_io_v2.py new file mode 100644 index 000000000..98a03cd7e --- /dev/null +++ b/tests/script_runner/pydantic_io_v2.py @@ -0,0 +1,108 @@ +from pathlib import Path +from typing import List + +from pydantic import BaseModel +from tests.helper import ARTIFACT_PATH + +from hera.shared import global_config +from hera.workflows import Artifact, ArtifactLoader, Parameter, script + +try: + from hera.workflows.io.v2 import RunnerInput, RunnerOutput +except ImportError: + from hera.workflows.io.v1 import RunnerInput, RunnerOutput + +try: + from typing import Annotated # type: ignore +except ImportError: + from typing_extensions import Annotated # type: ignore + +global_config.experimental_features["script_annotations"] = True +global_config.experimental_features["script_pydantic_io"] = True + + +class ParamOnlyInput(RunnerInput): + my_required_int: int + my_int: int = 1 + my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 + my_ints: Annotated[List[int], Parameter(name="multiple-ints")] = [] + + +class ParamOnlyOutput(RunnerOutput): + my_output_str: str = "my-default-str" + annotated_str: Annotated[str, Parameter(name="second-output")] + + +@script(constructor="runner") +def pydantic_input_parameters( + my_input: ParamOnlyInput, +) -> int: + return 42 + + +@script(constructor="runner") +def pydantic_io_in_generic( + my_inputs: List[ParamOnlyInput], +) -> str: + """my_inputs is a `list` type, we cannot infer its sub-type in the runner + so it should behave like a normal Pydantic input class. + """ + return len(my_inputs) + + +@script(constructor="runner") +def pydantic_output_parameters() -> ParamOnlyOutput: + outputs = ParamOnlyOutput(annotated_str="my-val") + outputs.my_output_str = "a string!" + + return outputs + + +@script(constructor="runner") +def pydantic_output_using_exit_code() -> ParamOnlyOutput: + outputs = ParamOnlyOutput(exit_code=42, annotated_str="my-val") + outputs.my_output_str = "a string!" + + return outputs + + +@script(constructor="runner") +def pydantic_output_using_result() -> ParamOnlyOutput: + outputs = ParamOnlyOutput(result=42, annotated_str="my-val") + outputs.my_output_str = "a string!" + + return outputs + + +class MyArtifact(BaseModel): + a: int = 42 + b: str = "b" + + +class ArtifactOnlyInput(RunnerInput): + json_artifact: Annotated[ + MyArtifact, Artifact(name="json-artifact", path=ARTIFACT_PATH + "/json", loader=ArtifactLoader.json) + ] + path_artifact: Annotated[Path, Artifact(name="path-artifact", path=ARTIFACT_PATH + "/path", loader=None)] + str_path_artifact: Annotated[ + str, Artifact(name="str-path-artifact", path=ARTIFACT_PATH + "/str-path", loader=None) + ] + file_artifact: Annotated[ + str, Artifact(name="file-artifact", path=ARTIFACT_PATH + "/file", loader=ArtifactLoader.file) + ] + + +class ArtifactOnlyOutput(RunnerOutput): + an_artifact: Annotated[str, Artifact(name="artifact-str-output")] + + +@script(constructor="runner") +def pydantic_input_artifact( + my_input: ArtifactOnlyInput, +) -> str: + return my_input.json_artifact + + +@script(constructor="runner") +def pydantic_output_artifact() -> ArtifactOnlyOutput: + return ArtifactOnlyOutput(an_artifact="test") diff --git a/tests/test_runner.py b/tests/test_runner.py index 779e84cf0..e713c3309 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -18,7 +18,7 @@ from hera.shared import GlobalConfig from hera.shared._pydantic import _PYDANTIC_VERSION from hera.shared.serialization import serialize -from hera.workflows.io import RunnerOutput +from hera.workflows.io.v1 import RunnerOutput from hera.workflows.runner import _run, _runner from hera.workflows.script import RunnerScriptConstructor @@ -619,11 +619,12 @@ def test_run_null_string(mock_parse_args, mock_runner, tmp_path: Path): mock_runner.assert_called_once_with("my_entrypoint", []) +@pytest.mark.parametrize("pydantic_mode", [1, _PYDANTIC_VERSION]) @pytest.mark.parametrize( - "entrypoint,kwargs_list,expected_output,pydantic_mode", + "entrypoint,kwargs_list,expected_output", [ pytest.param( - "tests.script_runner.pydantic_io:pydantic_input_parameters", + "tests.script_runner.pydantic_io_vX:pydantic_input_parameters", [ {"name": "my_required_int", "value": "4"}, {"name": "my_int", "value": "3"}, @@ -631,16 +632,14 @@ def test_run_null_string(mock_parse_args, mock_runner, tmp_path: Path): {"name": "multiple-ints", "value": "[1, 2, 3]"}, ], "42", - 1, id="test parameter only input variations", ), pytest.param( - "tests.script_runner.pydantic_io:pydantic_io_in_generic", + "tests.script_runner.pydantic_io_vX:pydantic_io_in_generic", [ {"name": "my_inputs", "value": '[{"my_required_int": 2, "my_annotated_int": 3}]'}, ], "1", - 1, id="test generic usage (reverts to regular pydantic class implementation)", ), ], @@ -656,6 +655,7 @@ def test_runner_pydantic_inputs_params( tmp_path: Path, ): # GIVEN + entrypoint = entrypoint.replace("pydantic_io_vX", f"pydantic_io_v{pydantic_mode}") monkeypatch.setenv("hera__pydantic_mode", str(pydantic_mode)) os.environ["hera__script_annotations"] = "" os.environ["hera__script_pydantic_io"] = "" @@ -674,7 +674,7 @@ def test_runner_pydantic_inputs_params( "entrypoint,kwargs_list,expected_files,pydantic_mode", [ pytest.param( - "tests.script_runner.pydantic_io:pydantic_output_parameters", + "tests.script_runner.pydantic_io_v1:pydantic_output_parameters", [], [ {"subpath": "tmp/hera-outputs/parameters/my_output_str", "value": "a string!"}, @@ -700,7 +700,7 @@ def test_runner_pydantic_output_params( os.environ["hera__script_annotations"] = "" os.environ["hera__script_pydantic_io"] = "" - import tests.script_runner.pydantic_io as module + import tests.script_runner.pydantic_io_v1 as module importlib.reload(module) @@ -721,14 +721,14 @@ def test_runner_pydantic_output_params( "entrypoint,input_files,expected_output,pydantic_mode", [ pytest.param( - "tests.script_runner.pydantic_io:pydantic_input_artifact", + "tests.script_runner.pydantic_io_v1:pydantic_input_artifact", { "json": '{"a": 3, "b": "bar"}', "path": "dummy", "str-path": "dummy", "file": "dummy", }, - '{"a": "3", "b": "bar"}', + '{"a": 3, "b": "bar"}', 1, id="pydantic io artifact input variations", ), @@ -755,7 +755,7 @@ def test_runner_pydantic_input_artifacts( os.environ["hera__script_annotations"] = "" os.environ["hera__script_pydantic_io"] = "" - import tests.script_runner.pydantic_io as module + import tests.script_runner.pydantic_io_v1 as module importlib.reload(module) @@ -773,7 +773,7 @@ def test_runner_pydantic_input_artifacts( "entrypoint,input_files,expected_files,pydantic_mode", [ pytest.param( - "tests.script_runner.pydantic_io:pydantic_output_artifact", + "tests.script_runner.pydantic_io_v1:pydantic_output_artifact", { "json": '{"a": 3, "b": "bar"}', "path": "dummy", @@ -809,7 +809,7 @@ def test_runner_pydantic_output_artifacts( os.environ["hera__script_annotations"] = "" os.environ["hera__script_pydantic_io"] = "" - import tests.script_runner.pydantic_io as module + import tests.script_runner.pydantic_io_v1 as module importlib.reload(module) @@ -830,7 +830,7 @@ def test_runner_pydantic_output_artifacts( "entrypoint,expected_files,pydantic_mode", [ pytest.param( - "tests.script_runner.pydantic_io:pydantic_output_using_exit_code", + "tests.script_runner.pydantic_io_v1:pydantic_output_using_exit_code", [ {"subpath": "tmp/hera-outputs/parameters/my_output_str", "value": "a string!"}, {"subpath": "tmp/hera-outputs/parameters/second-output", "value": "my-val"}, @@ -854,7 +854,7 @@ def test_runner_pydantic_output_with_exit_code( os.environ["hera__script_annotations"] = "" os.environ["hera__script_pydantic_io"] = "" - import tests.script_runner.pydantic_io as module + import tests.script_runner.pydantic_io_v1 as module importlib.reload(module) @@ -875,7 +875,7 @@ def test_runner_pydantic_output_with_exit_code( "entrypoint,expected_files,pydantic_mode", [ pytest.param( - "tests.script_runner.pydantic_io:pydantic_output_using_exit_code", + "tests.script_runner.pydantic_io_v1:pydantic_output_using_exit_code", [ {"subpath": "tmp/hera-outputs/parameters/my_output_str", "value": "a string!"}, {"subpath": "tmp/hera-outputs/parameters/second-output", "value": "my-val"}, @@ -906,7 +906,7 @@ def test_run_pydantic_output_with_exit_code( os.environ["hera__script_annotations"] = "" os.environ["hera__script_pydantic_io"] = "" - import tests.script_runner.pydantic_io as module + import tests.script_runner.pydantic_io_v1 as module importlib.reload(module) @@ -929,7 +929,7 @@ def test_run_pydantic_output_with_exit_code( "entrypoint,expected_files,expected_result,pydantic_mode", [ pytest.param( - "tests.script_runner.pydantic_io:pydantic_output_using_result", + "tests.script_runner.pydantic_io_v1:pydantic_output_using_result", [ {"subpath": "tmp/hera-outputs/parameters/my_output_str", "value": "a string!"}, {"subpath": "tmp/hera-outputs/parameters/second-output", "value": "my-val"}, @@ -955,7 +955,7 @@ def test_runner_pydantic_output_with_result( os.environ["hera__script_annotations"] = "" os.environ["hera__script_pydantic_io"] = "" - import tests.script_runner.pydantic_io as module + import tests.script_runner.pydantic_io_v1 as module importlib.reload(module) diff --git a/tests/test_script_annotations.py b/tests/test_script_annotations.py index 5136254d0..05ab1ec2d 100644 --- a/tests/test_script_annotations.py +++ b/tests/test_script_annotations.py @@ -10,6 +10,7 @@ except ImportError: from typing_extensions import Annotated # type: ignore +from hera.shared._pydantic import _PYDANTIC_VERSION from hera.workflows import Workflow, script from hera.workflows.parameter import Parameter from hera.workflows.steps import Steps @@ -179,6 +180,13 @@ def test_configmap(global_config_fixture): _compare_workflows(workflow_old, output_old, output_new) +@pytest.mark.parametrize( + "pydantic_mode", + [ + 1, + _PYDANTIC_VERSION, + ], +) @pytest.mark.parametrize( "function_name,expected_input,expected_output", [ @@ -269,15 +277,16 @@ def test_configmap(global_config_fixture): ), ], ) -def test_script_pydantic_io(function_name, expected_input, expected_output, global_config_fixture): +def test_script_pydantic_io(pydantic_mode, function_name, expected_input, expected_output, global_config_fixture): """Test that output annotations work correctly by asserting correct inputs and outputs on the built workflow.""" # GIVEN global_config_fixture.experimental_features["script_annotations"] = True global_config_fixture.experimental_features["script_pydantic_io"] = True # Force a reload of the test module, as the runner performs "importlib.import_module", which # may fetch a cached version - import tests.script_annotations.pydantic_io as module + module_name = f"tests.script_annotations.pydantic_io_v{pydantic_mode}" + module = importlib.import_module(module_name) importlib.reload(module) workflow = importlib.import_module(module.__name__).w @@ -337,7 +346,7 @@ def test_script_pydantic_without_experimental_flag(global_config_fixture): global_config_fixture.experimental_features["script_pydantic_io"] = False # Force a reload of the test module, as the runner performs "importlib.import_module", which # may fetch a cached version - import tests.script_annotations.pydantic_io as module + import tests.script_annotations.pydantic_io_v1 as module importlib.reload(module) workflow = importlib.import_module(module.__name__).w @@ -347,6 +356,6 @@ def test_script_pydantic_without_experimental_flag(global_config_fixture): workflow.to_dict() assert ( - "Unable to instantiate since it is an experimental feature." + "Unable to instantiate since it is an experimental feature." in str(e.value) )