Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/small fix #3305

Merged
merged 8 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 1 addition & 15 deletions sweepai/agents/assistant_function_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sweepai.utils.diff import generate_diff
from sweepai.utils.progress import AssistantConversation, TicketProgress
from sweepai.utils.utils import check_code, chunk_code
from sweepai.core.repo_parsing_utils import read_file_with_fallback_encodings

# Pre-amble using ideas from https://github.com/paul-gauthier/aider/blob/main/aider/coders/udiff_prompts.py
# Doesn't regress on the benchmark but improves average code generated and avoids empty comments.
Expand Down Expand Up @@ -84,21 +85,6 @@ def ensure_additional_messages_length(additional_messages: list[Message]):
)
return additional_messages


def read_file_with_fallback_encodings(
    file_path, encodings=["utf-8", "windows-1252", "iso-8859-1"]
):
    """Read a text file, trying each encoding in turn.

    Args:
        file_path: Path of the file to read.
        encodings: Encoding names to attempt, in order. NOTE(review): the
            default ends with iso-8859-1, which decodes any byte sequence,
            so the failure branch is only reachable with a stricter custom
            list.

    Returns:
        The decoded file contents as a str.

    Raises:
        UnicodeDecodeError: If no encoding in ``encodings`` can decode the file.
    """
    for encoding in encodings:
        try:
            with open(file_path, "r", encoding=encoding) as file:
                return file.read()
        except UnicodeDecodeError:
            continue
    # Bug fix: UnicodeDecodeError requires (encoding, object, start, end,
    # reason); the previous single-argument call raised TypeError instead of
    # the intended exception, so callers catching UnicodeDecodeError missed it.
    raise UnicodeDecodeError(
        "unknown",
        b"",
        0,
        1,
        f"Could not decode {file_path} with any of the specified encodings: {encodings}",
    )


def build_keyword_search_match_results(
match_indices: list[int], chunks: list[str], keyword: str, success_message
) -> str:
Expand Down
4 changes: 3 additions & 1 deletion sweepai/core/context_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,9 @@ def get_relevant_context(
repo_context_manager = build_import_trees(repo_context_manager, import_graph)
# for any code file mentioned in the query add it to the top relevant snippets
repo_context_manager = add_relevant_files_to_top_snippets(repo_context_manager)
# check to see if there are any files that are mentioned in the query
# add relevant files to dir_obj inside repo_context_manager, in case dir_obj is too large when rendered as a string
repo_context_manager.dir_obj.add_relevant_files(repo_context_manager.relevant_file_paths)

user_prompt = repo_context_manager.format_context(
unformatted_user_prompt=unformatted_user_prompt,
query=query,
Expand Down
49 changes: 32 additions & 17 deletions sweepai/core/repo_parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@

tiktoken_client = Tiktoken()

def read_file_with_fallback_encodings(
    file_path, encodings=["utf-8", "windows-1252", "iso-8859-1"]
):
    """Read a text file, trying each encoding in turn.

    Args:
        file_path: Path of the file to read.
        encodings: Encoding names to attempt, in order. NOTE(review): the
            default ends with iso-8859-1, which decodes any byte sequence,
            so the failure branch is only reachable with a stricter custom
            list.

    Returns:
        The decoded file contents as a str.

    Raises:
        UnicodeDecodeError: If no encoding in ``encodings`` can decode the file.
    """
    for encoding in encodings:
        try:
            with open(file_path, "r", encoding=encoding) as file:
                return file.read()
        except UnicodeDecodeError:
            continue
    # Bug fix: UnicodeDecodeError requires (encoding, object, start, end,
    # reason); the previous single-argument call raised TypeError instead of
    # the intended exception, so callers catching UnicodeDecodeError missed it.
    raise UnicodeDecodeError(
        "unknown",
        b"",
        0,
        1,
        f"Could not decode {file_path} with any of the specified encodings: {encodings}",
    )


def filter_file(directory: str, file: str, sweep_config: SweepConfig) -> bool:
"""
Expand Down Expand Up @@ -54,23 +67,25 @@ def filter_file(directory: str, file: str, sweep_config: SweepConfig) -> bool:
if is_binary:
return False
f.close()
with open(file, "r") as f:
try:
lines = f.readlines()
except UnicodeDecodeError:
logger.warning(f"UnicodeDecodeError: {file}, skipping")
return False
line_count = len(lines)
data = "\n".join(lines)
# if average line length is greater than 200, then it is likely not human readable
if len(data)/line_count > 200:
return False
# check token density, if it is greater than 2, then it is likely not human readable
token_count = tiktoken_client.count(data)
if token_count == 0:
return False
if len(data)/token_count < 2:
return False


try:
# fetch file
data = read_file_with_fallback_encodings(file)
lines = data.split("\n")
except UnicodeDecodeError:
logger.warning(f"UnicodeDecodeError: {file}, skipping")
return False
line_count = len(lines)
# if average line length is greater than 200, then it is likely not human readable
if len(data)/line_count > 200:
return False
# check token density, if it is greater than 2, then it is likely not human readable
token_count = tiktoken_client.count(data)
if token_count == 0:
return False
if len(data)/token_count < 2:
return False
return True


Expand Down
1 change: 0 additions & 1 deletion sweepai/handlers/on_ticket.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,6 @@ def on_ticket(
ticket_progress.save()

config_pr_url = None

user_settings = UserSettings.from_username(username=username)
user_settings_message = user_settings.get_message()

Expand Down
4 changes: 2 additions & 2 deletions sweepai/utils/ticket_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def prep_snippets(
prefixes.append("/".join(snippet_path.split("/")[:idx]) + "/")
prefixes.append(snippet_path)
_, dir_obj = cloned_repo.list_directory_tree(
included_directories=prefixes,
included_files=snippet_paths,
included_directories=list(set(prefixes)),
included_files=list(set(snippet_paths)),
)
repo_context_manager = RepoContextManager(
dir_obj=dir_obj,
Expand Down
31 changes: 30 additions & 1 deletion sweepai/utils/tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class DirectoryTree:
def __init__(self):
    """Initialize an empty directory tree."""
    # Lines as originally parsed, kept unmodified for reference.
    self.original_lines: list[Line] = []
    # Current working set of tree lines.
    self.lines: list[Line] = []
    # File paths flagged as relevant; consulted by __str__ when the
    # rendered tree string becomes too large and must be truncated.
    self.relevant_files: list[str] = []

def add_relevant_files(self, files: list[str]) -> None:
relevant_files = copy.deepcopy(files)
self.relevant_files += relevant_files

def parse(self, input_str: str):
stack: list[Line] = [] # To keep track of parent directories
Expand Down Expand Up @@ -132,7 +137,31 @@ def __str__(self):
for line in self.lines:
line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text
results.append((" " * line.indent_count) + line_text)
return "\n".join(results)
raw_str = "\n".join(results)
# if raw_str is too large (> 20k chars) we will use a truncated version
if len(raw_str) > 20000:
results = []
logger.warning("While attempting to dump the directory tree, the string was too large. Outputting the truncated version instead...")
for line in self.lines:
# always print out directories
if line.is_dir:
line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text
results.append((" " * line.indent_count) + line_text)
continue
# if a file name doesn't appear as a file name in one of the relevant files, don't print it
# instead print ... unless the previous item is already a ...
if (line.parent):
full_path_of_file = line.parent.full_path() + line.full_path()
else:
full_path_of_file = line.full_path()
if full_path_of_file in self.relevant_files:
line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text
results.append((" " * line.indent_count) + line_text)
elif len(results) > 0 and results[-1] != (" " * line.indent_count) + "...":
results.append((" " * line.indent_count) + "...")
raw_str = "\n".join(results)

return raw_str


@file_cache()
Expand Down
Loading