From 2709a9205a042e2baabd7d2f97f40365337b8c30 Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Wed, 29 Jan 2025 10:24:50 -0800 Subject: [PATCH] =?UTF-8?q?fixes=20interpolation=20issues=20when=20inputs?= =?UTF-8?q?=20are=20type=20dict,list=20specificall=E2=80=A6=20(#1992)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fixes interpolation issues when inputs are type dict,list specifically when defined on expected_output * improvements with type hints, doc fixes and rm print statements * more tests * test passing --------- Co-authored-by: Brandon Hancock --- src/crewai/task.py | 48 ++++--- tests/task_test.py | 317 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 350 insertions(+), 15 deletions(-) diff --git a/src/crewai/task.py b/src/crewai/task.py index 030bce779c..cbf651f9bb 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -431,7 +431,9 @@ def _execute_core( content = ( json_output if json_output - else pydantic_output.model_dump_json() if pydantic_output else result + else pydantic_output.model_dump_json() + if pydantic_output + else result ) self._save_file(content) @@ -452,7 +454,7 @@ def prompt(self) -> str: return "\n".join(tasks_slices) def interpolate_inputs_and_add_conversation_history( - self, inputs: Dict[str, Union[str, int, float]] + self, inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]] ) -> None: """Interpolate inputs into the task description, expected output, and output file path. Add conversation history if present. @@ -524,7 +526,9 @@ def interpolate_inputs_and_add_conversation_history( ) def interpolate_only( - self, input_string: Optional[str], inputs: Dict[str, Union[str, int, float]] + self, + input_string: Optional[str], + inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]], ) -> str: """Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched. @@ -532,17 +536,39 @@ def interpolate_only( input_string: The string containing template variables to interpolate. Can be None or empty, in which case an empty string is returned. inputs: Dictionary mapping template variables to their values. - Supported value types are strings, integers, and floats. - If input_string is empty or has no placeholders, inputs can be empty. + Supported value types are strings, integers, floats, and dicts/lists + containing only these types and other nested dicts/lists. Returns: The interpolated string with all template variables replaced with their values. Empty string if input_string is None or empty. Raises: - ValueError: If a required template variable is missing from inputs. - KeyError: If a template variable is not found in the inputs dictionary. + ValueError: If a value contains unsupported types """ + + # Validation function for recursive type checking + def validate_type(value: Any) -> None: + if value is None: + return + if isinstance(value, (str, int, float, bool)): + return + if isinstance(value, (dict, list)): + for item in value.values() if isinstance(value, dict) else value: + validate_type(item) + return + raise ValueError( + f"Unsupported type {type(value).__name__} in inputs. " + "Only str, int, float, bool, dict, and list are allowed." + ) + + # Validate all input values + for key, value in inputs.items(): + try: + validate_type(value) + except ValueError as e: + raise ValueError(f"Invalid value for key '{key}': {str(e)}") from e + if input_string is None or not input_string: return "" if "{" not in input_string and "}" not in input_string: @@ -551,15 +577,7 @@ def interpolate_only( raise ValueError( "Inputs dictionary cannot be empty when interpolating variables" ) - try: - # Validate input types - for key, value in inputs.items(): - if not isinstance(value, (str, int, float)): - raise ValueError( - f"Value for key '{key}' must be a string, integer, or float, got {type(value).__name__}" - ) - escaped_string = input_string.replace("{", "{{").replace("}", "}}") for key in inputs.keys(): diff --git a/tests/task_test.py b/tests/task_test.py index 59e58dcca9..5ffaf2534d 100644 --- a/tests/task_test.py +++ b/tests/task_test.py @@ -779,6 +779,43 @@ def test_interpolate_only(): assert result == no_placeholders +def test_interpolate_only_with_dict_inside_expected_output(): + """Test the interpolate_only method for various scenarios including JSON structure preservation.""" + task = Task( + description="Unused in this test", + expected_output="Unused in this test: {questions}", + ) + + json_string = '{"questions": {"main_question": "What is the user\'s name?", "secondary_question": "What is the user\'s age?"}}' + result = task.interpolate_only( + input_string=json_string, + inputs={ + "questions": { + "main_question": "What is the user's name?", + "secondary_question": "What is the user's age?", + } + }, + ) + assert '"main_question": "What is the user\'s name?"' in result + assert '"secondary_question": "What is the user\'s age?"' in result + assert result == json_string + + normal_string = "Hello {name}, welcome to {place}!" + result = task.interpolate_only( + input_string=normal_string, inputs={"name": "John", "place": "CrewAI"} + ) + assert result == "Hello John, welcome to CrewAI!" + + result = task.interpolate_only(input_string="", inputs={"unused": "value"}) + assert result == "" + + no_placeholders = "Hello, this is a test" + result = task.interpolate_only( + input_string=no_placeholders, inputs={"unused": "value"} + ) + assert result == no_placeholders + + def test_task_output_str_with_pydantic(): from crewai.tasks.output_format import OutputFormat @@ -966,3 +1003,283 @@ def test_task_execution_times(): assert task.start_time is not None assert task.end_time is not None assert task.execution_duration == (task.end_time - task.start_time).total_seconds() + + +def test_interpolate_with_list_of_strings(): + task = Task( + description="Test list interpolation", + expected_output="List: {items}", + ) + + # Test simple list of strings + input_str = "Available items: {items}" + inputs = {"items": ["apple", "banana", "cherry"]} + result = task.interpolate_only(input_str, inputs) + assert result == f"Available items: {inputs['items']}" + + # Test empty list + empty_list_input = {"items": []} + result = task.interpolate_only(input_str, empty_list_input) + assert result == "Available items: []" + + +def test_interpolate_with_list_of_dicts(): + task = Task( + description="Test list of dicts interpolation", + expected_output="People: {people}", + ) + + input_data = { + "people": [ + {"name": "Alice", "age": 30, "skills": ["Python", "AI"]}, + {"name": "Bob", "age": 25, "skills": ["Java", "Cloud"]}, + ] + } + result = task.interpolate_only("{people}", input_data) + + parsed_result = eval(result) + assert isinstance(parsed_result, list) + assert len(parsed_result) == 2 + assert parsed_result[0]["name"] == "Alice" + assert parsed_result[0]["age"] == 30 + assert parsed_result[0]["skills"] == ["Python", "AI"] + assert parsed_result[1]["name"] == "Bob" + assert parsed_result[1]["age"] == 25 + assert parsed_result[1]["skills"] == ["Java", "Cloud"] + + +def test_interpolate_with_nested_structures(): + task = Task( + description="Test nested structures", + expected_output="Company: {company}", + ) + + input_data = { + "company": { + "name": "TechCorp", + "departments": [ + { + "name": "Engineering", + "employees": 50, + "tools": ["Git", "Docker", "Kubernetes"], + }, + {"name": "Sales", "employees": 20, "regions": {"north": 5, "south": 3}}, + ], + } + } + result = task.interpolate_only("{company}", input_data) + parsed = eval(result) + + assert parsed["name"] == "TechCorp" + assert len(parsed["departments"]) == 2 + assert parsed["departments"][0]["tools"] == ["Git", "Docker", "Kubernetes"] + assert parsed["departments"][1]["regions"]["north"] == 5 + + +def test_interpolate_with_special_characters(): + task = Task( + description="Test special characters in dicts", + expected_output="Data: {special_data}", + ) + + input_data = { + "special_data": { + "quotes": """This has "double" and 'single' quotes""", + "unicode": "文字化けテスト", + "symbols": "!@#$%^&*()", + "empty": "", + } + } + result = task.interpolate_only("{special_data}", input_data) + parsed = eval(result) + + assert parsed["quotes"] == """This has "double" and 'single' quotes""" + assert parsed["unicode"] == "文字化けテスト" + assert parsed["symbols"] == "!@#$%^&*()" + assert parsed["empty"] == "" + + +def test_interpolate_mixed_types(): + task = Task( + description="Test mixed type interpolation", + expected_output="Mixed: {data}", + ) + + input_data = { + "data": { + "name": "Test Dataset", + "samples": 1000, + "features": ["age", "income", "location"], + "metadata": { + "source": "public", + "validated": True, + "tags": ["demo", "test", "temp"], + }, + } + } + result = task.interpolate_only("{data}", input_data) + parsed = eval(result) + + assert parsed["name"] == "Test Dataset" + assert parsed["samples"] == 1000 + assert parsed["metadata"]["tags"] == ["demo", "test", "temp"] + + +def test_interpolate_complex_combination(): + task = Task( + description="Test complex combination", + expected_output="Report: {report}", + ) + + input_data = { + "report": [ + { + "month": "January", + "metrics": {"sales": 15000, "expenses": 8000, "profit": 7000}, + "top_products": ["Product A", "Product B"], + }, + { + "month": "February", + "metrics": {"sales": 18000, "expenses": 8500, "profit": 9500}, + "top_products": ["Product C", "Product D"], + }, + ] + } + result = task.interpolate_only("{report}", input_data) + parsed = eval(result) + + assert len(parsed) == 2 + assert parsed[0]["month"] == "January" + assert parsed[1]["metrics"]["profit"] == 9500 + assert "Product D" in parsed[1]["top_products"] + + +def test_interpolate_invalid_type_validation(): + task = Task( + description="Test invalid type validation", + expected_output="Should never reach here", + ) + + # Test with invalid top-level type + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{data}", {"data": set()}) # type: ignore we are purposely testing this failure + + assert "Unsupported type set" in str(excinfo.value) + + # Test with invalid nested type + invalid_nested = { + "profile": { + "name": "John", + "age": 30, + "tags": {"a", "b", "c"}, # Set is invalid + } + } + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{data}", {"data": invalid_nested}) + assert "Unsupported type set" in str(excinfo.value) + + +def test_interpolate_custom_object_validation(): + task = Task( + description="Test custom object rejection", + expected_output="Should never reach here", + ) + + class CustomObject: + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + # Test with custom object at top level + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{obj}", {"obj": CustomObject(5)}) # type: ignore we are purposely testing this failure + assert "Unsupported type CustomObject" in str(excinfo.value) + + # Test with nested custom object in dictionary + with pytest.raises(ValueError) as excinfo: + task.interpolate_only( + "{data}", {"data": {"valid": 1, "invalid": CustomObject(5)}} + ) + assert "Unsupported type CustomObject" in str(excinfo.value) + + # Test with nested custom object in list + with pytest.raises(ValueError) as excinfo: + task.interpolate_only("{data}", {"data": [1, "valid", CustomObject(5)]}) + assert "Unsupported type CustomObject" in str(excinfo.value) + + # Test with deeply nested custom object + with pytest.raises(ValueError) as excinfo: + task.interpolate_only( + "{data}", {"data": {"level1": {"level2": [{"level3": CustomObject(5)}]}}} + ) + assert "Unsupported type CustomObject" in str(excinfo.value) + + +def test_interpolate_valid_complex_types(): + task = Task( + description="Test valid complex types", + expected_output="Validation should pass", + ) + + # Valid complex structure + valid_data = { + "name": "Valid Dataset", + "stats": { + "count": 1000, + "distribution": [0.2, 0.3, 0.5], + "features": ["age", "income"], + "nested": {"deep": [1, 2, 3], "deeper": {"a": 1, "b": 2.5}}, + }, + } + + # Should not raise any errors + result = task.interpolate_only("{data}", {"data": valid_data}) + parsed = eval(result) + assert parsed["name"] == "Valid Dataset" + assert parsed["stats"]["nested"]["deeper"]["b"] == 2.5 + + +def test_interpolate_edge_cases(): + task = Task( + description="Test edge cases", + expected_output="Edge case handling", + ) + + # Test empty dict and list + assert task.interpolate_only("{}", {"data": {}}) == "{}" + assert task.interpolate_only("[]", {"data": []}) == "[]" + + # Test numeric types + assert task.interpolate_only("{num}", {"num": 42}) == "42" + assert task.interpolate_only("{num}", {"num": 3.14}) == "3.14" + + # Test boolean values (valid JSON types) + assert task.interpolate_only("{flag}", {"flag": True}) == "True" + assert task.interpolate_only("{flag}", {"flag": False}) == "False" + + +def test_interpolate_valid_types(): + task = Task( + description="Test valid types including null and boolean", + expected_output="Should pass validation", + ) + + # Test with boolean and null values (valid JSON types) + valid_data = { + "name": "Test", + "active": True, + "deleted": False, + "optional": None, + "nested": {"flag": True, "empty": None}, + } + + result = task.interpolate_only("{data}", {"data": valid_data}) + parsed = eval(result) + + assert parsed["active"] is True + assert parsed["deleted"] is False + assert parsed["optional"] is None + assert parsed["nested"]["flag"] is True + assert parsed["nested"]["empty"] is None