-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a776a78
commit 0add927
Showing
4 changed files
with
183 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# MIT License | ||
# Copyright (c) 2023-2024 wunderwuzzi23 | ||
# Greetings from Seattle! | ||
|
||
from abc import ABC, abstractmethod | ||
from openai import OpenAI | ||
from groq import Groq | ||
from ollama import Client | ||
from openai import AzureOpenAI | ||
from anthropic import Anthropic | ||
import os | ||
|
||
class AIModel(ABC): | ||
@abstractmethod | ||
def chat(self, model, messages): | ||
pass | ||
|
||
@abstractmethod | ||
def moderate(self, message): | ||
pass | ||
|
||
@staticmethod | ||
def get_model_client(config): | ||
api_provider=config["api"] | ||
|
||
if api_provider == "" or api_provider==None: | ||
api_provider = "groq" | ||
|
||
if api_provider == "groq": | ||
return GroqModel(api_key=os.environ.get("GROQ_API_KEY")) | ||
|
||
elif api_provider == "openai": | ||
api_key = os.getenv("OPENAI_API_KEY") | ||
if not api_key: | ||
api_key=config["openai_api_key"] | ||
if not api_key: #If statement to avoid "invalid filepath" error | ||
home_path = os.path.expanduser("~") | ||
api_key=open(os.path.join(home_path,".openai.apikey"), "r").readline().strip() | ||
api_key = api_key | ||
|
||
return OpenAIModel(api_key=os.environ.get("OPENAI_API_KEY")) | ||
|
||
elif api_provider == "azure": | ||
api_key = os.getenv("AZURE_OPENAI_API_KEY") | ||
if not api_key: | ||
api_key=config["azure_openai_api_key"] | ||
if not api_key: | ||
home_path = os.path.expanduser("~") | ||
api_key=open(os.path.join(home_path,".azureopenai.apikey"), "r").readline().strip() | ||
|
||
return AzureOpenAIModel( | ||
api_key=api_key, | ||
azure_endpoint=config["azure_endpoint"], | ||
api_version=config["azure_api_version"]) | ||
|
||
elif api_provider == "ollama": | ||
ollama_api = os.environ.get("OLLAMA_ENDPOINT", "http://localhost:11434") | ||
#ollama_model = os.environ.get("OLLAMA_MODEL", "llama3-8b-8192") | ||
return OllamaModel(ollama_api) | ||
|
||
if api_provider == "anthropic": | ||
api_key = os.getenv("ANTHROPIC_API_KEY") | ||
if not api_key: | ||
api_key=config["anthropic_api_key"] | ||
return AnthropicModel(api_key=api_key) | ||
else: | ||
raise ValueError(f"Invalid AI model provider: {api_provider}") | ||
|
||
class GroqModel(AIModel): | ||
def __init__(self, api_key): | ||
self.client = Groq(api_key=api_key) | ||
|
||
def chat(self, messages, model, temperature, max_tokens): | ||
resp = self.client.chat.completions.create(model=model, | ||
messages=messages, | ||
temperature=temperature, | ||
max_tokens=max_tokens) | ||
return resp.choices[0].message.content | ||
|
||
def moderate(self, message): | ||
pass | ||
|
||
class OpenAIModel(AIModel): | ||
def __init__(self, api_key): | ||
self.client = OpenAI(api_key=api_key) | ||
|
||
def chat(self, messages, model, temperature, max_tokens): | ||
resp = self.client.chat.completions.create(model=model, | ||
messages=messages, | ||
temperature=temperature, | ||
max_tokens=max_tokens) | ||
|
||
return resp.choices[0].message.content | ||
|
||
def moderate(self, message): | ||
return self.client.moderations.create(input=message) | ||
|
||
class OllamaModel(AIModel): | ||
def __init__(self, host): | ||
self.client = Client(host=host) | ||
|
||
def chat(self, messages, model, temperature, max_tokens): | ||
resp = self.client.chat(model=model, | ||
messages=messages) | ||
return resp["message"]["content"] | ||
|
||
def moderate(self, message): | ||
pass | ||
|
||
|
||
class AzureOpenAIModel(AIModel): | ||
def __init__(self, azure_endpoint, api_key, api_version): | ||
self.client = AzureOpenAI(azure_endpoint=azure_endpoint, api_key=api_key, api_version=api_version) | ||
|
||
def chat(self, messages, model, temperature, max_tokens): | ||
|
||
resp = self.client.chat.completions.create(model=model, | ||
messages=messages, | ||
temperature=temperature, | ||
max_tokens=max_tokens) | ||
|
||
return resp.choices[0].message.content | ||
|
||
def moderate(self, message): | ||
return self.client.moderations.create(input=message) | ||
|
||
class AnthropicModel(AIModel): | ||
def __init__(self, api_key): | ||
self.client = Anthropic(api_key=api_key) | ||
|
||
def chat(self, messages, model, temperature, max_tokens): | ||
## Anthropic requires the system prompt to be passed separately | ||
## Hence extracting system prompt role from the messages | ||
## and then passing the messages without the system role | ||
## messages is not subscriptable, so we need to convert it to a list | ||
system_prompt = next((m.get("content", "") for m in messages if m.get("role") == "system"), "") | ||
|
||
# Remove system messages from the list | ||
user_messages = [m for m in messages if m.get("role") != "system"] | ||
resp = self.client.messages.create(model=model, | ||
system=system_prompt, | ||
messages=user_messages, | ||
temperature=temperature, | ||
max_tokens=max_tokens) | ||
|
||
return resp.content[0].text | ||
|
||
def moderate(self, message): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters