Skip to content

Commit

Permalink
yolo v0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
wunderwuzzi23 committed Jun 30, 2024
1 parent a776a78 commit 0add927
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 79 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

![Animated GIF](https://github.com/wunderwuzzi23/blog/raw/master/static/images/2023/yolo-shell-anim-gif.gif)

# Update Yolo v0.5 - Support for Claude and other providers

* Added Claude support. You can get an API key from Anthropic; the current model is `claude-3-5-sonnet-20240620`.
* Added ai_model.py to abstract model usage and make adding new providers easier
* Rewrote some logic to simplify and generalize support for various new APIs (like Ollama, Claude)

# Update Yolo v0.4 - Support for Groq

* Added groq support. You can get an API key at `https://console.groq.com` and set model to, for instance, `llama3-8b-8192`. groq is lightning fast.
Expand Down
149 changes: 149 additions & 0 deletions ai_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# MIT License
# Copyright (c) 2023-2024 wunderwuzzi23
# Greetings from Seattle!

from abc import ABC, abstractmethod
from openai import OpenAI
from groq import Groq
from ollama import Client
from openai import AzureOpenAI
from anthropic import Anthropic
import os

class AIModel(ABC):
    """Abstract interface for a chat-capable AI provider.

    Concrete subclasses wrap a vendor SDK client and expose a uniform
    ``chat``/``moderate`` API so the caller needs no provider-specific code.
    """

    @abstractmethod
    def chat(self, messages, model, temperature, max_tokens):
        """Send a chat request and return the assistant's reply as a string.

        Signature fixed to match every concrete implementation
        (previously declared as ``chat(self, model, messages)``).
        """

    @abstractmethod
    def moderate(self, message):
        """Run the provider's moderation endpoint; no-op where unsupported."""

    @staticmethod
    def get_model_client(config):
        """Factory: build the provider client selected by ``config["api"]``.

        API keys are resolved in order: environment variable, then the
        config file, then (OpenAI/Azure only) a key file in the home
        directory.

        Raises:
            ValueError: if ``config["api"]`` names an unknown provider.
        """
        api_provider = config["api"]

        # Default to groq when no provider is configured (empty or None).
        if not api_provider:
            api_provider = "groq"

        if api_provider == "groq":
            # Mirror the env -> config fallback used by the other providers.
            api_key = os.getenv("GROQ_API_KEY")
            if not api_key:
                api_key = config.get("groq_api_key")
            return GroqModel(api_key=api_key)

        elif api_provider == "openai":
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                api_key = config["openai_api_key"]
            if not api_key:  # only touch the key file as a last resort
                home_path = os.path.expanduser("~")
                with open(os.path.join(home_path, ".openai.apikey"), "r") as key_file:
                    api_key = key_file.readline().strip()
            # Bug fix: pass the key resolved above. Previously this re-read
            # os.environ, silently discarding the config/key-file fallbacks.
            return OpenAIModel(api_key=api_key)

        elif api_provider == "azure":
            api_key = os.getenv("AZURE_OPENAI_API_KEY")
            if not api_key:
                api_key = config["azure_openai_api_key"]
            if not api_key:
                home_path = os.path.expanduser("~")
                with open(os.path.join(home_path, ".azureopenai.apikey"), "r") as key_file:
                    api_key = key_file.readline().strip()
            return AzureOpenAIModel(
                api_key=api_key,
                azure_endpoint=config["azure_endpoint"],
                api_version=config["azure_api_version"])

        elif api_provider == "ollama":
            # Local Ollama server by default; endpoint overridable via env.
            ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434")
            return OllamaModel(ollama_api)

        elif api_provider == "anthropic":
            api_key = os.getenv("ANTHROPIC_API_KEY")
            if not api_key:
                api_key = config["anthropic_api_key"]
            return AnthropicModel(api_key=api_key)

        else:
            raise ValueError(f"Invalid AI model provider: {api_provider}")

class GroqModel(AIModel):
    """AIModel backed by the Groq chat-completions API."""

    def __init__(self, api_key):
        # The Groq SDK client handles auth and transport.
        self.client = Groq(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Send a chat request and return the assistant's reply text."""
        completion = self.client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content

    def moderate(self, message):
        # Groq exposes no moderation endpoint; deliberate no-op.
        pass

class OpenAIModel(AIModel):
    """AIModel implementation on top of the official OpenAI SDK."""

    def __init__(self, api_key):
        self.client = OpenAI(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Return the assistant text for a chat-completion request."""
        completion = self.client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        choice = completion.choices[0]
        return choice.message.content

    def moderate(self, message):
        """Run the input through OpenAI's moderation endpoint."""
        return self.client.moderations.create(input=message)

class OllamaModel(AIModel):
    """AIModel backed by a local (or remote) Ollama server."""

    def __init__(self, host):
        # host is the Ollama HTTP endpoint, e.g. http://localhost:11434
        self.client = Client(host=host)

    def chat(self, messages, model, temperature, max_tokens):
        """Send a chat request and return the reply text.

        Bug fix: ``temperature`` and ``max_tokens`` were previously accepted
        but silently ignored. Ollama takes them via the ``options`` dict;
        ``num_predict`` is Ollama's name for the completion token limit.
        """
        resp = self.client.chat(
            model=model,
            messages=messages,
            options={"temperature": temperature, "num_predict": max_tokens})
        return resp["message"]["content"]

    def moderate(self, message):
        # Ollama has no moderation endpoint; deliberate no-op.
        pass


class AzureOpenAIModel(AIModel):
    """AIModel implementation for Azure-hosted OpenAI deployments."""

    def __init__(self, azure_endpoint, api_key, api_version):
        self.client = AzureOpenAI(
            azure_endpoint=azure_endpoint,
            api_key=api_key,
            api_version=api_version,
        )

    def chat(self, messages, model, temperature, max_tokens):
        """Return the assistant text; ``model`` is the Azure deployment name."""
        completion = self.client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return completion.choices[0].message.content

    def moderate(self, message):
        """Delegate to the moderation endpoint of the Azure client."""
        return self.client.moderations.create(input=message)

class AnthropicModel(AIModel):
    """AIModel backed by Anthropic's Messages API."""

    def __init__(self, api_key):
        self.client = Anthropic(api_key=api_key)

    def chat(self, messages, model, temperature, max_tokens):
        """Send a chat request and return the reply text.

        Anthropic's Messages API takes the system prompt as a separate
        ``system`` argument rather than a message with role "system", so
        the first system message's content is split out and all system
        entries are dropped from the conversation list.
        """
        system_prompt = None
        conversation = []
        for entry in messages:
            if entry.get("role") == "system":
                # Keep only the first system message's content.
                if system_prompt is None:
                    system_prompt = entry.get("content", "")
            else:
                conversation.append(entry)
        if system_prompt is None:
            system_prompt = ""

        resp = self.client.messages.create(
            model=model,
            system=system_prompt,
            messages=conversation,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return resp.content[0].text

    def moderate(self, message):
        # Anthropic offers no standalone moderation endpoint; no-op.
        pass
98 changes: 23 additions & 75 deletions yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,17 @@

import os
import platform
from openai import OpenAI
from openai import AzureOpenAI
from groq import Groq
from ai_model import AIModel, GroqModel, OpenAIModel, OllamaModel, AnthropicModel, AzureOpenAIModel
import sys
import subprocess
import dotenv
import distro
import yaml
import pyperclip

from termcolor import colored
from colorama import init

def read_config() -> any:

def read_config():
## Find the executing directory (e.g. in case an alias is set)
## So we can find the config file
yolo_path = os.path.abspath(__file__)
Expand All @@ -30,29 +26,27 @@ def read_config() -> any:
with open(config_file, 'r') as file:
return yaml.safe_load(file)

# Construct the prompt
def get_full_prompt(user_prompt, shell):

def get_system_prompt(shell):
## Find the executing directory (e.g. in case an alias is set)
## So we can find the prompt.txt file
yolo_path = os.path.abspath(__file__)
prompt_path = os.path.dirname(yolo_path)

## Load the prompt and prep it
prompt_file = os.path.join(prompt_path, "prompt.txt")
pre_prompt = open(prompt_file,"r").read()
pre_prompt = pre_prompt.replace("{shell}", shell)
pre_prompt = pre_prompt.replace("{os}", get_os_friendly_name())
prompt = pre_prompt + user_prompt

# be nice and make it a question
system_prompt = open(prompt_file,"r").read()
system_prompt = system_prompt.replace("{shell}", shell)
system_prompt = system_prompt.replace("{os}", get_os_friendly_name())

return system_prompt

def ensure_prompt_is_question(prompt):
    """Append a '?' unless the prompt already ends with '?' or '.'.

    An empty prompt becomes just "?" (slice-based check never raises).
    """
    if not prompt.endswith(("?", ".")):
        prompt += "?"
    return prompt

def print_usage(config):
print("Yolo v0.4 - by @wunderwuzzi23 (June 1, 2024)")
print("Yolo v0.5 - by @wunderwuzzi23 (June 29, 2024)")
print()
print("Usage: yolo [-a] list the current directory information")
print("Argument: -a: Prompt the user before running the command (only useful when safety is off)")
Expand All @@ -66,7 +60,6 @@ def print_usage(config):
print("* Safety : " + str(bool(config["safety"])))
print("* Command Color: " + str(config["suggested_command_color"]))


def get_os_friendly_name():
os_name = platform.system()

Expand All @@ -79,64 +72,23 @@ def get_os_friendly_name():
else:
return os_name

def create_client(config):

dotenv.load_dotenv()

if config["api"] == "azure_openai":
api_key = os.getenv("AZURE_OPENAI_API_KEY")
if not api_key: api_key=config["azure_openai_api_key"]
if not api_key:
home_path = os.path.expanduser("~")
api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip()

return AzureOpenAI(
azure_endpoint=config["azure_endpoint"],
api_key=api_key,
api_version=config["azure_api_version"]
)

if config["api"] == "openai":
api_key = os.getenv("OPENAI_API_KEY")
if not api_key: api_key=config["openai_api_key"]
if not api_key: #If statement to avoid "invalid filepath" error
home_path = os.path.expanduser("~")
api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip()

api_key = api_key
return OpenAI(api_key=api_key)

if config["api"] == "groq":
api_key = os.getenv("GROQ_API_KEY")
if not api_key:
api_key=config["groq_api_key"]

return Groq(api_key=api_key)

def chat_completion(client, query, config, shell):
# do we have a prompt from the user?
if query == "":
print ("No user prompt specified.")
sys.exit(-1)

# Load prompt based on Shell and OS and append the user's prompt
prompt = get_full_prompt(query, shell)

# Make the first line the system prompt
system_prompt = prompt.split('\n')[0]
#print(prompt)
system_prompt = get_system_prompt(shell)

# Call the Model API
response = client.chat.completions.create(
response = client.chat(
model=config["model"],
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}
{"role": "user", "content": query}
],
temperature=config["temperature"],
max_tokens=config["max_tokens"])

return response.choices[0].message.content.strip()
return response

def check_for_issue(response):
prefixes = ("sorry", "i'm sorry", "the question is not clear", "i'm", "i am")
Expand All @@ -145,7 +97,6 @@ def check_for_issue(response):
sys.exit(-1)

def check_for_markdown(response):
    """Abort if the model wrapped the command in a markdown code fence.

    A fenced response is not a directly executable shell command, so it is
    printed for the user and the program exits instead of running it.
    """
    # Odd corner case: chat models sometimes return the command inside
    # ``` fences. NOTE(review): the second argument of count("```", 2) is a
    # start *offset* (count fences from index 2 on), not a minimum count —
    # confirm this is the intended check rather than count(...) >= 2.
    if response.count("```",2):
        print(colored("The proposed command contains markdown, so I did not execute the response directly: \n", 'red')+response)
        sys.exit(-1)
Expand Down Expand Up @@ -189,7 +140,7 @@ def eval_user_intent_and_execute(client, config, user_input, command, shell, ask
if bool(config["modify"]) and user_input.upper() == "M":
print("Modify prompt: ", end = '')
modded_query = input()
modded_response = call_open_ai(client, modded_query, config, shell)
modded_response = chat_completion(client, modded_query, config, shell)
check_for_issue(modded_response)
check_for_markdown(modded_response)
user_intent = prompt_user_for_action(config, ask_flag, modded_response)
Expand All @@ -203,14 +154,12 @@ def eval_user_intent_and_execute(client, config, user_input, command, shell, ask
pyperclip.copy(command)
print("Copied command to clipboard.")



def main():
#Enable color output on Windows using colorama
init()
init() #Enable color output on Windows using colorama
dotenv.load_dotenv()

config = read_config()
client = create_client(config)
client = AIModel.get_model_client(config)

# Unix based SHELL (/bin/bash, /bin/zsh), otherwise assuming it's Windows
shell = os.environ.get("SHELL", "powershell.exe")
Expand All @@ -230,9 +179,8 @@ def main():
ask_flag = True
command_start_idx = 2

# To allow easy/natural use we don't require the input to be a
# single string. So, the user can just type yolo what is my name?
# without having to put the question between ''
# To allow easy/natural use we don't require the input to be a single string.
# User can just type yolo what is my name? without having to put the question between ''
arguments = sys.argv[command_start_idx:]
user_prompt = " ".join(arguments)

Expand All @@ -245,6 +193,6 @@ def main():
print()
eval_user_intent_and_execute(client, config, users_intent, result, shell, ask_flag)


if __name__ == "__main__":
main()

9 changes: 5 additions & 4 deletions yolo.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
api: openai # openai, azure_openai, groq
model: gpt-4o # if azure_openai this is the deployment name, for groq, e.g llama3-8b-8192
api: openai # openai, azure, groq, ollama, anthropic
model: gpt-4o # if azure this is the deployment name
# other options: gpt-4o, llama3-8b-8192, or claude-3-5-sonnet-20240620

# Azure specific (only needed if api: azure-openai)
azure_endpoint: https://<name>.openai.azure.com
Expand All @@ -14,8 +15,8 @@ modify: False # Enable prompt modify feature
suggested_command_color: blue # Suggested Command Color

# API Keys (optional): Preferred to use environment variables
# OPENAI_API_KEY, AZURE_OPENAI_API_KEY or GROQ_API_KEY (.env file is also supported)
# OPENAI_API_KEY, AZURE_OPENAI_API_KEY, ANTHROPIC_API_KEY or GROQ_API_KEY (.env file is also supported)
azure_openai_api_key:
openai_api_key:
groq_api_key:

anthropic_api_key:

0 comments on commit 0add927

Please sign in to comment.