diff --git a/pyproject.toml b/pyproject.toml index 1311cceef..c89459919 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ elastic = [ "elasticsearch" ] pgvector = [ "psycopg", "psycopg-binary", "pgvector" ] pgvecto_rs = [ "psycopg2" ] redis = [ "redis" ] +memorydb = [ "memorydb" ] chromadb = [ "chromadb" ] zilliz_cloud = [] diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index 6a99661c2..6abaf32fe 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -31,6 +31,7 @@ class DB(Enum): PgVector = "PgVector" PgVectoRS = "PgVectoRS" Redis = "Redis" + MemoryDB = "MemoryDB" Chroma = "Chroma" Test = "test" @@ -73,6 +74,10 @@ def init_cls(self) -> Type[VectorDB]: if self == DB.Redis: from .redis.redis import Redis return Redis + + if self == DB.MemoryDB: + from .memorydb.memorydb import MemoryDB + return MemoryDB if self == DB.Chroma: from .chroma.chroma import ChromaClient @@ -116,6 +121,10 @@ def config_cls(self) -> Type[DBConfig]: if self == DB.Redis: from .redis.config import RedisConfig return RedisConfig + + if self == DB.MemoryDB: + from .memorydb.config import MemoryDBConfig + return MemoryDBConfig if self == DB.Chroma: from .chroma.config import ChromaConfig diff --git a/vectordb_bench/backend/clients/memorydb/cli.py b/vectordb_bench/backend/clients/memorydb/cli.py new file mode 100644 index 000000000..29a812e96 --- /dev/null +++ b/vectordb_bench/backend/clients/memorydb/cli.py @@ -0,0 +1,80 @@ +from typing import Annotated, TypedDict, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + HNSWFlavor2, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB + + +class MemoryDBTypedDict(TypedDict): + host: Annotated[ + str, click.option("--host", type=str, help="Db host", required=True) + ] + password: Annotated[str, click.option("--password", type=str, help="Db password")] + port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")] + ssl: Annotated[ + bool, + click.option( + "--ssl/--no-ssl", + is_flag=True, + show_default=True, + default=True, + help="Enable or disable SSL for Redis", + ), + ] + ssl_ca_certs: Annotated[ + str, + click.option( + "--ssl-ca-certs", + show_default=True, + help="Path to certificate authority file to use for SSL", + ), + ] + cmd: Annotated[ + bool, + click.option( + "--cmd", + is_flag=True, + show_default=True, + default=False, + help="Cluster Mode Disabled (CMD) for Redis doesn't use Cluster conn", + ), + ] + + +class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2): + ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MemoryDBHNSWTypedDict) +def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]): + from .config import MemoryDBConfig, MemoryDBHNSWConfig + + run( + db=DB.MemoryDB, + db_config=MemoryDBConfig( + db_label=parameters["db_label"], + password=SecretStr(parameters["password"]) + if parameters["password"] + else None, + host=SecretStr(parameters["host"]), + port=parameters["port"], + ssl=parameters["ssl"], + ssl_ca_certs=parameters["ssl_ca_certs"], + cmd=parameters["cmd"], + ), + db_case_config=MemoryDBHNSWConfig( + M=parameters["m"], + ef_construction=parameters["ef_construction"], + ef_runtime=parameters["ef_runtime"], + ), + **parameters, + ) \ No newline at end of file diff --git a/vectordb_bench/backend/clients/memorydb/config.py b/vectordb_bench/backend/clients/memorydb/config.py new file mode 100644 index 000000000..94f2478b1 --- /dev/null +++ b/vectordb_bench/backend/clients/memorydb/config.py @@ -0,0 +1,54 @@ +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + + +class MemoryDBConfig(DBConfig): + host: SecretStr + password: SecretStr | None = None + port: int | None = None + ssl: bool | None = None + cmd: bool | None = None + ssl_ca_certs: str | None = None + + def to_dict(self) -> dict: + return { + "host": self.host.get_secret_value(), + "port": self.port, + "password": self.password.get_secret_value() if self.password else None, + "ssl": self.ssl, + "cmd": self.cmd, + "ssl_ca_certs": self.ssl_ca_certs, + } + + +class MemoryDBIndexConfig(BaseModel, DBCaseConfig): + metric_type: MetricType | None = None + insert_batch_size: int | None = 10 # Adjust this as needed, but don't make too big + + def parse_metric(self) -> str: + if self.metric_type == MetricType.L2: + return "l2" + elif self.metric_type == MetricType.IP: + return "ip" + return "cosine" + + +class MemoryDBHNSWConfig(MemoryDBIndexConfig): + M: int | None = 16 + ef_construction: int | None = 64 + ef_runtime: int | None = 10 + index: IndexType = IndexType.HNSW + + def index_param(self) -> dict: + return { + "metric": self.parse_metric(), + "index_type": self.index.value, + "m": self.M, + "ef_construction": self.ef_construction, + } + + def search_param(self) -> dict: + return { + "ef_runtime": self.ef_runtime, + } \ No newline at end of file diff --git a/vectordb_bench/backend/clients/memorydb/memorydb.py b/vectordb_bench/backend/clients/memorydb/memorydb.py new file mode 100644 index 000000000..fea34dca2 --- /dev/null +++ b/vectordb_bench/backend/clients/memorydb/memorydb.py @@ -0,0 +1,256 @@ +import logging, time +from contextlib import contextmanager +from typing import Any, Generator, Optional, Tuple, Type +from ..api import VectorDB, DBCaseConfig, IndexType +from .config import MemoryDBIndexConfig +import redis +from redis import Redis +from redis.cluster import RedisCluster +from redis.commands.search.field import TagField, VectorField, NumericField +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import Query +import numpy as np + + +log = logging.getLogger(__name__) +INDEX_NAME = "index" # Vector Index Name + +class MemoryDB(VectorDB): + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: MemoryDBIndexConfig, + drop_old: bool = False, + **kwargs + ): + + self.db_config = db_config + self.case_config = db_case_config + self.collection_name = INDEX_NAME + self.target_nodes = RedisCluster.RANDOM if not self.db_config["cmd"] else None + self.insert_batch_size = db_case_config.insert_batch_size or 1 + self.dbsize = kwargs.get("num_rows") + + # Create a redis connection, if db has password configured, add it to the connection here and in init(): + log.info(f"Redis establishing connection to: {self.db_config}") + conn = self.get_client(primary=True) + log.info(f"Connection established: {conn}") + log.info(conn.execute_command("INFO server")) + + if drop_old: + try: + log.info(f"Redis client getting info for: {INDEX_NAME}") + info = conn.ft(INDEX_NAME).info() + log.info(f"Index info: {info}") + except redis.exceptions.ResponseError as e: + log.error(e) + drop_old = False + log.info(f"Redis client drop_old collection: {self.collection_name}") + + log.info("Executing FLUSHALL") + conn.flushall() + + # Since the default behaviour of FLUSHALL is asynchronous, wait for db to be empty + self.wait_until(self.wait_for_empty_db, 3, "", conn) + if not self.db_config["cmd"]: + replica_clients = self.get_client(replicas=True) + for rc, host in replica_clients: + self.wait_until(self.wait_for_empty_db, 3, "", rc) + log.debug(f"Flushall done in the host: {host}") + rc.close() + + self.make_index(dim, conn) + conn.close() + conn = None + + def make_index(self, vector_dimensions: int, conn: redis.Redis): + try: + # check to see if index exists + conn.ft(INDEX_NAME).info() + except Exception as e: + log.warn(f"Error getting info for index '{INDEX_NAME}': {e}") + index_param = self.case_config.index_param() + search_param = self.case_config.search_param() + vector_parameters = { # Vector Index Type: FLAT or HNSW + "TYPE": "FLOAT32", # FLOAT32 or FLOAT64 + "DIM": vector_dimensions, # Number of Vector Dimensions + "DISTANCE_METRIC": index_param[ + "metric" + ], # Vector Search Distance Metric + } + if index_param["m"]: + vector_parameters["M"] = index_param["m"] + if index_param["ef_construction"]: + vector_parameters["EF_CONSTRUCTION"] = index_param["ef_construction"] + if search_param["ef_runtime"]: + vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"] + + schema = ( + TagField("id"), + NumericField("metadata"), + VectorField("vector", # Vector Field Name + "HNSW", vector_parameters + ), + ) + + definition = IndexDefinition(index_type=IndexType.HASH) + rs = conn.ft(INDEX_NAME) + rs.create_index(schema, definition=definition) + + def get_client(self, **kwargs): + """ + Gets either cluster connection or normal redis connection based on `cmd` flag. + CMD stands for Cluster Mode Disabled and is a "mode" for Redis. + """ + if not self.db_config["cmd"]: + # Cluster mode enabled + + client = RedisCluster( + host=self.db_config["host"], + port=self.db_config["port"], + ssl=self.db_config["ssl"], + password=self.db_config["password"], + ssl_ca_certs=self.db_config["ssl_ca_certs"], + ssl_cert_reqs=None, + ) + + # Ping all nodes to create a connection + client.execute_command("PING", target_nodes=RedisCluster.ALL_NODES) + replicas = client.get_replicas() + + if len(replicas) > 0: + # FT.SEARCH is a keyless command, use READONLY for replica connections + client.execute_command("READONLY", target_nodes=RedisCluster.REPLICAS) + + if kwargs.get("primary", False): + client = client.get_primaries()[0].redis_connection + + if kwargs.get("replicas", False): + # Return client and host name for each replica + return [(c.redis_connection, c.host) for c in replicas] + + else: + client = Redis( + host=self.db_config["host"], + port=self.db_config["port"], + db=0, + ssl=self.db_config["ssl"], + password=self.db_config["password"], + ssl_ca_certs=self.db_config["ssl_ca_certs"], + ssl_cert_reqs=None, + ) + client.execute_command("PING") + return client + + @contextmanager + def init(self) -> Generator[None, None, None]: + """ create and destory connections to database. + + Examples: + >>> with self.init(): + >>> self.insert_embeddings() + """ + self.conn = self.get_client() + search_param = self.case_config.search_param() + if search_param["ef_runtime"]: + self.ef_runtime_str = f'EF_RUNTIME {search_param["ef_runtime"]}' + else: + self.ef_runtime_str = "" + yield + self.conn.close() + self.conn = None + + def ready_to_load(self) -> bool: + pass + + def optimize(self) -> None: + self._post_insert() + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + """Insert embeddings into the database. + Should call self.init() first. + """ + + try: + with self.conn.pipeline(transaction=False) as pipe: + for i, embedding in enumerate(embeddings): + embedding = np.array(embedding).astype(np.float32) + pipe.hset(metadata[i], mapping = { + "id": str(metadata[i]), + "metadata": metadata[i], + "vector": embedding.tobytes(), + }) + # Execute the pipe so we don't keep too much in memory at once + if (i + 1) % self.insert_batch_size == 0: + pipe.execute() + + pipe.execute() + result_len = i + 1 + except Exception as e: + return 0, e + + return result_len, None + + def _post_insert(self): + """Wait for indexing to finish""" + client = self.get_client(primary=True) + log.info("Waiting for background indexing to finish") + args = (self.wait_for_no_activity, 5, "", client) + self.wait_until(*args) + if not self.db_config["cmd"]: + replica_clients = self.get_client(replicas=True) + for rc, host_name in replica_clients: + args = (self.wait_for_no_activity, 5, "", rc) + self.wait_until(*args) + log.debug(f"Background indexing completed in the host: {host_name}") + rc.close() + + def wait_until( + self, condition, interval=5, message="Operation took too long", *args + ): + while not condition(*args): + time.sleep(interval) + + def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis): + return ( + client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY" + ) + + def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis): + return client.execute_command("DBSIZE") == 0 + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> (list[int]): + assert self.conn is not None + + query_vector = np.array(query).astype(np.float32).tobytes() + query_obj = Query(f"*=>[KNN {k} @vector $vec as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2) + query_params = {"vec": query_vector} + + if filters: + # benchmark test filters of format: {'metadata': '>=10000', 'id': 10000} + # gets exact match for id, and range for metadata if they exist in filters + id_value = filters.get("id") + metadata_value = filters.get("metadata") + if id_value and metadata_value: + query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2) + elif id_value: + #gets exact match for id + query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2) + else: #metadata only case, greater than or equal to metadata value + query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2) + res = self.conn.ft(INDEX_NAME).search(query_obj, query_params) + # doc in res of format {'id': '9831', 'payload': None, 'score': '1.19209289551e-07'} + return [int(doc["id"]) for doc in res.docs] \ No newline at end of file diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index 396909cd5..390b6d9eb 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,5 +1,6 @@ from ..backend.clients.pgvector.cli import PgVectorHNSW from ..backend.clients.redis.cli import Redis +from ..backend.clients.memorydb.cli import MemoryDB from ..backend.clients.test.cli import Test from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex @@ -10,6 +11,7 @@ cli.add_command(PgVectorHNSW) cli.add_command(Redis) +cli.add_command(MemoryDB) cli.add_command(Weaviate) cli.add_command(Test) cli.add_command(ZillizAutoIndex)