From 614d364613c184d56b150581873b2caa1ad75c55 Mon Sep 17 00:00:00 2001 From: Flaviu Vadan Date: Thu, 23 Nov 2023 14:02:05 -0800 Subject: [PATCH] Implement get_parameter API on IO mixin and add tests (#876) **Pull Request Checklist** - [x] Fixes #816 - [x] Tests added - [x] Documentation/examples added - [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR title Implements the `get_parameter` API on the `IOMixin`. This allows clients that need parameters in `arguments` fields to get the parameter object from DAGs, Steps, etc. This helps avoid the need to write `{{inputs.parameters.whatever}}` explicitly. `get_parameter` _assumes_ that clients want to use a parameter as input, so the value field is set accordingly Signed-off-by: Flaviu Vadan --- .../workflows/callable_dag_with_param_get.md | 101 ++++++++++++++++++ .../loops_arbitrary_sequential_steps.md | 2 +- .../workflows/upstream/parallelism_nested.md | 8 +- .../callable-dag-with-param-get.yaml | 54 ++++++++++ .../workflows/callable_dag_with_param_get.py | 30 ++++++ .../loops_arbitrary_sequential_steps.py | 2 +- .../workflows/upstream/parallelism_nested.py | 8 +- src/hera/workflows/_mixins.py | 26 +++++ src/hera/workflows/dag.py | 1 + src/hera/workflows/parameter.py | 15 ++- tests/test_unit/test_mixins.py | 47 +++++++- 11 files changed, 277 insertions(+), 17 deletions(-) create mode 100644 docs/examples/workflows/callable_dag_with_param_get.md create mode 100644 examples/workflows/callable-dag-with-param-get.yaml create mode 100644 examples/workflows/callable_dag_with_param_get.py diff --git a/docs/examples/workflows/callable_dag_with_param_get.md b/docs/examples/workflows/callable_dag_with_param_get.md new file mode 100644 index 000000000..a54d35bc0 --- /dev/null +++ b/docs/examples/workflows/callable_dag_with_param_get.md @@ -0,0 +1,101 @@ +# Callable Dag With Param Get + + + + + + +=== "Hera" + + ```python linenums="1" + from typing_extensions import Annotated + + from hera.workflows import DAG, Parameter, Workflow, script + + + @script(constructor="runner") + def hello_with_output(name: str) -> Annotated[str, Parameter(name="output-message")]: + return "Hello, {name}!".format(name=name) + + + with Workflow( + generate_name="callable-dag-", + entrypoint="calling-dag", + ) as w: + with DAG( + name="my-dag-with-outputs", + inputs=Parameter(name="my-dag-input"), + outputs=Parameter( + name="my-dag-output", + value_from={"parameter": "{{hello.outputs.parameters.output-message}}"}, + ), + ) as my_dag: + # Here, get_parameter searches through the *inputs* of my_dag + hello_with_output(name="hello", arguments={"name": f"hello {my_dag.get_parameter('my-dag-input')}"}) + + with DAG(name="calling-dag") as d: + t1 = my_dag(name="call-1", arguments={"my-dag-input": "call-1"}) + # Here, t1 is a Task from the called dag, so get_parameter is called on the Task to get the output parameter! 🚀 + t2 = my_dag(name="call-2", arguments=t1.get_parameter("my-dag-output").with_name("my-dag-input")) + t1 >> t2 + ``` + +=== "YAML" + + ```yaml linenums="1" + apiVersion: argoproj.io/v1alpha1 + kind: Workflow + metadata: + generateName: callable-dag- + spec: + entrypoint: calling-dag + templates: + - dag: + tasks: + - arguments: + parameters: + - name: name + value: hello {{inputs.parameters.my-dag-input}} + name: hello + template: hello-with-output + inputs: + parameters: + - name: my-dag-input + name: my-dag-with-outputs + outputs: + parameters: + - name: my-dag-output + valueFrom: + parameter: '{{hello.outputs.parameters.output-message}}' + - inputs: + parameters: + - name: name + name: hello-with-output + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.callable_dag_with_param_get:hello_with_output + command: + - python + image: python:3.8 + source: '{{inputs.parameters}}' + - dag: + tasks: + - arguments: + parameters: + - name: my-dag-input + value: call-1 + name: call-1 + template: my-dag-with-outputs + - arguments: + parameters: + - name: my-dag-input + value: '{{tasks.call-1.outputs.parameters.my-dag-output}}' + depends: call-1 + name: call-2 + template: my-dag-with-outputs + name: calling-dag + ``` + diff --git a/docs/examples/workflows/upstream/loops_arbitrary_sequential_steps.md b/docs/examples/workflows/upstream/loops_arbitrary_sequential_steps.md index 8a8aafbc3..62e454523 100644 --- a/docs/examples/workflows/upstream/loops_arbitrary_sequential_steps.md +++ b/docs/examples/workflows/upstream/loops_arbitrary_sequential_steps.md @@ -52,7 +52,7 @@ The upstream example can be [found here](https://github.com/argoproj/argo-workfl "exit_code": f"{g.item.exit_code:$}", "message": f"{g.item.message:$}", }, - with_param="{{inputs.parameters.step_params}}", + with_param=s.get_parameter("step_params"), ) ``` diff --git a/docs/examples/workflows/upstream/parallelism_nested.md b/docs/examples/workflows/upstream/parallelism_nested.md index 01df5d2c1..8207920f3 100644 --- a/docs/examples/workflows/upstream/parallelism_nested.md +++ b/docs/examples/workflows/upstream/parallelism_nested.md @@ -36,10 +36,10 @@ The upstream example can be [found here](https://github.com/argoproj/argo-workfl one_job( name="seq-step", arguments=[ - Parameter(name="parallel-id", value="{{inputs.parameters.parallel-id}}"), + seq_worker.get_parameter("parallel-id"), Parameter(name="seq-id", value="{{item}}"), ], - with_param="{{inputs.parameters.seq-list}}", + with_param=seq_worker.get_parameter("seq-list"), ) with Steps( @@ -49,10 +49,10 @@ The upstream example can be [found here](https://github.com/argoproj/argo-workfl name="parallel-worker", template=seq_worker, arguments=[ - Parameter(name="seq-list", value="{{inputs.parameters.seq-list}}"), + seq_worker.get_parameter("seq-list"), Parameter(name="parallel-id", value="{{item}}"), ], - with_param="{{inputs.parameters.parallel-list}}", + with_param=parallel_worker.get_parameter("parallel-list"), ) ``` diff --git a/examples/workflows/callable-dag-with-param-get.yaml b/examples/workflows/callable-dag-with-param-get.yaml new file mode 100644 index 000000000..e0fde2fed --- /dev/null +++ b/examples/workflows/callable-dag-with-param-get.yaml @@ -0,0 +1,54 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: callable-dag- +spec: + entrypoint: calling-dag + templates: + - dag: + tasks: + - arguments: + parameters: + - name: name + value: hello {{inputs.parameters.my-dag-input}} + name: hello + template: hello-with-output + inputs: + parameters: + - name: my-dag-input + name: my-dag-with-outputs + outputs: + parameters: + - name: my-dag-output + valueFrom: + parameter: '{{hello.outputs.parameters.output-message}}' + - inputs: + parameters: + - name: name + name: hello-with-output + script: + args: + - -m + - hera.workflows.runner + - -e + - examples.workflows.callable_dag_with_param_get:hello_with_output + command: + - python + image: python:3.8 + source: '{{inputs.parameters}}' + - dag: + tasks: + - arguments: + parameters: + - name: my-dag-input + value: call-1 + name: call-1 + template: my-dag-with-outputs + - arguments: + parameters: + - name: my-dag-input + value: '{{tasks.call-1.outputs.parameters.my-dag-output}}' + depends: call-1 + name: call-2 + template: my-dag-with-outputs + name: calling-dag diff --git a/examples/workflows/callable_dag_with_param_get.py b/examples/workflows/callable_dag_with_param_get.py new file mode 100644 index 000000000..4e6c3acce --- /dev/null +++ b/examples/workflows/callable_dag_with_param_get.py @@ -0,0 +1,30 @@ +from typing_extensions import Annotated + +from hera.workflows import DAG, Parameter, Workflow, script + + +@script(constructor="runner") +def hello_with_output(name: str) -> Annotated[str, Parameter(name="output-message")]: + return "Hello, {name}!".format(name=name) + + +with Workflow( + generate_name="callable-dag-", + entrypoint="calling-dag", +) as w: + with DAG( + name="my-dag-with-outputs", + inputs=Parameter(name="my-dag-input"), + outputs=Parameter( + name="my-dag-output", + value_from={"parameter": "{{hello.outputs.parameters.output-message}}"}, + ), + ) as my_dag: + # Here, get_parameter searches through the *inputs* of my_dag + hello_with_output(name="hello", arguments={"name": f"hello {my_dag.get_parameter('my-dag-input')}"}) + + with DAG(name="calling-dag") as d: + t1 = my_dag(name="call-1", arguments={"my-dag-input": "call-1"}) + # Here, t1 is a Task from the called dag, so get_parameter is called on the Task to get the output parameter! 🚀 + t2 = my_dag(name="call-2", arguments=t1.get_parameter("my-dag-output").with_name("my-dag-input")) + t1 >> t2 diff --git a/examples/workflows/upstream/loops_arbitrary_sequential_steps.py b/examples/workflows/upstream/loops_arbitrary_sequential_steps.py index f6f859c8a..8849aae13 100644 --- a/examples/workflows/upstream/loops_arbitrary_sequential_steps.py +++ b/examples/workflows/upstream/loops_arbitrary_sequential_steps.py @@ -39,5 +39,5 @@ "exit_code": f"{g.item.exit_code:$}", "message": f"{g.item.message:$}", }, - with_param="{{inputs.parameters.step_params}}", + with_param=s.get_parameter("step_params"), ) diff --git a/examples/workflows/upstream/parallelism_nested.py b/examples/workflows/upstream/parallelism_nested.py index 90dd045c4..282a71a28 100644 --- a/examples/workflows/upstream/parallelism_nested.py +++ b/examples/workflows/upstream/parallelism_nested.py @@ -23,10 +23,10 @@ one_job( name="seq-step", arguments=[ - Parameter(name="parallel-id", value="{{inputs.parameters.parallel-id}}"), + seq_worker.get_parameter("parallel-id"), Parameter(name="seq-id", value="{{item}}"), ], - with_param="{{inputs.parameters.seq-list}}", + with_param=seq_worker.get_parameter("seq-list"), ) with Steps( @@ -36,8 +36,8 @@ name="parallel-worker", template=seq_worker, arguments=[ - Parameter(name="seq-list", value="{{inputs.parameters.seq-list}}"), + seq_worker.get_parameter("seq-list"), Parameter(name="parallel-id", value="{{item}}"), ], - with_param="{{inputs.parameters.parallel-list}}", + with_param=parallel_worker.get_parameter("parallel-list"), ) diff --git a/src/hera/workflows/_mixins.py b/src/hera/workflows/_mixins.py index db6d5529f..5f7475307 100644 --- a/src/hera/workflows/_mixins.py +++ b/src/hera/workflows/_mixins.py @@ -274,6 +274,32 @@ class IOMixin(BaseMixin): inputs: InputsT = None outputs: OutputsT = None + def get_parameter(self, name: str) -> Parameter: + """Finds and returns the parameter with the supplied name. + + Note that this method will raise an error if the parameter is not found. + + Args: + name: name of the input parameter to find and return. + + Returns: + Parameter: the parameter with the supplied name. + + Raises: + KeyError: if the parameter is not found. + """ + inputs = self._build_inputs() + if inputs is None: + raise KeyError(f"No inputs set. Parameter {name} not found.") + if inputs.parameters is None: + raise KeyError(f"No parameters set. Parameter {name} not found.") + for p in inputs.parameters: + if p.name == name: + param = Parameter.from_model(p) + param.value = f"{{{{inputs.parameters.{param.name}}}}}" + return param + raise KeyError(f"Parameter {name} not found.") + def _build_inputs(self) -> Optional[ModelInputs]: """Processes the `inputs` field and returns a generated `ModelInputs`.""" if self.inputs is None: diff --git a/src/hera/workflows/dag.py b/src/hera/workflows/dag.py index 4bbbfd25e..3a5cd2702 100644 --- a/src/hera/workflows/dag.py +++ b/src/hera/workflows/dag.py @@ -32,6 +32,7 @@ class DAG( >>> @script() >>> def foo() -> None: >>> print(42) + >>> >>> with DAG(...) as dag: >>> foo() """ diff --git a/src/hera/workflows/parameter.py b/src/hera/workflows/parameter.py index 7e5f7a726..5ad92abb8 100644 --- a/src/hera/workflows/parameter.py +++ b/src/hera/workflows/parameter.py @@ -51,6 +51,11 @@ def _check_values(cls, values): return values + @classmethod + def _get_input_attributes(cls): + """Return the attributes used for input parameter annotations.""" + return ["enum", "description", "default", "name", "value", "value_from"] + def __str__(self): """Represent the parameter as a string by pointing to its value. @@ -61,6 +66,11 @@ def __str__(self): raise ValueError("Cannot represent `Parameter` as string as `value` is not set") return self.value + @classmethod + def from_model(cls, model: _ModelParameter) -> Parameter: + """Creates a `Parameter` from a `Parameter` model.""" + return cls(**model.dict()) + def with_name(self, name: str) -> Parameter: """Returns a copy of the parameter with the name set to the value.""" p = self.copy(deep=True) @@ -108,10 +118,5 @@ def as_output(self) -> _ModelParameter: value_from=self.value_from, ) - @classmethod - def _get_input_attributes(cls): - """Return the attributes used for input parameter annotations.""" - return ["enum", "description", "default", "name", "value", "value_from"] - __all__ = ["Parameter"] diff --git a/tests/test_unit/test_mixins.py b/tests/test_unit/test_mixins.py index 62471e7bb..71e65b32f 100644 --- a/tests/test_unit/test_mixins.py +++ b/tests/test_unit/test_mixins.py @@ -1,5 +1,11 @@ -from hera.workflows._mixins import ContainerMixin -from hera.workflows.models import ImagePullPolicy +import pytest + +from hera.workflows import Parameter +from hera.workflows._mixins import ContainerMixin, IOMixin +from hera.workflows.models import ( + ImagePullPolicy, + Inputs as ModelInputs, +) class TestContainerMixin: @@ -10,3 +16,40 @@ def test_build_image_pull_policy(self) -> None: == ImagePullPolicy.always ) assert ContainerMixin()._build_image_pull_policy() is None + + +class TestIOMixin: + @pytest.fixture(autouse=True) + def setup(self): + self.io_mixin = IOMixin() + + def test_get_parameter_success(self): + self.io_mixin.inputs = ModelInputs(parameters=[Parameter(name="test", value="value")]) + param = self.io_mixin.get_parameter("test") + assert param.name == "test" + assert param.value == "{{inputs.parameters.test}}" + + def test_get_parameter_no_inputs(self): + with pytest.raises(KeyError): + self.io_mixin.get_parameter("test") + + def test_get_parameter_no_parameters(self): + self.io_mixin.inputs = ModelInputs() + with pytest.raises(KeyError): + self.io_mixin.get_parameter("test") + + def test_get_parameter_not_found(self): + self.io_mixin.inputs = ModelInputs(parameters=[Parameter(name="test", value="value")]) + with pytest.raises(KeyError): + self.io_mixin.get_parameter("not_exist") + + def test_build_inputs_none(self): + assert self.io_mixin._build_inputs() is None + + def test_build_inputs_from_model_inputs(self): + model_inputs = ModelInputs(parameters=[Parameter(name="test", value="value")]) + self.io_mixin.inputs = model_inputs + assert self.io_mixin._build_inputs() == model_inputs + + def test_build_outputs_none(self): + assert self.io_mixin._build_outputs() is None