Skip to content

Commit

Permalink
Fix reading special character issue for tool resolver (#778)
Browse files Browse the repository at this point in the history
# Description

For LLM/Prompt/CUSTOM_LLM tools, we need to read the source file to resolve the
tool.
If the file contains special characters, this introduces an unexpected error.
So we read the file with the encoding explicitly set to "utf-8".

This is a follow-up PR to
#759

# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [X] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Robben Wang <[email protected]>
  • Loading branch information
huaiyan and Robben Wang authored Oct 16, 2023
1 parent 7170d0d commit 48a4c19
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 7 deletions.
7 changes: 4 additions & 3 deletions src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,16 @@ def resolve_tool_by_node(self, node: Node, convert_input_types=True) -> Resolved

def _load_source_content(self, node: Node) -> str:
# Return the text of the file that `node.source.path` points to, resolved
# against `self._working_dir`. Raises InvalidSource when the node has no
# source, the source has no path, or the resolved path is not a file.
# NOTE(review): this is a diff rendering without +/- markers or indentation,
# so the pre-change and post-change lines appear interleaved below.
source = node.source
# Pre-change condition (appears to be the line replaced by the is_file() check below).
if source is None or source.path is None or not Path(self._working_dir / source.path).exists():
# If is_file returns True, the path points to an existing file, so we don't need to check if exists.
if source is None or source.path is None or not (self._working_dir / source.path).is_file():
raise InvalidSource(
target=ErrorTarget.EXECUTOR,
message_format="Node source path '{source_path}' is invalid on node '{node_name}'.",
# `source` itself may be None here, so guard the attribute access.
source_path=source.path if source is not None else None,
node_name=node.name,
)
# Pre-change read: open() without an explicit encoding uses the platform default.
with open(self._working_dir / source.path) as fin:
return fin.read()
# Post-change read: decode explicitly as UTF-8 so files containing special
# characters are read the same way regardless of the platform default.
file = self._working_dir / source.path
return file.read_text(encoding="utf-8")

def _validate_duplicated_inputs(self, prompt_tpl_inputs: list, tool_params: list, msg: str):
duplicated_inputs = set(prompt_tpl_inputs) & set(tool_params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TestToolMetaUtils:
("prompt_tools", "summarize_text_content_prompt.jinja2", "llm"),
("script_with_import", "dummy_utils/main.py", "python"),
("script_with___file__", "script_with___file__.py", "python"),
("script_with_special_character", "script_with_special_character.py", "python"),
],
)
def test_generate_tool_meta_dict_by_file(self, flow_dir, tool_path, tool_type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from promptflow._core.tools_manager import ToolLoader
from promptflow._sdk.entities import CustomConnection, CustomStrongTypeConnection
from promptflow.connections import AzureOpenAIConnection
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSource, ToolSourceType
from promptflow.contracts.tool import InputDefinition, Secret, Tool, ToolType, ValueType
from promptflow.contracts.types import PromptTemplate
from promptflow.exceptions import UserErrorException
Expand All @@ -20,6 +20,8 @@
)
from promptflow.executor._tool_resolver import ResolvedTool, ToolResolver

from ...utils import FLOW_ROOT

TEST_ROOT = Path(__file__).parent.parent.parent
REQUESTS_PATH = TEST_ROOT / "test_configs/executor_api_requests"
WRONG_REQUESTS_PATH = TEST_ROOT / "test_configs/executor_wrong_requests"
Expand Down Expand Up @@ -416,3 +418,30 @@ def test_convert_to_custom_strong_type_connection_value(self, conn_types: list[s
actual = tool_resolver._convert_to_custom_strong_type_connection_value("conn_name", v, node, conn_types, m)
assert isinstance(actual, expected_type)
assert actual.api_base == "mock"

def test_load_source(self):
    """Source content containing special characters is loaded intact.

    Points a node at the ``script_with_special_character`` script, loads it
    through a resolver created with ``FLOW_ROOT``, and checks that the
    private-use characters U+E000 and U+E001 are present in the returned text.
    """
    script_path = "./script_with_special_character/script_with_special_character.py"
    expected_snippet = "https://www.bing.com/\ue000\ue001/"

    # Build the node with an empty source first, then point it at the script.
    mock_node = Node(name="mock", tool=None, inputs={}, source=ToolSource())
    mock_node.source.path = script_path

    tool_resolver = ToolResolver(FLOW_ROOT)
    content = tool_resolver._load_source_content(mock_node)

    assert expected_snippet in content

@pytest.mark.parametrize(
    "source",
    [
        # The node has no source object at all.
        None,
        # The source object exists but carries no path.
        ToolSource(path=None),
        # An empty path; presumably resolves to the working directory itself, not a file.
        ToolSource(path=""),
        # A path that is expected not to exist under the flow root.
        ToolSource(path="NotExistPath.py"),
    ],
)
def test_load_source_error(self, source):
    """Loading source content from a node with an unusable source raises InvalidSource."""
    # Unlike the happy-path test, the node here is built around an invalid source.
    broken_node = Node(name="mock", tool=None, inputs={}, source=source)
    tool_resolver = ToolResolver(FLOW_ROOT)

    with pytest.raises(InvalidSource):
        tool_resolver._load_source_content(broken_node)
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

from promptflow import tool

# Add special character to test if file read is working.
print("https://www.bing.com/")

print(f"The script is {__file__}")
assert Path(__file__).is_absolute(), f"__file__ should be absolute path, got {__file__}"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"name": "script_with_special_character",
"type": "python",
"inputs": {
"input1": {
"type": [
"string"
]
}
},
"source": "script_with_special_character.py",
"function": "print_special_character"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from promptflow import tool

@tool
def print_special_character(input1: str) -> str:
    """Return a URL string that embeds non-ASCII special characters.

    The returned literal contains the two private-use code points U+E000 and
    U+E001 between its last two slashes. They are kept as raw characters
    rather than escape sequences so that this file's own text is non-ASCII.

    :param input1: Not used by the function body.
    :return: The URL string containing the special characters.
    """
    # Add special character to test if file read is working.
    return "https://www.bing.com//"

0 comments on commit 48a4c19

Please sign in to comment.