Add explain command (#21)
shruti222patel authored Aug 22, 2023
1 parent 22b8f87 commit d786231
Showing 5 changed files with 51 additions and 19 deletions.
Binary file modified .repo_gpt/code_embeddings.pkl
src/repo_gpt/cli.py (18 changes: 15 additions & 3 deletions)
@@ -67,6 +67,13 @@ def print_help(*args):
         default=CODE_EMBEDDING_FILE_PATH,
     )
 
+    # Sub-command to explain a file
+    explain_code = subparsers.add_parser("explain", help="Explain a code snippet")
+    explain_code.add_argument(
+        "--language", default="", type=str, help="Language of the code"
+    )
+    explain_code.add_argument("--code", type=str, help="Code you want to explain")
+
     # Sub-command to analyze a file
     add_test = subparsers.add_parser("add-test", help="Add tests for existing function")
     add_test.add_argument(
@@ -108,8 +115,8 @@ def print_help(*args):
     openai_service = OpenAIService()
 
     search_service = (
-        SearchService(args.pickle_path, openai_service)
-        if args.command != "setup"
+        SearchService(openai_service, args.pickle_path)
+        if args.command not in ["setup", "explain"]
         else None
     )
 
@@ -125,6 +132,9 @@ def print_help(*args):
         search_service.question_answer(args.question)
     elif args.command == "analyze":
         search_service.analyze_file(args.file_path)
+    elif args.command == "explain":
+        search_service = SearchService(openai_service, language=args.language)
+        return search_service.explain(args.code)
     elif args.command == "add-test":
         code_manager = CodeManager(args.pickle_path)
         # Look for the function name in the embedding file
@@ -202,4 +212,6 @@ def add_tests(
 
 
 if __name__ == "__main__":
-    main()
+    result = main()
+    if result != None:
+        print(result)
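For readers less familiar with argparse subcommands, here is a minimal, self-contained sketch of the dispatch pattern the CLI change above follows: one subparser per command, a dispatch on args.command, and a return value that the __main__ guard prints. All names below are illustrative assumptions (including the program name), not the repository's actual code.

```python
# Illustrative sketch of the subcommand dispatch pattern; not repo_gpt's real CLI.
import argparse


def main():
    parser = argparse.ArgumentParser(prog="repo-gpt")  # program name is an assumption
    subparsers = parser.add_subparsers(dest="command")

    # Mirror of the new "explain" subcommand and its two flags.
    explain = subparsers.add_parser("explain", help="Explain a code snippet")
    explain.add_argument("--language", default="", type=str, help="Language of the code")
    explain.add_argument("--code", type=str, help="Code you want to explain")

    args = parser.parse_args()
    if args.command == "explain":
        # The real CLI builds a SearchService and returns search_service.explain(args.code).
        return f"Explaining {args.language or 'unknown-language'} code: {args.code!r}"
    return None


if __name__ == "__main__":
    # Returning a value from main() and printing it here is what lets the
    # explain result reach stdout, as in the diff above.
    result = main()
    if result is not None:
        print(result)
```

Usage would look something like `repo-gpt explain --language python --code "def add(a, b): return a + b"`, with the explanation printed to stdout.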
src/repo_gpt/openai_service.py (28 changes: 15 additions & 13 deletions)
@@ -1,7 +1,5 @@
 # Set your OpenAI API key as an environment variable
 import os
-import time
-from typing import Any, Callable
 
 import numpy as np
 import openai as openai
@@ -71,7 +69,7 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo"):
 
 class OpenAIService:
     GENERAL_SYSTEM_PROMPT = "You are a world-class software engineer and technical writer specializing in understanding code + architecture + tradeoffs and explaining them clearly and in detail. You are helpful and answer questions the user asks. You organize your explanations in markdown-formatted, bulleted lists."
-    ANALYSIS_SYSTEM_PROMPT = "You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You carefully explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists."
+    ANALYSIS_SYSTEM_PROMPT = "You are a world-class developer with an eagle eye for unintended bugs and edge cases. You carefully explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists."
 
     @retry(wait=wait_random_exponential(min=0.2, max=60), stop=stop_after_attempt(6))
     def get_answer(
@@ -96,16 +94,20 @@ def get_answer(
         )
         return response.choices[0]["message"]["content"]
 
-    def _retry_on_exception(self, func: Callable[..., Any], *args, **kwargs):
-        for retry in range(MAX_RETRIES):
-            try:
-                return func(*args, **kwargs)
-            except Exception:
-                if retry < MAX_RETRIES - 1:  # if it's not the last retry
-                    sleep_time = 2**retry  # exponential backoff
-                    time.sleep(sleep_time)
-                else:  # if it's the last retry, re-raise the exception
-                    raise
+    @retry(wait=wait_random_exponential(min=0.2, max=60), stop=stop_after_attempt(6))
+    def query(self, query: str, system_prompt: str = GENERAL_SYSTEM_PROMPT):
+        response = openai.ChatCompletion.create(
+            messages=[
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+                {"role": "user", "content": query},
+            ],
+            model=GPT_MODEL,
+            temperature=TEMPERATURE,
+        )
+        return response.choices[0]["message"]["content"]
 
     @retry(wait=wait_random_exponential(min=0.2, max=60), stop=stop_after_attempt(6))
     def get_embedding(self, text: str):
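The hand-rolled _retry_on_exception helper is dropped in favor of the same tenacity decorator already used on get_answer. Here is a minimal sketch of that retry pattern, with a stand-in function instead of the real OpenAI call:

```python
# Sketch of the tenacity-based retry pattern used above. flaky_call stands in
# for openai.ChatCompletion.create; it fails twice, then succeeds.
import itertools

from tenacity import retry, stop_after_attempt, wait_random_exponential

_attempts = itertools.count(1)


@retry(wait=wait_random_exponential(min=0.2, max=60), stop=stop_after_attempt(6))
def flaky_call() -> str:
    # wait_random_exponential sleeps a random, exponentially growing amount
    # (capped at 60s); stop_after_attempt(6) gives up after six tries, at
    # which point tenacity raises RetryError.
    attempt = next(_attempts)
    if attempt < 3:
        raise RuntimeError(f"transient API error on attempt {attempt}")
    return "ok"


print(flaky_call())  # retries transparently and prints "ok" on the third attempt
```

With the decorator, the backoff and give-up logic live on one line, which is also why the `time` and `typing` imports are removed at the top of this file.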
src/repo_gpt/search_service.py (22 changes: 20 additions & 2 deletions)
@@ -14,10 +14,14 @@
 
 class SearchService:
     def __init__(
-        self, pickle_path: Path, openai_service: OpenAIService, language: str = "python"
+        self,
+        openai_service: OpenAIService,
+        pickle_path: Path = None,
+        language: str = "python",
     ):
         self.pickle_path = pickle_path
-        self.refresh_df()
+        if pickle_path is not None:
+            self.refresh_df()
         self.openai_service = openai_service
         self.language = language
 
@@ -105,3 +109,17 @@ def analyze_file(self, file_path: str):
         ans_md = Markdown(ans)
         console.print("🤖 Answer from `GPT3.5` 🤖")
         console.print(ans_md)
+
+    def explain(self, code: str):
+        try:
+            explanation = self.openai_service.query(
+                f"""Please explain the following {self.language} function.
+```{self.language}
+{code}
+```""",
+                system_prompt=f"""You are a world-class {self.language} developer and you are explaining the following code to a junior developer. Organize your explanation as a markdown-format.""",
+            )
+            return explanation
+        except Exception as e:
+            return e.message
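To show how the new pieces fit together, here is a simplified, self-contained mirror of the explain flow: the service is built without a pickle path, so no embedding refresh happens, and explain() simply delegates to the query backend. Class and method names echo the diff, but the OpenAI call is stubbed so the snippet runs offline; treat it as a sketch, not the repository's implementation.

```python
# Simplified mirror of the new SearchService.explain flow, with the OpenAI
# backend stubbed out so it runs without credentials.
from pathlib import Path
from typing import Optional


class StubOpenAIService:
    def query(self, prompt: str, system_prompt: str = "") -> str:
        # In the real OpenAIService, this sends a chat completion request.
        return f"[stubbed explanation for a prompt of {len(prompt)} characters]"


class SearchService:
    def __init__(self, openai_service, pickle_path: Optional[Path] = None, language: str = "python"):
        self.pickle_path = pickle_path
        # The real __init__ only calls self.refresh_df() when a pickle path is
        # supplied; explain() never needs the embeddings, so none is passed here.
        self.openai_service = openai_service
        self.language = language

    def explain(self, code: str) -> str:
        prompt = f"Please explain the following {self.language} function:\n{code}"
        system = f"You are a world-class {self.language} developer explaining code to a junior developer."
        return self.openai_service.query(prompt, system_prompt=system)


service = SearchService(StubOpenAIService(), language="python")
print(service.explain("def add(a, b):\n    return a + b"))
```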
src/repo_gpt/test_generator.py (2 changes: 1 addition & 1 deletion)
@@ -70,7 +70,7 @@ def print_message_delta(self, delta) -> None:
     def get_explanation_of_function(self) -> str:
         self.create_gpt_message(
             "system",
-            "You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. ...",
+            f"You are a world-class {self.language} developer with an eagle eye for unintended bugs and edge cases. ...",
         )
         self.create_gpt_message(
             "user",
