From 393eeb6f538aceacfda70471bc644a0c9fa9b290 Mon Sep 17 00:00:00 2001 From: martin ye Date: Thu, 14 Mar 2024 18:27:10 +0000 Subject: [PATCH 1/4] fixes --- sweepai/core/context_pruning.py | 3 ++- sweepai/core/repo_parsing_utils.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sweepai/core/context_pruning.py b/sweepai/core/context_pruning.py index 84b5245a37..f416b78e2f 100644 --- a/sweepai/core/context_pruning.py +++ b/sweepai/core/context_pruning.py @@ -484,7 +484,8 @@ def get_relevant_context( logger.exception(e) if len(repo_context_manager.current_top_snippets) == 0: repo_context_manager.current_top_snippets = old_top_snippets - discord_log_error(f"Context manager empty ({ticket_progress.tracking_id})") + if ticket_progress: + discord_log_error(f"Context manager empty ({ticket_progress.tracking_id})") return repo_context_manager except Exception as e: logger.exception(e) diff --git a/sweepai/core/repo_parsing_utils.py b/sweepai/core/repo_parsing_utils.py index 99c28be674..7d28bdfc67 100644 --- a/sweepai/core/repo_parsing_utils.py +++ b/sweepai/core/repo_parsing_utils.py @@ -84,7 +84,7 @@ def read_file(file_name: str) -> str: return "" -FILE_THRESHOLD = 120 +FILE_THRESHOLD = 240 def file_path_to_chunks(file_path: str) -> list[str]: From 50420f2302fe92c9d1cf919cc1285842751ebc7b Mon Sep 17 00:00:00 2001 From: martin ye Date: Thu, 14 Mar 2024 19:20:24 +0000 Subject: [PATCH 2/4] added banned files to get_file_list --- sweepai/utils/github_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sweepai/utils/github_utils.py b/sweepai/utils/github_utils.py index 49d7ce5c12..24587142ab 100644 --- a/sweepai/utils/github_utils.py +++ b/sweepai/utils/github_utils.py @@ -300,7 +300,7 @@ def list_directory_contents( def get_file_list(self) -> str: root_directory = self.repo_dir files = [] - + sweep_config: SweepConfig = SweepConfig() def dfs_helper(directory): nonlocal files for item in os.listdir(directory): @@ -308,7 +308,9 
@@ def dfs_helper(directory): continue item_path = os.path.join(directory, item) if os.path.isfile(item_path): - files.append(item_path) # Add the file to the list + # make sure the item_path is not in one of the banned directories + if len([banned_dir for banned_dir in sweep_config.exclude_dirs if banned_dir in item_path.split(os.path.sep)]) == 0: + files.append(item_path) # Add the file to the list elif os.path.isdir(item_path): dfs_helper(item_path) # Recursive call to explore subdirectory From 761fe04b8c61859fb1aa963161ea0684007e9d98 Mon Sep 17 00:00:00 2001 From: martin ye Date: Fri, 15 Mar 2024 02:58:35 +0000 Subject: [PATCH 3/4] multiple fixes --- sweepai/agents/assistant_function_modify.py | 16 +------ sweepai/core/context_pruning.py | 5 ++- sweepai/core/repo_parsing_utils.py | 49 ++++++++++++++------- sweepai/handlers/on_ticket.py | 1 - sweepai/utils/ticket_utils.py | 4 +- sweepai/utils/tree_utils.py | 31 ++++++++++++- 6 files changed, 69 insertions(+), 37 deletions(-) diff --git a/sweepai/agents/assistant_function_modify.py b/sweepai/agents/assistant_function_modify.py index a386a33976..f06cd47bc1 100644 --- a/sweepai/agents/assistant_function_modify.py +++ b/sweepai/agents/assistant_function_modify.py @@ -16,6 +16,7 @@ from sweepai.utils.diff import generate_diff from sweepai.utils.progress import AssistantConversation, TicketProgress from sweepai.utils.utils import check_code, chunk_code +from sweepai.core.repo_parsing_utils import read_file_with_fallback_encodings # Pre-amble using ideas from https://github.com/paul-gauthier/aider/blob/main/aider/coders/udiff_prompts.py # Doesn't regress on the benchmark but improves average code generated and avoids empty comments. 
@@ -84,21 +85,6 @@ def ensure_additional_messages_length(additional_messages: list[Message]): ) return additional_messages - -def read_file_with_fallback_encodings( - file_path, encodings=["utf-8", "windows-1252", "iso-8859-1"] -): - for encoding in encodings: - try: - with open(file_path, "r", encoding=encoding) as file: - return file.read() - except UnicodeDecodeError: - continue - raise UnicodeDecodeError( - f"Could not decode {file_path} with any of the specified encodings: {encodings}" - ) - - def build_keyword_search_match_results( match_indices: list[int], chunks: list[str], keyword: str, success_message ) -> str: diff --git a/sweepai/core/context_pruning.py b/sweepai/core/context_pruning.py index f416b78e2f..7dbc469880 100644 --- a/sweepai/core/context_pruning.py +++ b/sweepai/core/context_pruning.py @@ -33,6 +33,7 @@ from sweepai.utils.github_utils import ClonedRepo from sweepai.utils.progress import AssistantConversation, TicketProgress from sweepai.utils.tree_utils import DirectoryTree +from time import time if OPENAI_API_TYPE == "openai": client = OpenAI(api_key=OPENAI_API_KEY, timeout=90) if OPENAI_API_KEY else None @@ -447,7 +448,9 @@ def get_relevant_context( repo_context_manager = build_import_trees(repo_context_manager, import_graph) # for any code file mentioned in the query add it to the top relevant snippets repo_context_manager = add_relevant_files_to_top_snippets(repo_context_manager) - # check to see if there are any files that are mentioned in the query + # add relevant files to dir_obj inside repo_context_manager, this is in case dir_obj is too large when as a string + repo_context_manager.dir_obj.add_relevant_files(repo_context_manager.relevant_file_paths) + user_prompt = repo_context_manager.format_context( unformatted_user_prompt=unformatted_user_prompt, query=query, diff --git a/sweepai/core/repo_parsing_utils.py b/sweepai/core/repo_parsing_utils.py index 0e2d149534..c358fe093d 100644 --- a/sweepai/core/repo_parsing_utils.py +++ 
b/sweepai/core/repo_parsing_utils.py @@ -14,6 +14,19 @@ tiktoken_client = Tiktoken() +def read_file_with_fallback_encodings( + file_path, encodings=["utf-8", "windows-1252", "iso-8859-1"] +): + for encoding in encodings: + try: + with open(file_path, "r", encoding=encoding) as file: + return file.read() + except UnicodeDecodeError: + continue + raise UnicodeDecodeError( + f"Could not decode {file_path} with any of the specified encodings: {encodings}" + ) + def filter_file(directory: str, file: str, sweep_config: SweepConfig) -> bool: """ @@ -54,23 +67,25 @@ def filter_file(directory: str, file: str, sweep_config: SweepConfig) -> bool: if is_binary: return False f.close() - with open(file, "r") as f: - try: - lines = f.readlines() - except UnicodeDecodeError: - logger.warning(f"UnicodeDecodeError: {file}, skipping") - return False - line_count = len(lines) - data = "\n".join(lines) - # if average line length is greater than 200, then it is likely not human readable - if len(data)/line_count > 200: - return False - # check token density, if it is greater than 2, then it is likely not human readable - token_count = tiktoken_client.count(data) - if token_count == 0: - return False - if len(data)/token_count < 2: - return False + + + try: + # fetch file + data = read_file_with_fallback_encodings(file) + lines = data.split("\n") + except UnicodeDecodeError: + logger.warning(f"UnicodeDecodeError: {file}, skipping") + return False + line_count = len(lines) + # if average line length is greater than 200, then it is likely not human readable + if len(data)/line_count > 200: + return False + # check token density, if it is greater than 2, then it is likely not human readable + token_count = tiktoken_client.count(data) + if token_count == 0: + return False + if len(data)/token_count < 2: + return False return True diff --git a/sweepai/handlers/on_ticket.py b/sweepai/handlers/on_ticket.py index 1d259f98e4..bd8403c8b9 100644 --- a/sweepai/handlers/on_ticket.py +++ 
b/sweepai/handlers/on_ticket.py @@ -619,7 +619,6 @@ def on_ticket( ticket_progress.save() config_pr_url = None - user_settings = UserSettings.from_username(username=username) user_settings_message = user_settings.get_message() diff --git a/sweepai/utils/ticket_utils.py b/sweepai/utils/ticket_utils.py index 570e016a81..6fd8a66231 100644 --- a/sweepai/utils/ticket_utils.py +++ b/sweepai/utils/ticket_utils.py @@ -85,8 +85,8 @@ def prep_snippets( prefixes.append("/".join(snippet_path.split("/")[:idx]) + "/") prefixes.append(snippet_path) _, dir_obj = cloned_repo.list_directory_tree( - included_directories=prefixes, - included_files=snippet_paths, + included_directories=list(set(prefixes)), + included_files=list(set(snippet_paths)), ) repo_context_manager = RepoContextManager( dir_obj=dir_obj, diff --git a/sweepai/utils/tree_utils.py b/sweepai/utils/tree_utils.py index 38e2158e06..e2fbf9272d 100644 --- a/sweepai/utils/tree_utils.py +++ b/sweepai/utils/tree_utils.py @@ -39,6 +39,11 @@ class DirectoryTree: def __init__(self): self.original_lines: list[Line] = [] self.lines: list[Line] = [] + self.relevant_files: list[str] = [] # this is for __str__ method if the resulting string becomes too large + + def add_relevant_files(self, files: list[str]) -> None: + relevant_files = copy.deepcopy(files) + self.relevant_files += relevant_files def parse(self, input_str: str): stack: list[Line] = [] # To keep track of parent directories @@ -132,7 +137,31 @@ def __str__(self): for line in self.lines: line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text results.append((" " * line.indent_count) + line_text) - return "\n".join(results) + raw_str = "\n".join(results) + # if raw_str is too large (> 20k chars) we will use a truncated version + if len(raw_str) > 20000: + results = [] + logger.warning("While attempting to dump the directory tree, the string was too large. 
Outputting the truncated version instead...") + for line in self.lines: + # always print out directories + if line.is_dir: + line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text + results.append((" " * line.indent_count) + line_text) + continue + # if a file name doesn't appear as a file name in one of the relevant files, don't print it + # instead print ... unless the previous item is already a ... + if (line.parent): + full_path_of_file = line.parent.full_path() + line.full_path() + else: + full_path_of_file = line.full_path() + if full_path_of_file in self.relevant_files: + line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text + results.append((" " * line.indent_count) + line_text) + elif len(results) > 0 and results[-1] != (" " * line.indent_count) + "...": + results.append((" " * line.indent_count) + "...") + raw_str = "\n".join(results) + + return raw_str @file_cache() From 5d1ed7797c83a6608df4834199fe963b1a7dd8ad Mon Sep 17 00:00:00 2001 From: martin ye Date: Fri, 15 Mar 2024 03:25:28 +0000 Subject: [PATCH 4/4] ruff fix --- sweepai/core/context_pruning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sweepai/core/context_pruning.py b/sweepai/core/context_pruning.py index 7dbc469880..5314a24c6d 100644 --- a/sweepai/core/context_pruning.py +++ b/sweepai/core/context_pruning.py @@ -33,7 +33,6 @@ from sweepai.utils.github_utils import ClonedRepo from sweepai.utils.progress import AssistantConversation, TicketProgress from sweepai.utils.tree_utils import DirectoryTree -from time import time if OPENAI_API_TYPE == "openai": client = OpenAI(api_key=OPENAI_API_KEY, timeout=90) if OPENAI_API_KEY else None