Merge pull request #447 from spo0nman/pkaushal/vectordb-chroma 
feat: Add ChromaDB integration for vector storage
LarFii authored Dec 11, 2024
2 parents 7fbd9aa + 504a3c2 commit aee622f
Showing 3 changed files with 287 additions and 0 deletions.
172 changes: 172 additions & 0 deletions lightrag/kg/chroma_impl.py
@@ -0,0 +1,172 @@
import asyncio
from dataclasses import dataclass
from typing import Union
import numpy as np
from chromadb import HttpClient
from chromadb.config import Settings
from lightrag.base import BaseVectorStorage
from lightrag.utils import logger


@dataclass
class ChromaVectorDBStorage(BaseVectorStorage):
"""ChromaDB vector storage implementation."""

cosine_better_than_threshold: float = 0.2

def __post_init__(self):
try:
# Use global config value if specified, otherwise use default
self.cosine_better_than_threshold = self.global_config.get(
"cosine_better_than_threshold", self.cosine_better_than_threshold
)

config = self.global_config.get("vector_db_storage_cls_kwargs", {})
user_collection_settings = config.get("collection_settings", {})
# Default HNSW index settings for ChromaDB
default_collection_settings = {
# Distance metric used for similarity search (cosine similarity)
"hnsw:space": "cosine",
# Number of nearest neighbors to explore during index construction
# Higher values = better recall but slower indexing
"hnsw:construction_ef": 128,
# Number of nearest neighbors to explore during search
# Higher values = better recall but slower search
"hnsw:search_ef": 128,
# Number of connections per node in the HNSW graph
# Higher values = better recall but more memory usage
"hnsw:M": 16,
# Number of vectors to process in one batch during indexing
"hnsw:batch_size": 100,
# Number of updates before forcing index synchronization
# Lower values = more frequent syncs but slower indexing
"hnsw:sync_threshold": 1000,
}
collection_settings = {
**default_collection_settings,
**user_collection_settings,
}

auth_provider = config.get(
"auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
auth_credentials = config.get("auth_token", "secret-token")
headers = {}

if "token_authn" in auth_provider:
headers = {
config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
}
elif "basic_authn" in auth_provider:
auth_credentials = config.get("auth_credentials", "admin:admin")

self._client = HttpClient(
host=config.get("host", "localhost"),
port=config.get("port", 8000),
headers=headers,
settings=Settings(
chroma_api_impl="rest",
chroma_client_auth_provider=auth_provider,
chroma_client_auth_credentials=auth_credentials,
allow_reset=True,
anonymized_telemetry=False,
),
)

self._collection = self._client.get_or_create_collection(
name=self.namespace,
metadata={
**collection_settings,
"dimension": self.embedding_func.embedding_dim,
},
)
            # Prefer the global embedding_batch_num; otherwise fall back to
            # the collection's hnsw:batch_size
self._max_batch_size = self.global_config.get(
"embedding_batch_num", collection_settings.get("hnsw:batch_size", 32)
)
except Exception as e:
logger.error(f"ChromaDB initialization failed: {str(e)}")
raise

async def upsert(self, data: dict[str, dict]):
if not data:
logger.warning("Empty data provided to vector DB")
return []

try:
ids = list(data.keys())
documents = [v["content"] for v in data.values()]
metadatas = [
{k: v for k, v in item.items() if k in self.meta_fields}
or {"_default": "true"}
for item in data.values()
]

# Process in batches
batches = [
documents[i : i + self._max_batch_size]
for i in range(0, len(documents), self._max_batch_size)
]

            embedding_tasks = [self.embedding_func(batch) for batch in batches]
            # asyncio.gather runs the batches concurrently and preserves input
            # order, so embeddings stay aligned with ids/documents/metadatas
            embeddings_list = await asyncio.gather(*embedding_tasks)

embeddings = np.concatenate(embeddings_list)

# Upsert in batches
for i in range(0, len(ids), self._max_batch_size):
batch_slice = slice(i, i + self._max_batch_size)

self._collection.upsert(
ids=ids[batch_slice],
embeddings=embeddings[batch_slice].tolist(),
documents=documents[batch_slice],
metadatas=metadatas[batch_slice],
)

return ids

except Exception as e:
logger.error(f"Error during ChromaDB upsert: {str(e)}")
raise

async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
try:
embedding = await self.embedding_func([query])

results = self._collection.query(
query_embeddings=embedding.tolist(),
n_results=top_k * 2, # Request more results to allow for filtering
include=["metadatas", "distances", "documents"],
)

            # Filter results by cosine similarity threshold, then take top k.
            # We request 2x results up front so enough survive the filter.
            # With hnsw:space=cosine, ChromaDB returns cosine *distance*
            # (0 = identical); (1 - distance) converts it back to a similarity
            # score, which is compared against cosine_better_than_threshold.
            # Note the "distance" field below therefore holds a similarity.
return [
{
"id": results["ids"][0][i],
"distance": 1 - results["distances"][0][i],
"content": results["documents"][0][i],
**results["metadatas"][0][i],
}
for i in range(len(results["ids"][0]))
if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold
][:top_k]

except Exception as e:
logger.error(f"Error during ChromaDB query: {str(e)}")
raise

async def index_done_callback(self):
# ChromaDB handles persistence automatically
pass
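
To make the query-time filtering concrete, here is a minimal standalone sketch of the same threshold logic with invented distance values (illustration only, not output from a real collection):

# Toy illustration of the similarity filter in ChromaVectorDBStorage.query.
# The distances are invented; with hnsw:space=cosine, a real collection
# returns cosine distances where 0 means identical.
cosine_better_than_threshold = 0.2
top_k = 2
results = {
    "ids": [["doc-a", "doc-b", "doc-c", "doc-d"]],
    "distances": [[0.05, 0.30, 0.75, 0.95]],
}
filtered = [
    {"id": results["ids"][0][i], "similarity": 1 - results["distances"][0][i]}
    for i in range(len(results["ids"][0]))
    if (1 - results["distances"][0][i]) >= cosine_better_than_threshold
][:top_k]
print(filtered)  # keeps doc-a and doc-b (similarities ~0.95 and ~0.70)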
2 changes: 2 additions & 0 deletions lightrag/lightrag.py
@@ -76,6 +76,7 @@ def import_class(*args, **kwargs):
OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")


def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@@ -263,6 +264,7 @@ def _get_storage_class(self) -> Type[BaseGraphStorage]:
"NanoVectorDBStorage": NanoVectorDBStorage,
"OracleVectorDBStorage": OracleVectorDBStorage,
"MilvusVectorDBStorge": MilvusVectorDBStorge,
"ChromaVectorDBStorage": ChromaVectorDBStorage,
# graph storage
"NetworkXStorage": NetworkXStorage,
"Neo4JStorage": Neo4JStorage,
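For context, lazy_external_import defers importing a backend module until the class is first used, which keeps chromadb an optional dependency. Its actual implementation is outside this diff; the sketch below is only a generic illustration of that pattern, not LightRAG's code:

# Generic sketch of a lazy-import factory (illustration of the pattern,
# NOT LightRAG's actual lazy_external_import)
import importlib

def lazy_import(module_name: str, class_name: str, package: str = "lightrag"):
    def loader(*args, **kwargs):
        # The import happens only on first instantiation, so users without
        # chromadb installed are unaffected unless they select this backend
        module = importlib.import_module(module_name, package=package)
        return getattr(module, class_name)(*args, **kwargs)
    return loader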
113 changes: 113 additions & 0 deletions test_chromadb.py
@@ -0,0 +1,113 @@
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np

#########
# Uncomment the two lines below when running in a Jupyter notebook
# to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./chromadb_test_dir"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)

# ChromaDB Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
CHROMADB_AUTH_PROVIDER = os.environ.get(
"CHROMADB_AUTH_PROVIDER", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
CHROMADB_AUTH_HEADER = os.environ.get("CHROMADB_AUTH_HEADER", "X-Chroma-Token")

# Embedding Configuration and Functions
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))

# ChromaDB requires the embedding dimension up front when creating a
# collection. The dimension is model-specific (e.g. text-embedding-3-large
# produces 3072-dimensional vectors), so we determine it dynamically by
# running a test embedding and pass it to ChromaVectorDBStorage.


async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model=EMBEDDING_MODEL,
)


async def get_embedding_dimension():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
return embedding.shape[1]


async def create_embedding_function_instance():
# Get embedding dimension
embedding_dimension = await get_embedding_dimension()
# Create embedding function instance
return EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
)


async def initialize_rag():
embedding_func_instance = await create_embedding_function_instance()

return LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
)


# Run the initialization
rag = asyncio.run(initialize_rag())

# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
# rag.insert(f.read())

# Perform naive search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
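
The script above assumes a ChromaDB server is already running at CHROMADB_HOST:CHROMADB_PORT with token authentication enabled. A small pre-flight check like the sketch below can confirm connectivity before running the pipeline (heartbeat() is part of chromadb's HTTP client; the header name must match the server's configured transport header):

# Optional pre-flight connectivity check, reusing the settings above.
# heartbeat() returns a nanosecond timestamp when the server is reachable.
from chromadb import HttpClient

chroma_client = HttpClient(
    host=CHROMADB_HOST,
    port=CHROMADB_PORT,
    headers={CHROMADB_AUTH_HEADER: CHROMADB_AUTH_TOKEN},
)
print("ChromaDB heartbeat:", chroma_client.heartbeat())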
