diff --git a/README.md b/README.md
index 4ff2f9121..f093bfcfb 100644
--- a/README.md
+++ b/README.md
@@ -172,6 +172,24 @@ Case No. | Case Type | Dataset Size | Filtering Rate | Results |
Each case provides an in-depth examination of a vector database's abilities, providing you a comprehensive view of the database's performance.
+#### Custom Dataset for Performance case
+
+Through the `/custom` page, users can customize their own performance case using local datasets. After saving, the corresponding case can be selected from the `/run_test` page to perform the test.
+
+![image](fig/custom_dataset.png)
+![image](fig/custom_case_run_test.png)
+
+We have strict requirements for the dataset format; please follow them.
+- `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
+  - Vector data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
+  - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
+  - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`.
+
+- `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files.
+
+- `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
+
+
## Goals
Our goals of this benchmark are:
### Reproducibility & Usability
diff --git a/fig/custom_case_run_test.png b/fig/custom_case_run_test.png
new file mode 100644
index 000000000..8817b3439
Binary files /dev/null and b/fig/custom_case_run_test.png differ
diff --git a/fig/custom_dataset.png b/fig/custom_dataset.png
new file mode 100644
index 000000000..9d665891a
Binary files /dev/null and b/fig/custom_dataset.png differ
diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py
index eca190832..c68ee854d 100644
--- a/vectordb_bench/__init__.py
+++ b/vectordb_bench/__init__.py
@@ -22,6 +22,7 @@ class config:
NUM_CONCURRENCY = [1, 5, 10, 15, 20, 25, 30, 35]
RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results")
+ CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h
diff --git a/vectordb_bench/backend/assembler.py b/vectordb_bench/backend/assembler.py
index 6b0e3c81d..e7da4d49f 100644
--- a/vectordb_bench/backend/assembler.py
+++ b/vectordb_bench/backend/assembler.py
@@ -14,7 +14,7 @@ class Assembler:
def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
c_cls = task.case_config.case_id.case_cls
- c = c_cls()
+ c = c_cls(task.case_config.custom_case)
if type(task.db_case_config) != EmptyDBCaseConfig:
task.db_case_config.metric_type = c.dataset.data.metric_type
diff --git a/vectordb_bench/backend/cases.py b/vectordb_bench/backend/cases.py
index 7f40ae4f8..22e9dfbc7 100644
--- a/vectordb_bench/backend/cases.py
+++ b/vectordb_bench/backend/cases.py
@@ -3,9 +3,11 @@
from enum import Enum, auto
from vectordb_bench import config
+from vectordb_bench.backend.clients.api import MetricType
from vectordb_bench.base import BaseModel
+from vectordb_bench.frontend.components.custom.getCustomConfig import CustomDatasetConfig
-from .dataset import Dataset, DatasetManager
+from .dataset import CustomDataset, Dataset, DatasetManager
log = logging.getLogger(__name__)
@@ -43,23 +45,24 @@ class CaseType(Enum):
Performance1536D5M99P = 15
Custom = 100
+ PerformanceCustomDataset = 101
- @property
def case_cls(self, custom_configs: dict | None = None) -> Case:
- return type2case.get(self)
+ if custom_configs is None:
+ return type2case.get(self)()
+ else:
+ return type2case.get(self)(**custom_configs)
- @property
- def case_name(self) -> str:
- c = self.case_cls
+ def case_name(self, custom_configs: dict | None = None) -> str:
+ c = self.case_cls(custom_configs)
if c is not None:
- return c().name
+ return c.name
raise ValueError("Case unsupported")
- @property
- def case_description(self) -> str:
- c = self.case_cls
+ def case_description(self, custom_configs: dict | None = None) -> str:
+ c = self.case_cls(custom_configs)
if c is not None:
- return c().description
+ return c.description
raise ValueError("Case unsupported")
@@ -115,6 +118,7 @@ class PerformanceCase(Case, BaseModel):
load_timeout: float | int = config.LOAD_TIMEOUT_DEFAULT
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_DEFAULT
+
class CapacityDim960(CapacityCase):
case_id: CaseType = CaseType.CapacityDim960
dataset: DatasetManager = Dataset.GIST.manager(100_000)
@@ -238,6 +242,7 @@ class Performance1536D500K1P(PerformanceCase):
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
+
class Performance1536D5M1P(PerformanceCase):
case_id: CaseType = CaseType.Performance1536D5M1P
filter_rate: float | int | None = 0.01
@@ -248,6 +253,7 @@ class Performance1536D5M1P(PerformanceCase):
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
+
class Performance1536D500K99P(PerformanceCase):
case_id: CaseType = CaseType.Performance1536D500K99P
filter_rate: float | int | None = 0.99
@@ -258,6 +264,7 @@ class Performance1536D500K99P(PerformanceCase):
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
+
class Performance1536D5M99P(PerformanceCase):
case_id: CaseType = CaseType.Performance1536D5M99P
filter_rate: float | int | None = 0.99
@@ -269,6 +276,40 @@ class Performance1536D5M99P(PerformanceCase):
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
+def metric_type_map(s: str) -> MetricType:
+    """Translate a metric name from a custom dataset config into a MetricType.
+
+    Matching is case-insensitive: ``cosine`` gives COSINE, ``l2`` and
+    ``euclidean`` give L2, ``ip`` gives IP. Any other name is logged as an
+    error and raised as RuntimeError.
+    """
+    known_metrics = {
+        'cosine': MetricType.COSINE,
+        'l2': MetricType.L2,
+        'euclidean': MetricType.L2,
+        'ip': MetricType.IP,
+    }
+    resolved = known_metrics.get(s.lower())
+    if resolved is not None:
+        return resolved
+    err_msg = f"Not support metric_type: {s}"
+    log.error(err_msg)
+    raise RuntimeError(err_msg)
+
+
+class PerformanceCustomDataset(PerformanceCase):
+    """Performance case whose dataset is built from a custom dataset config."""
+
+    case_id: CaseType = CaseType.PerformanceCustomDataset
+    name: str = "Performance With Custom Dataset"
+    description: str = ""
+    dataset: DatasetManager
+
+    def __init__(self, name, description, load_timeout, optimize_timeout, dataset_config, **kwargs):
+        """Build the case from the fields of one custom case entry.
+
+        ``dataset_config`` is unpacked into CustomDatasetConfig and then
+        turned into a CustomDataset wrapped in a DatasetManager. Extra
+        keyword arguments are accepted and not forwarded.
+        """
+        parsed = CustomDatasetConfig(**dataset_config)
+        custom_dataset = CustomDataset(
+            name=parsed.name,
+            size=parsed.size,
+            dim=parsed.dim,
+            # The config stores the metric as a string; convert it to MetricType.
+            metric_type=metric_type_map(parsed.metric_type),
+            use_shuffled=parsed.use_shuffled,
+            with_gt=parsed.with_gt,
+            dir=parsed.dir,
+            file_num=parsed.file_count,
+        )
+        manager = DatasetManager(data=custom_dataset)
+        super().__init__(
+            name=name,
+            description=description,
+            load_timeout=load_timeout,
+            optimize_timeout=optimize_timeout,
+            dataset=manager,
+        )
+
+
type2case = {
CaseType.CapacityDim960: CapacityDim960,
CaseType.CapacityDim128: CapacityDim128,
@@ -291,4 +332,5 @@ class Performance1536D5M99P(PerformanceCase):
CaseType.Performance1536D500K99P: Performance1536D500K99P,
CaseType.Performance1536D5M99P: Performance1536D5M99P,
+ CaseType.PerformanceCustomDataset: PerformanceCustomDataset,
}
diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py
index 2b630eae3..d559eb6be 100644
--- a/vectordb_bench/backend/dataset.py
+++ b/vectordb_bench/backend/dataset.py
@@ -33,6 +33,7 @@ class BaseDataset(BaseModel):
use_shuffled: bool
with_gt: bool = False
_size_label: dict[int, SizeLabel] = PrivateAttr()
+ isCustom: bool = False
@validator("size")
def verify_size(cls, v):
@@ -52,7 +53,27 @@ def dir_name(self) -> str:
def file_count(self) -> int:
return self._size_label.get(self.size).file_count
+class CustomDataset(BaseDataset):
+    """Dataset description whose directory and file count are supplied explicitly.
+
+    ``isCustom`` is True for this class; DatasetManager.prepare checks that
+    flag and skips ``source.reader().read(...)`` for such datasets.
+    """
+
+    dir: str
+    file_num: int
+    isCustom: bool = True
+
+    @validator("size")
+    def verify_size(cls, v):
+        # Replaces the base-class validator of the same name: the size is
+        # returned unchanged.
+        return v
+
+    @property
+    def label(self) -> str:
+        return "Custom"
+
+    @property
+    def dir_name(self) -> str:
+        # The configured directory is used as the directory name as-is.
+        return self.dir
+
+    @property
+    def file_count(self) -> int:
+        # Taken from the explicit ``file_num`` field.
+        return self.file_num
+
class LAION(BaseDataset):
name: str = "LAION"
dim: int = 768
@@ -186,11 +207,12 @@ def prepare(self,
gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
all_files.extend([gt_file, test_file])
- source.reader().read(
- dataset=self.data.dir_name.lower(),
- files=all_files,
- local_ds_root=self.data_dir,
- )
+ if not self.data.isCustom:
+ source.reader().read(
+ dataset=self.data.dir_name.lower(),
+ files=all_files,
+ local_ds_root=self.data_dir,
+ )
if gt_file is not None and test_file is not None:
self.test_data = self._read_file(test_file)
diff --git a/vectordb_bench/custom/custom_case.json b/vectordb_bench/custom/custom_case.json
new file mode 100644
index 000000000..48ca8d8c4
--- /dev/null
+++ b/vectordb_bench/custom/custom_case.json
@@ -0,0 +1,18 @@
+[
+ {
+ "name": "My Dataset (Performance Case)",
+ "description": "this is a customized dataset.",
+ "load_timeout": 36000,
+ "optimize_timeout": 36000,
+ "dataset_config": {
+ "name": "My Dataset",
+ "dir": "/my_dataset_path",
+ "size": 1000000,
+ "dim": 1024,
+ "metric_type": "L2",
+ "file_count": 1,
+ "use_shuffled": false,
+ "with_gt": true
+ }
+ }
+]
\ No newline at end of file
diff --git a/vectordb_bench/frontend/components/check_results/charts.py b/vectordb_bench/frontend/components/check_results/charts.py
index 7e28d1e66..c2b2813b8 100644
--- a/vectordb_bench/frontend/components/check_results/charts.py
+++ b/vectordb_bench/frontend/components/check_results/charts.py
@@ -1,19 +1,19 @@
from vectordb_bench.backend.cases import Case
from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle
from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.models import ResultLabel
import plotly.express as px
-def drawCharts(st, allData, failedTasks, cases: list[Case]):
+def drawCharts(st, allData, failedTasks, caseNames: list[str]):
initMainExpanderStyle(st)
- for case in cases:
- chartContainer = st.expander(case.name, True)
- data = [data for data in allData if data["case_name"] == case.name]
+ for caseName in caseNames:
+ chartContainer = st.expander(caseName, True)
+ data = [data for data in allData if data["case_name"] == caseName]
drawChart(data, chartContainer)
- errorDBs = failedTasks[case.name]
+ errorDBs = failedTasks[caseName]
showFailedDBs(chartContainer, errorDBs)
diff --git a/vectordb_bench/frontend/components/check_results/data.py b/vectordb_bench/frontend/components/check_results/data.py
index 10fa3f459..c8b10ee87 100644
--- a/vectordb_bench/frontend/components/check_results/data.py
+++ b/vectordb_bench/frontend/components/check_results/data.py
@@ -8,9 +8,9 @@
def getChartData(
tasks: list[CaseResult],
dbNames: list[str],
- cases: list[Case],
+ caseNames: list[str],
):
- filterTasks = getFilterTasks(tasks, dbNames, cases)
+ filterTasks = getFilterTasks(tasks, dbNames, caseNames)
mergedTasks, failedTasks = mergeTasks(filterTasks)
return mergedTasks, failedTasks
@@ -18,14 +18,13 @@ def getChartData(
def getFilterTasks(
tasks: list[CaseResult],
dbNames: list[str],
- cases: list[Case],
+ caseNames: list[str],
) -> list[CaseResult]:
- case_ids = [case.case_id for case in cases]
filterTasks = [
task
for task in tasks
if task.task_config.db_name in dbNames
- and task.task_config.case_config.case_id in case_ids
+ and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
]
return filterTasks
@@ -36,16 +35,17 @@ def mergeTasks(tasks: list[CaseResult]):
db_name = task.task_config.db_name
db = task.task_config.db.value
db_label = task.task_config.db_config.db_label or ""
- case_id = task.task_config.case_config.case_id
- dbCaseMetricsMap[db_name][case_id] = {
+ case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
+ dbCaseMetricsMap[db_name][case.name] = {
"db": db,
"db_label": db_label,
"metrics": mergeMetrics(
- dbCaseMetricsMap[db_name][case_id].get("metrics", {}),
+ dbCaseMetricsMap[db_name][case.name].get("metrics", {}),
asdict(task.metrics),
),
"label": getBetterLabel(
- dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED),
+ dbCaseMetricsMap[db_name][case.name].get(
+ "label", ResultLabel.FAILED),
task.label,
),
}
@@ -53,12 +53,11 @@ def mergeTasks(tasks: list[CaseResult]):
mergedTasks = []
failedTasks = defaultdict(lambda: defaultdict(str))
for db_name, caseMetricsMap in dbCaseMetricsMap.items():
- for case_id, metricInfo in caseMetricsMap.items():
+ for case_name, metricInfo in caseMetricsMap.items():
metrics = metricInfo["metrics"]
db = metricInfo["db"]
db_label = metricInfo["db_label"]
label = metricInfo["label"]
- case_name = case_id.case_name
if label == ResultLabel.NORMAL:
mergedTasks.append(
{
@@ -80,7 +79,8 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
metrics = {**metrics_1}
for key, value in metrics_2.items():
metrics[key] = (
- getBetterMetric(key, value, metrics[key]) if key in metrics else value
+ getBetterMetric(
+ key, value, metrics[key]) if key in metrics else value
)
return metrics
diff --git a/vectordb_bench/frontend/components/check_results/expanderStyle.py b/vectordb_bench/frontend/components/check_results/expanderStyle.py
index 9496313e8..436eeec38 100644
--- a/vectordb_bench/frontend/components/check_results/expanderStyle.py
+++ b/vectordb_bench/frontend/components/check_results/expanderStyle.py
@@ -1,7 +1,7 @@
def initMainExpanderStyle(st):
st.markdown(
"""""",
+ unsafe_allow_html=True,
+ )
\ No newline at end of file
diff --git a/vectordb_bench/frontend/components/run_test/autoRefresh.py b/vectordb_bench/frontend/components/run_test/autoRefresh.py
index fe31d8205..034ab5017 100644
--- a/vectordb_bench/frontend/components/run_test/autoRefresh.py
+++ b/vectordb_bench/frontend/components/run_test/autoRefresh.py
@@ -1,5 +1,5 @@
from streamlit_autorefresh import st_autorefresh
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
def autoRefresh():
diff --git a/vectordb_bench/frontend/components/run_test/caseSelector.py b/vectordb_bench/frontend/components/run_test/caseSelector.py
index 49b839163..58799deff 100644
--- a/vectordb_bench/frontend/components/run_test/caseSelector.py
+++ b/vectordb_bench/frontend/components/run_test/caseSelector.py
@@ -1,9 +1,13 @@
-from vectordb_bench.frontend.const.styles import *
+
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.backend.cases import CaseType
-from vectordb_bench.frontend.const.dbCaseConfigs import *
+from vectordb_bench.frontend.config.dbCaseConfigs import *
+from collections import defaultdict
+
+from vectordb_bench.frontend.utils import addHorizontalLine
-def caseSelector(st, activedDbList):
+def caseSelector(st, activedDbList: list[DB]):
st.markdown(
"
",
unsafe_allow_html=True,
@@ -14,41 +18,49 @@ def caseSelector(st, activedDbList):
unsafe_allow_html=True,
)
- caseIsActived = {case: False for case in CASE_LIST}
- allCaseConfigs = {db: {case: {} for case in CASE_LIST} for db in DB_LIST}
- for caseOrDivider in CASE_LIST_WITH_DIVIDER:
- if caseOrDivider == DIVIDER:
- caseItemContainer.markdown(
- "",
- unsafe_allow_html=True,
- )
+ activedCaseList: list[CaseConfig] = []
+ dbToCaseClusterConfigs = defaultdict(lambda: defaultdict(dict))
+ dbToCaseConfigs = defaultdict(lambda: defaultdict(dict))
+ caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter()]
+ for caseCluster in caseClusters:
+ activedCaseList += caseClusterExpander(
+ st, caseCluster, dbToCaseClusterConfigs, activedDbList)
+ for db in dbToCaseClusterConfigs:
+ for uiCaseItem in dbToCaseClusterConfigs[db]:
+ for case in uiCaseItem.cases:
+ dbToCaseConfigs[db][case] = dbToCaseClusterConfigs[db][uiCaseItem]
+
+ return activedCaseList, dbToCaseConfigs
+
+
+def caseClusterExpander(st, caseCluster: UICaseItemCluster, dbToCaseClusterConfigs, activedDbList: list[DB]):
+ expander = st.expander(caseCluster.label, False)
+ activedCases: list[CaseConfig] = []
+ for uiCaseItem in caseCluster.uiCaseItems:
+ if uiCaseItem.isLine:
+ addHorizontalLine(expander)
else:
- case = caseOrDivider
- caseItemContainer = st.container()
- caseIsActived[case] = caseItem(
- caseItemContainer, allCaseConfigs, case, activedDbList
- )
- activedCaseList = [case for case in CASE_LIST if caseIsActived[case]]
- return activedCaseList, allCaseConfigs
+ activedCases += caseItemCheckbox(expander,
+ dbToCaseClusterConfigs, uiCaseItem, activedDbList)
+ return activedCases
-def caseItem(st, allCaseConfigs, case: CaseType, activedDbList):
- selected = st.checkbox(case.case_name)
+def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
+ selected = st.checkbox(uiCaseItem.label)
st.markdown(
- f"{case.case_description}
",
+ f"{uiCaseItem.description}
",
unsafe_allow_html=True,
)
if selected:
- caseConfigSettingContainer = st.container()
caseConfigSetting(
- caseConfigSettingContainer, allCaseConfigs, case, activedDbList
+ st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList
)
- return selected
+ return uiCaseItem.cases if selected else []
-def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
+def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]):
for db in activedDbList:
columns = st.columns(1 + CASE_CONFIG_SETTING_COLUMNS)
# column 0 - title
@@ -57,12 +69,12 @@ def caseConfigSetting(st, allCaseConfigs, case, activedDbList):
f"{db.name}
",
unsafe_allow_html=True,
)
- caseConfig = allCaseConfigs[db][case]
k = 0
- for config in CASE_CONFIG_MAP.get(db, {}).get(case.case_cls().label, []):
+ caseConfig = dbToCaseClusterConfigs[db][uiCaseItem]
+ for config in CASE_CONFIG_MAP.get(db, {}).get(uiCaseItem.caseLabel, []):
if config.isDisplayed(caseConfig):
column = columns[1 + k % CASE_CONFIG_SETTING_COLUMNS]
- key = "%s-%s-%s" % (db, case, config.label.value)
+ key = "%s-%s-%s" % (db, uiCaseItem.label, config.label.value)
if config.inputType == InputType.Text:
caseConfig[config.label] = column.text_input(
config.displayLabel if config.displayLabel else config.label.value,
diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
index ffd52721f..8f4f35c93 100644
--- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
+++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
@@ -1,13 +1,9 @@
from pydantic import ValidationError
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.frontend.utils import inputIsPassword
def dbConfigSettings(st, activedDbList):
- st.markdown(
- "",
- unsafe_allow_html=True,
- )
expander = st.expander("Configurations for the selected databases", True)
dbConfigs = {}
diff --git a/vectordb_bench/frontend/components/run_test/dbSelector.py b/vectordb_bench/frontend/components/run_test/dbSelector.py
index 61db843f3..ad2f57a0f 100644
--- a/vectordb_bench/frontend/components/run_test/dbSelector.py
+++ b/vectordb_bench/frontend/components/run_test/dbSelector.py
@@ -1,5 +1,5 @@
-from vectordb_bench.frontend.const.styles import *
-from vectordb_bench.frontend.const.dbCaseConfigs import DB_LIST
+from vectordb_bench.frontend.config.styles import *
+from vectordb_bench.frontend.config.dbCaseConfigs import DB_LIST
def dbSelector(st):
@@ -16,17 +16,6 @@ def dbSelector(st):
dbContainerColumns = st.columns(DB_SELECTOR_COLUMNS, gap="small")
dbIsActived = {db: False for db in DB_LIST}
- # style - image; column gap; checkbox font;
- st.markdown(
- """
-
- """,
- unsafe_allow_html=True,
- )
for i, db in enumerate(DB_LIST):
column = dbContainerColumns[i % DB_SELECTOR_COLUMNS]
dbIsActived[db] = column.checkbox(db.name)
diff --git a/vectordb_bench/frontend/components/run_test/generateTasks.py b/vectordb_bench/frontend/components/run_test/generateTasks.py
index 55f3c8399..828913f30 100644
--- a/vectordb_bench/frontend/components/run_test/generateTasks.py
+++ b/vectordb_bench/frontend/components/run_test/generateTasks.py
@@ -1,17 +1,15 @@
+from vectordb_bench.backend.clients import DB
from vectordb_bench.models import CaseConfig, CaseConfigParamType, TaskConfig
-def generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs):
+def generate_tasks(activedDbList: list[DB], dbConfigs, activedCaseList: list[CaseConfig], allCaseConfigs):
tasks = []
for db in activedDbList:
for case in activedCaseList:
task = TaskConfig(
db=db.value,
db_config=dbConfigs[db],
- case_config=CaseConfig(
- case_id=case.value,
- custom_case={},
- ),
+ case_config=case,
db_case_config=db.case_config_cls(
allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None)
)(**{key.value: value for key, value in allCaseConfigs[db][case].items()}),
diff --git a/vectordb_bench/frontend/components/run_test/initStyle.py b/vectordb_bench/frontend/components/run_test/initStyle.py
new file mode 100644
index 000000000..59dd438e1
--- /dev/null
+++ b/vectordb_bench/frontend/components/run_test/initStyle.py
@@ -0,0 +1,14 @@
+def initStyle(st):
+ # Emit the run-test page's markup block through st.markdown with raw HTML enabled.
+ # NOTE(review): the markup literal appears empty in this view - confirm against the real file.
+ st.markdown(
+ """""",
+ unsafe_allow_html=True,
+ )
\ No newline at end of file
diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py
index 22d34b4f5..acf4b2c67 100644
--- a/vectordb_bench/frontend/components/run_test/submitTask.py
+++ b/vectordb_bench/frontend/components/run_test/submitTask.py
@@ -1,5 +1,5 @@
from datetime import datetime
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.interface import benchMarkRunner
diff --git a/vectordb_bench/frontend/const/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py
similarity index 82%
rename from vectordb_bench/frontend/const/dbCaseConfigs.py
rename to vectordb_bench/frontend/config/dbCaseConfigs.py
index 1e69c57aa..136dbaa70 100644
--- a/vectordb_bench/frontend/const/dbCaseConfigs.py
+++ b/vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -1,42 +1,134 @@
-from enum import IntEnum
+from enum import IntEnum, Enum
import typing
from pydantic import BaseModel
from vectordb_bench.backend.cases import CaseLabel, CaseType
from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.api import IndexType
+from vectordb_bench.frontend.components.custom.getCustomConfig import get_custom_configs
-from vectordb_bench.models import CaseConfigParamType
+from vectordb_bench.models import CaseConfig, CaseConfigParamType
MAX_STREAMLIT_INT = (1 << 53) - 1
DB_LIST = [d for d in DB]
-DIVIDER = "DIVIDER"
-CASE_LIST_WITH_DIVIDER = [
+
+class Delimiter(Enum):
+ # Single-member enum whose only value is the string "line".
+ # NOTE(review): no use of this enum is visible in these changes - confirm it is still needed.
+ Line = "line"
+
+
+class BatchCaseConfig(BaseModel):
+ # A labelled, described list of CaseConfig entries; every field has an empty default.
+ label: str = ""
+ description: str = ""
+ cases: list[CaseConfig] = []
+
+
+class UICaseItem(BaseModel):
+    """One entry of the case selector.
+
+    An entry is a separator line, a single case identified by ``case_id``, or
+    an explicitly labelled list of cases.
+    """
+
+    isLine: bool = False  # True when the entry is only a horizontal separator
+    label: str = ""
+    description: str = ""
+    cases: list[CaseConfig] = []
+    caseLabel: CaseLabel = CaseLabel.Performance
+
+    def __init__(
+        self,
+        isLine: bool = False,
+        case_id: CaseType | None = None,
+        custom_case: dict | None = None,
+        cases: list[CaseConfig] | None = None,
+        label: str = "",
+        description: str = "",
+        caseLabel: CaseLabel = CaseLabel.Performance,
+    ):
+        """Create the entry in one of three modes.
+
+        * ``isLine=True``: separator only; all other arguments are ignored.
+        * ``case_id`` given: label, description and case label are read from
+          the case built by ``case_id.case_cls(custom_case)``, and ``cases``
+          holds a single CaseConfig for it.
+        * otherwise: ``label``, ``description``, ``cases`` and ``caseLabel``
+          are stored as passed.
+
+        ``custom_case`` and ``cases`` default to a fresh empty dict / list.
+        """
+        # Build new containers per call instead of sharing one default object
+        # between every instance (and every CaseConfig created below).
+        custom_case = {} if custom_case is None else custom_case
+        cases = [] if cases is None else cases
+
+        if isLine is True:
+            super().__init__(isLine=True)
+        elif isinstance(case_id, CaseType):
+            c = case_id.case_cls(custom_case)
+            super().__init__(
+                label=c.name,
+                description=c.description,
+                cases=[CaseConfig(case_id=case_id, custom_case=custom_case)],
+                caseLabel=c.label,
+            )
+        else:
+            super().__init__(
+                label=label,
+                description=description,
+                cases=cases,
+                caseLabel=caseLabel,
+            )
+
+    def __hash__(self) -> int:
+        # Hash the JSON serialization so items can be used as dict keys
+        # (e.g. dbToCaseClusterConfigs[db][uiCaseItem]).
+        return hash(self.json())
+
+
+class UICaseItemCluster(BaseModel):
+ # A labelled group of UICaseItem entries; caseClusterExpander draws one expander per cluster.
+ label: str = ""
+ uiCaseItems: list[UICaseItem] = []
+
+def get_custom_case_items() -> list[UICaseItem]:
+    """Build one PerformanceCustomDataset UICaseItem per config from get_custom_configs()."""
+    items: list[UICaseItem] = []
+    for custom_config in get_custom_configs():
+        items.append(
+            UICaseItem(
+                case_id=CaseType.PerformanceCustomDataset,
+                custom_case=custom_config.dict(),
+            )
+        )
+    return items
+
+def get_custom_case_cluter() -> UICaseItemCluster:
+    """Wrap the current custom case items in a cluster labelled "Custom Search Performance Test"."""
+    custom_items = get_custom_case_items()
+    return UICaseItemCluster(label="Custom Search Performance Test", uiCaseItems=custom_items)
+
+
+# Built-in case clusters shown by the case selector, in display order.
+# UICaseItem(isLine=True) entries are separators between groups of cases.
+UI_CASE_CLUSTERS: list[UICaseItemCluster] = [
+ UICaseItemCluster(
+ label="Search Performance Test",
+ uiCaseItems=[
+ UICaseItem(case_id=CaseType.Performance768D100M),
+ UICaseItem(case_id=CaseType.Performance768D10M),
+ UICaseItem(case_id=CaseType.Performance768D1M),
+ UICaseItem(isLine=True),
+ UICaseItem(case_id=CaseType.Performance1536D5M),
+ UICaseItem(case_id=CaseType.Performance1536D500K),
+ ]
+ ),
+ UICaseItemCluster(
+ label="Filter Search Performance Test",
+ uiCaseItems=[
+ UICaseItem(case_id=CaseType.Performance768D10M1P),
+ UICaseItem(case_id=CaseType.Performance768D10M99P),
+ UICaseItem(case_id=CaseType.Performance768D1M1P),
+ UICaseItem(case_id=CaseType.Performance768D1M99P),
+ UICaseItem(isLine=True),
+ UICaseItem(case_id=CaseType.Performance1536D5M1P),
+ UICaseItem(case_id=CaseType.Performance1536D5M99P),
+ UICaseItem(case_id=CaseType.Performance1536D500K1P),
+ UICaseItem(case_id=CaseType.Performance1536D500K99P),
+ ]
+ ),
+ UICaseItemCluster(
+ label="Capacity Test",
+ uiCaseItems=[
+ UICaseItem(case_id=CaseType.CapacityDim960),
+ UICaseItem(case_id=CaseType.CapacityDim128),
+ ]
+ ),
+]
+
+# DIVIDER = "DIVIDER"
+DISPLAY_CASE_ORDER = [
CaseType.Performance768D100M,
CaseType.Performance768D10M,
CaseType.Performance768D1M,
- DIVIDER,
CaseType.Performance1536D5M,
CaseType.Performance1536D500K,
- DIVIDER,
CaseType.Performance768D10M1P,
CaseType.Performance768D1M1P,
- DIVIDER,
CaseType.Performance1536D5M1P,
CaseType.Performance1536D500K1P,
- DIVIDER,
CaseType.Performance768D10M99P,
CaseType.Performance768D1M99P,
- DIVIDER,
CaseType.Performance1536D5M99P,
CaseType.Performance1536D500K99P,
- DIVIDER,
CaseType.CapacityDim960,
CaseType.CapacityDim128,
]
+CASE_NAME_ORDER = [case.case_cls().name for case in DISPLAY_CASE_ORDER]
-CASE_LIST = [item for item in CASE_LIST_WITH_DIVIDER if isinstance(item, CaseType)]
+# CASE_LIST = [
+# item for item in CASE_LIST_WITH_DIVIDER if isinstance(item, CaseType)]
class InputType(IntEnum):
@@ -512,7 +604,8 @@ class CaseConfigInput(BaseModel):
inputConfig={
"options": ["x4", "x8", "x16", "x32", "x64"],
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None)
+ isDisplayed=lambda config: config.get(
+ CaseConfigParamType.quantizationType, None)
== "product" and config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
@@ -574,7 +667,8 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_EF_Weaviate,
]
-ESLoadingConfig = [CaseConfigParamInput_EFConstruction_ES, CaseConfigParamInput_M_ES]
+ESLoadingConfig = [CaseConfigParamInput_EFConstruction_ES,
+ CaseConfigParamInput_M_ES]
ESPerformanceConfig = [
CaseConfigParamInput_EFConstruction_ES,
CaseConfigParamInput_M_ES,
diff --git a/vectordb_bench/frontend/const/dbPrices.py b/vectordb_bench/frontend/config/dbPrices.py
similarity index 100%
rename from vectordb_bench/frontend/const/dbPrices.py
rename to vectordb_bench/frontend/config/dbPrices.py
diff --git a/vectordb_bench/frontend/const/styles.py b/vectordb_bench/frontend/config/styles.py
similarity index 100%
rename from vectordb_bench/frontend/const/styles.py
rename to vectordb_bench/frontend/config/styles.py
diff --git a/vectordb_bench/frontend/pages/custom.py b/vectordb_bench/frontend/pages/custom.py
new file mode 100644
index 000000000..28c249f78
--- /dev/null
+++ b/vectordb_bench/frontend/pages/custom.py
@@ -0,0 +1,64 @@
+import streamlit as st
+from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
+from vectordb_bench.frontend.components.custom.displayCustomCase import displayCustomCase
+from vectordb_bench.frontend.components.custom.displaypPrams import displayParams
+from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig, generate_custom_case, get_custom_configs, save_custom_configs
+from vectordb_bench.frontend.components.custom.initStyle import initStyle
+from vectordb_bench.frontend.config.styles import FAVICON, PAGE_TITLE
+
+
+class CustomCaseManager:
+    """Keeps the list of custom case configs and saves it after every change."""
+
+    customCaseItems: list[CustomCaseConfig]
+
+    def __init__(self):
+        # Start from whatever get_custom_configs() returns.
+        self.customCaseItems = get_custom_configs()
+
+    def save(self):
+        """Hand the current list to save_custom_configs()."""
+        save_custom_configs(self.customCaseItems)
+
+    def addCase(self):
+        """Append a generated case, suffixing its dataset name with the current item count, then save."""
+        fresh_case = generate_custom_case()
+        base_name = fresh_case.dataset_config.name
+        fresh_case.dataset_config.name = f"{base_name} {len(self.customCaseItems)}"
+        self.customCaseItems.append(fresh_case)
+        self.save()
+
+    def deleteCase(self, idx: int):
+        """Remove the case at position ``idx``, then save."""
+        self.customCaseItems.pop(idx)
+        self.save()
+
+
+def main():
+    """Render the custom dataset page.
+
+    Sets the page config, draws the header and styles, then shows one expander
+    per item in ``CustomCaseManager.customCaseItems`` with Save and Delete
+    buttons, followed by a button that adds a new case.
+    """
+    st.set_page_config(
+        page_title=PAGE_TITLE,
+        page_icon=FAVICON,
+        # layout="wide",
+        # initial_sidebar_state="collapsed",
+    )
+
+    # header
+    drawHeaderIcon(st)
+
+    # init style
+    initStyle(st)
+
+    st.title("Custom Dataset")
+    displayParams(st)
+    customCaseManager = CustomCaseManager()
+
+    for idx, customCase in enumerate(customCaseManager.customCaseItems):
+        expander = st.expander(customCase.dataset_config.name, expanded=True)
+        key = f"custom_case_{idx}"
+        displayCustomCase(customCase, expander, key=key)
+
+        columns = expander.columns(8)
+        columns[0].button(
+            "Save", key=f"{key}_", type="secondary", on_click=customCaseManager.save)
+        # Pass idx through `args` so each Delete button is bound to its own
+        # index. A bare `lambda: ...deleteCase(idx)` reads the loop variable
+        # only when clicked, so every button would delete the last case.
+        columns[1].button(":red[Delete]", key=f"{key}_delete", type="secondary",
+                          on_click=customCaseManager.deleteCase, args=(idx,))
+
+    # Raw string: the backslash is part of the label text, and "\+" is not a
+    # valid escape sequence in a normal string literal.
+    st.button(r"\+ New Dataset", key="add_custom_configs",
+              type="primary", on_click=customCaseManager.addCase)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/vectordb_bench/frontend/pages/quries_per_dollar.py b/vectordb_bench/frontend/pages/quries_per_dollar.py
index 10c1ac8f1..0bb05294b 100644
--- a/vectordb_bench/frontend/pages/quries_per_dollar.py
+++ b/vectordb_bench/frontend/pages/quries_per_dollar.py
@@ -8,7 +8,7 @@
from vectordb_bench.frontend.components.check_results.charts import drawMetricChart
from vectordb_bench.frontend.components.check_results.filters import getshownData
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.interface import benchMarkRunner
from vectordb_bench.metric import QURIES_PER_DOLLAR_METRIC
@@ -26,7 +26,7 @@ def main():
# results selector
resultSelectorContainer = st.sidebar.container()
- shownData, _, showCases = getshownData(allResults, resultSelectorContainer)
+ shownData, _, showCaseNames = getshownData(allResults, resultSelectorContainer)
resultSelectorContainer.divider()
@@ -45,8 +45,8 @@ def main():
priceMap = priceTable(priceTableContainer, shownData)
# charts
- for case in showCases:
- data = [data for data in shownData if data["case_name"] == case.name]
+ for caseName in showCaseNames:
+ data = [data for data in shownData if data["case_name"] == caseName]
dataWithMetric = []
metric = QURIES_PER_DOLLAR_METRIC
for d in data:
@@ -56,7 +56,7 @@ def main():
d[metric] = d["qps"] / price * 3.6
dataWithMetric.append(d)
if len(dataWithMetric) > 0:
- chartContainer = st.expander(case.name, True)
+ chartContainer = st.expander(caseName, True)
drawMetricChart(data, metric, chartContainer)
# footer
diff --git a/vectordb_bench/frontend/pages/run_test.py b/vectordb_bench/frontend/pages/run_test.py
index 0712bb6cc..1297743ae 100644
--- a/vectordb_bench/frontend/pages/run_test.py
+++ b/vectordb_bench/frontend/pages/run_test.py
@@ -5,6 +5,7 @@
from vectordb_bench.frontend.components.run_test.dbSelector import dbSelector
from vectordb_bench.frontend.components.run_test.generateTasks import generate_tasks
from vectordb_bench.frontend.components.run_test.hideSidebar import hideSidebar
+from vectordb_bench.frontend.components.run_test.initStyle import initStyle
from vectordb_bench.frontend.components.run_test.submitTask import submitTask
from vectordb_bench.frontend.components.check_results.nav import NavToResults
from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
@@ -15,6 +16,9 @@ def main():
# set page config
initRunTestPageConfig(st)
+ # init style
+ initStyle(st)
+
# header
drawHeaderIcon(st)
diff --git a/vectordb_bench/frontend/utils.py b/vectordb_bench/frontend/utils.py
index 139854af6..787b67d03 100644
--- a/vectordb_bench/frontend/utils.py
+++ b/vectordb_bench/frontend/utils.py
@@ -1,6 +1,22 @@
-from vectordb_bench.models import CaseType
+import random
+import string
+
passwordKeys = ["password", "api_key"]
+
+
def inputIsPassword(key: str) -> bool:
return key.lower() in passwordKeys
+
+def addHorizontalLine(st):
+ # Emit a markup snippet into the given container via st.markdown with raw HTML enabled.
+ # NOTE(review): the markup literal appears empty in this view - confirm against the real file.
+ st.markdown(
+ "",
+ unsafe_allow_html=True,
+ )
+
+
+def generate_random_string(length):
+    """Return ``length`` characters picked at random from ASCII letters and digits."""
+    alphabet = string.ascii_letters + string.digits
+    picked = [random.choice(alphabet) for _ in range(length)]
+    return ''.join(picked)
diff --git a/vectordb_bench/frontend/vdb_benchmark.py b/vectordb_bench/frontend/vdb_benchmark.py
index 0be43470e..b859c68b8 100644
--- a/vectordb_bench/frontend/vdb_benchmark.py
+++ b/vectordb_bench/frontend/vdb_benchmark.py
@@ -6,7 +6,7 @@
from vectordb_bench.frontend.components.check_results.charts import drawCharts
from vectordb_bench.frontend.components.check_results.filters import getshownData
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
-from vectordb_bench.frontend.const.styles import *
+from vectordb_bench.frontend.config.styles import *
from vectordb_bench.interface import benchMarkRunner
@@ -24,7 +24,7 @@ def main():
# results selector and filter
resultSelectorContainer = st.sidebar.container()
- shownData, failedTasks, showCases = getshownData(
+ shownData, failedTasks, showCaseNames = getshownData(
allResults, resultSelectorContainer
)
@@ -40,7 +40,7 @@ def main():
getResults(resultesContainer, "vectordb_bench")
# charts
- drawCharts(st, shownData, failedTasks, showCases)
+ drawCharts(st, shownData, failedTasks, showCaseNames)
# footer
footer(st.container())
diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py
index ec1b610e1..0141f2f13 100644
--- a/vectordb_bench/models.py
+++ b/vectordb_bench/models.py
@@ -74,6 +74,9 @@ class CaseConfig(BaseModel):
case_id: CaseType
custom_case: dict | None = None
+
+ def __hash__(self) -> int:
+ # Hash the JSON serialization so instances can be used as dict keys (e.g. dbToCaseConfigs[db][case]).
+ return hash(self.json())
class TaskConfig(BaseModel):
diff --git a/vectordb_bench/results/getLeaderboardData.py b/vectordb_bench/results/getLeaderboardData.py
index 50f458533..c6484514d 100644
--- a/vectordb_bench/results/getLeaderboardData.py
+++ b/vectordb_bench/results/getLeaderboardData.py
@@ -2,7 +2,7 @@
import ujson
import pathlib
from vectordb_bench.backend.cases import CaseType
-from vectordb_bench.frontend.const.dbPrices import DB_DBLABEL_TO_PRICE
+from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE
from vectordb_bench.interface import benchMarkRunner
from vectordb_bench.models import CaseResult, ResultLabel, TestResult