Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CI/CD tests #3289

Merged
merged 25 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ jobs:

- name: Run e2e Tests
env:
MONGODB_URI: ${{ secrets.MONGODB_URI }}
GITHUB_PAT: ${{ secrets.GH_PAT }}
GITHUB_APP_ID: ${{ secrets.GH_APP_ID }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
Expand Down
8 changes: 5 additions & 3 deletions sweepai/agents/assistant_function_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def save_ticket_progress(assistant_id: str, thread_id: str, run_id: str):
try:
done_counter = 0
tool_name, tool_call = assistant_generator.send(None)
for i in range(10000):
for i in range(100): # TODO: tune this parameter
print(tool_name, json.dumps(tool_call, indent=2))
if tool_name == "done":
diff = generate_diff(file_contents, current_contents)
Expand Down Expand Up @@ -335,8 +335,10 @@ def save_ticket_progress(assistant_id: str, thread_id: str, run_id: str):

# Check if the changes are valid
if not error_message:
is_valid, message = check_code(
file_path, current_new_contents
is_valid, message = (
(True, "")
if not initial_code_valid
else check_code(file_path, current_new_contents)
)
current_diff = generate_diff(
new_contents, current_new_contents
Expand Down
69 changes: 59 additions & 10 deletions sweepai/agents/assistant_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import json
import traceback
from time import sleep
from typing import Callable
from typing import Callable, Optional

import openai
from loguru import logger
from openai import AzureOpenAI, OpenAI
from openai.pagination import SyncCursorPage
from openai.types.beta.threads.thread_message import ThreadMessage
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from pydantic import BaseModel

from sweepai.config.server import (
Expand Down Expand Up @@ -68,6 +72,43 @@ def openai_retry_with_timeout(call, *args, num_retries=3, timeout=5, **kwargs):
) from e


def fix_tool_calls(
    tool_calls: Optional[list[ChatCompletionMessageToolCall]],
) -> Optional[list[ChatCompletionMessageToolCall]]:
    """Expand synthetic "parallel" tool calls into individual function calls.

    The OpenAI API occasionally emits a single tool call named "parallel"
    (or "multi_tool_use.parallel") whose JSON arguments bundle several real
    calls under the "tool_uses" key. Each bundled call is unpacked into its
    own ChatCompletionMessageToolCall with a derived id of the form
    "<parent_id>_<index>" and any "functions." prefix stripped from its
    name. All other tool calls — and any call whose arguments fail to parse
    as JSON — are passed through unchanged.

    Returns None when tool_calls is None.
    """
    if tool_calls is None:
        return None

    fixed_tool_calls = []

    for tool_call in tool_calls:
        current_function = tool_call.function.name
        try:
            function_args = json.loads(tool_call.function.arguments)
        except json.JSONDecodeError:
            # BUG FIX: the attribute is `.arguments`, not `.args` — the old
            # log line raised AttributeError while reporting the failure.
            logger.error(
                f"Error: could not decode function arguments: {tool_call.function.arguments}"
            )
            # Keep the undecodable call as-is so the caller can surface it.
            fixed_tool_calls.append(tool_call)
            continue
        if current_function in ("parallel", "multi_tool_use.parallel"):
            for index, tool_use in enumerate(function_args["tool_uses"]):
                use_args = tool_use["parameters"]
                use_name = tool_use["recipient_name"].removeprefix("functions.")
                fixed_tool_calls.append(
                    ChatCompletionMessageToolCall(
                        id=f"{tool_call.id}_{index}",
                        type="function",
                        function=Function(
                            name=use_name, arguments=json.dumps(use_args)
                        ),
                    )
                )
        else:
            fixed_tool_calls.append(tool_call)

    return fixed_tool_calls


save_ticket_progress_type = Callable[[str, str, str], None]


Expand Down Expand Up @@ -244,7 +285,8 @@ def run_until_complete(
continue

response_message = response.choices[0].message
tool_calls = response_message.tool_calls
tool_calls = fix_tool_calls(response_message.tool_calls)
response_message.tool_calls = tool_calls
# extend conversation
response_message_dict = response_message.dict()
# in some cases the fields are None and we must replace these with empty strings
Expand All @@ -262,8 +304,18 @@ def run_until_complete(
if tool_calls:
for tool_call in tool_calls:
function_name = tool_call.function.name
function_args = json.loads(tool_call.function.arguments)
tool_output = yield function_name, function_args
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logger.debug(
f"Error: could not decode function arguments: {tool_call.function.args}"
)
tool_output = f"ERROR\nCould not decode function arguments:\n{e}"
else:
logger.debug(
f"tool_call: {function_name} with args: {function_args}"
)
tool_output = yield function_name, function_args
messages.append(
{
"tool_call_id": tool_call.id,
Expand Down Expand Up @@ -363,12 +415,9 @@ def openai_assistant_call(
assistant_name: str | None = None,
save_ticket_progress: save_ticket_progress_type | None = None,
):
model = (
"gpt-3.5-turbo-1106"
if (chat_logger is None or chat_logger.use_faster_model())
and not IS_SELF_HOSTED
else DEFAULT_GPT4_32K_MODEL
)
if chat_logger.use_faster_model():
raise Exception("GPT-3.5 is not supported on assistant calls.")
model = DEFAULT_GPT4_32K_MODEL
posthog.capture(
chat_logger.data.get("username") if chat_logger is not None else "anonymous",
"call_assistant_api",
Expand Down
120 changes: 120 additions & 0 deletions sweepai/agents/assistant_wrapper_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import json
import unittest

from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)

from sweepai.agents.assistant_wrapper import fix_tool_calls

# NOTE(review): removed a vestigial module-level `if __name__ == "__main__"`
# guard that built an unused ChatCompletionMessageToolCall fixture named
# `tool_call`. Nothing referenced it, and the real entry point is the
# `unittest.main()` guard at the bottom of this file.


class TestFixToolCalls(unittest.TestCase):
    """Unit tests for fix_tool_calls, which flattens OpenAI 'parallel' tool
    calls into individual function-type tool calls."""

    @staticmethod
    def _call(call_id, name, arguments):
        # Helper: build a function-type tool call with JSON-encoded arguments.
        return ChatCompletionMessageToolCall(
            id=call_id,
            type="function",
            function={"name": name, "arguments": json.dumps(arguments)},
        )

    def test_multiple_tool_calls(self):
        """A mix of 'parallel' and plain tool calls is flattened in order,
        with parallel children given derived ids '<parent_id>_<index>'."""
        first_parallel = self._call(
            "1",
            "parallel",
            {
                "tool_uses": [
                    {
                        "recipient_name": "functions.example_function",
                        "parameters": {"arg1": "value1"},
                    },
                    {
                        "recipient_name": "functions.example_function",
                        "parameters": {"arg1": "value2"},
                    },
                ]
            },
        )
        plain_call = self._call("2", "example_tool", {"arg2": "value2"})
        second_parallel = self._call(
            "3",
            "parallel",
            {
                "tool_uses": [
                    {
                        "recipient_name": "functions.another_function",
                        "parameters": {"arg3": "value3"},
                    },
                ]
            },
        )

        expected_tool_calls = [
            self._call("1_0", "example_function", {"arg1": "value1"}),
            self._call("1_1", "example_function", {"arg1": "value2"}),
            self._call("2", "example_tool", {"arg2": "value2"}),
            self._call("3_0", "another_function", {"arg3": "value3"}),
        ]

        output_tool_calls = fix_tool_calls(
            [first_parallel, plain_call, second_parallel]
        )

        self.assertEqual(len(output_tool_calls), len(expected_tool_calls))
        for actual, expected in zip(output_tool_calls, expected_tool_calls):
            self.assertEqual(actual.id, expected.id)
            self.assertEqual(actual.type, expected.type)
            self.assertEqual(actual.function, expected.function)


if __name__ == "__main__":
unittest.main()
2 changes: 0 additions & 2 deletions sweepai/agents/pr_description_bot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import re

from sweepai.config.server import DEFAULT_GPT35_MODEL
from sweepai.core.chat import ChatGPT

prompt = """\
Expand Down Expand Up @@ -30,7 +29,6 @@ def describe_diffs(
pr_title,
):
self.messages = []
self.model = DEFAULT_GPT35_MODEL
# attempt to generate description 3 times
pr_desc_pattern = r"<pr_description>\n(.*?)\n</pr_description>"
for attempt in [0, 1, 2]:
Expand Down
5 changes: 0 additions & 5 deletions sweepai/config/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,6 @@
)


OPENAI_USE_3_5_MODEL_ONLY = (
os.environ.get("OPENAI_USE_3_5_MODEL_ONLY", "false").lower() == "true"
)


MONGODB_URI = os.environ.get("MONGODB_URI", None)
IS_SELF_HOSTED = bool(os.environ.get("IS_SELF_HOSTED", MONGODB_URI is None))

Expand Down
12 changes: 8 additions & 4 deletions sweepai/core/chat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from math import inf
import traceback
from typing import Any, Literal

Expand All @@ -8,7 +9,6 @@
from sweepai.config.client import get_description
from sweepai.config.server import (
DEFAULT_GPT4_32K_MODEL,
DEFAULT_GPT35_MODEL,
)
from sweepai.core.entities import Message
from sweepai.core.prompts import repo_description_prefix_prompt, system_message_prompt
Expand Down Expand Up @@ -192,9 +192,11 @@ def call_openai(
and not self.chat_logger.is_paying_user()
and not self.chat_logger.is_consumer_tier()
):
model = DEFAULT_GPT35_MODEL
raise ValueError(
"You have no more tickets! Please upgrade to a paid plan."
)
else:
tickets_allocated = 120 if self.chat_logger.is_paying_user() else 5
tickets_allocated = inf if self.chat_logger.is_paying_user() else 5
tickets_count = self.chat_logger.get_ticket_count()
purchased_tickets = self.chat_logger.get_ticket_count(purchased=True)
if tickets_count < tickets_allocated:
Expand All @@ -208,7 +210,9 @@ def call_openai(
f"{purchased_tickets} purchased tickets found in MongoDB, using {model}"
)
else:
model = DEFAULT_GPT35_MODEL
raise ValueError(
f"Tickets allocated: {tickets_allocated}, tickets found: {tickets_count}. You have no more tickets!"
)

count_tokens = Tiktoken().count
messages_length = sum(
Expand Down
15 changes: 5 additions & 10 deletions sweepai/core/context_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@
AZURE_API_KEY,
AZURE_OPENAI_DEPLOYMENT,
DEFAULT_GPT4_32K_MODEL,
IS_SELF_HOSTED,
OPENAI_API_BASE,
OPENAI_API_KEY,
OPENAI_API_TYPE,
OPENAI_API_VERSION,
)
from sweepai.core.entities import AssistantRaisedException, Snippet
from sweepai.core.entities import Snippet
from sweepai.logn.cache import file_cache
from sweepai.utils.chat_logger import ChatLogger, discord_log_error
from sweepai.utils.code_tree import CodeTree
from sweepai.utils.event_logger import posthog
from sweepai.utils.github_utils import ClonedRepo
from sweepai.utils.progress import AssistantConversation, TicketProgress
from sweepai.utils.str_utils import FASTER_MODEL_MESSAGE
from sweepai.utils.tree_utils import DirectoryTree

if OPENAI_API_TYPE == "openai":
Expand Down Expand Up @@ -426,12 +426,9 @@ def get_relevant_context(
ticket_progress: TicketProgress | None = None,
chat_logger: ChatLogger = None,
):
model = (
"gpt-3.5-turbo-1106"
if (chat_logger is None or chat_logger.use_faster_model())
and not IS_SELF_HOSTED
else DEFAULT_GPT4_32K_MODEL
)
if chat_logger.use_faster_model():
raise Exception(FASTER_MODEL_MESSAGE)
model = DEFAULT_GPT4_32K_MODEL
posthog.capture(
chat_logger.data.get("username") if chat_logger is not None else "anonymous",
"call_assistant_api",
Expand Down Expand Up @@ -550,8 +547,6 @@ def modify_context(
time.sleep(3)
continue
num_tool_calls_made += 1
if num_tool_calls_made > 15 and model.startswith("gpt-3.5"):
raise AssistantRaisedException("Too many tool calls made on gpt-3.5.")
tool_calls = run.required_action.submit_tool_outputs.tool_calls
tool_outputs = []
for tool_call in tool_calls:
Expand Down
2 changes: 0 additions & 2 deletions sweepai/core/external_searcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import re

from sweepai.config.server import DEFAULT_GPT35_MODEL
from sweepai.core.chat import ChatGPT
from sweepai.core.entities import Message
from sweepai.core.prompts import external_search_prompt, external_search_system_prompt
Expand All @@ -18,7 +17,6 @@ def extract_summary_from_link(self, url: str, problem: str) -> str:
page_metadata = extract_info(url)

self.messages = [Message(role="system", content=external_search_system_prompt)]
self.model = DEFAULT_GPT35_MODEL # can be optimized
response = self.chat(
external_search_prompt.format(
page_metadata=page_metadata,
Expand Down
Loading
Loading