Skip to content

Commit

Permalink
Add support for MemoryDB
Browse files Browse the repository at this point in the history
  • Loading branch information
Baswanth Vegunta committed Jul 2, 2024
1 parent 09306a0 commit d8b1407
Show file tree
Hide file tree
Showing 6 changed files with 402 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ elastic = [ "elasticsearch" ]
pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "psycopg2" ]
redis = [ "redis" ]
memorydb = [ "memorydb" ]
chromadb = [ "chromadb" ]
zilliz_cloud = []

Expand Down
9 changes: 9 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class DB(Enum):
PgVector = "PgVector"
PgVectoRS = "PgVectoRS"
Redis = "Redis"
MemoryDB = "MemoryDB"
Chroma = "Chroma"
Test = "test"

Expand Down Expand Up @@ -73,6 +74,10 @@ def init_cls(self) -> Type[VectorDB]:
if self == DB.Redis:
from .redis.redis import Redis
return Redis

if self == DB.MemoryDB:
from .memorydb.memorydb import MemoryDB
return MemoryDB

if self == DB.Chroma:
from .chroma.chroma import ChromaClient
Expand Down Expand Up @@ -116,6 +121,10 @@ def config_cls(self) -> Type[DBConfig]:
if self == DB.Redis:
from .redis.config import RedisConfig
return RedisConfig

if self == DB.MemoryDB:
from .memorydb.config import MemoryDBConfig
return MemoryDBConfig

if self == DB.Chroma:
from .chroma.config import ChromaConfig
Expand Down
80 changes: 80 additions & 0 deletions vectordb_bench/backend/clients/memorydb/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Annotated, TypedDict, Unpack

import click
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor2,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from .. import DB


class MemoryDBTypedDict(TypedDict):
    """Connection-related CLI parameters for the MemoryDB command.

    Each field pairs its value type with the ``click.option`` that
    supplies it on the command line.
    """

    # Endpoint host name; the only mandatory option in this group.
    host: Annotated[
        str,
        click.option("--host", type=str, help="Db host", required=True),
    ]

    # Optional password; the option declares no default.
    password: Annotated[
        str,
        click.option("--password", type=str, help="Db password"),
    ]

    # TCP port, defaulting to 6379.
    port: Annotated[
        int,
        click.option("--port", type=int, default=6379, help="Db Port"),
    ]

    # Paired on/off switch; SSL is enabled unless --no-ssl is passed.
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            is_flag=True,
            show_default=True,
            default=True,
            help="Enable or disable SSL for Redis",
        ),
    ]

    # Path to a CA bundle used for SSL; the option declares no default.
    ssl_ca_certs: Annotated[
        str,
        click.option(
            "--ssl-ca-certs",
            show_default=True,
            help="Path to certificate authority file to use for SSL",
        ),
    ]

    # Cluster Mode Disabled switch; off unless --cmd is passed.
    cmd: Annotated[
        bool,
        click.option(
            "--cmd",
            is_flag=True,
            show_default=True,
            default=False,
            help="Cluster Mode Disabled (CMD) for Redis doesn't use Cluster conn",
        ),
    ]


class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2):
    """Full parameter set of the MemoryDB command.

    Adds no fields of its own; it only merges the common options, the
    MemoryDB connection options and the ``HNSWFlavor2`` options.
    """


# CLI entry point for the MemoryDB backend: turns the parsed command-line
# parameters into a connection config and an HNSW case config, then hands
# both to ``run``. (Kept as comments rather than a docstring so the command's
# generated help text is unchanged.)
@cli.command()
@click_parameter_decorators_from_typed_dict(MemoryDBHNSWTypedDict)
def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]):
    # Imported lazily, inside the command, as in the original layout.
    from .config import MemoryDBConfig, MemoryDBHNSWConfig

    # Wrap the password only when a non-empty one was supplied; otherwise
    # the config receives None.
    raw_password = parameters["password"]
    secret_password = SecretStr(raw_password) if raw_password else None

    connection_config = MemoryDBConfig(
        db_label=parameters["db_label"],
        password=secret_password,
        host=SecretStr(parameters["host"]),
        port=parameters["port"],
        ssl=parameters["ssl"],
        ssl_ca_certs=parameters["ssl_ca_certs"],
        cmd=parameters["cmd"],
    )

    hnsw_config = MemoryDBHNSWConfig(
        M=parameters["m"],
        ef_construction=parameters["ef_construction"],
        ef_runtime=parameters["ef_runtime"],
    )

    # Every CLI parameter is additionally forwarded to ``run`` as a keyword
    # argument alongside the two config objects.
    run(
        db=DB.MemoryDB,
        db_config=connection_config,
        db_case_config=hnsw_config,
        **parameters,
    )
54 changes: 54 additions & 0 deletions vectordb_bench/backend/clients/memorydb/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType


class MemoryDBConfig(DBConfig):
    """Connection settings for the MemoryDB client.

    ``host`` and ``password`` are held as ``SecretStr``; ``to_dict``
    unwraps them into plain values.
    """

    # Only ``host`` is required; every other field defaults to None.
    host: SecretStr
    password: SecretStr | None = None
    port: int | None = None
    ssl: bool | None = None
    cmd: bool | None = None
    ssl_ca_certs: str | None = None

    def to_dict(self) -> dict:
        """Return the settings as a plain dict with secrets unwrapped.

        ``password`` maps to None when no (or an empty) password is set.
        """
        plain_host = self.host.get_secret_value()
        secret = self.password
        plain_password = secret.get_secret_value() if secret else None
        return dict(
            host=plain_host,
            port=self.port,
            password=plain_password,
            ssl=self.ssl,
            cmd=self.cmd,
            ssl_ca_certs=self.ssl_ca_certs,
        )


class MemoryDBIndexConfig(BaseModel, DBCaseConfig):
    """Index settings shared by MemoryDB case configs."""

    metric_type: MetricType | None = None
    insert_batch_size: int | None = 10  # Adjust this as needed, but don't make too big

    def parse_metric(self) -> str:
        """Map ``metric_type`` to its lowercase metric name.

        L2 gives ``"l2"`` and IP gives ``"ip"``; anything else, including
        an unset metric, gives ``"cosine"``.
        """
        metric = self.metric_type
        if metric == MetricType.L2:
            return "l2"
        return "ip" if metric == MetricType.IP else "cosine"


class MemoryDBHNSWConfig(MemoryDBIndexConfig):
    """HNSW case config: build-time and query-time parameters."""

    M: int | None = 16
    ef_construction: int | None = 64
    ef_runtime: int | None = 10
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        """Return the index-creation parameters.

        Contains the parsed metric name, the index type value, ``m`` and
        ``ef_construction``.
        """
        metric_name = self.parse_metric()
        index_kind = self.index.value
        return dict(
            metric=metric_name,
            index_type=index_kind,
            m=self.M,
            ef_construction=self.ef_construction,
        )

    def search_param(self) -> dict:
        """Return the search parameters (only ``ef_runtime``)."""
        return dict(ef_runtime=self.ef_runtime)
Loading

0 comments on commit d8b1407

Please sign in to comment.