Skip to content

Commit

Permalink
[PromptFlowService] Add flask_restx to pfs to generate swagger (#1123)
Browse files Browse the repository at this point in the history
# Description
1. Add flask_restx to pfs to generate swagger
2. Update connection api
Please add an informative description that covers that changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
lalala123123 authored Nov 17, 2023
1 parent 277ee50 commit bf0f519
Show file tree
Hide file tree
Showing 16 changed files with 570 additions and 150 deletions.
3 changes: 2 additions & 1 deletion .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@
"vcrpy",
"uionly",
"llmops",
"Abhishek"
"Abhishek",
"restx"
],
"flagWords": [
"Prompt Flow"
Expand Down
1 change: 1 addition & 0 deletions src/promptflow/promptflow/_sdk/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
HOME_PROMPT_FLOW_DIR = (Path.home() / PROMPT_FLOW_DIR_NAME).resolve()
SERVICE_CONFIG_FILE = "pf.yaml"
PF_SERVICE_PORT_FILE = "pfs.port"
PF_SERVICE_LOG_FILE = "pfs.log"

if not HOME_PROMPT_FLOW_DIR.is_dir():
HOME_PROMPT_FLOW_DIR.mkdir(exist_ok=True)
Expand Down
32 changes: 32 additions & 0 deletions src/promptflow/promptflow/_sdk/_service/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Prompt Flow Service
This document will describe the usage of pfs(prompt flow service) CLI.

### Start prompt flow service (optional)
If you don't install pfs as a service, you need to start pfs manually.
pfs CLI provides **start** command to start service. You can also use this command to specify the service port.

```commandline
usage: pfs [-h] [-p PORT]
Start prompt flow service.
optional arguments:
-h, --help show this help message and exit
-p PORT, --port PORT port of the promptflow service
```

If you don't specify a port to start service, pfs will first use the port in the configure file in "~/.promptflow/pfs.port".

If not found port configuration or the port is used, pfs will use a random port to start the service.

### Swagger of service
After start the service, it will provide Swagger UI documentation, served from "http://localhost:your-port/v1.0/swagger.json".

For details, please refer to [swagger.json](./swagger.json).

#### Generate C# client
1. Right click the project, Add -> Rest API Client... -> Generate with OpenAPI Generator

2. It will open a dialog, fill in the file name and swagger url, it will generate the client under the project.

For details, please refer to [REST API Client Code Generator](https://marketplace.visualstudio.com/items?itemName=ChristianResmaHelle.ApiClientCodeGenerator2022).
5 changes: 5 additions & 0 deletions src/promptflow/promptflow/_sdk/_service/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
108 changes: 108 additions & 0 deletions src/promptflow/promptflow/_sdk/_service/apis/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json

from flask import jsonify, request
from flask_restx import Namespace, Resource, fields

from promptflow._sdk._errors import ConnectionNotFoundError
from promptflow._sdk._service.utils.utils import local_user_only
from promptflow._sdk.entities._connection import _Connection
from promptflow._sdk.operations._connection_operations import ConnectionOperations

api = Namespace("Connections", description="Connections Management")

# Define base connection request parsing
remote_parser = api.parser()
remote_parser.add_argument("X-Remote-User", location="headers", required=True)

# Define create or update connection request parsing
create_or_update_parser = remote_parser.copy()
create_or_update_parser.add_argument("connection_dict", type=str, location="args", required=True)

# Response model of list connections
list_connection_field = api.model(
"Connection",
{
"name": fields.String,
"type": fields.String,
"module": fields.String,
"expiry_time": fields.DateTime(),
"created_date": fields.DateTime(),
"last_modified_date": fields.DateTime(),
},
)
# Response model of connection operation
dict_field = api.schema_model("ConnectionDict", {"additionalProperties": True, "type": "object"})


@api.errorhandler(ConnectionNotFoundError)
def handle_connection_not_found_exception(error):
api.logger.warning(f"Raise ConnectionNotFoundError, {error.message}")
return {"error_message": error.message}, 404


@api.route("/")
class ConnectionList(Resource):
@api.doc(parser=remote_parser, description="List all connection")
@api.marshal_with(list_connection_field, skip_none=True, as_list=True)
@local_user_only
def get(self):
connection_op = ConnectionOperations()
# parse query parameters
max_results = request.args.get("max_results", default=50, type=int)
all_results = request.args.get("all_results", default=False, type=bool)

connections = connection_op.list(max_results=max_results, all_results=all_results)
connections_dict = [connection._to_dict() for connection in connections]
return connections_dict


@api.route("/<string:name>")
@api.param("name", "The connection name.")
class Connection(Resource):
@api.doc(parser=remote_parser, description="Get connection")
@api.response(code=200, description="Connection details", model=dict_field)
@local_user_only
def get(self, name: str):
connection_op = ConnectionOperations()
# parse query parameters
with_secrets = request.args.get("with_secrets", default=False, type=bool)
raise_error = request.args.get("raise_error", default=True, type=bool)

connection = connection_op.get(name=name, with_secrets=with_secrets, raise_error=raise_error)
connection_dict = connection._to_dict()
return jsonify(connection_dict)

@api.doc(parser=create_or_update_parser, description="Create connection")
@api.response(code=200, description="Connection details", model=dict_field)
@local_user_only
def post(self, name: str):
connection_op = ConnectionOperations()
args = create_or_update_parser.parse_args()
connection_data = json.loads(args["connection_dict"])
connection_data["name"] = name
connection = _Connection._load(data=connection_data)
connection = connection_op.create_or_update(connection)
return jsonify(connection._to_dict())

@api.doc(parser=create_or_update_parser, description="Update connection")
@api.response(code=200, description="Connection details", model=dict_field)
@local_user_only
def put(self, name: str):
connection_op = ConnectionOperations()
args = create_or_update_parser.parse_args()
params_override = [{k: v} for k, v in json.loads(args["connection_dict"]).items()]
existing_connection = connection_op.get(name)
connection = _Connection._load(data=existing_connection._to_dict(), params_override=params_override)
connection._secrets = existing_connection._secrets
connection = connection_op.create_or_update(connection)
return jsonify(connection._to_dict())

@api.doc(parser=remote_parser, description="Delete connection")
@local_user_only
def delete(self, name: str):
connection_op = ConnectionOperations()
connection_op.delete(name=name)
82 changes: 82 additions & 0 deletions src/promptflow/promptflow/_sdk/_service/apis/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from dataclasses import asdict

from flask import jsonify, request
from flask_restx import Namespace, Resource

from promptflow._sdk._constants import FlowRunProperties, get_list_view_type
from promptflow._sdk._errors import RunNotFoundError
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
from promptflow._sdk.operations._run_operations import RunOperations
from promptflow.contracts._run_management import RunMetadata

api = Namespace("Runs", description="Runs Management")


@api.errorhandler(RunNotFoundError)
def handle_run_not_found_exception(error):
api.logger.warning(f"Raise RunNotFoundError, {error.message}")
return {"error_message": error.message}, 404


@api.route("/")
class RunList(Resource):
@api.doc(description="List all runs")
def get(self):
# parse query parameters
max_results = request.args.get("max_results", default=50, type=int)
all_results = request.args.get("all_results", default=False, type=bool)
archived_only = request.args.get("archived_only", default=False, type=bool)
include_archived = request.args.get("include_archived", default=False, type=bool)
# align with CLI behavior
if all_results:
max_results = None
list_view_type = get_list_view_type(archived_only=archived_only, include_archived=include_archived)

op = RunOperations()
runs = op.list(max_results=max_results, list_view_type=list_view_type)
runs_dict = [run._to_dict() for run in runs]
return jsonify(runs_dict)


@api.route("/<string:name>")
class Run(Resource):
def get(self, name: str):
op = RunOperations()
run = op.get(name=name)
run_dict = run._to_dict()
return jsonify(run_dict)


@api.route("/<string:name>/metadata")
class MetaData(Resource):
def get(self, name: str):
run_op = RunOperations()
run = run_op.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
metadata = RunMetadata(
name=run.name,
display_name=run.display_name,
create_time=run.created_on,
flow_path=run.properties[FlowRunProperties.FLOW_PATH],
output_path=run.properties[FlowRunProperties.OUTPUT_PATH],
tags=run.tags,
lineage=run.run,
metrics=local_storage_op.load_metrics(),
dag=local_storage_op.load_dag_as_string(),
flow_tools_json=local_storage_op.load_flow_tools_json(),
)
return jsonify(asdict(metadata))


@api.route("/<string:name>/detail")
class Detail(Resource):
def get(self, name: str):
run_op = RunOperations()
run = run_op.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
detail_dict = local_storage_op.load_detail()
return jsonify(detail_dict)
44 changes: 36 additions & 8 deletions src/promptflow/promptflow/_sdk/_service/app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import logging
from flask import Blueprint, Flask, jsonify
from flask_restx import Api
from werkzeug.exceptions import HTTPException

from flask import Flask, jsonify

from promptflow._sdk._service.connection import connection_bp
from promptflow._sdk._service.run import run_bp
from promptflow._sdk._utils import get_promptflow_sdk_version
from promptflow._sdk._constants import HOME_PROMPT_FLOW_DIR, PF_SERVICE_LOG_FILE
from promptflow._sdk._service.apis.connection import api as connection_api
from promptflow._sdk._service.apis.run import api as run_api
from promptflow._sdk._utils import get_promptflow_sdk_version, read_write_by_user


def heartbeat():
Expand All @@ -17,6 +20,31 @@ def heartbeat():
def create_app():
app = Flask(__name__)
app.add_url_rule("/heartbeat", view_func=heartbeat)
app.register_blueprint(run_bp)
app.register_blueprint(connection_bp)
return app
with app.app_context():
api_v1 = Blueprint("Prompt Flow Service", __name__, url_prefix="/v1.0")

# Registers resources from namespace for current instance of api
api = Api(api_v1, title="Prompt Flow Service", version="1.0")
api.add_namespace(connection_api)
api.add_namespace(run_api)
app.register_blueprint(api_v1)

# Disable flask-restx set X-Fields in header. https://flask-restx.readthedocs.io/en/latest/mask.html#usage
app.config["RESTX_MASK_SWAGGER"] = False

# Enable log
app.logger.setLevel(logging.INFO)
log_file = HOME_PROMPT_FLOW_DIR / PF_SERVICE_LOG_FILE
log_file.touch(mode=read_write_by_user(), exist_ok=True)
handler = logging.FileHandler(filename=log_file)
app.logger.addHandler(handler)

# Basic error handler
@app.errorhandler(Exception)
def handle_exception(e):
if isinstance(e, HTTPException):
return e
app.logger.error(e, exc_info=True, stack_info=True)
return jsonify({"error_message": "Internal Server Error"}), 500

return app, api
36 changes: 0 additions & 36 deletions src/promptflow/promptflow/_sdk/_service/connection.py

This file was deleted.

9 changes: 6 additions & 3 deletions src/promptflow/promptflow/_sdk/_service/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
# ---------------------------------------------------------
import argparse
import sys

import waitress
import yaml

from promptflow._sdk._constants import HOME_PROMPT_FLOW_DIR, PF_SERVICE_PORT_FILE
from promptflow._sdk._service.app import create_app
from promptflow._sdk._service.utils import get_random_port, is_port_in_use
from promptflow._sdk._service.utils.utils import get_random_port, is_port_in_use
from promptflow._sdk._utils import read_write_by_user
from promptflow.exceptions import UserErrorException

Expand All @@ -24,10 +23,13 @@ def main():
)

parser.add_argument("-p", "--port", type=int, help="port of the promptflow service")

args = parser.parse_args(command_args)
port = args.port
app, _ = create_app()

if port and is_port_in_use(port):
app.logger.warning(f"Service port {port} is used.")
raise UserErrorException(f"Service port {port} is used.")
if not port:
(HOME_PROMPT_FLOW_DIR / PF_SERVICE_PORT_FILE).touch(mode=read_write_by_user(), exist_ok=True)
Expand All @@ -42,10 +44,11 @@ def main():
service_config["service"]["port"] = port
yaml.dump(service_config, f)

app = create_app()
if is_port_in_use(port):
app.logger.warning(f"Service port {port} is used.")
raise UserErrorException(f"Service port {port} is used.")
# Set host to localhost, only allow request from localhost.
app.logger.info(f"Start Prompt Flow Service on http://localhost:{port}")
waitress.serve(app, host="127.0.0.1", port=port)


Expand Down
Loading

0 comments on commit bf0f519

Please sign in to comment.