Skip to content

Commit

Permalink
Fix _is_str_kwarg_of for annotated parameters (#859)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [x] Fixes #855 
- [x] Tests added
- [x] Documentation/examples added
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, annotations break the `_is_str_kwarg_of` function. Adds logic
to check whether it's a built-in `str`, an `Annotated` str or a simple
subclass of `str`. It also fixes/makes explicit the logic behind doing a
json parse of the value if it is untyped.

---------

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton authored Nov 23, 2023
1 parent 9940b48 commit 862c45b
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 11 deletions.
19 changes: 12 additions & 7 deletions src/hera/workflows/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,19 @@ def _parse(value, key, f):
def _is_str_kwarg_of(key: str, f: Callable):
"""Check if param `key` of function `f` has a type annotation of a subclass of str."""
type_ = inspect.signature(f).parameters[key].annotation
if not type_:
return True
try:
return issubclass(type_, str)
except TypeError:
# If this happens then it means that the annotation is a more complex type annotation
# and may be interpretable by the Hera runner

if type_ is inspect.Parameter.empty:
# Untyped args are interpreted according to json spec
# ie. we will try to load it via json.loads in _parse
return False
if get_origin(type_) is None:
return issubclass(type_, str)

origin_type = cast(type, get_origin(type_))
if origin_type is Annotated:
return issubclass(get_args(type_)[0], str)

return issubclass(origin_type, str)


def _is_artifact_loaded(key, f):
Expand Down
38 changes: 37 additions & 1 deletion tests/script_runner/parameter_inputs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
import json
from typing import Any, List

try:
from typing import Annotated
Expand Down Expand Up @@ -41,3 +42,38 @@ def annotated_parameter_no_name(
annotated_input_value: Annotated[Input, Parameter(description="a value to input")]
) -> Output:
return Output(output=[annotated_input_value])


@script()
def no_type_parameter(my_anything) -> Any:
"""`my_anything` will be whatever the json loader gives back."""
return my_anything


@script()
def str_parameter_expects_jsonstr_dict(my_json_str: str) -> dict:
return json.loads(my_json_str)


@script()
def str_parameter_expects_jsonstr_list(my_json_str: str) -> list:
return json.loads(my_json_str)


@script()
def annotated_str_parameter_expects_jsonstr_dict(my_json_str: Annotated[str, "some metadata"]) -> list:
return json.loads(my_json_str)


class MyStr(str):
pass


@script()
def str_subclass_parameter_expects_jsonstr_dict(my_json_str: MyStr) -> list:
return json.loads(my_json_str)


@script()
def str_subclass_annotated_parameter_expects_jsonstr_dict(my_json_str: Annotated[MyStr, "some metadata"]) -> list:
return json.loads(my_json_str)
105 changes: 102 additions & 3 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
part of a Workflow when running locally.
"""
import importlib
import json
import os
from pathlib import Path
from typing import Dict, List
Expand All @@ -20,6 +21,95 @@
from hera.workflows.script import RunnerScriptConstructor


@pytest.mark.parametrize(
"entrypoint,kwargs_list,expected_output",
(
pytest.param(
"tests.script_runner.parameter_inputs:no_type_parameter",
[{"name": "my_anything", "value": "test"}],
"test",
id="no-type-string",
),
pytest.param(
"tests.script_runner.parameter_inputs:no_type_parameter",
[{"name": "my_anything", "value": "1"}],
1,
id="no-type-int",
),
pytest.param(
"tests.script_runner.parameter_inputs:no_type_parameter",
[{"name": "my_anything", "value": "null"}],
None,
id="no-type-none",
),
pytest.param(
"tests.script_runner.parameter_inputs:no_type_parameter",
[{"name": "my_anything", "value": "true"}],
True,
id="no-type-bool",
),
pytest.param(
"tests.script_runner.parameter_inputs:no_type_parameter",
[{"name": "my_anything", "value": "[]"}],
[],
id="no-type-list",
),
pytest.param(
"tests.script_runner.parameter_inputs:no_type_parameter",
[{"name": "my_anything", "value": "{}"}],
{},
id="no-type-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
{"my": "dict"},
id="str-json-param-as-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_parameter_expects_jsonstr_list",
[{"name": "my_json_str", "value": json.dumps([{"my": "dict"}])}],
[{"my": "dict"}],
id="str-json-param-as-list",
),
pytest.param(
"tests.script_runner.parameter_inputs:annotated_str_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
{"my": "dict"},
id="str-json-annotated-param-as-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_subclass_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
{"my": "dict"},
id="str-subclass-json-param-as-dict",
),
pytest.param(
"tests.script_runner.parameter_inputs:str_subclass_annotated_parameter_expects_jsonstr_dict",
[{"name": "my_json_str", "value": json.dumps({"my": "dict"})}],
{"my": "dict"},
id="str-subclass-json-annotated-param-as-dict",
),
),
)
def test_parameter_loading(
entrypoint,
kwargs_list: List[Dict[str, str]],
expected_output,
global_config_fixture: GlobalConfig,
environ_annotations_fixture: None,
):
# GIVEN
global_config_fixture.experimental_features["script_annotations"] = True
global_config_fixture.experimental_features["script_runner"] = True

# WHEN
output = _runner(entrypoint, kwargs_list)

# THEN
assert output == expected_output


@pytest.mark.parametrize(
"entrypoint,kwargs_list,expected_output",
[
Expand Down Expand Up @@ -74,20 +164,29 @@ def test_runner_parameter_inputs(
@pytest.mark.parametrize(
"entrypoint,kwargs_list,expected_output",
[
(
pytest.param(
"tests.script_runner.parameter_inputs:annotated_basic_types",
[{"name": "a-but-kebab", "value": "3"}, {"name": "b-but-kebab", "value": "bar"}],
'{"output": [{"a": 3, "b": "bar"}]}',
id="basic-test",
),
(
pytest.param(
"tests.script_runner.parameter_inputs:annotated_basic_types",
[{"name": "a-but-kebab", "value": "3"}, {"name": "b-but-kebab", "value": "1"}],
'{"output": [{"a": 3, "b": "1"}]}',
id="str-param-given-int",
),
pytest.param(
"tests.script_runner.parameter_inputs:annotated_object",
[{"name": "input-value", "value": '{"a": 3, "b": "bar"}'}],
'{"output": [{"a": 3, "b": "bar"}]}',
id="annotated-object",
),
(
pytest.param(
"tests.script_runner.parameter_inputs:annotated_parameter_no_name",
[{"name": "annotated_input_value", "value": '{"a": 3, "b": "bar"}'}],
'{"output": [{"a": 3, "b": "bar"}]}',
id="annotated-param-no-name",
),
],
)
Expand Down

0 comments on commit 862c45b

Please sign in to comment.