Skip to content

Commit

Permalink
Refactor RunnerIO classes (#1060)
Browse files Browse the repository at this point in the history
* Use a new pair of mixin classes to hold the functions that are
compatible with Pydantic V1 and V2, meaning we no longer need to
maintain two sets of classes for each Pydantic version
* "RunnerIO" classes to be renamed to `Input` and `Output`, keeping the
old name as an alias for one release so we can announce the change

This PR splits out parts of commits from
#1059 that only refactored
RunnerIO

---------

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton authored May 8, 2024
1 parent b841f6c commit 345e687
Show file tree
Hide file tree
Showing 31 changed files with 393 additions and 325 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,30 @@

```python linenums="1"
from hera.shared import global_config
from hera.workflows import RunnerInput, RunnerOutput, WorkflowTemplate
from hera.workflows import Input, Output, WorkflowTemplate

global_config.experimental_features["script_annotations"] = True
global_config.experimental_features["script_pydantic_io"] = True

w = WorkflowTemplate(name="my-template")


class MyInput(RunnerInput):
class MyInput(Input):
user: str


@w.script()
def hello_world(my_input: MyInput) -> RunnerOutput:
output = RunnerOutput()
def hello_world(my_input: MyInput) -> Output:
output = Output()
output.result = f"Hello Hera User: {my_input.user}!"
return output


# Pass script kwargs (including an alternative public template name) in the decorator
@w.set_entrypoint
@w.script(name="goodbye-world", labels={"my-label": "my-value"})
def goodbye(my_input: MyInput) -> RunnerOutput:
output = RunnerOutput()
def goodbye(my_input: MyInput) -> Output:
output = Output()
output.result = f"Goodbye Hera User: {my_input.user}!"
return output
```
Expand Down
6 changes: 3 additions & 3 deletions docs/examples/workflows/experimental/script_runner_io.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from hera.shared import global_config
from hera.workflows import Artifact, ArtifactLoader, Parameter, Steps, Workflow, script
from hera.workflows.archive import NoneArchiveStrategy
from hera.workflows.io import RunnerInput, RunnerOutput
from hera.workflows.io import Input, Output

try:
from typing import Annotated # type: ignore
Expand All @@ -32,15 +32,15 @@
a_str: str = "a default string"


class MyInput(RunnerInput):
class MyInput(Input):
param_int: Annotated[int, Parameter(name="param-input")] = 42
an_object: Annotated[MyObject, Parameter(name="obj-input")] = MyObject(
a_dict={"my-key": "a-value"}, a_str="hello world!"
)
artifact_int: Annotated[int, Artifact(name="artifact-input", loader=ArtifactLoader.json)]


class MyOutput(RunnerOutput):
class MyOutput(Output):
param_int: Annotated[int, Parameter(name="param-output")]
artifact_int: Annotated[int, Artifact(name="artifact-output")]

Expand Down
27 changes: 15 additions & 12 deletions docs/user-guides/script-runner-io.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Script Runner IO

Hera provides the `RunnerInput` and `RunnerOutput` Pydantic classes which can be used to more succinctly write your
> ⚠️ The `RunnerInput` and `RunnerOutput` classes are deprecated since `v5.16.0`, please use `Input` and `Output` for
> equivalent functionality. They will be removed in `v5.17.0`.
Hera provides the `Input` and `Output` Pydantic classes which can be used to more succinctly write your
script function inputs and outputs, and requires use of the Hera Runner. Use of these classes also requires the
`"script_pydantic_io"` experimental feature flag to be enabled:

Expand All @@ -10,17 +13,17 @@ global_config.experimental_features["script_pydantic_io"] = True

## Pydantic V1 or V2?

You can import `RunnerInput` and `RunnerOutput` from the `hera.workflows.io` submodule to import the version of Pydantic
You can import `Input` and `Output` from the `hera.workflows.io` submodule to import the version of Pydantic
that matches your V1 or V2 installation.

If you need to use V1 models when you have V2 installed, you should import
`RunnerInput` and `RunnerOutput` from the `hera.workflows.io.v1` or `hera.workflows.io.v2` module explicitly. The V2
`Input` and `Output` from the `hera.workflows.io.v1` or `hera.workflows.io.v2` module explicitly. The V2
models will not be available if you have installed `pydantic<2`, but the V1 models are usable for either version,
allowing you to migrate at your own pace.

## Script inputs using `RunnerInput`
## Script inputs using `Input`

For your script inputs, you can create a derived class of `RunnerInput`, and declare all your input parameters (and
For your script inputs, you can create a derived class of `Input`, and declare all your input parameters (and
artifacts) as fields of the class. If you want to use `Annotated` to declare `Artifacts` add metadata to your
`Parameters`, you will also need to enable the `"script_annotations"` experimental feature flag.

Expand All @@ -29,15 +32,15 @@ from typing import Annotated
from pydantic import BaseModel

from hera.workflows import Artifact, ArtifactLoader, Parameter, script
from hera.workflows.io import RunnerInput
from hera.workflows.io import Input


class MyObject(BaseModel):
a_dict: dict
a_str: str = "a default string"


class MyInput(RunnerInput):
class MyInput(Input):
param_int: Annotated[int, Parameter(name="param-input")] = 42
an_object: Annotated[MyObject, Parameter(name="obj-input")] = MyObject(
a_dict={"my-key": "a-value"}, a_str="hello world!"
Expand Down Expand Up @@ -72,26 +75,26 @@ template will also have the `"artifact-input"` artifact. The yaml generated from
...
```
## Script outputs using `RunnerOutput`
## Script outputs using `Output`

The `RunnerOutput` class comes with two special variables, `exit_code` and `result`. The `exit_code` is used to exit the
The `Output` class comes with two special variables, `exit_code` and `result`. The `exit_code` is used to exit the
container when running on Argo with the specific exit code - it is set to `0` by default. The `result` is used to print
any serializable object to stdout, which means you can now use `.result` on tasks or steps that use a "runner
constructor" script - you should be mindful of printing/logging anything else to stdout, which will stop the `result`
functionality working as intended. If you want an output parameters/artifacts with the name `exit_code` or `result`, you
can declare another field with an annotation of that name, e.g.
`my_exit_code: Annotated[int, Parameter(name="exit_code")]`.

Aside from the `exit_code` and `result`, the `RunnerOutput` behaves exactly like the `RunnerInput`:
Aside from the `exit_code` and `result`, the `Output` behaves exactly like the `Input`:

```py
from typing import Annotated
from hera.workflows import Artifact, Parameter, script
from hera.workflows.io import RunnerOutput
from hera.workflows.io import Output
class MyOutput(RunnerOutput):
class MyOutput(Output):
param_int: Annotated[int, Parameter(name="param-output")]
artifact_int: Annotated[int, Artifact(name="artifact-output")]
Expand Down
4 changes: 2 additions & 2 deletions docs/walk-through/advanced-hera-features.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ Read the full guide on script annotations in [the script user guide](../user-gui
### Script IO Models

Hera provides Pydantic models for you to create subclasses from, which allow you to more easily declare script template
inputs. Any fields that you declare in your subclass of `RunnerInput` will become input parameters or artifacts, while
`RunnerOutput` fields will become output parameters artifacts. The fields that you declare can be `Annotated` as a
inputs. Any fields that you declare in your subclass of `Input` will become input parameters or artifacts, while
`Output` fields will become output parameters artifacts. The fields that you declare can be `Annotated` as a
`Parameter` or `Artifact`, as any fields with a basic type will become `Parameters` - you will also need the
`script_annotations` experimental feature enabled.

Expand Down
12 changes: 6 additions & 6 deletions examples/workflows/experimental/new_decorators_basic_script.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
from hera.shared import global_config
from hera.workflows import RunnerInput, RunnerOutput, WorkflowTemplate
from hera.workflows import Input, Output, WorkflowTemplate

global_config.experimental_features["script_annotations"] = True
global_config.experimental_features["script_pydantic_io"] = True

w = WorkflowTemplate(name="my-template")


class MyInput(RunnerInput):
class MyInput(Input):
user: str


@w.script()
def hello_world(my_input: MyInput) -> RunnerOutput:
output = RunnerOutput()
def hello_world(my_input: MyInput) -> Output:
output = Output()
output.result = f"Hello Hera User: {my_input.user}!"
return output


# Pass script kwargs (including an alternative public template name) in the decorator
@w.set_entrypoint
@w.script(name="goodbye-world", labels={"my-label": "my-value"})
def goodbye(my_input: MyInput) -> RunnerOutput:
output = RunnerOutput()
def goodbye(my_input: MyInput) -> Output:
output = Output()
output.result = f"Goodbye Hera User: {my_input.user}!"
return output
6 changes: 3 additions & 3 deletions examples/workflows/experimental/script_runner_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from hera.shared import global_config
from hera.workflows import Artifact, ArtifactLoader, Parameter, Steps, Workflow, script
from hera.workflows.archive import NoneArchiveStrategy
from hera.workflows.io import RunnerInput, RunnerOutput
from hera.workflows.io import Input, Output

try:
from typing import Annotated # type: ignore
Expand All @@ -22,15 +22,15 @@ class MyObject(BaseModel):
a_str: str = "a default string"


class MyInput(RunnerInput):
class MyInput(Input):
param_int: Annotated[int, Parameter(name="param-input")] = 42
an_object: Annotated[MyObject, Parameter(name="obj-input")] = MyObject(
a_dict={"my-key": "a-value"}, a_str="hello world!"
)
artifact_int: Annotated[int, Artifact(name="artifact-input", loader=ArtifactLoader.json)]


class MyOutput(RunnerOutput):
class MyOutput(Output):
param_int: Annotated[int, Parameter(name="param-output")]
artifact_int: Annotated[int, Artifact(name="artifact-output")]

Expand Down
File renamed without changes.
33 changes: 24 additions & 9 deletions src/hera/shared/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""Module that holds the underlying base Pydantic models for Hera objects."""

from collections import ChainMap
from typing import TYPE_CHECKING, Any, Dict, Type

from pydantic import VERSION

try:
from inspect import get_annotations # type: ignore
except ImportError:
from hera.shared._inspect import get_annotations # type: ignore


_PYDANTIC_VERSION: int = int(VERSION.split(".")[0])
# The pydantic v1 interface is used for both pydantic v1 and v2 in order to support
# users across both versions.
Expand All @@ -28,29 +35,26 @@
# native pydantic hinting for `__init__` arguments.
if TYPE_CHECKING:
from pydantic import BaseModel as PydanticBaseModel
from pydantic.fields import FieldInfo
else:
if _PYDANTIC_VERSION == 2:
from pydantic.v1 import BaseModel as PydanticBaseModel # type: ignore
from pydantic.v1.fields import FieldInfo
else:
from pydantic import BaseModel as PydanticBaseModel # type: ignore[assignment,no-redef]
from pydantic.fields import FieldInfo


def get_fields(cls: Type[PydanticBaseModel]) -> Dict[str, Any]:
def get_fields(cls: Type[PydanticBaseModel]) -> Dict[str, FieldInfo]:
"""Centralize access to __fields__."""
try:
return cls.model_fields # type: ignore
except AttributeError:
return cls.__fields__ # type: ignore


__all__ = [
"BaseModel",
"Field",
"PydanticBaseModel", # Export for serialization.py to cover user-defined models
"ValidationError",
"root_validator",
"validator",
]
def get_field_annotations(cls: Type[PydanticBaseModel]) -> Dict[str, Any]:
return {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}


class BaseModel(PydanticBaseModel):
Expand All @@ -74,3 +78,14 @@ class Config:

smart_union = True
"""uses smart union for matching a field's specified value to the underlying type that's part of a union"""


__all__ = [
"BaseModel",
"Field",
"FieldInfo",
"PydanticBaseModel", # Export for serialization.py to cover user-defined models
"ValidationError",
"root_validator",
"validator",
]
4 changes: 3 additions & 1 deletion src/hera/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from hera.workflows.env_from import ConfigMapEnvFrom, SecretEnvFrom
from hera.workflows.exceptions import InvalidDispatchType, InvalidTemplateCall, InvalidType
from hera.workflows.http_template import HTTP
from hera.workflows.io import RunnerInput, RunnerOutput
from hera.workflows.io import Input, Output, RunnerInput, RunnerOutput
from hera.workflows.metrics import Counter, Gauge, Histogram, Label, Metric, Metrics
from hera.workflows.operator import Operator
from hera.workflows.parameter import Parameter
Expand Down Expand Up @@ -126,6 +126,7 @@
"HostPathVolume",
"ISCSIVolume",
"InlineScriptConstructor",
"Input",
"InvalidDispatchType",
"InvalidTemplateCall",
"InvalidType",
Expand All @@ -136,6 +137,7 @@
"NoneArchiveStrategy",
"OSSArtifact",
"Operator",
"Output",
"Parallel",
"Parameter",
"PhotonPersistentDiskVolume",
Expand Down
4 changes: 2 additions & 2 deletions src/hera/workflows/_meta_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
try:
from inspect import get_annotations # type: ignore
except ImportError:
from hera.workflows._inspect import get_annotations # type: ignore
from hera.shared._inspect import get_annotations # type: ignore

_yaml: Optional[ModuleType] = None
try:
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(self, model_path: str, hera_builder: Optional[Callable] = None):
fields = get_fields(curr_class)
if key not in fields:
raise ValueError(f"Model key '{key}' does not exist in class {curr_class}")
curr_class = fields[key].outer_type_
curr_class = fields[key].outer_type_ # type: ignore

@classmethod
def _get_model_class(cls) -> Type[BaseModel]:
Expand Down
Loading

0 comments on commit 345e687

Please sign in to comment.