Skip to content

Commit

Permalink
Fix/small fix (#3305)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinlu1248 authored Mar 15, 2024
2 parents 8d00a15 + bbf565c commit fd15ff8
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 37 deletions.
16 changes: 1 addition & 15 deletions sweepai/agents/assistant_function_modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sweepai.utils.diff import generate_diff
from sweepai.utils.progress import AssistantConversation, TicketProgress
from sweepai.utils.utils import check_code, chunk_code
from sweepai.core.repo_parsing_utils import read_file_with_fallback_encodings

# Pre-amble using ideas from https://github.com/paul-gauthier/aider/blob/main/aider/coders/udiff_prompts.py
# Doesn't regress on the benchmark but improves average code generated and avoids empty comments.
Expand Down Expand Up @@ -84,21 +85,6 @@ def ensure_additional_messages_length(additional_messages: list[Message]):
)
return additional_messages


def read_file_with_fallback_encodings(
    file_path, encodings=["utf-8", "windows-1252", "iso-8859-1"]
):
    """Read a text file, trying each encoding in order until one succeeds.

    Args:
        file_path: Path of the file to read.
        encodings: Candidate encodings, tried in order. The list is only
            read, never mutated, so the mutable default is safe here.

    Returns:
        The file contents decoded with the first encoding that works.

    Raises:
        UnicodeDecodeError: If none of the encodings can decode the file.
    """
    for encoding in encodings:
        try:
            with open(file_path, "r", encoding=encoding) as file:
                return file.read()
        except UnicodeDecodeError:
            continue
    # BUG FIX: UnicodeDecodeError requires exactly five constructor args
    # (encoding, object, start, end, reason); the previous single-string
    # call raised TypeError at this line instead of the intended exception.
    # Keep raising UnicodeDecodeError so existing `except UnicodeDecodeError`
    # callers still catch it.
    raise UnicodeDecodeError(
        encodings[-1] if encodings else "utf-8",
        b"",
        0,
        1,
        f"Could not decode {file_path} with any of the specified encodings: {encodings}",
    )


def build_keyword_search_match_results(
match_indices: list[int], chunks: list[str], keyword: str, success_message
) -> str:
Expand Down
4 changes: 3 additions & 1 deletion sweepai/core/context_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,9 @@ def get_relevant_context(
repo_context_manager = build_import_trees(repo_context_manager, import_graph)
# for any code file mentioned in the query add it to the top relevant snippets
repo_context_manager = add_relevant_files_to_top_snippets(repo_context_manager)
# check to see if there are any files that are mentioned in the query
# add relevant files to dir_obj inside repo_context_manager; this is in case dir_obj is too large when rendered as a string
repo_context_manager.dir_obj.add_relevant_files(repo_context_manager.relevant_file_paths)

user_prompt = repo_context_manager.format_context(
unformatted_user_prompt=unformatted_user_prompt,
query=query,
Expand Down
49 changes: 32 additions & 17 deletions sweepai/core/repo_parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@

tiktoken_client = Tiktoken()

def read_file_with_fallback_encodings(
    file_path, encodings=["utf-8", "windows-1252", "iso-8859-1"]
):
    """Read a text file, trying each encoding in order until one succeeds.

    Args:
        file_path: Path of the file to read.
        encodings: Candidate encodings, tried in order. The list is only
            read, never mutated, so the mutable default is safe here.

    Returns:
        The file contents decoded with the first encoding that works.

    Raises:
        UnicodeDecodeError: If none of the encodings can decode the file.
    """
    for encoding in encodings:
        try:
            with open(file_path, "r", encoding=encoding) as file:
                return file.read()
        except UnicodeDecodeError:
            continue
    # BUG FIX: UnicodeDecodeError requires exactly five constructor args
    # (encoding, object, start, end, reason); the previous single-string
    # call raised TypeError at this line instead of the intended exception.
    # Keep raising UnicodeDecodeError so callers such as filter_file's
    # `except UnicodeDecodeError` still catch it.
    raise UnicodeDecodeError(
        encodings[-1] if encodings else "utf-8",
        b"",
        0,
        1,
        f"Could not decode {file_path} with any of the specified encodings: {encodings}",
    )


def filter_file(directory: str, file: str, sweep_config: SweepConfig) -> bool:
"""
Expand Down Expand Up @@ -54,23 +67,25 @@ def filter_file(directory: str, file: str, sweep_config: SweepConfig) -> bool:
if is_binary:
return False
f.close()
with open(file, "r") as f:
try:
lines = f.readlines()
except UnicodeDecodeError:
logger.warning(f"UnicodeDecodeError: {file}, skipping")
return False
line_count = len(lines)
data = "\n".join(lines)
# if average line length is greater than 200, then it is likely not human readable
if len(data)/line_count > 200:
return False
# check token density, if it is greater than 2, then it is likely not human readable
token_count = tiktoken_client.count(data)
if token_count == 0:
return False
if len(data)/token_count < 2:
return False


try:
# fetch file
data = read_file_with_fallback_encodings(file)
lines = data.split("\n")
except UnicodeDecodeError:
logger.warning(f"UnicodeDecodeError: {file}, skipping")
return False
line_count = len(lines)
# if average line length is greater than 200, then it is likely not human readable
if len(data)/line_count > 200:
return False
# check average characters per token: fewer than 2 chars per token means the text is token-dense and likely not human readable
token_count = tiktoken_client.count(data)
if token_count == 0:
return False
if len(data)/token_count < 2:
return False
return True


Expand Down
1 change: 0 additions & 1 deletion sweepai/handlers/on_ticket.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,6 @@ def on_ticket(
ticket_progress.save()

config_pr_url = None

user_settings = UserSettings.from_username(username=username)
user_settings_message = user_settings.get_message()

Expand Down
4 changes: 2 additions & 2 deletions sweepai/utils/ticket_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def prep_snippets(
prefixes.append("/".join(snippet_path.split("/")[:idx]) + "/")
prefixes.append(snippet_path)
_, dir_obj = cloned_repo.list_directory_tree(
included_directories=prefixes,
included_files=snippet_paths,
included_directories=list(set(prefixes)),
included_files=list(set(snippet_paths)),
)
repo_context_manager = RepoContextManager(
dir_obj=dir_obj,
Expand Down
31 changes: 30 additions & 1 deletion sweepai/utils/tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class DirectoryTree:
def __init__(self):
self.original_lines: list[Line] = []
self.lines: list[Line] = []
self.relevant_files: list[str] = [] # this is for __str__ method if the resulting string becomes too large

def add_relevant_files(self, files: list[str]) -> None:
relevant_files = copy.deepcopy(files)
self.relevant_files += relevant_files

def parse(self, input_str: str):
stack: list[Line] = [] # To keep track of parent directories
Expand Down Expand Up @@ -132,7 +137,31 @@ def __str__(self):
for line in self.lines:
line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text
results.append((" " * line.indent_count) + line_text)
return "\n".join(results)
raw_str = "\n".join(results)
# if raw_str is too large (> 20k chars) we will use a truncated version
if len(raw_str) > 20000:
results = []
logger.warning("While attempting to dump the directory tree, the string was too large. Outputting the truncated version instead...")
for line in self.lines:
# always print out directories
if line.is_dir:
line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text
results.append((" " * line.indent_count) + line_text)
continue
# if a file name doesn't appear as a file name in one of the relevant files, don't print it
# instead print ... unless the previous item is already a ...
if (line.parent):
full_path_of_file = line.parent.full_path() + line.full_path()
else:
full_path_of_file = line.full_path()
if full_path_of_file in self.relevant_files:
line_text = line.text.split("/")[-2] + "/" if line.is_dir else line.text
results.append((" " * line.indent_count) + line_text)
elif len(results) > 0 and results[-1] != (" " * line.indent_count) + "...":
results.append((" " * line.indent_count) + "...")
raw_str = "\n".join(results)

return raw_str


@file_cache()
Expand Down

0 comments on commit fd15ff8

Please sign in to comment.