diff --git a/src/promptflow-devkit/promptflow/_sdk/operations/_flow_operations.py b/src/promptflow-devkit/promptflow/_sdk/operations/_flow_operations.py index 0a1631f009e..daa55132863 100644 --- a/src/promptflow-devkit/promptflow/_sdk/operations/_flow_operations.py +++ b/src/promptflow-devkit/promptflow/_sdk/operations/_flow_operations.py @@ -29,7 +29,6 @@ FLOW_TOOLS_JSON_GEN_TIMEOUT, LOCAL_MGMT_DB_PATH, SERVE_SAMPLE_JSON_PATH, - SignatureValueType, ) from promptflow._sdk._load_functions import load_flow from promptflow._sdk._orchestrator import TestSubmitter @@ -55,7 +54,6 @@ parse_variant, ) from promptflow._utils.yaml_utils import dump_yaml, load_yaml -from promptflow.contracts.tool import ValueType from promptflow.exceptions import ErrorTarget, UserErrorException @@ -1071,11 +1069,12 @@ def _infer_signature( raise UserErrorException("Entry must be a function or a class.") # signature is language irrelevant, so we apply json type system + # TODO: enable this mapping after service supports more types value_type_map = { - ValueType.INT.value: SignatureValueType.INT.value, - ValueType.DOUBLE.value: SignatureValueType.NUMBER.value, - ValueType.LIST.value: SignatureValueType.ARRAY.value, - ValueType.BOOL.value: SignatureValueType.BOOL.value, + # ValueType.INT.value: SignatureValueType.INT.value, + # ValueType.DOUBLE.value: SignatureValueType.NUMBER.value, + # ValueType.LIST.value: SignatureValueType.ARRAY.value, + # ValueType.BOOL.value: SignatureValueType.BOOL.value, } for port_type in ["inputs", "outputs", "init"]: if port_type not in flow_meta: diff --git a/src/promptflow-devkit/promptflow/_sdk/schemas/_flow.py b/src/promptflow-devkit/promptflow/_sdk/schemas/_flow.py index a50094264fc..c64cf2cc3a3 100644 --- a/src/promptflow-devkit/promptflow/_sdk/schemas/_flow.py +++ b/src/promptflow-devkit/promptflow/_sdk/schemas/_flow.py @@ -6,9 +6,10 @@ from promptflow._constants import LANGUAGE_KEY, ConnectionType, FlowLanguage from promptflow._proxy import ProxyFactory -from promptflow._sdk._constants import FlowType, SignatureValueType +from promptflow._sdk._constants import FlowType from promptflow._sdk.schemas._base import PatchedSchemaMeta, YamlFileSchema from promptflow._sdk.schemas._fields import NestedField +from promptflow.contracts.tool import ValueType class FlowInputSchema(metaclass=PatchedSchemaMeta): @@ -60,11 +61,21 @@ class FlowSchema(BaseFlowSchema): node_variants = fields.Dict(keys=fields.Str(), values=fields.Dict()) +ALLOWED_TYPES = [ + ValueType.STRING.value, + ValueType.INT.value, + ValueType.DOUBLE.value, + ValueType.BOOL.value, + ValueType.LIST.value, + ValueType.OBJECT.value, +] + + class FlexFlowInputSchema(FlowInputSchema): type = fields.Str( required=True, # TODO 3062609: Flex flow GPT-V support - validate=validate.OneOf(list(map(lambda x: x.value, SignatureValueType))), + validate=validate.OneOf(ALLOWED_TYPES), ) @@ -72,7 +83,7 @@ class FlexFlowInitSchema(FlowInputSchema): type = fields.Str( required=True, validate=validate.OneOf( - list(map(lambda x: x.value, SignatureValueType)) + ALLOWED_TYPES + list( map(lambda x: f"{x.value}Connection", filter(lambda x: x != ConnectionType._NOT_SET, ConnectionType)) ) @@ -83,7 +94,7 @@ class FlexFlowInitSchema(FlowInputSchema): class FlexFlowOutputSchema(FlowOutputSchema): type = fields.Str( required=True, - validate=validate.OneOf(list(map(lambda x: x.value, SignatureValueType))), + validate=validate.OneOf(ALLOWED_TYPES), ) diff --git a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_save.py b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_save.py index 6121adaa41b..e5d53877147 100644 --- a/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_save.py +++ b/src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_save.py @@ -70,7 +70,11 @@ def global_hello_int_return(text: str) -> int: def global_hello_strong_return(text: str) -> GlobalHello: - return len(text) + return GlobalHello(AzureOpenAIConnection("test")) + + +def global_hello_kwargs(text: str, **kwargs) -> str: + return f"Hello {text}!" @pytest.mark.usefixtures( @@ -162,7 +166,7 @@ class TestFlowSave: "type": "string", }, "length": { - "type": "integer", + "type": "int", }, }, }, @@ -200,16 +204,16 @@ class TestFlowSave: "type": "string", }, "i": { - "type": "integer", + "type": "int", }, "f": { - "type": "number", + "type": "double", }, "b": { - "type": "boolean", + "type": "bool", }, "li": { - "type": "array", + "type": "list", }, "d": { "type": "object", @@ -220,16 +224,16 @@ class TestFlowSave: "type": "string", }, "i": { - "type": "integer", + "type": "int", }, "f": { - "type": "number", + "type": "double", }, "b": { - "type": "boolean", + "type": "bool", }, "li": { - "type": "array", + "type": "list", }, "d": { "type": "object", @@ -240,16 +244,16 @@ class TestFlowSave: "type": "string", }, "i": { - "type": "integer", + "type": "int", }, "f": { - "type": "number", + "type": "double", }, "b": { - "type": "boolean", + "type": "bool", }, "l": { - "type": "array", + "type": "list", }, "d": { "type": "object", @@ -490,6 +494,17 @@ def test_pf_save_callable_function(self): }, id="inherited_typed_dict_output", ), + pytest.param( + global_hello_kwargs, + { + "inputs": { + "text": { + "type": "string", + } + }, + }, + id="kwargs", + ), ], ) def test_infer_signature(