From 969bc84d4e20c487ad1112a361018306dd2b9572 Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Thu, 14 Mar 2024 18:57:14 -0700 Subject: [PATCH 1/2] Reverted assistant API --- sweepai/agents/assistant_wrapper.py | 351 +++++++++++++++++++++++++++- 1 file changed, 347 insertions(+), 4 deletions(-) diff --git a/sweepai/agents/assistant_wrapper.py b/sweepai/agents/assistant_wrapper.py index 0d8a9cfecd..2261441850 100644 --- a/sweepai/agents/assistant_wrapper.py +++ b/sweepai/agents/assistant_wrapper.py @@ -1,9 +1,13 @@ import json +import os +import re +import time import traceback from time import sleep from typing import Callable, Optional import openai +from anyio import Path from loguru import logger from openai import AzureOpenAI, OpenAI from openai.pagination import SyncCursorPage @@ -14,6 +18,7 @@ ) from pydantic import BaseModel +from sweepai.agents.assistant_functions import raise_error_schema from sweepai.config.server import ( AZURE_API_KEY, DEFAULT_GPT4_32K_MODEL, @@ -241,6 +246,344 @@ def get_json_messages( def run_until_complete( + thread_id: str, + run_id: str, + assistant_id: str, + model: str = DEFAULT_GPT4_32K_MODEL, + chat_logger: ChatLogger | None = None, + sleep_time: int = 3, + max_iterations: int = 2000, + save_ticket_progress: save_ticket_progress_type | None = None, +): + message_strings = [] + json_messages = [] + try: + num_tool_calls_made = 0 + for i in range(max_iterations): + run = openai_retry_with_timeout( + client.beta.threads.runs.retrieve, + thread_id=thread_id, + run_id=run_id, + ) + if run.status == "completed": + logger.info( + f"Run completed with {run.status} (i={num_tool_calls_made})" + ) + done_response = yield "done", { + "status": "completed", + "message": "Run completed successfully", + } + if not done_response: + break + else: + run = client.beta.threads.runs.create( + thread_id=thread_id, + assistant_id=assistant_id, + instructions=done_response, + model=model, + ) + elif run.status in ("cancelled", "cancelling", "failed", "expired"): + logger.info( + f"Run completed with {run.status} (i={num_tool_calls_made}) and reason {run.last_error}." + ) + done_response = yield "done", { + "status": run.status, + "message": "Run failed", + } + if not done_response: + raise Exception( + f"Run failed assistant_id={assistant_id}, run_id={run_id}, thread_id={thread_id} with status {run.status} (i={num_tool_calls_made})" + ) + else: + run = client.beta.threads.runs.create( + thread_id=thread_id, + assistant_id=assistant_id, + instructions=done_response, + model=model, + ) + elif run.status == "requires_action": + num_tool_calls_made += 1 + if num_tool_calls_made > 15 and model.startswith("gpt-3.5"): + raise AssistantRaisedException( + "Too many tool calls made on GPT 3.5." + ) + raw_tool_calls = [ + tool_call + for tool_call in run.required_action.submit_tool_outputs.tool_calls + ] + tool_outputs = [] + tool_calls = [] + if any( + [ + tool_call.function.name == raise_error_schema["name"] + for tool_call in raw_tool_calls + ] + ): + arguments_parsed = json.loads(tool_calls[0].function.arguments) + raise AssistantRaisedException(arguments_parsed["message"]) + # tool_calls = raw_tool_calls + for tool_call in raw_tool_calls: + try: + tool_call_arguments = re.sub( + r"\\+'", "", tool_call.function.arguments + ) + function_input: dict = json.loads(tool_call_arguments) + except Exception: + logger.warning( + f"Could not parse function arguments (i={num_tool_calls_made}): {tool_call_arguments}" + ) + tool_outputs.append( + { + "tool_call_id": tool_call.id, + "output": "FAILURE: Could not parse function arguments.", + } + ) + continue + tool_function_name = tool_call.function.name + tool_function_input = function_input + # OpenAI has a bug where it calls the imaginary function "multi_tool_use.parallel" + # Based on https://github.com/phdowling/openai_multi_tool_use_parallel_patch/blob/main/openai_multi_tool_use_parallel_patch.py + if tool_function_name in ("multi_tool_use.parallel", "parallel"): + for fake_i, fake_tool_use in function_input["tool_uses"]: + function_input = fake_tool_use["parameters"] + function_name: str = fake_tool_use["recipient_name"] + function_name = function_name.removeprefix("functions.") + tool_calls.append( + ( + f"{tool_call.id}_{fake_i}", + function_name, + function_input, + ) + ) + else: + tool_calls.append( + (tool_call.id, tool_function_name, tool_function_input) + ) + + for tool_call_id, tool_function_name, tool_function_input in tool_calls: + tool_output = yield tool_function_name, tool_function_input + tool_output_formatted = { + "tool_call_id": tool_call_id, + "output": tool_output, + } + tool_outputs.append(tool_output_formatted) + run = openai_retry_with_timeout( + client.beta.threads.runs.submit_tool_outputs, + thread_id=thread_id, + run_id=run.id, + tool_outputs=tool_outputs, + ) + if save_ticket_progress is not None: + save_ticket_progress( + assistant_id=assistant_id, + thread_id=thread_id, + run_id=run_id, + ) + messages = openai_retry_with_timeout( + client.beta.threads.messages.list, + thread_id=thread_id, + ) + current_message_strings = [ + message.content[0].text.value if message.content else "" + for message in messages.data + ] + if message_strings != current_message_strings and current_message_strings: + logger.info(run.status) + logger.info(current_message_strings[0]) + message_strings = current_message_strings + json_messages = get_json_messages( + thread_id=thread_id, + run_id=run_id, + assistant_id=assistant_id, + ) + if chat_logger is not None: + chat_logger.add_chat( + { + "model": model, + "messages": json_messages, + "output": message_strings[0], + "thread_id": thread_id, + "run_id": run_id, + "max_tokens": 1000, + "temperature": 0, + } + ) + else: + if i % 5 == 0: + logger.info(run.status) + if i == max_iterations - 1: + logger.warning( + f"run_until_complete hit max iterations, run.status is {run.status}" + ) + time.sleep(sleep_time) + except (KeyboardInterrupt, SystemExit): + client.beta.threads.runs.cancel(thread_id=thread_id, run_id=run_id) + logger.warning(f"Run cancelled: {run_id} (n={num_tool_calls_made})") + raise SystemExit + if save_ticket_progress is not None: + save_ticket_progress( + assistant_id=assistant_id, + thread_id=thread_id, + run_id=run_id, + ) + for json_message in json_messages: + logger.info(f'(n={num_tool_calls_made}) {json_message["content"]}') + return client.beta.threads.messages.list( + thread_id=thread_id, + ) + + +def openai_assistant_call_helper( + request: str, + instructions: str | None = None, + additional_messages: list[Message] = [], + file_paths: list[str] = [], # use either file_paths or file_ids + uploaded_file_ids: list[str] = [], + tools: list[dict[str, str]] = [{"type": "code_interpreter"}], + model: str = DEFAULT_GPT4_32K_MODEL, + sleep_time: int = 3, + chat_logger: ChatLogger | None = None, + assistant_id: str | None = None, + assistant_name: str | None = None, + save_ticket_progress: save_ticket_progress_type | None = None, +): + file_ids = [] if not uploaded_file_ids else uploaded_file_ids + file_object = None + if not file_ids: + for file_path in file_paths: + if not any(file_path.endswith(extension) for extension in allowed_exts): + os.rename(file_path, file_path + ".txt") + file_path += ".txt" + file_object = client.files.create( + file=Path(file_path), purpose="assistants" + ) + file_ids.append(file_object.id) + + logger.debug(instructions) + # always create new one + assistant = openai_retry_with_timeout( + client.beta.assistants.create, + name=assistant_name, + instructions=instructions, + tools=tools, + model=model, + ) + thread = client.beta.threads.create() + if file_ids: + logger.info("Uploading files...") + if request: + client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content=request, + file_ids=file_ids, + ) + if file_ids: + logger.info("Files uploaded") + for message in additional_messages: + client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content=message.content, + ) + run = client.beta.threads.runs.create( + thread_id=thread.id, + assistant_id=assistant.id, + instructions=instructions, + model=model, + ) + if len(tools) > 1: + return run_until_complete( + thread_id=thread.id, + run_id=run.id, + model=model, + chat_logger=chat_logger, + assistant_id=assistant.id, + sleep_time=sleep_time, + save_ticket_progress=save_ticket_progress, + ) + for file_id in file_ids: + client.files.delete(file_id=file_id) + return ( + assistant.id, + run.id, + thread.id, + ) + + +# Split in two so it can be cached +def openai_assistant_call( + request: str, + instructions: str | None = None, + additional_messages: list[Message] = [], + file_paths: list[str] = [], + uploaded_file_ids: list[str] = [], + tools: list[dict[str, str]] = [{"type": "code_interpreter"}], + model: str = DEFAULT_GPT4_32K_MODEL, + sleep_time: int = 3, + chat_logger: ChatLogger | None = None, + assistant_id: str | None = None, + assistant_name: str | None = None, + save_ticket_progress: save_ticket_progress_type | None = None, +): + model = ( + "gpt-3.5-turbo-1106" + if (chat_logger is None or chat_logger.use_faster_model()) + and not IS_SELF_HOSTED + else DEFAULT_GPT4_32K_MODEL + ) + posthog.capture( + chat_logger.data.get("username") if chat_logger is not None else "anonymous", + "call_assistant_api", + { + "query": request, + "model": model, + "username": ( + chat_logger.data.get("username", "anonymous") + if chat_logger is not None + else "anonymous" + ), + "is_self_hosted": IS_SELF_HOSTED, + "trace": "".join(traceback.format_list(traceback.extract_stack())), + }, + ) + retries = range(3) + for _ in retries: + try: + response = openai_assistant_call_helper( + request=request, + instructions=instructions, + additional_messages=additional_messages, + file_paths=file_paths, + uploaded_file_ids=uploaded_file_ids, + tools=tools, + model=model, + sleep_time=sleep_time, + chat_logger=chat_logger, + assistant_id=assistant_id, + assistant_name=assistant_name, + save_ticket_progress=save_ticket_progress, + ) + if len(tools) > 1: + return response + (assistant_id, run_id, thread_id) = response + messages = client.beta.threads.messages.list( + thread_id=thread_id, + ) + return AssistantResponse( + messages=messages, + assistant_id=assistant_id, + run_id=run_id, + thread_id=thread_id, + ) + except AssistantRaisedException as e: + logger.warning(e.message) + except Exception as e: + logger.error(e) + raise e + + +def run_until_complete_unstable( tools: list[dict[str, str]], model: str = DEFAULT_GPT4_32K_MODEL, chat_logger: ChatLogger | None = None, @@ -367,7 +710,7 @@ def run_until_complete( # ) -def openai_assistant_call_helper( +def openai_assistant_call_helper_unstable( request: str, instructions: str | None = None, additional_messages: list[Message] = [], @@ -388,7 +731,7 @@ def openai_assistant_call_helper( # tools must always be > 1 if len(tools) > 1: - return run_until_complete( + return run_until_complete_unstable( tools=tools, messages=messages, model=model, @@ -401,7 +744,7 @@ def openai_assistant_call_helper( # Split in two so it can be cached -def openai_assistant_call( +def openai_assistant_call_unstable( request: str, instructions: str | None = None, additional_messages: list[Message] = [], @@ -436,7 +779,7 @@ def openai_assistant_call( retries = range(3) for _ in retries: try: - response = openai_assistant_call_helper( + response = openai_assistant_call_helper_unstable( request=request, instructions=instructions, additional_messages=additional_messages, From 1d126d895a196e4bea801ac0cfe481f43270da53 Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Thu, 14 Mar 2024 19:11:23 -0700 Subject: [PATCH 2/2] Max 1 run --- sweepai/handlers/on_ticket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sweepai/handlers/on_ticket.py b/sweepai/handlers/on_ticket.py index b13519f7af..4ab0b8d49a 100644 --- a/sweepai/handlers/on_ticket.py +++ b/sweepai/handlers/on_ticket.py @@ -412,7 +412,7 @@ def on_ticket( # we want to pass in the failing github action messages to the next run in order to fix them failing_gha_messages: list[Message] = [] # we rerun this logic 3 times at most if the github actions associated with the created pr fails - max_pr_attempts = 3 + max_pr_attempts = 1 for run_attempt in range(max_pr_attempts): if tracking_id is None: tracking_id = get_hash()