Merge pull request #115 from wandb/feature/groq-support
Add support for Groq!
vanpelt authored May 13, 2024
2 parents 3fc829b + 6f4a26d commit 815d7cc
Showing 10 changed files with 422 additions and 250 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,2 +1,3 @@
.DS_Store
nohup.out
nohup.out
.cache/
14 changes: 13 additions & 1 deletion README.md
@@ -23,9 +23,21 @@ cd openui/backend
pip install .
# This must be set to use OpenAI models; find your API key here: https://platform.openai.com/api-keys
export OPENAI_API_KEY=xxx
# You may change the base URL to use an OpenAI-compatible API by setting the OPENAI_BASE_URL environment variable
# export OPENAI_BASE_URL=https://api.myopenai.com/v1
python -m openui
```

## Groq

To use the super fast [Groq](https://groq.com) models, set `GROQ_API_KEY` to your Groq API key, which you can [find here](https://console.groq.com/keys).

You can also change the default base URL used for Groq if necessary, e.g.:

```bash
export GROQ_BASE_URL=https://api.groq.com/openai/v1
```

### Docker Compose

> DISCLAIMER: This is likely going to be very slow. If you have a GPU you may need to change the tag of the `ollama` container to one that supports it. If you're running on a Mac, follow the instructions above and run Ollama natively to take advantage of the M1/M2.
@@ -47,7 +59,7 @@ You can build and run the docker file manually from the `/backend` directory:

```bash
docker build . -t wandb/openui --load
docker run -p 7878:7878 -e OPENAI_API_KEY wandb/openui
docker run -p 7878:7878 -e OPENAI_API_KEY -e GROQ_API_KEY wandb/openui
```

Now you can go to [http://localhost:7878](http://localhost:7878)
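As a quick sanity check for the Groq settings documented above, here is a minimal sketch (not part of this commit) that lists the available models. It assumes the `openai` Python package is installed and a valid `GROQ_API_KEY` is set; since Groq exposes an OpenAI-compatible API, the stock client works unchanged:

```python
# Minimal sketch: verify GROQ_API_KEY / GROQ_BASE_URL by listing models.
# Not part of this PR; assumes the `openai` package is installed.
import os

from openai import OpenAI

client = OpenAI(
    base_url=os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1"),
    api_key=os.environ["GROQ_API_KEY"],
)

# Groq serves an OpenAI-compatible API, so models.list() works as-is.
print([model.id for model in client.models.list().data])
```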
4 changes: 4 additions & 0 deletions backend/openui/config.py
@@ -49,3 +49,7 @@ class Env(Enum):
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
BUCKET_NAME = os.getenv("BUCKET_NAME", "openui")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
69 changes: 65 additions & 4 deletions backend/openui/server.py
@@ -1,3 +1,4 @@
import asyncio
from contextlib import asynccontextmanager
from fastapi.responses import (
    StreamingResponse,
@@ -64,7 +65,17 @@ async def lifespan(app: FastAPI):
    description="API for proxying LLM requests to different services",
)

openai = AsyncOpenAI() # AsyncOpenAI(base_url="http://127.0.0.1:11434/v1")
openai = AsyncOpenAI(
    base_url=config.OPENAI_BASE_URL,
    api_key=config.OPENAI_API_KEY)

if config.GROQ_API_KEY is not None:
    groq = AsyncOpenAI(
        base_url=config.GROQ_BASE_URL,
        api_key=config.GROQ_API_KEY)
else:
    groq = None

ollama = AsyncClient()
ollama_openai = AsyncOpenAI(base_url=os.getenv("OLLAMA_HOST", "http://127.0.0.1:11434") + "/v1")
router = APIRouter()
@@ -106,6 +117,8 @@ async def chat_completions(
    input_tokens = count_tokens(data["messages"])
    # TODO: we always assume 4096 max tokens (random fudge factor here)
    data["max_tokens"] = 4096 - input_tokens - 20
    # TODO: refactor all these blocks into one once Ollama supports vision
    # OpenAI Models
    if data.get("model").startswith("gpt"):
        if data["model"] == "gpt-4" or data["model"] == "gpt-4-32k":
            raise HTTPException(status_code=400, detail="Model not supported")
@@ -120,6 +133,21 @@
            openai_stream_generator(response, input_tokens, user_id, multiplier),
            media_type="text/event-stream",
        )
    # Groq Models
    elif data.get("model").startswith("groq/"):
        data["model"] = data["model"].replace("groq/", "")
        if groq is None:
            raise HTTPException(status_code=500, detail="Groq API key is not set.")
        response: AsyncStream[ChatCompletionChunk] = (
            await groq.chat.completions.create(
                **data,
            )
        )
        return StreamingResponse(
            openai_stream_generator(response, input_tokens, user_id, 1),
            media_type="text/event-stream",
        )
    # Ollama Time
    elif data.get("model").startswith("ollama/"):
        data["model"] = data["model"].replace("ollama/", "")
        data.pop("max_tokens")
@@ -310,10 +338,43 @@ async def create_share(id: str, payload: ShareRequest):
async def get_share(id: str):
    return Response(storage.download(f"{id}.json"), media_type="application/json")

async def get_openai_models():
    try:
        await openai.models.list()
        # We only support 3.5 and 4 for now
        return ["gpt-3.5-turbo", "gpt-4-turbo"]
    except Exception:
        return []

async def get_ollama_models():
    try:
        return (await ollama.list())["models"]
    except Exception as e:
        logger.exception("Ollama Error: %s", e)
        return []

@router.get("/v1/ollama/tags", tags="openui/ollama/tags")
async def ollama_models():
return await ollama.list()
async def get_groq_models():
try:
return (await groq.models.list()).data
except Exception as e:
logger.exception("Groq Error: %s", e)
return []

@router.get("/v1/models", tags="openui/models")
async def models():
tasks = [
get_openai_models(),
get_groq_models(),
get_ollama_models(),
]
openai_models, groq_models, ollama_models = await asyncio.gather(*tasks)
return {
"models": {
"openai": openai_models,
"groq": groq_models,
"ollama": ollama_models,
}
}


@router.get(
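To see the new routing end to end, here is a hedged sketch of a local smoke test (not part of this commit). It assumes the backend is running on port 7878 per the README, that session handling permits the request, and that `llama3-8b-8192` is a currently valid Groq model id:

```python
# Hypothetical smoke test against a locally running OpenUI backend.
# The port, paths, and model id are assumptions; adjust for your setup.
import asyncio

import httpx


async def main() -> None:
    async with httpx.AsyncClient(
        base_url="http://localhost:7878/v1", timeout=30
    ) as http:
        # /v1/models now gathers every provider concurrently and returns
        # {"models": {"openai": [...], "groq": [...], "ollama": [...]}}
        print((await http.get("/models")).json())

        # A "groq/" model prefix is stripped server-side and the request is
        # forwarded to Groq's OpenAI-compatible chat completions endpoint.
        async with http.stream(
            "POST",
            "/chat/completions",
            json={
                "model": "groq/llama3-8b-8192",
                "messages": [{"role": "user", "content": "Hello!"}],
                "stream": True,
            },
        ) as response:
            async for line in response.aiter_lines():
                print(line)


asyncio.run(main())
```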
2 changes: 1 addition & 1 deletion backend/pyproject.toml
@@ -11,7 +11,7 @@ dependencies = [
"boto3>=1.34.67",
]
name = "openui"
version = "0.2"
version = "0.3"
description = "A backend service for generating HTML components with Ollama or OpenAI models"
readme = "README.md"
requires-python = ">=3.8"
1 change: 1 addition & 0 deletions frontend/package.json
@@ -44,6 +44,7 @@
"@uiw/react-codemirror": "^4.22.0",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"groq-sdk": "^0.3.3",
"jotai": "^2.8.0",
"js-cookie": "^3.0.5",
"litellm": "^0.12.0",