Gemini 2.0 support (#27)

* merge server gitignore into global gitignore
* reimplement llm in fastapi
* add gemini api support to llm server
* format llm_server.py
* adjust llm_classes.py comments
* add new llm server to docker & docker compose
* add phrase-level yielding to llama model class
* add phrase-level yielding to gemini2 class
* add support for gemini2 to server.py
* apply black formatting to server.py
* add ollama installation to llm Dockerfile
* bugfix: fix llm.py yielding binary strings to server
* add gemini instructions to .env.prod
* fix TTS server crash in prod: the TTS server throws an exception if the input text consists only of whitespace. This was fixed by adding a check in server.py to skip whitespace-only phrases in audio generation.
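The server.py change described in the last bullet is not included in the diff below. A minimal sketch of the idea, with a hypothetical synthesize() callback standing in for the real TTS call (names are placeholders, not the actual server.py code):

# Hypothetical sketch of the whitespace guard described above.
def speak_phrases(phrases, synthesize):
    for phrase in phrases:
        if not phrase.strip():
            # skip whitespace-only phrases so the TTS engine never receives
            # empty input, which previously made it raise an exception
            continue
        synthesize(phrase)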
1 parent 02eb5b2, commit fab013f
Showing 15 changed files with 997 additions and 96 deletions.
@@ -2,3 +2,4 @@
 .vscode/
 *.egg-info/
 build/
+__pycache__/
@@ -0,0 +1,2 @@
+.venv/
+__pycache__/
@@ -0,0 +1 @@
+3.12
@@ -1,7 +1,40 @@
-FROM ollama/ollama:latest
-
-# # ./llm/entrypoint.sh:/entrypoint.sh # dev
-COPY ./entrypoint.sh /
-
-# overriding the deafult `ollama` entrypoint
-ENTRYPOINT [ "/bin/bash", "/entrypoint.sh" ]
+# implementation reference: https://github.com/astral-sh/uv-docker-example
+
+# Use a Python image with uv pre-installed
+FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
+
+# install ollama
+RUN apt-get update && apt-get install -y curl
+RUN curl -fsSL https://ollama.com/install.sh | sh
+
+# Install the project into `/app`
+WORKDIR /app
+
+# Enable bytecode compilation
+ENV UV_COMPILE_BYTECODE=1
+
+# Copy from the cache instead of linking since it's a mounted volume
+ENV UV_LINK_MODE=copy
+
+# Install the project's dependencies using the lockfile and settings
+RUN --mount=type=cache,target=/root/.cache/uv \
+    --mount=type=bind,source=uv.lock,target=uv.lock \
+    --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
+    uv sync --frozen --no-install-project --no-dev
+
+# Then, add the rest of the project source code and install it
+# Installing separately from its dependencies allows optimal layer caching
+ADD . /app
+RUN --mount=type=cache,target=/root/.cache/uv \
+    uv sync --frozen --no-dev
+
+# Place executables in the environment at the front of the path
+ENV PATH="/app/.venv/bin:$PATH"
+
+# Reset the entrypoint, don't invoke `uv`
+ENTRYPOINT []
+
+# Run the FastAPI application by default
+# Uses `fastapi dev` to enable hot-reloading when the `watch` sync occurs
+# Uses `--host 0.0.0.0` to allow access from outside the container
+CMD ["python", "./src/llm_server.py"]
This file was deleted.
@@ -0,0 +1,12 @@
[project]
name = "llm"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "fastapi>=0.115.6",
    "google-generativeai>=0.8.3",
    "ollama>=0.4.4",
    "uvicorn>=0.34.0",
]
@@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
import ollama
import subprocess
import time
import google.generativeai as genai
import re


class AbstractModel(ABC):
    def __init__(self):
        # match all punctuation followed by whitespace
        # src: https://www.freecodecamp.org/news/what-is-punct-in-regex-how-to-match-all-punctuation-marks-in-regular-expressions/
        # I can't hardcode punctuation characters since the output won't
        # just be in English.
        # TODO: find a way to at least ignore commas and quotes in languages
        self.delimiter = r"[^\w\s]+"

    @abstractmethod
    def generate(self, prompt: str):
        """
        Generates a streaming response for the given prompt.
        Yields one phrase at a time. A phrase is a sequence of words ending with
        a punctuation mark.
        """
        pass


class Llama(AbstractModel):
    def __init__(self, llama_name: str):
        super().__init__()

        # src: https://stackoverflow.com/a/78501628/14751074
        # translated the answer's bash script logic to python

        # start the ollama server in the background
        print("Waiting for Ollama server to start")
        subprocess.Popen(
            ["ollama", "serve"],
            start_new_session=True,
        )
        time.sleep(0.5)  # ugly hack.

        # download the model if it's not here
        print("Downloading if'", llama_name, "'not already downloaded...", flush=True)
        ollama.pull(llama_name)

        self.model_name = llama_name

    def generate(self, prompt: str):
        stream = ollama.generate(model=self.model_name, prompt=prompt, stream=True)

        # llama outputs mostly in words, so I'm piecing them together and
        # returning when a phrase is complete.
        phrase_word_list = []
        for chunk_obj in stream:
            chunk = chunk_obj["response"]
            phrase_word_list.append(chunk)

            if re.search(self.delimiter, chunk):
                phrase = "".join(phrase_word_list)
                phrase_word_list = []
                yield phrase

        # if no more responses, return whatever's left
        if phrase_word_list:
            yield "".join(phrase_word_list)


class Gemini2(AbstractModel):
    def __init__(self, model_name: str, api_key: str):
        super().__init__()

        # src: copied code from Google's AI studio
        genai.configure(api_key=api_key)

        # Create the model
        generation_config = {
            "temperature": 1,
            "top_p": 0.95,
            "top_k": 40,
            "max_output_tokens": 8192,
            "response_mime_type": "text/plain",
        }

        self.model = genai.GenerativeModel(
            model_name=model_name,
            generation_config=generation_config,
        )

    def generate(self, prompt: str):
        # src: https://github.com/google-gemini/generative-ai-python/blob/main/docs/api/google/generativeai/GenerativeModel.md#generate_content
        stream = self.model.generate_content(prompt, stream=True)

        # gemini's phrase size is between one sentence and one paragraph.
        # I'll have to split the content into words and iterate over them
        phrase_word_list = []
        for chunk_obj in stream:
            chunk = chunk_obj.text

            # split phrase into words (including the whitespace) & append to list
            space_delimiter = r"\s"
            whitespaces = re.findall(space_delimiter, chunk)
            words = re.split(space_delimiter, chunk)
            for i in range(len(whitespaces)):
                words[i] += whitespaces[i]

            # iterate over the words, yield when a phrase is completed
            for word in words:
                phrase_word_list.append(word)
                if re.search(self.delimiter, word):
                    phrase = "".join(phrase_word_list)
                    phrase_word_list = []
                    yield phrase

        # if no more responses, return whatever's left
        if phrase_word_list:
            yield "".join(phrase_word_list)
@@ -0,0 +1,60 @@
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import uvicorn
from contextlib import asynccontextmanager
import os
import llm_classes


model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    # startup and shutdown events
    # see https://fastapi.tiangolo.com/advanced/events/

    global model

    # startup
    # get model name
    model_name = os.environ.get("MODEL_NAME")
    assert model_name is not None, "MODEL_NAME env variable not set!"

    if model_name.startswith("llama"):
        model = llm_classes.Llama(llama_name=model_name)
    elif model_name.startswith("gemini"):
        gemini_model_name, api_key = model_name.split("|")

        model = llm_classes.Gemini2(model_name=gemini_model_name, api_key=api_key)
    else:
        raise Exception("Invalid model name:", model_name)

    # app starts here
    yield

    # shutdown
    # empty for now...


app = FastAPI(lifespan=lifespan)


@app.get("/api/generate")
def generate(prompt: str):
    def generate_stream():
        for phrase in model.generate(prompt):
            phrase = phrase.encode()
            yield phrase

    return StreamingResponse(generate_stream())


if __name__ == "__main__":
    port = os.environ.get("LLM_PORT")
    assert port is not None, "LLM_PORT in llm-server is not set!"
    port = int(port)

    is_dev = os.environ.get("IS_DEV", "false").lower() == "true"

    uvicorn.run(__name__ + ":app", host="0.0.0.0", port=port, reload=is_dev)
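A hypothetical client for this endpoint could consume the stream phrase by phrase; the requests package is not in the project's dependencies, and the host and port below are placeholders for wherever LLM_PORT points:

# Hypothetical streaming client; localhost:8000 is a placeholder.
import requests

resp = requests.get(
    "http://localhost:8000/api/generate",
    params={"prompt": "Tell me a short story."},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None):
    # each chunk is one encoded phrase yielded by generate_stream()
    print(chunk.decode(), end="", flush=True)

Note that for Gemini the server expects MODEL_NAME in the form "<gemini model name>|<api key>", matching the split("|") in the lifespan handler above.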