Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming chat #12

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
Expand Down
6 changes: 4 additions & 2 deletions ipython_gpt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(self, *args, **kwargs):
def chat(self, line, cell):
    """Cell magic: send the cell contents as a chat message and display the reply.

    `ChatCommand.execute` is a generator (streaming-aware), so every yielded
    chunk is displayed as it arrives rather than once at the end.
    """
    cmd = ChatCommand(self._context)
    result = cmd.execute(line, cell)
    # Iterate: with --stream each chunk is a partial delta; without it the
    # generator yields a single complete response.
    for piece in result:
        self.display.display(piece)

@line_magic
def chat_config(self, line):
Expand All @@ -37,7 +38,8 @@ def chat_config(self, line):
def chat_models(self, line):
    """Line magic: list the chat models available to the configured API key.

    Mirrors `chat`: the command's `execute` is a generator, so display every
    yielded item (for this command it yields exactly one formatted listing).
    """
    cmd = ChatModelsBrowserCommand(self._context)
    result = cmd.execute(line)
    for piece in result:
        self.display.display(piece)


name = "ipython_gpt"
Expand Down
59 changes: 34 additions & 25 deletions ipython_gpt/api_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import http.client
# import http.client
import json
import urllib.parse
import requests

OPEN_AI_API_HOST = "api.openai.com"
OPEN_AI_API_PORT = 443
Expand Down Expand Up @@ -30,48 +31,56 @@ def __str__(self):
class UnauthorizedAPIException(APIResponseException):
pass

class IPythonGPTResponse:
    """Lightweight wrapper pairing an API response payload with a streaming flag.

    `data` holds the decoded JSON payload (a full response body, or a single
    streamed delta); `is_streaming` tells consumers which shape to expect.
    """

    def __init__(self, data, is_streaming=False):
        self.data = data
        self.is_streaming = is_streaming

class OpenAIClient:
def __init__(self, openai_api_key, api_version=DEFAULT_API_VERSION):
    """Remember the credentials and API version used for every request."""
    self.api_version = api_version
    self.openai_api_key = openai_api_key

def request(self, method, path, headers=None, query_params=None, json_body=None, stream=False):
    """Issue an HTTP request against the OpenAI API and yield parsed responses.

    This is a generator: callers must iterate it. A non-streaming request
    yields exactly one `IPythonGPTResponse`; a streaming request yields one
    wrapper per server-sent-event chunk.

    :param method: HTTP verb, case-insensitive (e.g. "get", "POST").
    :param path: API path WITHOUT the version prefix; must start with "/".
    :param headers: optional extra headers; Authorization and Content-Type
        are filled in when absent.
    :param query_params: optional mapping encoded into the query string.
    :param json_body: optional dict serialized as the JSON request body.
    :param stream: when True, ask the API for a streamed (SSE) reply.
    :raises UnauthorizedAPIException: on HTTP 401.
    :raises APIResponseException: on any other non-2xx status.

    NOTE(review): because this is a generator, the HTTP call and the error
    raising happen only when the caller starts iterating — confirm callers
    always consume the generator.
    """
    method = method.upper()
    assert path.startswith("/"), "Invalid path"
    assert not path.startswith(
        "/v"
    ), "API Version must be specified at moment of client creation"

    headers = headers or {}
    headers.setdefault("Authorization", f"Bearer {self.openai_api_key}")
    headers.setdefault("Content-Type", "application/json")

    body = None
    if json_body:
        # Copy so the caller's dict is not mutated when "stream" is injected.
        json_body = json_body.copy()
        if stream:
            json_body["stream"] = True
        body = json.dumps(json_body)

    url = f"https://{OPEN_AI_API_HOST}:{OPEN_AI_API_PORT}/{self.api_version}{path}"
    if query_params is not None:
        url += "?" + urllib.parse.urlencode(query_params)

    resp = requests.request(method, url, headers=headers, data=body, stream=stream)

    if 200 <= resp.status_code < 300:
        yield from self._post_process_response(resp, stream=stream)
        return
    if resp.status_code == 401:
        raise UnauthorizedAPIException(method, path, resp.status_code, resp.json())

    # Catch all exception for any other not known status
    raise APIResponseException(method, path, resp.status_code, resp.json())

def _post_process_response(self, response, stream=False):
    """The pattern is borrowed from studying OpenAI's code for api-requestor"""
    # Generator: yields IPythonGPTResponse wrappers. For streaming (SSE)
    # replies each "data: {...}" event line becomes one wrapper; otherwise
    # the whole JSON body is yielded as a single non-streaming wrapper.
    if stream and "text/event-stream" in response.headers.get("Content-Type", ""):
        # TODO: Better handle errors for streaming responses as well.
        for line in response.iter_lines():
            decoded_line = line.decode("utf-8")
            # Skip keep-alive blank lines and the terminal "[DONE]" sentinel.
            if '[DONE]' not in decoded_line and decoded_line:
                # NOTE(review): assumes every non-empty event line begins
                # with the "data: " prefix — confirm for all SSE events.
                json_line = json.loads(decoded_line[len('data: '):])
                yield IPythonGPTResponse(json_line, True)
    else:
        yield IPythonGPTResponse(response.json(), False)
2 changes: 1 addition & 1 deletion ipython_gpt/displays.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def display(self, results):

class ShellDisplay(BaseDisplay):
    """Display backend that writes results directly to stdout."""

    def display(self, results):
        # No trailing newline and an immediate flush so streamed chunks
        # render progressively instead of buffering until a line completes.
        print(results, end="", flush=True)


DISPLAY_METHODS = {
Expand Down
47 changes: 40 additions & 7 deletions ipython_gpt/subcommands.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import argparse
import shlex
import json
from typing import Generator

from .api_client import OpenAIClient
from .api_client import IPythonGPTResponse, OpenAIClient


class BaseIPythonGPTCommand:
Expand Down Expand Up @@ -50,7 +52,10 @@ def execute(self, line, cell=None):
assert bool(openai_api_key), "OPENAI_API_KEY missing"
client = OpenAIClient(openai_api_key)
results = self._execute(client, args, line, cell)
return results
if isinstance(results, Generator):
yield from results
else:
yield from [results]


class ChatCommand(BaseIPythonGPTCommand):
Expand All @@ -72,10 +77,16 @@ def _customize_parser(self, parser):
help="The maximum number of tokens to generate in the chat completion.",
type=int,
)
parser.add_argument(
"--stream",
action="store_true",
help="Stream output to the console.",
)
return parser

def _execute(self, client, args, line, cell=None):
message_history = self.context["message_history"]
stream = args.stream or False
if args.reset_conversation:
message_history = []

Expand All @@ -97,13 +108,31 @@ def _execute(self, client, args, line, cell=None):
if args.max_tokens:
json_body["max_tokens"] = args.max_tokens

resp = client.request("POST", "/chat/completions", json_body=json_body)
chat_response = resp["choices"][0]["message"]["content"]
resp = client.request("POST", "/chat/completions", json_body=json_body, stream=stream)
# TEST ME
yield from self._from_wrapped_response(resp, message_history)


def _from_wrapped_response(self, wrapper_generator, message_history):
    """Unwrap IPythonGPTResponse objects and yield displayable text chunks.

    Accumulates the assistant's complete reply while yielding, then appends
    it to the conversation history stored on the command context so that
    follow-up turns keep context.

    :param wrapper_generator: iterable of IPythonGPTResponse objects.
    :param message_history: list of {"role": ..., "content": ...} dicts,
        extended in place and written back to ``self.context``.
    """
    message = ""
    for wrapper in wrapper_generator:
        assert isinstance(wrapper, IPythonGPTResponse)
        if wrapper.is_streaming:
            json_line = wrapper.data
            # Streamed deltas may omit "choices" (e.g. housekeeping chunks).
            if 'choices' in json_line:
                # A delta may also carry no "content" (role-only chunk).
                content = json_line['choices'][0]['delta'].get('content', '') or ''
                message += content
                yield content
        else:
            json_content = wrapper.data
            message = json_content["choices"][0]["message"]["content"]
            yield message

    message_history += [
        {"role": "assistant", "content": message},
    ]
    self.context["message_history"] = message_history


class ConfigCommand(BaseIPythonGPTCommand):
Expand Down Expand Up @@ -140,9 +169,13 @@ def _customize_parser(self, parser):

def _execute(self, client, args, line, cell):
resp = client.request("GET", "/models")
resp = next(resp)
assert isinstance(resp, IPythonGPTResponse)
assert resp.is_streaming is False
json_data = resp.data
models = [
m["id"]
for m in resp["data"]
for m in json_data["data"]
if args.all_models or m["id"].startswith("gpt")
]
formatted_models = "\n".join([f"\t- {model}" for model in models])
Expand Down
Loading