From 345e687315920f59cc7a4600a975e3cd026dedaa Mon Sep 17 00:00:00 2001 From: Elliot Gunton Date: Wed, 8 May 2024 17:44:52 +0100 Subject: [PATCH] Refactor RunnerIO classes (#1060) * Use a new pair of mixin classes to hold the functions that are compatible with Pydantic V1 and V2, meaning we no longer need to maintain two sets of classes for each Pydantic version * "RunnerIO" classes to be renamed to `Input` and `Output`, keeping the old name as an alias for one release so we can announce the change This PR splits out parts of commits from https://github.com/argoproj-labs/hera/pull/1059 that only refactored RunnerIO --------- Signed-off-by: Elliot Gunton --- .../new_decorators_basic_script.md | 12 +- .../experimental/script_runner_io.md | 6 +- docs/user-guides/script-runner-io.md | 27 +-- docs/walk-through/advanced-hera-features.md | 4 +- .../new_decorators_basic_script.py | 12 +- .../experimental/script_runner_io.py | 6 +- src/hera/{workflows => shared}/_inspect.py | 0 src/hera/shared/_pydantic.py | 33 ++- src/hera/workflows/__init__.py | 4 +- src/hera/workflows/_meta_mixins.py | 4 +- .../_runner/script_annotations_util.py | 33 ++- src/hera/workflows/_runner/util.py | 8 +- src/hera/workflows/io/__init__.py | 14 +- src/hera/workflows/io/_io_mixins.py | 192 ++++++++++++++++++ src/hera/workflows/io/v1.py | 98 +-------- src/hera/workflows/io/v2.py | 99 +-------- src/hera/workflows/script.py | 30 ++- ...pydantic_duplicate_input_artifact_names.py | 4 +- .../pydantic_io_invalid_multiple_inputs.py | 4 +- .../pydantic_io_invalid_outputs.py | 4 +- tests/script_annotations/pydantic_io_v1.py | 14 +- .../script_annotations/pydantic_io_v1_strs.py | 12 +- tests/script_annotations/pydantic_io_v2.py | 20 +- .../script_annotations/pydantic_io_v2_strs.py | 16 +- tests/script_runner/pydantic_io_v1.py | 10 +- tests/script_runner/pydantic_io_v2.py | 12 +- tests/test_runner.py | 6 +- tests/test_script_annotations.py | 4 +- .../test_unit/test_script_annotations_util.py | 6 +- .../multiple_entrypoints.py | 12 +- tests/workflow_decorators/set_entrypoint.py | 12 +- 31 files changed, 393 insertions(+), 325 deletions(-) rename src/hera/{workflows => shared}/_inspect.py (100%) create mode 100644 src/hera/workflows/io/_io_mixins.py diff --git a/docs/examples/workflows/experimental/new_decorators_basic_script.md b/docs/examples/workflows/experimental/new_decorators_basic_script.md index 6c8067dda..36f40158f 100644 --- a/docs/examples/workflows/experimental/new_decorators_basic_script.md +++ b/docs/examples/workflows/experimental/new_decorators_basic_script.md @@ -9,7 +9,7 @@ ```python linenums="1" from hera.shared import global_config - from hera.workflows import RunnerInput, RunnerOutput, WorkflowTemplate + from hera.workflows import Input, Output, WorkflowTemplate global_config.experimental_features["script_annotations"] = True global_config.experimental_features["script_pydantic_io"] = True @@ -17,13 +17,13 @@ w = WorkflowTemplate(name="my-template") - class MyInput(RunnerInput): + class MyInput(Input): user: str @w.script() - def hello_world(my_input: MyInput) -> RunnerOutput: - output = RunnerOutput() + def hello_world(my_input: MyInput) -> Output: + output = Output() output.result = f"Hello Hera User: {my_input.user}!" return output @@ -31,8 +31,8 @@ # Pass script kwargs (including an alternative public template name) in the decorator @w.set_entrypoint @w.script(name="goodbye-world", labels={"my-label": "my-value"}) - def goodbye(my_input: MyInput) -> RunnerOutput: - output = RunnerOutput() + def goodbye(my_input: MyInput) -> Output: + output = Output() output.result = f"Goodbye Hera User: {my_input.user}!" return output ``` diff --git a/docs/examples/workflows/experimental/script_runner_io.md b/docs/examples/workflows/experimental/script_runner_io.md index d631cf038..76d06cf26 100644 --- a/docs/examples/workflows/experimental/script_runner_io.md +++ b/docs/examples/workflows/experimental/script_runner_io.md @@ -16,7 +16,7 @@ from hera.shared import global_config from hera.workflows import Artifact, ArtifactLoader, Parameter, Steps, Workflow, script from hera.workflows.archive import NoneArchiveStrategy - from hera.workflows.io import RunnerInput, RunnerOutput + from hera.workflows.io import Input, Output try: from typing import Annotated # type: ignore @@ -32,7 +32,7 @@ a_str: str = "a default string" - class MyInput(RunnerInput): + class MyInput(Input): param_int: Annotated[int, Parameter(name="param-input")] = 42 an_object: Annotated[MyObject, Parameter(name="obj-input")] = MyObject( a_dict={"my-key": "a-value"}, a_str="hello world!" @@ -40,7 +40,7 @@ artifact_int: Annotated[int, Artifact(name="artifact-input", loader=ArtifactLoader.json)] - class MyOutput(RunnerOutput): + class MyOutput(Output): param_int: Annotated[int, Parameter(name="param-output")] artifact_int: Annotated[int, Artifact(name="artifact-output")] diff --git a/docs/user-guides/script-runner-io.md b/docs/user-guides/script-runner-io.md index eb7d8f6df..076faced0 100644 --- a/docs/user-guides/script-runner-io.md +++ b/docs/user-guides/script-runner-io.md @@ -1,6 +1,9 @@ # Script Runner IO -Hera provides the `RunnerInput` and `RunnerOutput` Pydantic classes which can be used to more succinctly write your +> ⚠️ The `RunnerInput` and `RunnerOutput` classes are deprecated since `v5.16.0`, please use `Input` and `Output` for +> equivalent functionality. They will be removed in `v5.17.0`. + +Hera provides the `Input` and `Output` Pydantic classes which can be used to more succinctly write your script function inputs and outputs, and requires use of the Hera Runner. Use of these classes also requires the `"script_pydantic_io"` experimental feature flag to be enabled: @@ -10,17 +13,17 @@ global_config.experimental_features["script_pydantic_io"] = True ## Pydantic V1 or V2? -You can import `RunnerInput` and `RunnerOutput` from the `hera.workflows.io` submodule to import the version of Pydantic +You can import `Input` and `Output` from the `hera.workflows.io` submodule to import the version of Pydantic that matches your V1 or V2 installation. If you need to use V1 models when you have V2 installed, you should import -`RunnerInput` and `RunnerOutput` from the `hera.workflows.io.v1` or `hera.workflows.io.v2` module explicitly. The V2 +`Input` and `Output` from the `hera.workflows.io.v1` or `hera.workflows.io.v2` module explicitly. The V2 models will not be available if you have installed `pydantic<2`, but the V1 models are usable for either version, allowing you to migrate at your own pace. -## Script inputs using `RunnerInput` +## Script inputs using `Input` -For your script inputs, you can create a derived class of `RunnerInput`, and declare all your input parameters (and +For your script inputs, you can create a derived class of `Input`, and declare all your input parameters (and artifacts) as fields of the class. If you want to use `Annotated` to declare `Artifacts` add metadata to your `Parameters`, you will also need to enable the `"script_annotations"` experimental feature flag. @@ -29,7 +32,7 @@ from typing import Annotated from pydantic import BaseModel from hera.workflows import Artifact, ArtifactLoader, Parameter, script -from hera.workflows.io import RunnerInput +from hera.workflows.io import Input class MyObject(BaseModel): @@ -37,7 +40,7 @@ class MyObject(BaseModel): a_str: str = "a default string" -class MyInput(RunnerInput): +class MyInput(Input): param_int: Annotated[int, Parameter(name="param-input")] = 42 an_object: Annotated[MyObject, Parameter(name="obj-input")] = MyObject( a_dict={"my-key": "a-value"}, a_str="hello world!" @@ -72,9 +75,9 @@ template will also have the `"artifact-input"` artifact. The yaml generated from ... ``` -## Script outputs using `RunnerOutput` +## Script outputs using `Output` -The `RunnerOutput` class comes with two special variables, `exit_code` and `result`. The `exit_code` is used to exit the +The `Output` class comes with two special variables, `exit_code` and `result`. The `exit_code` is used to exit the container when running on Argo with the specific exit code - it is set to `0` by default. The `result` is used to print any serializable object to stdout, which means you can now use `.result` on tasks or steps that use a "runner constructor" script - you should be mindful of printing/logging anything else to stdout, which will stop the `result` @@ -82,16 +85,16 @@ functionality working as intended. If you want an output parameters/artifacts wi can declare another field with an annotation of that name, e.g. `my_exit_code: Annotated[int, Parameter(name="exit_code")]`. -Aside from the `exit_code` and `result`, the `RunnerOutput` behaves exactly like the `RunnerInput`: +Aside from the `exit_code` and `result`, the `Output` behaves exactly like the `Input`: ```py from typing import Annotated from hera.workflows import Artifact, Parameter, script -from hera.workflows.io import RunnerOutput +from hera.workflows.io import Output -class MyOutput(RunnerOutput): +class MyOutput(Output): param_int: Annotated[int, Parameter(name="param-output")] artifact_int: Annotated[int, Artifact(name="artifact-output")] diff --git a/docs/walk-through/advanced-hera-features.md b/docs/walk-through/advanced-hera-features.md index dbdf0791b..71f9689ed 100644 --- a/docs/walk-through/advanced-hera-features.md +++ b/docs/walk-through/advanced-hera-features.md @@ -122,8 +122,8 @@ Read the full guide on script annotations in [the script user guide](../user-gui ### Script IO Models Hera provides Pydantic models for you to create subclasses from, which allow you to more easily declare script template -inputs. Any fields that you declare in your subclass of `RunnerInput` will become input parameters or artifacts, while -`RunnerOutput` fields will become output parameters artifacts. The fields that you declare can be `Annotated` as a +inputs. Any fields that you declare in your subclass of `Input` will become input parameters or artifacts, while +`Output` fields will become output parameters artifacts. The fields that you declare can be `Annotated` as a `Parameter` or `Artifact`, as any fields with a basic type will become `Parameters` - you will also need the `script_annotations` experimental feature enabled. diff --git a/examples/workflows/experimental/new_decorators_basic_script.py b/examples/workflows/experimental/new_decorators_basic_script.py index a3bc699cc..1cdf77320 100644 --- a/examples/workflows/experimental/new_decorators_basic_script.py +++ b/examples/workflows/experimental/new_decorators_basic_script.py @@ -1,5 +1,5 @@ from hera.shared import global_config -from hera.workflows import RunnerInput, RunnerOutput, WorkflowTemplate +from hera.workflows import Input, Output, WorkflowTemplate global_config.experimental_features["script_annotations"] = True global_config.experimental_features["script_pydantic_io"] = True @@ -7,13 +7,13 @@ w = WorkflowTemplate(name="my-template") -class MyInput(RunnerInput): +class MyInput(Input): user: str @w.script() -def hello_world(my_input: MyInput) -> RunnerOutput: - output = RunnerOutput() +def hello_world(my_input: MyInput) -> Output: + output = Output() output.result = f"Hello Hera User: {my_input.user}!" return output @@ -21,7 +21,7 @@ def hello_world(my_input: MyInput) -> RunnerOutput: # Pass script kwargs (including an alternative public template name) in the decorator @w.set_entrypoint @w.script(name="goodbye-world", labels={"my-label": "my-value"}) -def goodbye(my_input: MyInput) -> RunnerOutput: - output = RunnerOutput() +def goodbye(my_input: MyInput) -> Output: + output = Output() output.result = f"Goodbye Hera User: {my_input.user}!" return output diff --git a/examples/workflows/experimental/script_runner_io.py b/examples/workflows/experimental/script_runner_io.py index 1afc0486f..3523ec215 100644 --- a/examples/workflows/experimental/script_runner_io.py +++ b/examples/workflows/experimental/script_runner_io.py @@ -6,7 +6,7 @@ from hera.shared import global_config from hera.workflows import Artifact, ArtifactLoader, Parameter, Steps, Workflow, script from hera.workflows.archive import NoneArchiveStrategy -from hera.workflows.io import RunnerInput, RunnerOutput +from hera.workflows.io import Input, Output try: from typing import Annotated # type: ignore @@ -22,7 +22,7 @@ class MyObject(BaseModel): a_str: str = "a default string" -class MyInput(RunnerInput): +class MyInput(Input): param_int: Annotated[int, Parameter(name="param-input")] = 42 an_object: Annotated[MyObject, Parameter(name="obj-input")] = MyObject( a_dict={"my-key": "a-value"}, a_str="hello world!" @@ -30,7 +30,7 @@ class MyInput(RunnerInput): artifact_int: Annotated[int, Artifact(name="artifact-input", loader=ArtifactLoader.json)] -class MyOutput(RunnerOutput): +class MyOutput(Output): param_int: Annotated[int, Parameter(name="param-output")] artifact_int: Annotated[int, Artifact(name="artifact-output")] diff --git a/src/hera/workflows/_inspect.py b/src/hera/shared/_inspect.py similarity index 100% rename from src/hera/workflows/_inspect.py rename to src/hera/shared/_inspect.py diff --git a/src/hera/shared/_pydantic.py b/src/hera/shared/_pydantic.py index 84e857443..312e1c177 100644 --- a/src/hera/shared/_pydantic.py +++ b/src/hera/shared/_pydantic.py @@ -1,9 +1,16 @@ """Module that holds the underlying base Pydantic models for Hera objects.""" +from collections import ChainMap from typing import TYPE_CHECKING, Any, Dict, Type from pydantic import VERSION +try: + from inspect import get_annotations # type: ignore +except ImportError: + from hera.shared._inspect import get_annotations # type: ignore + + _PYDANTIC_VERSION: int = int(VERSION.split(".")[0]) # The pydantic v1 interface is used for both pydantic v1 and v2 in order to support # users across both versions. @@ -28,14 +35,17 @@ # native pydantic hinting for `__init__` arguments. if TYPE_CHECKING: from pydantic import BaseModel as PydanticBaseModel + from pydantic.fields import FieldInfo else: if _PYDANTIC_VERSION == 2: from pydantic.v1 import BaseModel as PydanticBaseModel # type: ignore + from pydantic.v1.fields import FieldInfo else: from pydantic import BaseModel as PydanticBaseModel # type: ignore[assignment,no-redef] + from pydantic.fields import FieldInfo -def get_fields(cls: Type[PydanticBaseModel]) -> Dict[str, Any]: +def get_fields(cls: Type[PydanticBaseModel]) -> Dict[str, FieldInfo]: """Centralize access to __fields__.""" try: return cls.model_fields # type: ignore @@ -43,14 +53,8 @@ def get_fields(cls: Type[PydanticBaseModel]) -> Dict[str, Any]: return cls.__fields__ # type: ignore -__all__ = [ - "BaseModel", - "Field", - "PydanticBaseModel", # Export for serialization.py to cover user-defined models - "ValidationError", - "root_validator", - "validator", -] +def get_field_annotations(cls: Type[PydanticBaseModel]) -> Dict[str, Any]: + return {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} class BaseModel(PydanticBaseModel): @@ -74,3 +78,14 @@ class Config: smart_union = True """uses smart union for matching a field's specified value to the underlying type that's part of a union""" + + +__all__ = [ + "BaseModel", + "Field", + "FieldInfo", + "PydanticBaseModel", # Export for serialization.py to cover user-defined models + "ValidationError", + "root_validator", + "validator", +] diff --git a/src/hera/workflows/__init__.py b/src/hera/workflows/__init__.py index 6c65792b3..a59062bad 100644 --- a/src/hera/workflows/__init__.py +++ b/src/hera/workflows/__init__.py @@ -30,7 +30,7 @@ from hera.workflows.env_from import ConfigMapEnvFrom, SecretEnvFrom from hera.workflows.exceptions import InvalidDispatchType, InvalidTemplateCall, InvalidType from hera.workflows.http_template import HTTP -from hera.workflows.io import RunnerInput, RunnerOutput +from hera.workflows.io import Input, Output, RunnerInput, RunnerOutput from hera.workflows.metrics import Counter, Gauge, Histogram, Label, Metric, Metrics from hera.workflows.operator import Operator from hera.workflows.parameter import Parameter @@ -126,6 +126,7 @@ "HostPathVolume", "ISCSIVolume", "InlineScriptConstructor", + "Input", "InvalidDispatchType", "InvalidTemplateCall", "InvalidType", @@ -136,6 +137,7 @@ "NoneArchiveStrategy", "OSSArtifact", "Operator", + "Output", "Parallel", "Parameter", "PhotonPersistentDiskVolume", diff --git a/src/hera/workflows/_meta_mixins.py b/src/hera/workflows/_meta_mixins.py index 926405f5c..d80271377 100644 --- a/src/hera/workflows/_meta_mixins.py +++ b/src/hera/workflows/_meta_mixins.py @@ -35,7 +35,7 @@ try: from inspect import get_annotations # type: ignore except ImportError: - from hera.workflows._inspect import get_annotations # type: ignore + from hera.shared._inspect import get_annotations # type: ignore _yaml: Optional[ModuleType] = None try: @@ -134,7 +134,7 @@ def __init__(self, model_path: str, hera_builder: Optional[Callable] = None): fields = get_fields(curr_class) if key not in fields: raise ValueError(f"Model key '{key}' does not exist in class {curr_class}") - curr_class = fields[key].outer_type_ + curr_class = fields[key].outer_type_ # type: ignore @classmethod def _get_model_class(cls) -> Type[BaseModel]: diff --git a/src/hera/workflows/_runner/script_annotations_util.py b/src/hera/workflows/_runner/script_annotations_util.py index 9640a65e6..14c6eaa55 100644 --- a/src/hera/workflows/_runner/script_annotations_util.py +++ b/src/hera/workflows/_runner/script_annotations_util.py @@ -11,19 +11,19 @@ from hera.workflows import Artifact, Parameter from hera.workflows.artifact import ArtifactLoader from hera.workflows.io.v1 import ( - RunnerInput as RunnerInputV1, - RunnerOutput as RunnerOutputV1, + Input as InputV1, + Output as OutputV1, ) try: from hera.workflows.io.v2 import ( # type: ignore - RunnerInput as RunnerInputV2, - RunnerOutput as RunnerOutputV2, + Input as InputV2, + Output as OutputV2, ) except ImportError: from hera.workflows.io.v1 import ( # type: ignore - RunnerInput as RunnerInputV2, - RunnerOutput as RunnerOutputV2, + Input as InputV2, + Output as OutputV2, ) try: @@ -127,7 +127,7 @@ def map_runner_input( runner_input_class: T, kwargs: Dict[str, str], ) -> T: - """Map argo input kwargs to the fields of the given RunnerInput, return an instance of the class. + """Map argo input kwargs to the fields of the given Input, return an instance of the class. If the field is annotated, we look for the kwarg with the name from the annotation (Parameter or Artifact). Otherwise, we look for the kwarg with the name of the field. @@ -198,9 +198,7 @@ def _map_argo_inputs_to_function(function: Callable, kwargs: Dict[str, str]) -> mapped_kwargs[func_param_name] = get_annotated_artifact_value(func_param_annotation) else: mapped_kwargs[func_param_name] = kwargs[func_param_name] - elif get_origin(func_param.annotation) is None and issubclass( - func_param.annotation, (RunnerInputV1, RunnerInputV2) - ): + elif get_origin(func_param.annotation) is None and issubclass(func_param.annotation, (InputV1, InputV2)): mapped_kwargs[func_param_name] = map_runner_input(func_param.annotation, kwargs) else: mapped_kwargs[func_param_name] = kwargs[func_param_name] @@ -209,10 +207,8 @@ def _map_argo_inputs_to_function(function: Callable, kwargs: Dict[str, str]) -> def _save_annotated_return_outputs( function_outputs: Union[Tuple[Any], Any], - output_annotations: List[ - Union[Tuple[type, Union[Parameter, Artifact]], Union[Type[RunnerOutputV1], Type[RunnerOutputV2]]] - ], -) -> Optional[Union[RunnerOutputV1, RunnerOutputV2]]: + output_annotations: List[Union[Tuple[type, Union[Parameter, Artifact]], Union[Type[OutputV1], Type[OutputV2]]]], +) -> Optional[Union[OutputV1, OutputV2]]: """Save the outputs of the function to the specified output destinations. The output values are matched with the output annotations and saved using the schema: @@ -230,7 +226,7 @@ def _save_annotated_return_outputs( return_obj = None for output_value, dest in zip(function_outputs, output_annotations): - if isinstance(output_value, (RunnerOutputV1, RunnerOutputV2)): + if isinstance(output_value, (OutputV1, OutputV2)): if os.environ.get("hera__script_pydantic_io", None) is None: raise ValueError("hera__script_pydantic_io environment variable is not set") @@ -273,7 +269,10 @@ def _save_annotated_return_outputs( def _save_dummy_outputs( output_annotations: List[ - Union[Tuple[type, Union[Parameter, Artifact]], Union[Type[RunnerOutputV1], Type[RunnerOutputV2]]] + Union[ + Tuple[type, Union[Parameter, Artifact]], + Union[Type[OutputV1], Type[OutputV2]], + ] ], ) -> None: """Save dummy values into the outputs specified. @@ -293,7 +292,7 @@ def _save_dummy_outputs( can be provided by the user or is set to /tmp/hera-outputs by default """ for dest in output_annotations: - if isinstance(dest, (RunnerOutputV1, RunnerOutputV2)): + if isinstance(dest, (OutputV1, OutputV2)): if os.environ.get("hera__script_pydantic_io", None) is None: raise ValueError("hera__script_pydantic_io environment variable is not set") diff --git a/src/hera/workflows/_runner/util.py b/src/hera/workflows/_runner/util.py index 609548159..fa1121e50 100644 --- a/src/hera/workflows/_runner/util.py +++ b/src/hera/workflows/_runner/util.py @@ -19,7 +19,7 @@ ) from hera.workflows.artifact import ArtifactLoader from hera.workflows.io.v1 import ( - RunnerOutput as RunnerOutputV1, + Output as OutputV1, ) from hera.workflows.script import _extract_return_annotation_output @@ -33,13 +33,13 @@ from pydantic.v1 import parse_obj_as # type: ignore from hera.workflows.io.v2 import ( # type: ignore - RunnerOutput as RunnerOutputV2, + Output as OutputV2, ) else: from pydantic import parse_obj_as from hera.workflows.io.v1 import ( # type: ignore - RunnerOutput as RunnerOutputV2, + Output as OutputV2, ) @@ -255,7 +255,7 @@ def _run() -> None: if not result: return - if isinstance(result, (RunnerOutputV1, RunnerOutputV2)): + if isinstance(result, (OutputV1, OutputV2)): print(serialize(result.result)) exit(result.exit_code) diff --git a/src/hera/workflows/io/__init__.py b/src/hera/workflows/io/__init__.py index baf8c7325..7dbaf3d24 100644 --- a/src/hera/workflows/io/__init__.py +++ b/src/hera/workflows/io/__init__.py @@ -1,13 +1,23 @@ """Hera IO models.""" +from typing import Type + from pydantic import VERSION if VERSION.split(".")[0] == "2": - from hera.workflows.io.v2 import RunnerInput, RunnerOutput + from hera.workflows.io.v2 import Input, Output + + RunnerInput: Type = Input + RunnerOutput: Type = Output else: - from hera.workflows.io.v1 import RunnerInput, RunnerOutput # type: ignore + from hera.workflows.io.v1 import Input, Output # type: ignore + + RunnerInput: Type = Input # type: ignore + RunnerOutput: Type = Output # type: ignore __all__ = [ + "Input", + "Output", "RunnerInput", "RunnerOutput", ] diff --git a/src/hera/workflows/io/_io_mixins.py b/src/hera/workflows/io/_io_mixins.py new file mode 100644 index 000000000..e7d16670e --- /dev/null +++ b/src/hera/workflows/io/_io_mixins.py @@ -0,0 +1,192 @@ +from typing import TYPE_CHECKING, List, Optional, Union + +from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields +from hera.shared.serialization import MISSING, serialize +from hera.workflows.artifact import Artifact +from hera.workflows.models import ( + Arguments as ModelArguments, + Artifact as ModelArtifact, + Parameter as ModelParameter, + ValueFrom, +) +from hera.workflows.parameter import Parameter + +if _PYDANTIC_VERSION == 2: + from pydantic import BaseModel as V2BaseModel + from pydantic.v1 import BaseModel as V1BaseModel + from pydantic_core import PydanticUndefined +else: + from pydantic import BaseModel as V1BaseModel # type: ignore[assignment] + + V2BaseModel = V1BaseModel # type: ignore + PydanticUndefined = None # type: ignore[assignment] + + +try: + from typing import Annotated, Self, get_args, get_origin # type: ignore +except ImportError: + from typing_extensions import Annotated, Self, get_args, get_origin # type: ignore + +if TYPE_CHECKING: + # We add BaseModel as a parent class of the mixins only when type checking which allows it + # to be used with either a V1 BaseModel or a V2 BaseModel + from pydantic import BaseModel +else: + # Subclassing `object` when using the real code (i.e. not type-checking) is basically a no-op + BaseModel = object # type: ignore + + +class InputMixin(BaseModel): + @classmethod + def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Parameter]: + parameters = [] + annotations = get_field_annotations(cls) + + for field, field_info in get_fields(cls).items(): + if get_origin(annotations[field]) is Annotated: + param = get_args(annotations[field])[1] + if isinstance(param, Parameter): + if object_override: + param.default = serialize(getattr(object_override, field)) + elif field_info.default is not None and field_info.default != PydanticUndefined: # type: ignore + # Serialize the value (usually done in Parameter's validator) + param.default = serialize(field_info.default) # type: ignore + parameters.append(param) + else: + # Create a Parameter from basic type annotations + default = getattr(object_override, field) if object_override else field_info.default + + # For users on Pydantic 2 but using V1 BaseModel, we still need to check if `default` is None + if default is None or default == PydanticUndefined: + default = MISSING + + parameters.append(Parameter(name=field, default=default)) + + return parameters + + @classmethod + def _get_artifacts(cls) -> List[Artifact]: + artifacts = [] + annotations = get_field_annotations(cls) + + for field in get_fields(cls): + if get_origin(annotations[field]) is Annotated: + artifact = get_args(annotations[field])[1] + if isinstance(artifact, Artifact): + if artifact.path is None: + artifact.path = artifact._get_default_inputs_path() + artifacts.append(artifact) + return artifacts + + @classmethod + def _get_inputs(cls) -> List[Union[Artifact, Parameter]]: + return cls._get_artifacts() + cls._get_parameters() + + @classmethod + def _get_as_templated_arguments(cls) -> Self: + """Returns the Input with templated values to propagate through a DAG/Steps function.""" + object_dict = {} + cls_fields = get_fields(cls) + annotations = get_field_annotations(cls) + + for field in cls_fields: + if get_origin(annotations[field]) is Annotated: + annotation = get_args(annotations[field])[1] + if isinstance(annotation, (Artifact, Parameter)): + name = annotation.name + if isinstance(annotation, Parameter): + object_dict[field] = "{{inputs.parameters." + f"{name}" + "}}" + elif isinstance(annotation, Artifact): + object_dict[field] = "{{inputs.artifacts." + f"{name}" + "}}" + else: + object_dict[field] = "{{inputs.parameters." + f"{field}" + "}}" + + return cls.construct(None, **object_dict) + + def _get_as_arguments(self) -> ModelArguments: + params = [] + artifacts = [] + annotations = get_field_annotations(type(self)) + + if isinstance(self, V1BaseModel): + self_dict = self.dict() + elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel): + self_dict = self.model_dump() + + for field in get_fields(type(self)): + templated_value = self_dict[field] + + if get_origin(annotations[field]) is Annotated: + annotation = get_args(annotations[field])[1] + if isinstance(annotation, Parameter) and annotation.name: + params.append(ModelParameter(name=annotation.name, value=templated_value)) + elif isinstance(annotation, Artifact) and annotation.name: + artifacts.append(ModelArtifact(name=annotation.name, from_=templated_value)) + else: + params.append(ModelParameter(name=field, value=templated_value)) + + return ModelArguments(parameters=params or None, artifacts=artifacts or None) + + +class OutputMixin(BaseModel): + @classmethod + def _get_outputs(cls) -> List[Union[Artifact, Parameter]]: + outputs = [] + annotations = get_field_annotations(cls) + + model_fields = get_fields(cls) + + for field in model_fields: + if field in {"exit_code", "result"}: + continue + if get_origin(annotations[field]) is Annotated: + if isinstance(get_args(annotations[field])[1], (Parameter, Artifact)): + outputs.append(get_args(annotations[field])[1]) + else: + # Create a Parameter from basic type annotations + default = model_fields[field].default + if default is None or default == PydanticUndefined: + default = MISSING + outputs.append(Parameter(name=field, default=default)) + return outputs + + @classmethod + def _get_output(cls, field_name: str) -> Union[Artifact, Parameter]: + annotations = get_field_annotations(cls) + annotation = annotations[field_name] + if get_origin(annotation) is Annotated: + if isinstance(get_args(annotation)[1], (Parameter, Artifact)): + return get_args(annotation)[1] + + # Create a Parameter from basic type annotations + default = get_fields(cls)[field_name].default + if default is None or default == PydanticUndefined: + default = MISSING + return Parameter(name=field_name, default=default) # type: ignore + + def _get_as_output(self) -> List[Union[Artifact, Parameter]]: + """Get the Output with values of ...""" + outputs: List[Union[Artifact, Parameter]] = [] + annotations = get_field_annotations(type(self)) + + if isinstance(self, V1BaseModel): + self_dict = self.dict() + elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel): + self_dict = self.model_dump() + + for field in get_fields(type(self)): + if field in {"exit_code", "result"}: + continue + + templated_value = self_dict[field] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"` + + if get_origin(annotations[field]) is Annotated: + annotation = get_args(annotations[field])[1] + if isinstance(annotation, Parameter) and annotation.name: + outputs.append(Parameter(name=annotation.name, value_from=ValueFrom(parameter=templated_value))) + elif isinstance(annotation, Artifact) and annotation.name: + outputs.append(Artifact(name=annotation.name, from_=templated_value)) + else: + outputs.append(Parameter(name=field, value_from=ValueFrom(parameter=templated_value))) + + return outputs diff --git a/src/hera/workflows/io/v1.py b/src/hera/workflows/io/v1.py index e88549e1f..154c0848e 100644 --- a/src/hera/workflows/io/v1.py +++ b/src/hera/workflows/io/v1.py @@ -1,109 +1,29 @@ """Pydantic V1 input/output models for the Hera runner.""" -from collections import ChainMap -from typing import Any, List, Optional, Union +from typing import Any -from hera.shared._pydantic import BaseModel, get_fields -from hera.shared.serialization import MISSING, serialize -from hera.workflows.artifact import Artifact -from hera.workflows.parameter import Parameter +from hera.shared._pydantic import BaseModel +from hera.workflows.io._io_mixins import InputMixin, OutputMixin -try: - from inspect import get_annotations # type: ignore -except ImportError: - from hera.workflows._inspect import get_annotations # type: ignore -try: - from typing import Annotated, get_args, get_origin # type: ignore -except ImportError: - from typing_extensions import Annotated, get_args, get_origin # type: ignore - - -class RunnerInput(BaseModel): +class Input(BaseModel, InputMixin): """Input model usable by the Hera Runner. - RunnerInput is a Pydantic model which users can create a subclass of. When a subclass - of RunnerInput is used as a function parameter type, the Hera Runner will take the fields + Input is a Pydantic model which users can create a subclass of. When a subclass + of Input is used as a function parameter type, the Hera Runner will take the fields of the user's subclass to create template input parameters and artifacts. See the example for the script_pydantic_io experimental feature. """ - @classmethod - def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> List[Parameter]: - parameters = [] - annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - - fields = get_fields(cls) - for field, field_info in fields.items(): - if get_origin(annotations[field]) is Annotated: - if isinstance(get_args(annotations[field])[1], Parameter): - param = get_args(annotations[field])[1] - if object_override: - param.default = serialize(getattr(object_override, field)) - elif field_info.default: - # Serialize the value (usually done in Parameter's validator) - param.default = serialize(field_info.default) - parameters.append(param) - else: - # Create a Parameter from basic type annotations - default = getattr(object_override, field) if object_override else field_info.default - if default is None: - default = MISSING - parameters.append(Parameter(name=field, default=default)) - - return parameters - - @classmethod - def _get_artifacts(cls) -> List[Artifact]: - artifacts = [] - annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - - for field in get_fields(cls): - if get_origin(annotations[field]) is Annotated: - if isinstance(get_args(annotations[field])[1], Artifact): - artifact = get_args(annotations[field])[1] - if artifact.path is None: - artifact.path = artifact._get_default_inputs_path() - artifacts.append(artifact) - return artifacts - -class RunnerOutput(BaseModel): +class Output(BaseModel, OutputMixin): """Output model usable by the Hera Runner. - RunnerOutput is a Pydantic model which users can create a subclass of. When a subclass - of RunnerOutput is used as a function return type, the Hera Runner will take the fields + Output is a Pydantic model which users can create a subclass of. When a subclass + of Output is used as a function return type, the Hera Runner will take the fields of the user's subclass to create template output parameters and artifacts. See the example for the script_pydantic_io experimental feature. """ exit_code: int = 0 result: Any = None - - @classmethod - def _get_outputs(cls) -> List[Union[Artifact, Parameter]]: - outputs = [] - annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - - fields = get_fields(cls) - for field in fields: - if field in {"exit_code", "result"}: - continue - if get_origin(annotations[field]) is Annotated: - if isinstance(get_args(annotations[field])[1], (Parameter, Artifact)): - outputs.append(get_args(annotations[field])[1]) - else: - # Create a Parameter from basic type annotations - outputs.append(Parameter(name=field, default=fields[field].default)) - return outputs - - @classmethod - def _get_output(cls, field_name: str) -> Union[Artifact, Parameter]: - annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - annotation = annotations[field_name] - if get_origin(annotation) is Annotated: - if isinstance(get_args(annotation)[1], (Parameter, Artifact)): - return get_args(annotation)[1] - - # Create a Parameter from basic type annotations - return Parameter(name=field_name, default=get_fields(cls)[field_name].default) diff --git a/src/hera/workflows/io/v2.py b/src/hera/workflows/io/v2.py index efa9e525c..ab0d9208a 100644 --- a/src/hera/workflows/io/v2.py +++ b/src/hera/workflows/io/v2.py @@ -1,114 +1,33 @@ """Pydantic V2 input/output models for the Hera runner. -RunnerInput/Output are only defined in this file if Pydantic v2 is installed. +Input/Output are only defined in this file if Pydantic v2 is installed. """ -from collections import ChainMap -from typing import Any, List, Optional, Union +from typing import Any from hera.shared._pydantic import _PYDANTIC_VERSION -from hera.shared.serialization import MISSING, serialize -from hera.workflows.artifact import Artifact -from hera.workflows.parameter import Parameter - -try: - from inspect import get_annotations # type: ignore -except ImportError: - from hera.workflows._inspect import get_annotations # type: ignore - -try: - from typing import Annotated, get_args, get_origin # type: ignore -except ImportError: - from typing_extensions import Annotated, get_args, get_origin # type: ignore - +from hera.workflows.io._io_mixins import InputMixin, OutputMixin if _PYDANTIC_VERSION == 2: from pydantic import BaseModel - from pydantic_core import PydanticUndefined - class RunnerInput(BaseModel): + class Input(InputMixin, BaseModel): """Input model usable by the Hera Runner. - RunnerInput is a Pydantic model which users can create a subclass of. When a subclass - of RunnerInput is used as a function parameter type, the Hera Runner will take the fields + Input is a Pydantic model which users can create a subclass of. When a subclass + of Input is used as a function parameter type, the Hera Runner will take the fields of the user's subclass to create template input parameters and artifacts. See the example for the script_pydantic_io experimental feature. """ - @classmethod - def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> List[Parameter]: - parameters = [] - annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - - for field, field_info in cls.model_fields.items(): # type: ignore - if get_origin(annotations[field]) is Annotated: - if isinstance(get_args(annotations[field])[1], Parameter): - param = get_args(annotations[field])[1] - if object_override: - param.default = serialize(getattr(object_override, field)) - elif field_info.default: # type: ignore - # Serialize the value (usually done in Parameter's validator) - param.default = serialize(field_info.default) # type: ignore - parameters.append(param) - else: - # Create a Parameter from basic type annotations - default = getattr(object_override, field) if object_override else field_info.default - if default == PydanticUndefined: - default = MISSING - - parameters.append(Parameter(name=field, default=default)) - - return parameters - - @classmethod - def _get_artifacts(cls) -> List[Artifact]: - artifacts = [] - annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - - for field in cls.model_fields: # type: ignore - if get_origin(annotations[field]) is Annotated: - if isinstance(get_args(annotations[field])[1], Artifact): - artifact = get_args(annotations[field])[1] - if artifact.path is None: - artifact.path = artifact._get_default_inputs_path() - artifacts.append(artifact) - return artifacts - - class RunnerOutput(BaseModel): + class Output(OutputMixin, BaseModel): """Output model usable by the Hera Runner. - RunnerOutput is a Pydantic model which users can create a subclass of. When a subclass - of RunnerOutput is used as a function return type, the Hera Runner will take the fields + Output is a Pydantic model which users can create a subclass of. When a subclass + of Output is used as a function return type, the Hera Runner will take the fields of the user's subclass to create template output parameters and artifacts. See the example for the script_pydantic_io experimental feature. """ exit_code: int = 0 result: Any = None - - @classmethod - def _get_outputs(cls) -> List[Union[Artifact, Parameter]]: - outputs = [] - annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - - for field in cls.model_fields: # type: ignore - if field in {"exit_code", "result"}: - continue - if get_origin(annotations[field]) is Annotated: - if isinstance(get_args(annotations[field])[1], (Parameter, Artifact)): - outputs.append(get_args(annotations[field])[1]) - else: - # Create a Parameter from basic type annotations - outputs.append(Parameter(name=field, default=cls.model_fields[field].default)) # type: ignore - return outputs - - @classmethod - def _get_output(cls, field_name: str) -> Union[Artifact, Parameter]: - annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()} - annotation = annotations[field_name] - if get_origin(annotation) is Annotated: - if isinstance(get_args(annotation)[1], (Parameter, Artifact)): - return get_args(annotation)[1] - - # Create a Parameter from basic type annotations - return Parameter(name=field_name, default=cls.model_fields[field_name].default) # type: ignore diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index bc9584064..ac3c7296c 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -42,19 +42,19 @@ Artifact, ) from hera.workflows.io.v1 import ( - RunnerInput as RunnerInputV1, - RunnerOutput as RunnerOutputV1, + Input as InputV1, + Output as OutputV1, ) try: from hera.workflows.io.v2 import ( # type: ignore - RunnerInput as RunnerInputV2, - RunnerOutput as RunnerOutputV2, + Input as InputV2, + Output as OutputV2, ) except ImportError: from hera.workflows.io.v1 import ( # type: ignore - RunnerInput as RunnerInputV2, - RunnerOutput as RunnerOutputV2, + Input as InputV2, + Output as OutputV2, ) from hera.workflows.models import ( EnvVar, @@ -380,11 +380,11 @@ def append_annotation(annotation: Union[Artifact, Parameter]): append_annotation(get_args(return_annotation)[1]) elif get_origin(return_annotation) is tuple: for annotation in get_args(return_annotation): - if isinstance(annotation, type) and issubclass(annotation, (RunnerOutputV1, RunnerOutputV2)): - raise ValueError("RunnerOutput cannot be part of a tuple output") + if isinstance(annotation, type) and issubclass(annotation, (OutputV1, OutputV2)): + raise ValueError("Output cannot be part of a tuple output") append_annotation(get_args(annotation)[1]) - elif return_annotation and issubclass(return_annotation, (RunnerOutputV1, RunnerOutputV2)): + elif return_annotation and issubclass(return_annotation, (OutputV1, OutputV2)): if not global_config.experimental_features["script_pydantic_io"]: raise ValueError( ( @@ -444,7 +444,7 @@ def _get_inputs_from_callable(source: Callable) -> Tuple[List[Parameter], List[A """Return all inputs from the function. This includes all basic Python function parameters, and all parameters with a Parameter or Artifact annotation. - For the Pydantic IO experimental feature, any input parameter which is a subclass of RunnerInput, the fields of the + For the Pydantic IO experimental feature, any input parameter which is a subclass of Input, the fields of the class will be used as inputs, rather than the class itself. Note, the given Parameter/Artifact names in annotations of different inputs could clash, which will raise a ValueError. @@ -453,9 +453,7 @@ class will be used as inputs, rather than the class itself. artifacts = [] for func_param in inspect.signature(source).parameters.values(): - if get_origin(func_param.annotation) is None and issubclass( - func_param.annotation, (RunnerInputV1, RunnerInputV2) - ): + if get_origin(func_param.annotation) is None and issubclass(func_param.annotation, (InputV1, InputV2)): if not global_config.experimental_features["script_pydantic_io"]: raise ValueError( ( @@ -467,7 +465,7 @@ class will be used as inputs, rather than the class itself. ) if len(inspect.signature(source).parameters) != 1: - raise SyntaxError("Only one function parameter can be specified when using a RunnerInput.") + raise SyntaxError("Only one function parameter can be specified when using an Input.") input_class = func_param.annotation if ( @@ -523,7 +521,7 @@ class will be used as inputs, rather than the class itself. def _extract_return_annotation_output(source: Callable) -> List: """Extract the output annotations from the return annotation of the function signature.""" - output: List[Union[Tuple[type, Union[Parameter, Artifact]], Type[Union[RunnerOutputV1, RunnerOutputV2]]]] = [] + output: List[Union[Tuple[type, Union[Parameter, Artifact]], Type[Union[OutputV1, OutputV2]]]] = [] return_annotation = inspect.signature(source).return_annotation origin_type = get_origin(return_annotation) @@ -536,7 +534,7 @@ def _extract_return_annotation_output(source: Callable) -> List: elif ( origin_type is None and isinstance(return_annotation, type) - and issubclass(return_annotation, (RunnerOutputV1, RunnerOutputV2)) + and issubclass(return_annotation, (OutputV1, OutputV2)) ): output.append(return_annotation) diff --git a/tests/script_annotations/pydantic_duplicate_input_artifact_names.py b/tests/script_annotations/pydantic_duplicate_input_artifact_names.py index 11732add4..97d035f7b 100644 --- a/tests/script_annotations/pydantic_duplicate_input_artifact_names.py +++ b/tests/script_annotations/pydantic_duplicate_input_artifact_names.py @@ -1,6 +1,6 @@ from hera.shared import global_config from hera.workflows import Artifact, ArtifactLoader, Workflow, script -from hera.workflows.io.v1 import RunnerInput +from hera.workflows.io.v1 import Input try: from typing import Annotated # type: ignore @@ -11,7 +11,7 @@ global_config.experimental_features["script_pydantic_io"] = True -class ArtifactOnlyInput(RunnerInput): +class ArtifactOnlyInput(Input): str_path_artifact: Annotated[str, Artifact(name="str-path-artifact", loader=None)] file_artifact: Annotated[str, Artifact(name="file-artifact", loader=ArtifactLoader.file)] diff --git a/tests/script_annotations/pydantic_io_invalid_multiple_inputs.py b/tests/script_annotations/pydantic_io_invalid_multiple_inputs.py index d3933e45a..b44f4e2b9 100644 --- a/tests/script_annotations/pydantic_io_invalid_multiple_inputs.py +++ b/tests/script_annotations/pydantic_io_invalid_multiple_inputs.py @@ -1,6 +1,6 @@ from hera.shared import global_config from hera.workflows import Parameter, Workflow, script -from hera.workflows.io.v1 import RunnerInput +from hera.workflows.io.v1 import Input try: from typing import Annotated # type: ignore @@ -11,7 +11,7 @@ global_config.experimental_features["script_pydantic_io"] = True -class ParamOnlyInput(RunnerInput): +class ParamOnlyInput(Input): my_int: int = 1 my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 diff --git a/tests/script_annotations/pydantic_io_invalid_outputs.py b/tests/script_annotations/pydantic_io_invalid_outputs.py index e57fae92b..7835d4695 100644 --- a/tests/script_annotations/pydantic_io_invalid_outputs.py +++ b/tests/script_annotations/pydantic_io_invalid_outputs.py @@ -3,7 +3,7 @@ from hera.shared import global_config from hera.workflows import Parameter, Workflow, script -from hera.workflows.io.v1 import RunnerOutput +from hera.workflows.io.v1 import Output try: from typing import Annotated # type: ignore @@ -14,7 +14,7 @@ global_config.experimental_features["script_pydantic_io"] = True -class ParamOnlyOutput(RunnerOutput): +class ParamOnlyOutput(Output): my_output_str: str = "my-default-str" another_output: Annotated[Path, Parameter(name="second-output")] diff --git a/tests/script_annotations/pydantic_io_v1.py b/tests/script_annotations/pydantic_io_v1.py index 69167454c..85d8a97f6 100644 --- a/tests/script_annotations/pydantic_io_v1.py +++ b/tests/script_annotations/pydantic_io_v1.py @@ -2,7 +2,7 @@ from typing import List from hera.workflows import Artifact, ArtifactLoader, Parameter, Workflow, script -from hera.workflows.io.v1 import RunnerInput, RunnerOutput +from hera.workflows.io.v1 import Input, Output try: from typing import Annotated # type: ignore @@ -10,13 +10,13 @@ from typing_extensions import Annotated # type: ignore -class ParamOnlyInput(RunnerInput): +class ParamOnlyInput(Input): my_int: int = 1 my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 no_default_param: int -class ParamOnlyOutput(RunnerOutput): +class ParamOnlyOutput(Output): my_output_str: str = "my-default-str" another_output: Annotated[Path, Parameter(name="second-output")] @@ -28,14 +28,14 @@ def pydantic_io_params( pass -class ArtifactOnlyInput(RunnerInput): +class ArtifactOnlyInput(Input): my_file_artifact: Annotated[Path, Artifact(name="file-artifact")] my_int_artifact: Annotated[ int, Artifact(name="an-int-artifact", description="my desc", loader=ArtifactLoader.json) ] -class ArtifactOnlyOutput(RunnerOutput): +class ArtifactOnlyOutput(Output): an_artifact: Annotated[str, Artifact(name="artifact-output")] @@ -46,12 +46,12 @@ def pydantic_io_artifacts( pass -class BothInput(RunnerInput): +class BothInput(Input): param_int: Annotated[int, Parameter(name="param-int")] = 42 artifact_int: Annotated[int, Artifact(name="artifact-int", loader=ArtifactLoader.json)] -class BothOutput(RunnerOutput): +class BothOutput(Output): param_int: Annotated[int, Parameter(name="param-int")] artifact_int: Annotated[int, Artifact(name="artifact-int")] diff --git a/tests/script_annotations/pydantic_io_v1_strs.py b/tests/script_annotations/pydantic_io_v1_strs.py index 9068aa3e2..03f07b39f 100644 --- a/tests/script_annotations/pydantic_io_v1_strs.py +++ b/tests/script_annotations/pydantic_io_v1_strs.py @@ -4,13 +4,13 @@ try: from hera.workflows.io.v2 import ( # type: ignore - RunnerInput, - RunnerOutput, + Input, + Output, ) except ImportError: from hera.workflows.io.v1 import ( # type: ignore - RunnerInput, - RunnerOutput, + Input, + Output, ) try: @@ -19,13 +19,13 @@ from typing_extensions import Annotated # type: ignore -class ParamOnlyInput(RunnerInput): +class ParamOnlyInput(Input): my_str: str my_empty_default_str: str = "" my_annotated_str: Annotated[str, Parameter(name="alt-name")] = "hello world!" -class ParamOnlyOutput(RunnerOutput): +class ParamOnlyOutput(Output): my_output_str: str = "my-default-str" another_output: Annotated[Path, Parameter(name="second-output")] diff --git a/tests/script_annotations/pydantic_io_v2.py b/tests/script_annotations/pydantic_io_v2.py index e3753b062..012377d9a 100644 --- a/tests/script_annotations/pydantic_io_v2.py +++ b/tests/script_annotations/pydantic_io_v2.py @@ -5,13 +5,13 @@ try: from hera.workflows.io.v2 import ( # type: ignore - RunnerInput, - RunnerOutput, + Input, + Output, ) except ImportError: from hera.workflows.io.v1 import ( # type: ignore - RunnerInput, - RunnerOutput, + Input, + Output, ) try: @@ -20,13 +20,13 @@ from typing_extensions import Annotated # type: ignore -class ParamOnlyInput(RunnerInput): +class ParamOnlyInput(Input): my_int: int = 1 my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 no_default_param: int -class ParamOnlyOutput(RunnerOutput): +class ParamOnlyOutput(Output): my_output_str: str = "my-default-str" another_output: Annotated[Path, Parameter(name="second-output")] @@ -38,14 +38,14 @@ def pydantic_io_params( pass -class ArtifactOnlyInput(RunnerInput): +class ArtifactOnlyInput(Input): my_file_artifact: Annotated[Path, Artifact(name="file-artifact")] my_int_artifact: Annotated[ int, Artifact(name="an-int-artifact", description="my desc", loader=ArtifactLoader.json) ] -class ArtifactOnlyOutput(RunnerOutput): +class ArtifactOnlyOutput(Output): an_artifact: Annotated[str, Artifact(name="artifact-output")] @@ -56,12 +56,12 @@ def pydantic_io_artifacts( pass -class BothInput(RunnerInput): +class BothInput(Input): param_int: Annotated[int, Parameter(name="param-int")] = 42 artifact_int: Annotated[int, Artifact(name="artifact-int", loader=ArtifactLoader.json)] -class BothOutput(RunnerOutput): +class BothOutput(Output): param_int: Annotated[int, Parameter(name="param-int")] artifact_int: Annotated[int, Artifact(name="artifact-int")] diff --git a/tests/script_annotations/pydantic_io_v2_strs.py b/tests/script_annotations/pydantic_io_v2_strs.py index d8cb434d9..03f07b39f 100644 --- a/tests/script_annotations/pydantic_io_v2_strs.py +++ b/tests/script_annotations/pydantic_io_v2_strs.py @@ -1,7 +1,17 @@ from pathlib import Path from hera.workflows import Parameter, Workflow, script -from hera.workflows.io.v1 import RunnerInput, RunnerOutput + +try: + from hera.workflows.io.v2 import ( # type: ignore + Input, + Output, + ) +except ImportError: + from hera.workflows.io.v1 import ( # type: ignore + Input, + Output, + ) try: from typing import Annotated # type: ignore @@ -9,13 +19,13 @@ from typing_extensions import Annotated # type: ignore -class ParamOnlyInput(RunnerInput): +class ParamOnlyInput(Input): my_str: str my_empty_default_str: str = "" my_annotated_str: Annotated[str, Parameter(name="alt-name")] = "hello world!" -class ParamOnlyOutput(RunnerOutput): +class ParamOnlyOutput(Output): my_output_str: str = "my-default-str" another_output: Annotated[Path, Parameter(name="second-output")] diff --git a/tests/script_runner/pydantic_io_v1.py b/tests/script_runner/pydantic_io_v1.py index 585bc2c95..c4f4f4256 100644 --- a/tests/script_runner/pydantic_io_v1.py +++ b/tests/script_runner/pydantic_io_v1.py @@ -5,7 +5,7 @@ from hera.shared import global_config from hera.workflows import Artifact, ArtifactLoader, Parameter, script -from hera.workflows.io.v1 import RunnerInput, RunnerOutput +from hera.workflows.io.v1 import Input, Output try: from pydantic.v1 import BaseModel @@ -21,14 +21,14 @@ global_config.experimental_features["script_pydantic_io"] = True -class ParamOnlyInput(RunnerInput): +class ParamOnlyInput(Input): my_required_int: int my_int: int = 1 my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 my_ints: Annotated[List[int], Parameter(name="multiple-ints")] = [] -class ParamOnlyOutput(RunnerOutput): +class ParamOnlyOutput(Output): my_output_str: str = "my-default-str" annotated_str: Annotated[str, Parameter(name="second-output")] @@ -79,7 +79,7 @@ class MyArtifact(BaseModel): b: str = "b" -class ArtifactOnlyInput(RunnerInput): +class ArtifactOnlyInput(Input): json_artifact: Annotated[ MyArtifact, Artifact(name="json-artifact", path=ARTIFACT_PATH + "/json", loader=ArtifactLoader.json) ] @@ -92,7 +92,7 @@ class ArtifactOnlyInput(RunnerInput): ] -class ArtifactOnlyOutput(RunnerOutput): +class ArtifactOnlyOutput(Output): an_artifact: Annotated[str, Artifact(name="artifact-str-output")] diff --git a/tests/script_runner/pydantic_io_v2.py b/tests/script_runner/pydantic_io_v2.py index 98a03cd7e..bd2a2dbf3 100644 --- a/tests/script_runner/pydantic_io_v2.py +++ b/tests/script_runner/pydantic_io_v2.py @@ -8,9 +8,9 @@ from hera.workflows import Artifact, ArtifactLoader, Parameter, script try: - from hera.workflows.io.v2 import RunnerInput, RunnerOutput + from hera.workflows.io.v2 import Input, Output except ImportError: - from hera.workflows.io.v1 import RunnerInput, RunnerOutput + from hera.workflows.io.v1 import Input, Output try: from typing import Annotated # type: ignore @@ -21,14 +21,14 @@ global_config.experimental_features["script_pydantic_io"] = True -class ParamOnlyInput(RunnerInput): +class ParamOnlyInput(Input): my_required_int: int my_int: int = 1 my_annotated_int: Annotated[int, Parameter(name="another-int", description="my desc")] = 42 my_ints: Annotated[List[int], Parameter(name="multiple-ints")] = [] -class ParamOnlyOutput(RunnerOutput): +class ParamOnlyOutput(Output): my_output_str: str = "my-default-str" annotated_str: Annotated[str, Parameter(name="second-output")] @@ -79,7 +79,7 @@ class MyArtifact(BaseModel): b: str = "b" -class ArtifactOnlyInput(RunnerInput): +class ArtifactOnlyInput(Input): json_artifact: Annotated[ MyArtifact, Artifact(name="json-artifact", path=ARTIFACT_PATH + "/json", loader=ArtifactLoader.json) ] @@ -92,7 +92,7 @@ class ArtifactOnlyInput(RunnerInput): ] -class ArtifactOnlyOutput(RunnerOutput): +class ArtifactOnlyOutput(Output): an_artifact: Annotated[str, Artifact(name="artifact-str-output")] diff --git a/tests/test_runner.py b/tests/test_runner.py index 14c574d45..25b709cb4 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -22,7 +22,7 @@ from hera.shared._pydantic import _PYDANTIC_VERSION from hera.shared.serialization import serialize from hera.workflows._runner.util import _run, _runner -from hera.workflows.io.v1 import RunnerOutput +from hera.workflows.io.v1 import Output @pytest.mark.parametrize( @@ -736,7 +736,7 @@ def test_runner_pydantic_output_params( output = _runner(entrypoint, []) # THEN - assert isinstance(output, RunnerOutput) + assert isinstance(output, Output) for file in expected_files: assert Path(tmp_path / file["subpath"]).is_file() assert Path(tmp_path / file["subpath"]).read_text() == file["value"] @@ -838,7 +838,7 @@ def test_runner_pydantic_output_artifacts( output = _runner(entrypoint, []) # THEN - assert isinstance(output, RunnerOutput) + assert isinstance(output, Output) for file in expected_files: assert Path(tmp_path / file["subpath"]).is_file() assert Path(tmp_path / file["subpath"]).read_text() == file["value"] diff --git a/tests/test_script_annotations.py b/tests/test_script_annotations.py index a756dd246..40cca57b6 100644 --- a/tests/test_script_annotations.py +++ b/tests/test_script_annotations.py @@ -373,7 +373,7 @@ def test_script_pydantic_invalid_outputs(global_config_fixture): with pytest.raises(ValueError) as e: workflow.to_dict() - assert "RunnerOutput cannot be part of a tuple output" in str(e.value) + assert "Output cannot be part of a tuple output" in str(e.value) def test_script_pydantic_multiple_inputs(global_config_fixture): @@ -392,7 +392,7 @@ def test_script_pydantic_multiple_inputs(global_config_fixture): with pytest.raises(SyntaxError) as e: workflow.to_dict() - assert "Only one function parameter can be specified when using a RunnerInput" in str(e.value) + assert "Only one function parameter can be specified when using an Input" in str(e.value) def test_script_pydantic_without_experimental_flag(global_config_fixture): diff --git a/tests/test_unit/test_script_annotations_util.py b/tests/test_unit/test_script_annotations_util.py index 991f5145d..1191d4626 100644 --- a/tests/test_unit/test_script_annotations_util.py +++ b/tests/test_unit/test_script_annotations_util.py @@ -12,7 +12,7 @@ map_runner_input, ) from hera.workflows.artifact import Artifact, ArtifactLoader -from hera.workflows.io import RunnerInput +from hera.workflows.io import Input from hera.workflows.models import ValueFrom from hera.workflows.parameter import Parameter @@ -142,7 +142,7 @@ def test_get_annotated_artifact_value_path_outputs( def test_map_runner_input(): - class MyInput(RunnerInput): + class MyInput(Input): a_str: str an_int: int a_dict: dict @@ -165,7 +165,7 @@ class MyInput(RunnerInput): def test_map_runner_input_strings(): """Test the parsing logic when str type fields are passed json-serialized strings.""" - class MyInput(RunnerInput): + class MyInput(Input): a_dict_str: str a_list_str: str diff --git a/tests/workflow_decorators/multiple_entrypoints.py b/tests/workflow_decorators/multiple_entrypoints.py index 424f29f1a..59300569a 100644 --- a/tests/workflow_decorators/multiple_entrypoints.py +++ b/tests/workflow_decorators/multiple_entrypoints.py @@ -1,5 +1,5 @@ from hera.shared import global_config -from hera.workflows import RunnerInput, RunnerOutput, WorkflowTemplate +from hera.workflows import Input, Output, WorkflowTemplate global_config.experimental_features["script_annotations"] = True global_config.experimental_features["script_pydantic_io"] = True @@ -7,21 +7,21 @@ w = WorkflowTemplate(name="my-template") -class MyInput(RunnerInput): +class MyInput(Input): user: str @w.set_entrypoint @w.script() -def hello_world(my_input: MyInput) -> RunnerOutput: - output = RunnerOutput() +def hello_world(my_input: MyInput) -> Output: + output = Output() output.result = f"Hello Hera User: {my_input.user}!" return output @w.set_entrypoint @w.script() -def hello_world_2(my_input: MyInput) -> RunnerOutput: - output = RunnerOutput() +def hello_world_2(my_input: MyInput) -> Output: + output = Output() output.result = f"Hello Hera User: {my_input.user}!" return output diff --git a/tests/workflow_decorators/set_entrypoint.py b/tests/workflow_decorators/set_entrypoint.py index a3bc699cc..1cdf77320 100644 --- a/tests/workflow_decorators/set_entrypoint.py +++ b/tests/workflow_decorators/set_entrypoint.py @@ -1,5 +1,5 @@ from hera.shared import global_config -from hera.workflows import RunnerInput, RunnerOutput, WorkflowTemplate +from hera.workflows import Input, Output, WorkflowTemplate global_config.experimental_features["script_annotations"] = True global_config.experimental_features["script_pydantic_io"] = True @@ -7,13 +7,13 @@ w = WorkflowTemplate(name="my-template") -class MyInput(RunnerInput): +class MyInput(Input): user: str @w.script() -def hello_world(my_input: MyInput) -> RunnerOutput: - output = RunnerOutput() +def hello_world(my_input: MyInput) -> Output: + output = Output() output.result = f"Hello Hera User: {my_input.user}!" return output @@ -21,7 +21,7 @@ def hello_world(my_input: MyInput) -> RunnerOutput: # Pass script kwargs (including an alternative public template name) in the decorator @w.set_entrypoint @w.script(name="goodbye-world", labels={"my-label": "my-value"}) -def goodbye(my_input: MyInput) -> RunnerOutput: - output = RunnerOutput() +def goodbye(my_input: MyInput) -> Output: + output = Output() output.result = f"Goodbye Hera User: {my_input.user}!" return output