-
Notifications
You must be signed in to change notification settings - Fork 905
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PromptFlowService] Add flask_restx to pfs to generate swagger (#1123)
# Description 1. Add flask_restx to pfs to generate swagger 2. Update connection api Please add an informative description that covers that changes made by the pull request and link all relevant issues. # All Promptflow Contribution checklist: - [ ] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [ ] Title of the pull request is clear and informative. - [ ] There are a small number of commits, each of which have an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes.
- Loading branch information
1 parent
277ee50
commit bf0f519
Showing
16 changed files
with
570 additions
and
150 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -149,7 +149,8 @@ | |
"vcrpy", | ||
"uionly", | ||
"llmops", | ||
"Abhishek" | ||
"Abhishek", | ||
"restx" | ||
], | ||
"flagWords": [ | ||
"Prompt Flow" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Prompt Flow Service | ||
This document will describe the usage of pfs(prompt flow service) CLI. | ||
|
||
### Start prompt flow service (optional) | ||
If you don't install pfs as a service, you need to start pfs manually. | ||
pfs CLI provides **start** command to start service. You can also use this command to specify the service port. | ||
|
||
```commandline | ||
usage: pfs [-h] [-p PORT] | ||
Start prompt flow service. | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
-p PORT, --port PORT port of the promptflow service | ||
``` | ||
|
||
If you don't specify a port to start service, pfs will first use the port in the configure file in "~/.promptflow/pfs.port". | ||
|
||
If not found port configuration or the port is used, pfs will use a random port to start the service. | ||
|
||
### Swagger of service | ||
After start the service, it will provide Swagger UI documentation, served from "http://localhost:your-port/v1.0/swagger.json". | ||
|
||
For details, please refer to [swagger.json](./swagger.json). | ||
|
||
#### Generate C# client | ||
1. Right click the project, Add -> Rest API Client... -> Generate with OpenAPI Generator | ||
|
||
2. It will open a dialog, fill in the file name and swagger url, it will generate the client under the project. | ||
|
||
For details, please refer to [REST API Client Code Generator](https://marketplace.visualstudio.com/items?itemName=ChristianResmaHelle.ApiClientCodeGenerator2022). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# --------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# --------------------------------------------------------- | ||
|
||
__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore |
108 changes: 108 additions & 0 deletions
108
src/promptflow/promptflow/_sdk/_service/apis/connection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# --------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# --------------------------------------------------------- | ||
|
||
import json | ||
|
||
from flask import jsonify, request | ||
from flask_restx import Namespace, Resource, fields | ||
|
||
from promptflow._sdk._errors import ConnectionNotFoundError | ||
from promptflow._sdk._service.utils.utils import local_user_only | ||
from promptflow._sdk.entities._connection import _Connection | ||
from promptflow._sdk.operations._connection_operations import ConnectionOperations | ||
|
||
api = Namespace("Connections", description="Connections Management") | ||
|
||
# Define base connection request parsing | ||
remote_parser = api.parser() | ||
remote_parser.add_argument("X-Remote-User", location="headers", required=True) | ||
|
||
# Define create or update connection request parsing | ||
create_or_update_parser = remote_parser.copy() | ||
create_or_update_parser.add_argument("connection_dict", type=str, location="args", required=True) | ||
|
||
# Response model of list connections | ||
list_connection_field = api.model( | ||
"Connection", | ||
{ | ||
"name": fields.String, | ||
"type": fields.String, | ||
"module": fields.String, | ||
"expiry_time": fields.DateTime(), | ||
"created_date": fields.DateTime(), | ||
"last_modified_date": fields.DateTime(), | ||
}, | ||
) | ||
# Response model of connection operation | ||
dict_field = api.schema_model("ConnectionDict", {"additionalProperties": True, "type": "object"}) | ||
|
||
|
||
@api.errorhandler(ConnectionNotFoundError) | ||
def handle_connection_not_found_exception(error): | ||
api.logger.warning(f"Raise ConnectionNotFoundError, {error.message}") | ||
return {"error_message": error.message}, 404 | ||
|
||
|
||
@api.route("/") | ||
class ConnectionList(Resource): | ||
@api.doc(parser=remote_parser, description="List all connection") | ||
@api.marshal_with(list_connection_field, skip_none=True, as_list=True) | ||
@local_user_only | ||
def get(self): | ||
connection_op = ConnectionOperations() | ||
# parse query parameters | ||
max_results = request.args.get("max_results", default=50, type=int) | ||
all_results = request.args.get("all_results", default=False, type=bool) | ||
|
||
connections = connection_op.list(max_results=max_results, all_results=all_results) | ||
connections_dict = [connection._to_dict() for connection in connections] | ||
return connections_dict | ||
|
||
|
||
@api.route("/<string:name>") | ||
@api.param("name", "The connection name.") | ||
class Connection(Resource): | ||
@api.doc(parser=remote_parser, description="Get connection") | ||
@api.response(code=200, description="Connection details", model=dict_field) | ||
@local_user_only | ||
def get(self, name: str): | ||
connection_op = ConnectionOperations() | ||
# parse query parameters | ||
with_secrets = request.args.get("with_secrets", default=False, type=bool) | ||
raise_error = request.args.get("raise_error", default=True, type=bool) | ||
|
||
connection = connection_op.get(name=name, with_secrets=with_secrets, raise_error=raise_error) | ||
connection_dict = connection._to_dict() | ||
return jsonify(connection_dict) | ||
|
||
@api.doc(parser=create_or_update_parser, description="Create connection") | ||
@api.response(code=200, description="Connection details", model=dict_field) | ||
@local_user_only | ||
def post(self, name: str): | ||
connection_op = ConnectionOperations() | ||
args = create_or_update_parser.parse_args() | ||
connection_data = json.loads(args["connection_dict"]) | ||
connection_data["name"] = name | ||
connection = _Connection._load(data=connection_data) | ||
connection = connection_op.create_or_update(connection) | ||
return jsonify(connection._to_dict()) | ||
|
||
@api.doc(parser=create_or_update_parser, description="Update connection") | ||
@api.response(code=200, description="Connection details", model=dict_field) | ||
@local_user_only | ||
def put(self, name: str): | ||
connection_op = ConnectionOperations() | ||
args = create_or_update_parser.parse_args() | ||
params_override = [{k: v} for k, v in json.loads(args["connection_dict"]).items()] | ||
existing_connection = connection_op.get(name) | ||
connection = _Connection._load(data=existing_connection._to_dict(), params_override=params_override) | ||
connection._secrets = existing_connection._secrets | ||
connection = connection_op.create_or_update(connection) | ||
return jsonify(connection._to_dict()) | ||
|
||
@api.doc(parser=remote_parser, description="Delete connection") | ||
@local_user_only | ||
def delete(self, name: str): | ||
connection_op = ConnectionOperations() | ||
connection_op.delete(name=name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# --------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# --------------------------------------------------------- | ||
|
||
from dataclasses import asdict | ||
|
||
from flask import jsonify, request | ||
from flask_restx import Namespace, Resource | ||
|
||
from promptflow._sdk._constants import FlowRunProperties, get_list_view_type | ||
from promptflow._sdk._errors import RunNotFoundError | ||
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations | ||
from promptflow._sdk.operations._run_operations import RunOperations | ||
from promptflow.contracts._run_management import RunMetadata | ||
|
||
api = Namespace("Runs", description="Runs Management") | ||
|
||
|
||
@api.errorhandler(RunNotFoundError) | ||
def handle_run_not_found_exception(error): | ||
api.logger.warning(f"Raise RunNotFoundError, {error.message}") | ||
return {"error_message": error.message}, 404 | ||
|
||
|
||
@api.route("/") | ||
class RunList(Resource): | ||
@api.doc(description="List all runs") | ||
def get(self): | ||
# parse query parameters | ||
max_results = request.args.get("max_results", default=50, type=int) | ||
all_results = request.args.get("all_results", default=False, type=bool) | ||
archived_only = request.args.get("archived_only", default=False, type=bool) | ||
include_archived = request.args.get("include_archived", default=False, type=bool) | ||
# align with CLI behavior | ||
if all_results: | ||
max_results = None | ||
list_view_type = get_list_view_type(archived_only=archived_only, include_archived=include_archived) | ||
|
||
op = RunOperations() | ||
runs = op.list(max_results=max_results, list_view_type=list_view_type) | ||
runs_dict = [run._to_dict() for run in runs] | ||
return jsonify(runs_dict) | ||
|
||
|
||
@api.route("/<string:name>") | ||
class Run(Resource): | ||
def get(self, name: str): | ||
op = RunOperations() | ||
run = op.get(name=name) | ||
run_dict = run._to_dict() | ||
return jsonify(run_dict) | ||
|
||
|
||
@api.route("/<string:name>/metadata") | ||
class MetaData(Resource): | ||
def get(self, name: str): | ||
run_op = RunOperations() | ||
run = run_op.get(name=name) | ||
local_storage_op = LocalStorageOperations(run=run) | ||
metadata = RunMetadata( | ||
name=run.name, | ||
display_name=run.display_name, | ||
create_time=run.created_on, | ||
flow_path=run.properties[FlowRunProperties.FLOW_PATH], | ||
output_path=run.properties[FlowRunProperties.OUTPUT_PATH], | ||
tags=run.tags, | ||
lineage=run.run, | ||
metrics=local_storage_op.load_metrics(), | ||
dag=local_storage_op.load_dag_as_string(), | ||
flow_tools_json=local_storage_op.load_flow_tools_json(), | ||
) | ||
return jsonify(asdict(metadata)) | ||
|
||
|
||
@api.route("/<string:name>/detail") | ||
class Detail(Resource): | ||
def get(self, name: str): | ||
run_op = RunOperations() | ||
run = run_op.get(name=name) | ||
local_storage_op = LocalStorageOperations(run=run) | ||
detail_dict = local_storage_op.load_detail() | ||
return jsonify(detail_dict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.