Release v1.2.0 (#29)
* Gemini 2.0 support (#27)

* merge server gitignore into global gitignore

* reimplement llm in fastapi

* add gemini api support to llm server

* format llm_server.py

* adjust llm_classes.py comments

* add new llm server to docker & docker compose

* add phrase-level yielding to llama model class

* add phrase-level yielding to gemini2 class

* add support for gemini2 to server.py

* apply black formatting to server.py

* add ollama installation to llm Dockerfile

* bugfix: fix llm.py yielding binary strings to server

* add gemini instructions to .env.prod

* fix TTS server crash in prod

The TTS server throws an exception if the input text consists only of
whitespace. This was fixed by adding a check in server.py to skip
whitespace-only phrases in audio generation.
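For illustration, a minimal sketch of the kind of guard described above (the changed server.py is not shown in this view; the function and variable names here are hypothetical, not the project's actual API):

def filter_tts_phrases(phrases):
    # Drop phrases that are empty or whitespace-only before they reach the
    # TTS server, which raises an exception on such input.
    for phrase in phrases:
        if phrase.strip():
            yield phrase

Audio generation then iterates only over the filtered phrases.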
AshkanArabim authored Dec 20, 2024
1 parent 02eb5b2 commit fab013f
Showing 15 changed files with 997 additions and 96 deletions.
7 changes: 5 additions & 2 deletions .env.prod
@@ -1,5 +1,5 @@
# ports
LLM_PORT=11434 # see ollama docs
LLM_PORT=11000 # avoid port 11434; it's used by the internal ollama server
TTS_PORT=5002
DB_PORT=5432
SERVER_PORT=5000
@@ -12,6 +12,9 @@ POSTGRES_DB=news_briefer
DB_USERNAME=api_handler
DB_PASSWORD=temp

# LLM configs
MODEL_NAME=llama3.2
# MODEL_NAME=gemini-2.0-flash-exp|<your_api_key>

# other stuff
JWT_SECRET_KEY=MyRandomSecretKey!
MODEL_NAME=llama3.2
1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@
.vscode/
*.egg-info/
build/
__pycache__/
4 changes: 4 additions & 0 deletions docker-compose.dev.yaml
@@ -32,8 +32,12 @@ services:
- ./tts/tts-server.py:/app/tts-server.py

llm:
environment:
- IS_DEV=true
ports:
- ${LLM_PORT}:${LLM_PORT}
volumes:
- ./llm/src/:/app/src/

volumes:
client_dev_node_modules:
2 changes: 1 addition & 1 deletion docker-compose.yaml
@@ -30,7 +30,6 @@ services:
- JWT_SECRET_KEY=${JWT_SECRET_KEY}
- DB_USERNAME=${DB_USERNAME}
- DB_PASSWORD=${DB_PASSWORD}
- MODEL_NAME=${MODEL_NAME}
- SERVER_PORT=${SERVER_PORT}
restart: unless-stopped

@@ -63,6 +62,7 @@ services:
- llm_vol:/root/.ollama
environment:
- MODEL_NAME=${MODEL_NAME}
- LLM_PORT=${LLM_PORT}
deploy:
resources:
reservations:
2 changes: 2 additions & 0 deletions llm/.dockerignore
@@ -0,0 +1,2 @@
.venv/
__pycache__/
1 change: 1 addition & 0 deletions llm/.python-version
@@ -0,0 +1 @@
3.12
43 changes: 38 additions & 5 deletions llm/Dockerfile
@@ -1,7 +1,40 @@
FROM ollama/ollama:latest
# implementation reference: https://github.com/astral-sh/uv-docker-example

# # ./llm/entrypoint.sh:/entrypoint.sh # dev
COPY ./entrypoint.sh /
# Use a Python image with uv pre-installed
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim

# overriding the default `ollama` entrypoint
ENTRYPOINT [ "/bin/bash", "/entrypoint.sh" ]
# install ollama
RUN apt-get update && apt-get install -y curl
RUN curl -fsSL https://ollama.com/install.sh | sh

# Install the project into `/app`
WORKDIR /app

# Enable bytecode compilation
ENV UV_COMPILE_BYTECODE=1

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Install the project's dependencies using the lockfile and settings
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
uv sync --frozen --no-install-project --no-dev

# Then, add the rest of the project source code and install it
# Installing separately from its dependencies allows optimal layer caching
ADD . /app
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --no-dev

# Place executables in the environment at the front of the path
ENV PATH="/app/.venv/bin:$PATH"

# Reset the entrypoint, don't invoke `uv`
ENTRYPOINT []

# Run the FastAPI application by default
# llm_server.py starts uvicorn itself, binding host 0.0.0.0 so the service
# is reachable from outside the container (reload is enabled when IS_DEV=true)
CMD ["python", "./src/llm_server.py"]
18 changes: 0 additions & 18 deletions llm/entrypoint.sh

This file was deleted.

12 changes: 12 additions & 0 deletions llm/pyproject.toml
@@ -0,0 +1,12 @@
[project]
name = "llm"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"fastapi>=0.115.6",
"google-generativeai>=0.8.3",
"ollama>=0.4.4",
"uvicorn>=0.34.0",
]
119 changes: 119 additions & 0 deletions llm/src/llm_classes.py
@@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
import ollama
import subprocess
import time
import google.generativeai as genai
import re


class AbstractModel(ABC):
def __init__(self):
# match one or more punctuation characters (non-word, non-whitespace)
# src: https://www.freecodecamp.org/news/what-is-punct-in-regex-how-to-match-all-punctuation-marks-in-regular-expressions/
# I can't hardcode punctuation characters since the output won't
# just be in English.
# TODO: find a way to at least ignore commas and quotes in languages
self.delimiter = r"[^\w\s]+"

@abstractmethod
def generate(self, prompt: str):
"""
Generates a streaming response for the given prompt.
Yields one phrase at a time. A phrase is a sequence of words ending with
a punctuation mark.
"""
pass


class Llama(AbstractModel):
def __init__(self, llama_name: str):
super().__init__()

# src: https://stackoverflow.com/a/78501628/14751074
# translated the answer's bash script logic to python

# start the ollama server in the background
print("Waiting for Ollama server to start")
subprocess.Popen(
["ollama", "serve"],
start_new_session=True,
)
time.sleep(0.5) # ugly hack.

# download the model if it's not here
print(f"Downloading '{llama_name}' if not already downloaded...", flush=True)
ollama.pull(llama_name)

self.model_name = llama_name

def generate(self, prompt: str):
stream = ollama.generate(model=self.model_name, prompt=prompt, stream=True)

# llama outputs mostly in words, so I'm piecing them together and
# returning when a phrase is complete.
phrase_word_list = []
for chunk_obj in stream:
chunk = chunk_obj["response"]
phrase_word_list.append(chunk)

if re.search(self.delimiter, chunk):
phrase = "".join(phrase_word_list)
phrase_word_list = []
yield phrase

# if no more responses, return whatever's left
if phrase_word_list:
yield "".join(phrase_word_list)


class Gemini2(AbstractModel):
def __init__(self, model_name: str, api_key: str):
super().__init__()

# src: copied code from Google's AI studio
genai.configure(api_key=api_key)

# Create the model
generation_config = {
"temperature": 1,
"top_p": 0.95,
"top_k": 40,
"max_output_tokens": 8192,
"response_mime_type": "text/plain",
}

self.model = genai.GenerativeModel(
model_name=model_name,
generation_config=generation_config,
)

def generate(self, prompt: str):
# src: https://github.com/google-gemini/generative-ai-python/blob/main/docs/api/google/generativeai/GenerativeModel.md#generate_content
stream = self.model.generate_content(prompt, stream=True)

# gemini's phrase size is between one sentence and one paragraph.
# I'll have to split the content into words and iterate over them
phrase_word_list = []
for chunk_obj in stream:
chunk = chunk_obj.text

# split phrase into words (including the whitespace) & append to list
space_delimiter = r"\s"
whitespaces = re.findall(space_delimiter, chunk)
words = re.split(space_delimiter, chunk)
for i in range(len(whitespaces)):
words[i] += whitespaces[i]

# iterate over the words, yield when a phrase is completed
for word in words:

phrase_word_list.append(word)
if re.search(self.delimiter, word):
phrase = "".join(phrase_word_list)
phrase_word_list = []
yield phrase

# if no more responses, return whatever's left
if phrase_word_list:
yield "".join(phrase_word_list)
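Usage sketch (not part of this commit): both classes expose the same phrase-level generator interface, so callers can stream from either backend the same way. This assumes a local Ollama install with the llama3.2 model available.

from llm_classes import Llama

model = Llama("llama3.2")
for phrase in model.generate("Summarize today's top headline in two sentences."):
    # each iteration yields one punctuation-terminated phrase
    print(phrase, end="", flush=True)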
60 changes: 60 additions & 0 deletions llm/src/llm_server.py
@@ -0,0 +1,60 @@
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import uvicorn
from contextlib import asynccontextmanager
import os
import llm_classes


model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
# startup and shutdown events
# see https://fastapi.tiangolo.com/advanced/events/

global model

# startup
# get model name
model_name = os.environ.get("MODEL_NAME")
assert model_name is not None, "MODEL_NAME env variable not set!"

if model_name.startswith("llama"):
model = llm_classes.Llama(llama_name=model_name)
elif model_name.startswith("gemini"):
gemini_model_name, api_key = model_name.split("|")

model = llm_classes.Gemini2(model_name=gemini_model_name, api_key=api_key)
else:
raise Exception("Invalid model name:", model_name)

# app starts here
yield

# shutdown
# empty for now...


app = FastAPI(lifespan=lifespan)


@app.get("/api/generate")
def generate(prompt: str):
def generate_stream():
for phrase in model.generate(prompt):
phrase = phrase.encode()
yield phrase

return StreamingResponse(generate_stream())


if __name__ == "__main__":
port = os.environ.get("LLM_PORT")
assert port is not None, "LLM_PORT in llm-server is not set!"
port = int(port)

is_dev = os.environ.get("IS_DEV", "false").lower() == "true"

uvicorn.run(__name__ + ":app", host="0.0.0.0", port=port, reload=is_dev)
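Client-side sketch (not part of this commit) of consuming the streaming endpoint, assuming the `requests` library and the LLM_PORT=11000 value from .env.prod; the main server's actual consumption code may differ:

import requests

resp = requests.get(
    "http://localhost:11000/api/generate",
    params={"prompt": "Give me a one-paragraph news briefing."},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None):
    # chunks arrive roughly one phrase at a time, matching generate_stream()
    print(chunk.decode(), end="", flush=True)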