Skip to content

Commit

Permalink
Fix the type handeling for tasks to support string types (#1470)
Browse files Browse the repository at this point in the history
Signed-off-by: elronbandel <[email protected]>
  • Loading branch information
elronbandel authored Jan 5, 2025
1 parent 5689aed commit 50f97a6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/unitxt/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def prepare_args(self):
self.prediction_type
)

def verify(self):
def task_deprecations(self):
if hasattr(self, "inputs") and self.inputs is not None:
depr_message = (
"The 'inputs' field is deprecated. Please use 'input_fields' instead."
Expand All @@ -127,6 +127,9 @@ def verify(self):
depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
warnings.warn(depr_message, DeprecationWarning, stacklevel=2)

def verify(self):
self.task_deprecations()

if self.input_fields is None:
raise UnitxtError(
"Missing attribute in task: 'input_fields' not set.",
Expand All @@ -152,7 +155,11 @@ def verify(self):
f"will raise an exception.",
Documentation.ADDING_TASK,
)
data = {key: Any for key in data}
if isinstance(data, dict):
data = parse_type_dict(to_type_dict(data))
else:
data = {key: Any for key in data}

if io_type == "input_fields":
self.input_fields = data
else:
Expand Down
3 changes: 3 additions & 0 deletions src/unitxt/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,9 @@ def strtype(typing_type) -> str:
- The function checks the `__origin__` attribute to determine the base type and formats
the type arguments accordingly.
"""
if isinstance(typing_type, str):
return typing_type

if not is_type(typing_type):
raise UnsupportedTypeError(typing_type)

Expand Down

0 comments on commit 50f97a6

Please sign in to comment.