Merge pull request #447 from spo0nman/pkaushal/vectordb-chroma
feat: Add ChromaDB integration for vector storage
Showing 3 changed files with 287 additions and 0 deletions.
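The new backend talks to a running ChromaDB server over HTTP, so a server must be reachable before any of the code below will work. As a quick pre-flight check, a sketch along the following lines can confirm connectivity; the host, port, and token are placeholders that must match your server, and heartbeat() is the Chroma client's liveness call:

from chromadb import HttpClient

client = HttpClient(
    host="localhost",  # placeholder: match your ChromaDB server
    port=8000,
    headers={"X-Chroma-Token": "secret-token"},  # placeholder token
)
print(client.heartbeat())  # returns a timestamp if the server is up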
@@ -0,0 +1,172 @@
import asyncio
from dataclasses import dataclass
from typing import Union
import numpy as np
from chromadb import HttpClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger


@dataclass
class ChromaVectorDBStorage(BaseVectorStorage):
    """ChromaDB vector storage implementation."""

    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        try:
            # Use the global config value if specified, otherwise keep the default
            self.cosine_better_than_threshold = self.global_config.get(
                "cosine_better_than_threshold", self.cosine_better_than_threshold
            )

            config = self.global_config.get("vector_db_storage_cls_kwargs", {})
            user_collection_settings = config.get("collection_settings", {})
            # Default HNSW index settings for ChromaDB
            default_collection_settings = {
                # Distance metric used for similarity search (cosine)
                "hnsw:space": "cosine",
                # Number of nearest neighbors to explore during index construction
                # Higher values = better recall but slower indexing
                "hnsw:construction_ef": 128,
                # Number of nearest neighbors to explore during search
                # Higher values = better recall but slower search
                "hnsw:search_ef": 128,
                # Number of connections per node in the HNSW graph
                # Higher values = better recall but more memory usage
                "hnsw:M": 16,
                # Number of vectors to process in one batch during indexing
                "hnsw:batch_size": 100,
                # Number of updates before forcing index synchronization
                # Lower values = more frequent syncs but slower indexing
                "hnsw:sync_threshold": 1000,
            }
            collection_settings = {
                **default_collection_settings,
                **user_collection_settings,
            }

            auth_provider = config.get(
                "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
            )
            auth_credentials = config.get("auth_token", "secret-token")
            headers = {}

            if "token_authn" in auth_provider:
                headers = {
                    config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
                }
            elif "basic_authn" in auth_provider:
                auth_credentials = config.get("auth_credentials", "admin:admin")

            self._client = HttpClient(
                host=config.get("host", "localhost"),
                port=config.get("port", 8000),
                headers=headers,
                settings=Settings(
                    chroma_api_impl="rest",
                    chroma_client_auth_provider=auth_provider,
                    chroma_client_auth_credentials=auth_credentials,
                    allow_reset=True,
                    anonymized_telemetry=False,
                ),
            )

            self._collection = self._client.get_or_create_collection(
                name=self.namespace,
                metadata={
                    **collection_settings,
                    "dimension": self.embedding_func.embedding_dim,
                },
            )
            # Prefer the global embedding_batch_num, falling back to the
            # collection's hnsw:batch_size, then to 32
            self._max_batch_size = self.global_config.get(
                "embedding_batch_num", collection_settings.get("hnsw:batch_size", 32)
            )
        except Exception as e:
            logger.error(f"ChromaDB initialization failed: {str(e)}")
            raise

    async def upsert(self, data: dict[str, dict]):
        if not data:
            logger.warning("Empty data provided to vector DB")
            return []

        try:
            ids = list(data.keys())
            documents = [v["content"] for v in data.values()]
            metadatas = [
                {k: v for k, v in item.items() if k in self.meta_fields}
                or {"_default": "true"}
                for item in data.values()
            ]

            # Embed documents in batches of at most _max_batch_size
            batches = [
                documents[i : i + self._max_batch_size]
                for i in range(0, len(documents), self._max_batch_size)
            ]

            # asyncio.gather preserves input order, so the concatenated
            # embeddings stay aligned with ids, documents, and metadatas
            embedding_tasks = [self.embedding_func(batch) for batch in batches]
            embeddings_list = list(await asyncio.gather(*embedding_tasks))
            embeddings = np.concatenate(embeddings_list)

            # Upsert in batches
            for i in range(0, len(ids), self._max_batch_size):
                batch_slice = slice(i, i + self._max_batch_size)

                self._collection.upsert(
                    ids=ids[batch_slice],
                    embeddings=embeddings[batch_slice].tolist(),
                    documents=documents[batch_slice],
                    metadatas=metadatas[batch_slice],
                )

            return ids

        except Exception as e:
            logger.error(f"Error during ChromaDB upsert: {str(e)}")
            raise

    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
        try:
            embedding = await self.embedding_func([query])

            results = self._collection.query(
                query_embeddings=embedding.tolist(),
                n_results=top_k * 2,  # Request extra results to allow for filtering
                include=["metadatas", "distances", "documents"],
            )

            # With hnsw:space set to "cosine", ChromaDB returns cosine distance
            # (0 = identical, 1 = orthogonal). Convert it back to similarity via
            # (1 - distance), keep only results whose similarity meets the
            # threshold, then take the top k; we requested 2x results above so
            # enough remain after filtering. E.g. with the default threshold of
            # 0.2, a raw distance of 0.3 gives similarity 0.7 and is kept.
            # Note: the field below is named "distance" for interface
            # compatibility but carries the similarity score.
            return [
                {
                    "id": results["ids"][0][i],
                    "distance": 1 - results["distances"][0][i],
                    "content": results["documents"][0][i],
                    **results["metadatas"][0][i],
                }
                for i in range(len(results["ids"][0]))
                if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold
            ][:top_k]

        except Exception as e:
            logger.error(f"Error during ChromaDB query: {str(e)}")
            raise

    async def index_done_callback(self):
        # ChromaDB handles persistence automatically
        pass
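For reference, here is a minimal sketch of exercising the new storage class directly, outside of LightRAG. It assumes BaseVectorStorage exposes the dataclass fields used in __post_init__ above (namespace, global_config, embedding_func, meta_fields); the random embedding function is a stand-in for illustration only:

import asyncio
import numpy as np
from lightrag.utils import EmbeddingFunc


async def fake_embed(texts: list[str]) -> np.ndarray:
    # Stand-in embedder: random 384-dimensional vectors
    return np.random.rand(len(texts), 384).astype(np.float32)


storage = ChromaVectorDBStorage(
    namespace="demo",
    global_config={
        "embedding_batch_num": 32,
        "vector_db_storage_cls_kwargs": {"host": "localhost", "port": 8000},
    },
    embedding_func=EmbeddingFunc(
        embedding_dim=384, max_token_size=8192, func=fake_embed
    ),
    meta_fields=set(),
)


async def demo():
    await storage.upsert({"doc-1": {"content": "hello world"}})
    print(await storage.query("hello", top_k=1))


asyncio.run(demo())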
@@ -0,0 +1,113 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np

#########
# Uncomment the two lines below if running in a Jupyter notebook, to handle
# the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./chromadb_test_dir"
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

# ChromaDB configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
CHROMADB_AUTH_PROVIDER = os.environ.get(
    "CHROMADB_AUTH_PROVIDER", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
CHROMADB_AUTH_HEADER = os.environ.get("CHROMADB_AUTH_HEADER", "X-Chroma-Token")

# Embedding configuration and functions
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))

# ChromaDB requires knowing the embedding dimension up front when creating a
# collection. The dimension is model-specific (e.g. text-embedding-3-large
# produces 3072-dimensional vectors), so we determine it dynamically by
# running a test embedding and pass it to the ChromaVectorDBStorage class.


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embedding(
        texts,
        model=EMBEDDING_MODEL,
    )


async def get_embedding_dimension():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    return embedding.shape[1]


async def create_embedding_function_instance():
    # Get the embedding dimension
    embedding_dimension = await get_embedding_dimension()
    # Create the embedding function instance
    return EmbeddingFunc(
        embedding_dim=embedding_dimension,
        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
        func=embedding_func,
    )


async def initialize_rag():
    embedding_func_instance = await create_embedding_function_instance()

    return LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=gpt_4o_mini_complete,
        embedding_func=embedding_func_instance,
        vector_storage="ChromaVectorDBStorage",
        log_level="DEBUG",
        embedding_batch_num=32,
        vector_db_storage_cls_kwargs={
            "host": CHROMADB_HOST,
            "port": CHROMADB_PORT,
            "auth_token": CHROMADB_AUTH_TOKEN,
            "auth_provider": CHROMADB_AUTH_PROVIDER,
            "auth_header_name": CHROMADB_AUTH_HEADER,
            "collection_settings": {
                "hnsw:space": "cosine",
                "hnsw:construction_ef": 128,
                "hnsw:search_ef": 128,
                "hnsw:M": 16,
                "hnsw:batch_size": 100,
                "hnsw:sync_threshold": 1000,
            },
        },
    )


# Run the initialization
rag = asyncio.run(initialize_rag())

# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
#     rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
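Because the demo reads its connection settings from the environment, it can be pointed at a different ChromaDB server without code changes. For example (hypothetical host and token values), set the variables before initialize_rag() runs:

import os

# Hypothetical values; export these (or set them in-process as shown)
# before the demo script initializes LightRAG
os.environ["CHROMADB_HOST"] = "chroma.internal.example"
os.environ["CHROMADB_PORT"] = "8000"
os.environ["CHROMADB_AUTH_TOKEN"] = "my-secret-token"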