[Prompty] Support model config in prompty #2728

Merged: 14 commits, Apr 12, 2024
114 changes: 72 additions & 42 deletions src/promptflow-core/promptflow/core/_flow.py
@@ -11,15 +11,16 @@
from promptflow._constants import DEFAULT_ENCODING, LANGUAGE_KEY, PROMPTY_EXTENSION, FlowLanguage
from promptflow._utils.flow_utils import is_flex_flow, is_prompty_flow, resolve_flow_path
from promptflow._utils.yaml_utils import load_yaml_string
from promptflow.contracts.tool import ValueType
from promptflow.core._errors import MissingRequiredInputError
from promptflow.core._model_configuration import PromptyModelConfiguration
from promptflow.core._prompty_utils import (
convert_model_configuration_to_connection,
convert_prompt_template,
format_llm_response,
get_connection,
get_open_ai_client_by_connection,
prepare_open_ai_request_params,
send_request_to_llm,
update_dict_recursively,
)
from promptflow.exceptions import UserErrorException
from promptflow.tracing import trace
@@ -230,6 +231,55 @@ class Prompty(FlowBase):
prompty = Prompty.load(source="path/to/prompty.prompty")
result = prompty(input_a=1, input_b=2)

# Override model config with dict
model_config = {
"api": "chat",
"configuration": {
"type": "azure_openai",
"azure_deployment": "gpt-35-turbo",
"api_key": ${env:AZURE_OPENAI_API_KEY},
"api_version": ${env:AZURE_OPENAI_API_VERSION},
"azure_endpoint": ${env:AZURE_OPENAI_ENDPOINT},
},
"parameters": {
"max_tokens": 512
}
}
prompty = Prompty.load(source="path/to/prompty.prompty", model=model_config)
result = prompty(input_a=1, input_b=2)

# Override model config with configuration
from promptflow.core._model_configuration import AzureOpenAIModelConfiguration
model_config = {
"api": "chat",
"configuration": AzureOpenAIModelConfiguration(
azure_deployment="gpt-35-turbo",
api_key="${env:AZURE_OPENAI_API_KEY}",
api_version="${env:AZURE_OPENAI_API_VERSION}",
azure_endpoint="${env:AZURE_OPENAI_ENDPOINT}",
),
"parameters": {
"max_tokens": 512
}
}
prompty = Prompty.load(source="path/to/prompty.prompty", model=model_config)
result = prompty(input_a=1, input_b=2)

# Override model config with created connection
from promptflow.core._model_configuration import AzureOpenAIModelConfiguration
model_config = {
"api": "chat",
"configuration": AzureOpenAIModelConfiguration(
connection="azure_open_ai_connection",
azure_deployment="gpt-35-turbo",
),
"parameters": {
"max_tokens": 512
}
}
prompty = Prompty.load(source="path/to/prompty.prompty", model=model_config)
result = prompty(input_a=1, input_b=2)

"""

def __init__(
@@ -240,29 +290,15 @@ def __init__(
):
# prompty file path
path = Path(path)
model = model or {}
configs, self._template = self._parse_prompty(path)
prompty_model = configs.get("model", {})
prompty_model["api"] = model.get("api") or prompty_model.get("api", "chat")
# TODO wait for model spec
prompty_model["connection"] = model.get("connection") or prompty_model.get("connection", None)
if model.get("parameters", None):
if prompty_model.get("parameters", {}):
prompty_model["parameters"].update(model["parameters"])
else:
prompty_model["parameters"] = model["parameters"]
for k in list(kwargs.keys()):
value = kwargs.pop(k)
if k in configs and isinstance(value, dict):
configs[k].update(value)
else:
configs[k] = value
configs["inputs"] = self._resolve_inputs(configs.get("inputs", {}))
self._connection = prompty_model["connection"]
self._parameters = prompty_model.get("parameters", None)
self._api = prompty_model["api"]
configs = update_dict_recursively(configs, kwargs)
configs["model"] = update_dict_recursively(configs.get("model", {}), model or {})

self._model = PromptyModelConfiguration(**configs["model"])
self._inputs = configs.get("inputs", {})
configs["model"] = prompty_model
self._outputs = configs.get("outputs", {})
# TODO support more templating engines
self._template_engine = configs.get("template", "jinja2")
super().__init__(code=path.parent, path=path, data=configs, content_hash=None, **kwargs)
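
Note on update_dict_recursively: the helper is imported from promptflow.core._prompty_utils and its body is not part of this diff. Based on how it is used here (a nested "parameters" block in the model override should update, not replace, the one parsed from the prompty file), a minimal sketch of the assumed deep-merge behavior would be:

def update_dict_recursively(origin_dict: dict, overwrite_dict: dict) -> dict:
    # Sketch of the assumed semantics, not the actual implementation.
    updated = dict(origin_dict)
    for key, value in overwrite_dict.items():
        if isinstance(value, dict) and isinstance(updated.get(key), dict):
            # Nested dicts (e.g. "parameters", "configuration") are merged key by key.
            updated[key] = update_dict_recursively(updated[key], value)
        else:
            # Scalars and new keys simply overwrite the original entry.
            updated[key] = value
    return updated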

@classmethod
@@ -314,15 +350,6 @@ def _parse_prompty(path):
configs = load_yaml_string(config_content)
return configs, prompt_template

def _resolve_inputs(self, inputs):
resolved_inputs = {}
for k, v in inputs.items():
if isinstance(v, dict):
resolved_inputs[k] = v
else:
resolved_inputs[k] = {"type": ValueType.from_value(v).value, "default": v}
return resolved_inputs

def _validate_inputs(self, input_values):
resolved_inputs = {}
missing_inputs = []
@@ -348,25 +375,26 @@ def __call__(self, *args, **kwargs):
"""
if args:
raise UserErrorException("Prompty can only be called with keyword arguments.")

# 1. Get connection
connection = get_connection(self._connection)
connection = convert_model_configuration_to_connection(self._model.configuration)

# 2. deal with prompt
inputs = self._validate_inputs(kwargs)
traced_convert_prompt_template = _traced(func=convert_prompt_template, args_to_ignore=["api"])
template = traced_convert_prompt_template(self._template, inputs, self._api)
template = traced_convert_prompt_template(self._template, inputs, self._model.api)

# 3. prepare params
params = prepare_open_ai_request_params(self._parameters, template, self._api, connection)
params = prepare_open_ai_request_params(self._model, template, connection)

# 4. send request to open ai
api_client = get_open_ai_client_by_connection(connection=connection)

traced_llm_call = _traced(send_request_to_llm)
response = traced_llm_call(api_client, self._api, params)
response = traced_llm_call(api_client, self._model.api, params)
return format_llm_response(
response=response,
api=self._api,
api=self._model.api,
response_format=params.get("response_format", None),
raw=self._data.get("format", None) == "raw",
)
@@ -399,24 +427,26 @@ async def __call__(self, *args, **kwargs) -> Mapping[str, Any]:
"""
if args:
raise UserErrorException("Prompty can only be called with keyword arguments.")

# 1. Get connection
connection = get_connection(self._connection)
connection = convert_model_configuration_to_connection(self._model.configuration)

# 2. deal with prompt
inputs = self._validate_inputs(kwargs)
template = convert_prompt_template(self._template, inputs, self._api)
traced_convert_prompt_template = _traced(func=convert_prompt_template, args_to_ignore=["api"])
template = traced_convert_prompt_template(self._template, inputs, self._model.api)

# 3. prepare params
params = prepare_open_ai_request_params(self._parameters, template, self._api, connection)
params = prepare_open_ai_request_params(self._model, template, connection)

# 4. send request to open ai
api_client = get_open_ai_client_by_connection(connection=connection, is_async=True)

traced_llm_call = _traced(send_request_to_llm)
response = await traced_llm_call(api_client, self._api, params)
response = await traced_llm_call(api_client, self._model.api, params)
return format_llm_response(
response=response,
api=self._api,
api=self._model.api,
response_format=params.get("response_format", None),
raw=self._data.get("format", None) == "raw",
)
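
The hunk above only shows the async __call__; usage presumably mirrors the synchronous class. Assuming the async variant is exposed as AsyncPrompty in the same module with a load classmethod analogous to Prompty.load (neither is shown in this diff), a call would look like this sketch:

import asyncio

from promptflow.core._flow import AsyncPrompty  # assumed class name; the class definition is outside this diff


async def main():
    # Same keyword-argument calling convention as the synchronous Prompty.
    prompty = AsyncPrompty.load(
        source="path/to/prompty.prompty",
        model={"parameters": {"max_tokens": 512}},
    )
    result = await prompty(input_a=1, input_b=2)
    print(result)


asyncio.run(main())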
79 changes: 79 additions & 0 deletions src/promptflow-core/promptflow/core/_model_configuration.py
@@ -0,0 +1,79 @@
from dataclasses import dataclass
from typing import Union

from promptflow._constants import ConnectionType
from promptflow.core._errors import InvalidConnectionError


class ModelConfiguration:
pass


@dataclass
class AzureOpenAIModelConfiguration(ModelConfiguration):
azure_deployment: str
azure_endpoint: str = None
api_version: str = None
api_key: str = None
organization: str = None
# connection and model configs are exclusive.
connection: str = None

def __post_init__(self):
self._type = ConnectionType.AZURE_OPEN_AI
if any([self.azure_endpoint, self.api_key, self.api_version, self.organization]) and self.connection:
raise InvalidConnectionError("Cannot configure model config and connection at the same time.")


@dataclass
class OpenAIModelConfiguration(ModelConfiguration):
model: str
base_url: str = None
api_key: str = None
organization: str = None
# connection and model configs are exclusive.
connection: str = None

def __post_init__(self):
self._type = ConnectionType.OPEN_AI
if any([self.base_url, self.api_key, self.organization]) and self.connection:
raise InvalidConnectionError("Cannot configure model config and connection at the same time.")


@dataclass
class PromptyModelConfiguration:
"""
A dataclass that represents a model config of prompty.

:param api: Type of the LLM request, default value is chat.
:type api: str
:param configuration: Prompty model connection configuration
:type configuration: Union[dict, AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
:param parameters: Params of the LLM request.
:type parameters: dict
:param response: Return the complete response or the first choice, default value is first.
:type response: str
"""

configuration: Union[dict, AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
parameters: dict
api: str = "chat"
response: str = "first"

def __post_init__(self):
if isinstance(self.configuration, dict):
# Load connection from model configuration
model_config = {
k: v
for k, v in self.configuration.items()
if k not in ["type", "connection", "model", "azure_deployment"]
}
if self.configuration.get("connection", None) and any([v for v in model_config.values()]):
raise InvalidConnectionError(
"Cannot configure model config and connection in configuration at the same time."
)
self._model = self.configuration.get("azure_deployment", None) or self.configuration.get("model", None)
elif isinstance(self.configuration, OpenAIModelConfiguration):
self._model = self.configuration.model
elif isinstance(self.configuration, AzureOpenAIModelConfiguration):
self._model = self.configuration.azure_deployment
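
A short usage sketch (not part of the PR) that exercises the exclusivity rule and the PromptyModelConfiguration wrapper defined above; endpoint, key, and version values are placeholders.

from promptflow.core._errors import InvalidConnectionError
from promptflow.core._model_configuration import (
    AzureOpenAIModelConfiguration,
    PromptyModelConfiguration,
)

# Explicit credentials, no named connection.
explicit = AzureOpenAIModelConfiguration(
    azure_deployment="gpt-35-turbo",
    azure_endpoint="https://<resource>.openai.azure.com",
    api_key="<api-key>",
    api_version="<api-version>",
)

# Named connection only; endpoint, key, and version are resolved from the connection.
by_connection = AzureOpenAIModelConfiguration(
    azure_deployment="gpt-35-turbo",
    connection="azure_open_ai_connection",
)

# Mixing both is rejected in __post_init__.
try:
    AzureOpenAIModelConfiguration(
        azure_deployment="gpt-35-turbo",
        azure_endpoint="https://<resource>.openai.azure.com",
        connection="azure_open_ai_connection",
    )
except InvalidConnectionError as err:
    print(err)

# Either form, plus request parameters, can be wrapped for a prompty.
prompty_model = PromptyModelConfiguration(
    configuration=by_connection,
    parameters={"max_tokens": 512, "temperature": 0.2},
)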