[Prompty] Support model config in prompty (#2728)
# Description
- prompty with Azure OpenAI
```
---
name: Basic Prompt
description: A basic prompt that uses the GPT-3 chat API to answer questions
model:
    api: chat
    configuration:
      type: azure_openai
      azure_deployment: gpt-35-turbo
      azure_endpoint: ${env:AZURE_ENDPOINT}
      api_key: ${env:AZURE_API_KEY}
    parameters:
      max_tokens: 128
      temperature: 0.2
inputs:
  firstName:
    type: string
    default: John
  lastName:
    type: string
    default: Doh
  question:
    type: string
---
system:
You are an AI assistant who helps people find information.
Use their name to address them in your responses.

user:
{{question}}

```
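A minimal usage sketch for the example above (assuming the file is saved as `basic.prompty`, the two `${env:...}` variables are set, and `Prompty` is exported from `promptflow.core` as in the docstring below):
```
import os

from promptflow.core import Prompty

# Hypothetical values for the env vars referenced via ${env:...} in the file.
os.environ["AZURE_ENDPOINT"] = "https://<your-resource>.openai.azure.com/"
os.environ["AZURE_API_KEY"] = "<your-api-key>"

# Load the prompty and call it with keyword arguments matching its inputs;
# firstName and lastName fall back to their declared defaults.
prompty = Prompty.load(source="basic.prompty")
result = prompty(question="What is the capital of France?")
print(result)
```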
- prompty with a named connection
```
---
name: Basic Prompt
description: A basic prompt that uses the GPT-3 chat API to answer questions
model:
    api: chat
    configuration:
      type: azure_openai
      connection: azure_open_ai_connection
      azure_deployment: gpt-35-turbo
    parameters:
      max_tokens: 128
      temperature: 0.2
inputs:
  firstName:
    type: string
    default: John
  lastName:
    type: string
    default: Doh
  question:
    type: string
---
system:
You are an AI assistant who helps people find information.
Use their name to address them in your responses.

user:
{{question}}

```
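The file-level connection can also be overridden at load time. Per the `Prompty` docstring in this diff, passing a model config dict with an `AzureOpenAIModelConfiguration` works like this (a sketch; the connection name and deployment are placeholders):
```
from promptflow.core import Prompty
from promptflow.core._model_configuration import AzureOpenAIModelConfiguration

# Point the prompty at a locally created connection, overriding the file.
model_config = {
    "api": "chat",
    "configuration": AzureOpenAIModelConfiguration(
        connection="azure_open_ai_connection",
        azure_deployment="gpt-35-turbo",
    ),
    "parameters": {"max_tokens": 512},
}
prompty = Prompty.load(source="basic.prompty", model=model_config)
result = prompty(question="What is the capital of France?")
```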

- prompty with OpenAI
```
---
name: Basic Prompt
description: A basic prompt that uses the GPT-3 chat API to answer questions
model:
    api: chat
    configuration:
      type: openai
      model: gpt-35-turbo
      api_key: ${env:API_KEY}
      base_url: ${env:BASE_URL}
    parameters:
      max_tokens: 128
      temperature: 0.2
inputs:
  firstName:
    type: string
    default: John
  lastName:
    type: string
    default: Doh
  question:
    type: string
---
system:
You are an AI assistant who helps people find information.
Use their name to address them in your responses.

user:
{{question}}

```
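The same override mechanism should apply to the `openai` type through `OpenAIModelConfiguration` (a sketch under the same assumptions; the key and base URL are placeholders):
```
from promptflow.core import Prompty
from promptflow.core._model_configuration import OpenAIModelConfiguration

model_config = {
    "api": "chat",
    "configuration": OpenAIModelConfiguration(
        model="gpt-3.5-turbo",
        api_key="<your-openai-api-key>",
        base_url="https://api.openai.com/v1",
    ),
    "parameters": {"max_tokens": 128, "temperature": 0.2},
}
prompty = Prompty.load(source="basic.prompty", model=model_config)
result = prompty(question="What is the capital of France?")
```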


# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
lalala123123 authored Apr 12, 2024
1 parent ac82e2e commit 15960bf
Showing 10 changed files with 300 additions and 102 deletions.
114 changes: 72 additions & 42 deletions src/promptflow-core/promptflow/core/_flow.py
@@ -11,15 +11,16 @@
from promptflow._constants import DEFAULT_ENCODING, LANGUAGE_KEY, PROMPTY_EXTENSION, FlowLanguage
from promptflow._utils.flow_utils import is_flex_flow, is_prompty_flow, resolve_flow_path
from promptflow._utils.yaml_utils import load_yaml_string
from promptflow.contracts.tool import ValueType
from promptflow.core._errors import MissingRequiredInputError
from promptflow.core._model_configuration import PromptyModelConfiguration
from promptflow.core._prompty_utils import (
convert_model_configuration_to_connection,
convert_prompt_template,
format_llm_response,
get_connection,
get_open_ai_client_by_connection,
prepare_open_ai_request_params,
send_request_to_llm,
update_dict_recursively,
)
from promptflow.exceptions import UserErrorException
from promptflow.tracing import trace
@@ -230,6 +231,55 @@ class Prompty(FlowBase):
prompty = Prompty.load(source="path/to/prompty.prompty")
result = prompty(input_a=1, input_b=2)
# Override model config with dict
model_config = {
"api": "chat",
"configuration": {
"type": "azure_openai",
"azure_deployment": "gpt-35-turbo",
"api_key": "${env:AZURE_OPENAI_API_KEY}",
"api_version": "${env:AZURE_OPENAI_API_VERSION}",
"azure_endpoint": "${env:AZURE_OPENAI_ENDPOINT}",
},
"parameters": {
"max_tokens": 512
}
}
prompty = Prompty.load(source="path/to/prompty.prompty", model=model_config)
result = prompty(input_a=1, input_b=2)
# Override model config with configuration
from promptflow.core._model_configuration import AzureOpenAIModelConfiguration
model_config = {
"api": "chat",
"configuration": AzureOpenAIModelConfiguration(
azure_deployment="gpt-35-turbo",
api_key="${env:AZURE_OPENAI_API_KEY}",
api_version="${env:AZURE_OPENAI_API_VERSION}",
azure_endpoint="${env:AZURE_OPENAI_ENDPOINT}",
),
"parameters": {
"max_tokens": 512
}
}
prompty = Prompty.load(source="path/to/prompty.prompty", model=model_config)
result = prompty(input_a=1, input_b=2)
# Override model config with created connection
from promptflow.core._model_configuration import AzureOpenAIModelConfiguration
model_config = {
"api": "chat",
"configuration": AzureOpenAIModelConfiguration(
connection="azure_open_ai_connection",
azure_deployment="gpt-35-turbo",
),
"parameters": {
"max_tokens": 512
}
}
prompty = Prompty.load(source="path/to/prompty.prompty", model=model_config)
result = prompty(input_a=1, input_b=2)
"""

def __init__(
@@ -240,29 +290,15 @@ def __init__(
):
# prompty file path
path = Path(path)
model = model or {}
configs, self._template = self._parse_prompty(path)
prompty_model = configs.get("model", {})
prompty_model["api"] = model.get("api") or prompty_model.get("api", "chat")
# TODO wait for model spec
prompty_model["connection"] = model.get("connection") or prompty_model.get("connection", None)
if model.get("parameters", None):
if prompty_model.get("parameters", {}):
prompty_model["parameters"].update(model["parameters"])
else:
prompty_model["parameters"] = model["parameters"]
for k in list(kwargs.keys()):
value = kwargs.pop(k)
if k in configs and isinstance(value, dict):
configs[k].update(value)
else:
configs[k] = value
configs["inputs"] = self._resolve_inputs(configs.get("inputs", {}))
self._connection = prompty_model["connection"]
self._parameters = prompty_model.get("parameters", None)
self._api = prompty_model["api"]
configs = update_dict_recursively(configs, kwargs)
configs["model"] = update_dict_recursively(configs.get("model", {}), model or {})

self._model = PromptyModelConfiguration(**configs["model"])
self._inputs = configs.get("inputs", {})
configs["model"] = prompty_model
self._outputs = configs.get("outputs", {})
# TODO support more templating engine
self._template_engine = configs.get("template", "jinja2")
super().__init__(code=path.parent, path=path, data=configs, content_hash=None, **kwargs)

@classmethod
@@ -314,15 +350,6 @@ def _parse_prompty(path):
configs = load_yaml_string(config_content)
return configs, prompt_template

def _resolve_inputs(self, inputs):
resolved_inputs = {}
for k, v in inputs.items():
if isinstance(v, dict):
resolved_inputs[k] = v
else:
resolved_inputs[k] = {"type": ValueType.from_value(v).value, "default": v}
return resolved_inputs

def _validate_inputs(self, input_values):
resolved_inputs = {}
missing_inputs = []
@@ -348,25 +375,26 @@ def __call__(self, *args, **kwargs):
"""
if args:
raise UserErrorException("Prompty can only be called with keyword arguments.")

# 1. Get connection
connection = get_connection(self._connection)
connection = convert_model_configuration_to_connection(self._model.configuration)

# 2. deal with prompt
inputs = self._validate_inputs(kwargs)
traced_convert_prompt_template = _traced(func=convert_prompt_template, args_to_ignore=["api"])
template = traced_convert_prompt_template(self._template, inputs, self._api)
template = traced_convert_prompt_template(self._template, inputs, self._model.api)

# 3. prepare params
params = prepare_open_ai_request_params(self._parameters, template, self._api, connection)
params = prepare_open_ai_request_params(self._model, template, connection)

# 4. send request to open ai
api_client = get_open_ai_client_by_connection(connection=connection)

traced_llm_call = _traced(send_request_to_llm)
response = traced_llm_call(api_client, self._api, params)
response = traced_llm_call(api_client, self._model.api, params)
return format_llm_response(
response=response,
api=self._api,
api=self._model.api,
response_format=params.get("response_format", None),
raw=self._data.get("format", None) == "raw",
)
@@ -399,24 +427,26 @@ async def __call__(self, *args, **kwargs) -> Mapping[str, Any]:
"""
if args:
raise UserErrorException("Prompty can only be called with keyword arguments.")

# 1. Get connection
connection = get_connection(self._connection)
connection = convert_model_configuration_to_connection(self._model.configuration)

# 2. deal with prompt
inputs = self._validate_inputs(kwargs)
template = convert_prompt_template(self._template, inputs, self._api)
traced_convert_prompt_template = _traced(func=convert_prompt_template, args_to_ignore=["api"])
template = traced_convert_prompt_template(self._template, inputs, self._model.api)

# 3. prepare params
params = prepare_open_ai_request_params(self._parameters, template, self._api, connection)
params = prepare_open_ai_request_params(self._model, template, connection)

# 4. send request to open ai
api_client = get_open_ai_client_by_connection(connection=connection, is_async=True)

traced_llm_call = _traced(send_request_to_llm)
response = await traced_llm_call(api_client, self._api, params)
response = await traced_llm_call(api_client, self._model.api, params)
return format_llm_response(
response=response,
api=self._api,
api=self._model.api,
response_format=params.get("response_format", None),
raw=self._data.get("format", None) == "raw",
)
79 changes: 79 additions & 0 deletions src/promptflow-core/promptflow/core/_model_configuration.py
@@ -0,0 +1,79 @@
from dataclasses import dataclass
from typing import Union

from promptflow._constants import ConnectionType
from promptflow.core._errors import InvalidConnectionError


class ModelConfiguration:
pass


@dataclass
class AzureOpenAIModelConfiguration(ModelConfiguration):
azure_deployment: str
azure_endpoint: str = None
api_version: str = None
api_key: str = None
organization: str = None
# connection and model configs are exclusive.
connection: str = None

def __post_init__(self):
self._type = ConnectionType.AZURE_OPEN_AI
if any([self.azure_endpoint, self.api_key, self.api_version, self.organization]) and self.connection:
raise InvalidConnectionError("Cannot configure model config and connection at the same time.")


@dataclass
class OpenAIModelConfiguration(ModelConfiguration):
model: str
base_url: str = None
api_key: str = None
organization: str = None
# connection and model configs are exclusive.
connection: str = None

def __post_init__(self):
self._type = ConnectionType.OPEN_AI
if any([self.base_url, self.api_key, self.organization]) and self.connection:
raise InvalidConnectionError("Cannot configure model config and connection at the same time.")


@dataclass
class PromptyModelConfiguration:
"""
A dataclass that represents a model config of prompty.
:param api: Type of the LLM request, default value is chat.
:type api: str
:param configuration: Prompty model connection configuration
:type configuration: Union[dict, AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
:param parameters: Params of the LLM request.
:type parameters: dict
:param response: Whether to return the complete response or only the first choice, default value is first.
:type response: str
"""

configuration: Union[dict, AzureOpenAIModelConfiguration, OpenAIModelConfiguration]
parameters: dict
api: str = "chat"
response: str = "first"

def __post_init__(self):
if isinstance(self.configuration, dict):
# Load connection from model configuration
model_config = {
k: v
for k, v in self.configuration.items()
if k not in ["type", "connection", "model", "azure_deployment"]
}
if self.configuration.get("connection", None) and any([v for v in model_config.values()]):
raise InvalidConnectionError(
"Cannot configure model config and connection in configuration at the same time."
)
self._model = self.configuration.get("azure_deployment", None) or self.configuration.get("model", None)
elif isinstance(self.configuration, OpenAIModelConfiguration):
self._model = self.configuration.model
elif isinstance(self.configuration, AzureOpenAIModelConfiguration):
self._model = self.configuration.azure_deployment
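For reference, a quick sketch of how the exclusivity check in `__post_init__` behaves (import paths as shown in this diff):
```
from promptflow.core._errors import InvalidConnectionError
from promptflow.core._model_configuration import AzureOpenAIModelConfiguration

# connection and explicit endpoint/key settings are mutually exclusive.
try:
    AzureOpenAIModelConfiguration(
        azure_deployment="gpt-35-turbo",
        api_key="some-key",
        connection="azure_open_ai_connection",
    )
except InvalidConnectionError as e:
    print(e)  # Cannot configure model config and connection at the same time.
```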