Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/sweep chat modify improvements 0621 #4099

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sweep_chat/components/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,10 @@ const UserMessageDisplay = ({
e.preventDefault()
}}
variant="default"
className="ml-2 bg-slate-600 text-white hover:bg-slate-700"
className="ml-2 bg-blue-900 text-white hover:bg-blue-800"
>
Send
<FaPaperPlane />
&nbsp;&nbsp;Send
</Button>
</>
)}
Expand Down
21 changes: 7 additions & 14 deletions sweepai/chat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,17 +538,9 @@ def chat_codebase_stream(
content=snippets_message,
role="user"
),
*messages[:-1],
*messages[:-1]
]

if len(messages) <= 2:
chat_gpt.messages.append(
Message(
content=openai_format_message if use_openai else anthropic_format_message,
role="user",
)
)

def stream_state(
initial_user_message: str,
snippets: list[Snippet],
Expand All @@ -559,7 +551,8 @@ def stream_state(
use_openai: bool,
k: int = DEFAULT_K
):
user_message = initial_user_message
# this is where the format and query are joined
user_message = initial_user_message + "\n" + anthropic_format_message

fetched_snippets = snippets
new_messages = [
Expand Down Expand Up @@ -838,10 +831,10 @@ async def autofix(
access_token: str = Depends(get_token_header)
):# -> dict[str, Any] | StreamingResponse:
# for debugging with rerun_chat_modify_direct.py
# from dataclasses import asdict
# data = [asdict(query) for query in code_suggestions]
# with open("code_suggestions.json", "w") as file:
# json.dump(data, file, indent=4)
from dataclasses import asdict
data = [asdict(query) for query in code_suggestions]
with open("code_suggestions.json", "w") as file:
json.dump(data, file, indent=4)
with Timer() as timer:
g = get_authenticated_github_client(repo_name, access_token)
logger.debug(f"Getting authenticated GitHub client took {timer.time_elapsed} seconds")
Expand Down
1 change: 0 additions & 1 deletion sweepai/chat/search_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,3 @@ def area(self):
Just respond with the search query, nothing else."""

query_optimizer_user_prompt = """Question: {query}"""

4 changes: 4 additions & 0 deletions sweepai/core/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def to_openai(self) -> str:
obj["name"] = self.name
return obj

def __repr__(self):
# take the first 100 and last 100 characters of the message if it's too long
truncated_message_content = self.content[:100] + "..." + self.content[-100:] if len(self.content) > 200 else self.content
return f"START OF MESSAGE\n\n{truncated_message_content}\n\nROLE: {self.role} FUNCTION_CALL: {self.function_call} NAME: {self.name} ANNOTATIONS: {self.annotations if self.annotations else ''} KEY: {self.key}\n\nEND OF MESSAGE\n\n"

class Function(BaseModel):
class Parameters(BaseModel):
Expand Down
15 changes: 9 additions & 6 deletions sweepai/utils/github_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def clone_url(self):
)

def clone(self):
os.environ['GIT_LFS_SKIP_SMUDGE'] = '1'
if not os.path.exists(self.cached_dir):
logger.info("Cloning repo...")
if self.branch:
Expand All @@ -505,6 +506,7 @@ def clone(self):
repo = git.Repo(self.cached_dir)
repo.git.remote("set-url", "origin", self.clone_url)
repo.git.clean('-fd')
self.handle_checkout_failures(git_repo=repo)
repo.git.pull()
logger.info("Pull repo succeeded")
except Exception as e:
Expand All @@ -517,6 +519,7 @@ def clone(self):
)
else:
repo = git.Repo.clone_from(self.clone_url, self.cached_dir)
self.handle_checkout_failures(git_repo=repo)
logger.info("Copying repo...")
shutil.copytree(
self.cached_dir, self.repo_dir, symlinks=True, copy_function=shutil.copy
Expand All @@ -542,16 +545,15 @@ def __post_init__(self):
try:
self.git_repo.git.checkout(self.branch)
except Exception as e:
self.handle_checkout_failures()
os.environ['GIT_LFS_SKIP_SMUDGE'] = '1'
self.handle_checkout_failures(self.git_repo)
self.git_repo.git.checkout(self.branch)

def handle_checkout_failures(self):
untracked_files = self.git_repo.untracked_files
def handle_checkout_failures(self, git_repo):
untracked_files = git_repo.untracked_files
if untracked_files:
logger.info(f"Untracked files found: {', '.join(untracked_files)}")
for file in untracked_files:
file_path = os.path.join(self.git_repo.working_dir, file)
file_path = os.path.join(git_repo.working_dir, file)
if os.path.isfile(file_path):
logger.info(f"Removing untracked file: {file}")
os.remove(file_path)
Expand All @@ -562,7 +564,7 @@ def handle_checkout_failures(self):
logger.info("No untracked files found")

logger.info("Cleaning untracked files")
self.git_repo.git.clean('-fd')
git_repo.git.clean('-fd')


def __del__(self):
Expand All @@ -575,6 +577,7 @@ def __del__(self):
def pull(self):
if self.git_repo:
self.git_repo.git.remote("set-url", "origin", self.clone_url)
self.handle_checkout_failures(self.git_repo)
self.git_repo.git.pull()

def list_directory_tree(
Expand Down