Skip to content

Commit

Permalink
Merge branch 'main' into yigao/move_devkit_azure_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
crazygao committed Apr 11, 2024
2 parents b43d162 + 0531f06 commit 794a39b
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
FLOW_TOOLS_JSON_GEN_TIMEOUT,
LOCAL_MGMT_DB_PATH,
SERVE_SAMPLE_JSON_PATH,
SignatureValueType,
)
from promptflow._sdk._load_functions import load_flow
from promptflow._sdk._orchestrator import TestSubmitter
Expand All @@ -55,7 +54,6 @@
parse_variant,
)
from promptflow._utils.yaml_utils import dump_yaml, load_yaml
from promptflow.contracts.tool import ValueType
from promptflow.exceptions import ErrorTarget, UserErrorException


Expand Down Expand Up @@ -1071,11 +1069,12 @@ def _infer_signature(
raise UserErrorException("Entry must be a function or a class.")

# signature is language irrelevant, so we apply json type system
# TODO: enable this mapping after service supports more types
value_type_map = {
ValueType.INT.value: SignatureValueType.INT.value,
ValueType.DOUBLE.value: SignatureValueType.NUMBER.value,
ValueType.LIST.value: SignatureValueType.ARRAY.value,
ValueType.BOOL.value: SignatureValueType.BOOL.value,
# ValueType.INT.value: SignatureValueType.INT.value,
# ValueType.DOUBLE.value: SignatureValueType.NUMBER.value,
# ValueType.LIST.value: SignatureValueType.ARRAY.value,
# ValueType.BOOL.value: SignatureValueType.BOOL.value,
}
for port_type in ["inputs", "outputs", "init"]:
if port_type not in flow_meta:
Expand Down
19 changes: 15 additions & 4 deletions src/promptflow-devkit/promptflow/_sdk/schemas/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

from promptflow._constants import LANGUAGE_KEY, ConnectionType, FlowLanguage
from promptflow._proxy import ProxyFactory
from promptflow._sdk._constants import FlowType, SignatureValueType
from promptflow._sdk._constants import FlowType
from promptflow._sdk.schemas._base import PatchedSchemaMeta, YamlFileSchema
from promptflow._sdk.schemas._fields import NestedField
from promptflow.contracts.tool import ValueType


class FlowInputSchema(metaclass=PatchedSchemaMeta):
Expand Down Expand Up @@ -60,19 +61,29 @@ class FlowSchema(BaseFlowSchema):
node_variants = fields.Dict(keys=fields.Str(), values=fields.Dict())


ALLOWED_TYPES = [
ValueType.STRING.value,
ValueType.INT.value,
ValueType.DOUBLE.value,
ValueType.BOOL.value,
ValueType.LIST.value,
ValueType.OBJECT.value,
]


class FlexFlowInputSchema(FlowInputSchema):
type = fields.Str(
required=True,
# TODO 3062609: Flex flow GPT-V support
validate=validate.OneOf(list(map(lambda x: x.value, SignatureValueType))),
validate=validate.OneOf(ALLOWED_TYPES),
)


class FlexFlowInitSchema(FlowInputSchema):
type = fields.Str(
required=True,
validate=validate.OneOf(
list(map(lambda x: x.value, SignatureValueType))
ALLOWED_TYPES
+ list(
map(lambda x: f"{x.value}Connection", filter(lambda x: x != ConnectionType._NOT_SET, ConnectionType))
)
Expand All @@ -83,7 +94,7 @@ class FlexFlowInitSchema(FlowInputSchema):
class FlexFlowOutputSchema(FlowOutputSchema):
type = fields.Str(
required=True,
validate=validate.OneOf(list(map(lambda x: x.value, SignatureValueType))),
validate=validate.OneOf(ALLOWED_TYPES),
)


Expand Down
43 changes: 29 additions & 14 deletions src/promptflow-devkit/tests/sdk_cli_test/e2etests/test_flow_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def global_hello_int_return(text: str) -> int:


def global_hello_strong_return(text: str) -> GlobalHello:
return len(text)
return GlobalHello(AzureOpenAIConnection("test"))


def global_hello_kwargs(text: str, **kwargs) -> str:
return f"Hello {text}!"


@pytest.mark.usefixtures(
Expand Down Expand Up @@ -162,7 +166,7 @@ class TestFlowSave:
"type": "string",
},
"length": {
"type": "integer",
"type": "int",
},
},
},
Expand Down Expand Up @@ -200,16 +204,16 @@ class TestFlowSave:
"type": "string",
},
"i": {
"type": "integer",
"type": "int",
},
"f": {
"type": "number",
"type": "double",
},
"b": {
"type": "boolean",
"type": "bool",
},
"li": {
"type": "array",
"type": "list",
},
"d": {
"type": "object",
Expand All @@ -220,16 +224,16 @@ class TestFlowSave:
"type": "string",
},
"i": {
"type": "integer",
"type": "int",
},
"f": {
"type": "number",
"type": "double",
},
"b": {
"type": "boolean",
"type": "bool",
},
"li": {
"type": "array",
"type": "list",
},
"d": {
"type": "object",
Expand All @@ -240,16 +244,16 @@ class TestFlowSave:
"type": "string",
},
"i": {
"type": "integer",
"type": "int",
},
"f": {
"type": "number",
"type": "double",
},
"b": {
"type": "boolean",
"type": "bool",
},
"l": {
"type": "array",
"type": "list",
},
"d": {
"type": "object",
Expand Down Expand Up @@ -490,6 +494,17 @@ def test_pf_save_callable_function(self):
},
id="inherited_typed_dict_output",
),
pytest.param(
global_hello_kwargs,
{
"inputs": {
"text": {
"type": "string",
}
},
},
id="kwargs",
),
],
)
def test_infer_signature(
Expand Down

0 comments on commit 794a39b

Please sign in to comment.