diff --git a/README.md b/README.md index 4ff2f9121..f093bfcfb 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,24 @@ Case No. | Case Type | Dataset Size | Filtering Rate | Results | Each case provides an in-depth examination of a vector database's abilities, providing you a comprehensive view of the database's performance. +#### Custom Dataset for Performance case + +Through the `/custom` page, users can customize their own performance case using local datasets. After saving, the corresponding case can be selected from the `/run_test` page to perform the test. + +![image](fig/custom_dataset.png) +![image](fig/custom_case_run_test.png) + +We have strict requirements for the data set format, please follow them. +- `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format. + - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. + - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. + - Ground truth file: The file must be named `neighbors.parquet` and should have two columns: `id` corresponding to query vectors and `neighbors_id` as an array of `int`. + +- `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files. + +- `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order. + + ## Goals Our goals of this benchmark are: ### Reproducibility & Usability diff --git a/fig/custom_case_run_test.png b/fig/custom_case_run_test.png new file mode 100644 index 000000000..8817b3439 Binary files /dev/null and b/fig/custom_case_run_test.png differ diff --git a/fig/custom_dataset.png b/fig/custom_dataset.png new file mode 100644 index 000000000..9d665891a Binary files /dev/null and b/fig/custom_dataset.png differ diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py index eca190832..c68ee854d 100644 --- a/vectordb_bench/__init__.py +++ b/vectordb_bench/__init__.py @@ -22,6 +22,7 @@ class config: NUM_CONCURRENCY = [1, 5, 10, 15, 20, 25, 30, 35] RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results") + CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json") CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h diff --git a/vectordb_bench/backend/assembler.py b/vectordb_bench/backend/assembler.py index 6b0e3c81d..e7da4d49f 100644 --- a/vectordb_bench/backend/assembler.py +++ b/vectordb_bench/backend/assembler.py @@ -14,7 +14,7 @@ class Assembler: def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner: c_cls = task.case_config.case_id.case_cls - c = c_cls() + c = c_cls(task.case_config.custom_case) if type(task.db_case_config) != EmptyDBCaseConfig: task.db_case_config.metric_type = c.dataset.data.metric_type diff --git a/vectordb_bench/backend/cases.py b/vectordb_bench/backend/cases.py index 7f40ae4f8..22e9dfbc7 100644 --- a/vectordb_bench/backend/cases.py +++ b/vectordb_bench/backend/cases.py @@ -3,9 +3,11 @@ from enum import Enum, auto from vectordb_bench import config +from vectordb_bench.backend.clients.api import MetricType from vectordb_bench.base import BaseModel +from vectordb_bench.frontend.components.custom.getCustomConfig import CustomDatasetConfig -from .dataset import Dataset, DatasetManager +from .dataset import CustomDataset, Dataset, DatasetManager log = logging.getLogger(__name__) @@ -43,23 +45,24 @@ class CaseType(Enum): Performance1536D5M99P = 15 Custom = 100 + PerformanceCustomDataset = 101 - @property def case_cls(self, custom_configs: dict | None = None) -> Case: - return type2case.get(self) + if custom_configs is None: + return type2case.get(self)() + else: + return type2case.get(self)(**custom_configs) - @property - def case_name(self) -> str: - c = self.case_cls + def case_name(self, custom_configs: dict | None = None) -> str: + c = self.case_cls(custom_configs) if c is not None: - return c().name + return c.name raise ValueError("Case unsupported") - @property - def case_description(self) -> str: - c = self.case_cls + def case_description(self, custom_configs: dict | None = None) -> str: + c = self.case_cls(custom_configs) if c is not None: - return c().description + return c.description raise ValueError("Case unsupported") @@ -115,6 +118,7 @@ class PerformanceCase(Case, BaseModel): load_timeout: float | int = config.LOAD_TIMEOUT_DEFAULT optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_DEFAULT + class CapacityDim960(CapacityCase): case_id: CaseType = CaseType.CapacityDim960 dataset: DatasetManager = Dataset.GIST.manager(100_000) @@ -238,6 +242,7 @@ class Performance1536D500K1P(PerformanceCase): load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K + class Performance1536D5M1P(PerformanceCase): case_id: CaseType = CaseType.Performance1536D5M1P filter_rate: float | int | None = 0.01 @@ -248,6 +253,7 @@ class Performance1536D5M1P(PerformanceCase): load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M + class Performance1536D500K99P(PerformanceCase): case_id: CaseType = CaseType.Performance1536D500K99P filter_rate: float | int | None = 0.99 @@ -258,6 +264,7 @@ class Performance1536D500K99P(PerformanceCase): load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K + class Performance1536D5M99P(PerformanceCase): case_id: CaseType = CaseType.Performance1536D5M99P filter_rate: float | int | None = 0.99 @@ -269,6 +276,40 @@ class Performance1536D5M99P(PerformanceCase): optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M +def metric_type_map(s: str) -> MetricType: + if s.lower() == 'cosine': + return MetricType.COSINE + if s.lower() == 'l2' or s.lower() == "euclidean": + return MetricType.L2 + if s.lower() == 'ip': + return MetricType.IP + err_msg = f"Not support metric_type: {s}" + log.error(err_msg) + raise RuntimeError(err_msg) + + +class PerformanceCustomDataset(PerformanceCase): + case_id: CaseType = CaseType.PerformanceCustomDataset + name: str = "Performance With Custom Dataset" + description: str = "" + dataset: DatasetManager + + def __init__(self, name, description, load_timeout, optimize_timeout, dataset_config, **kwargs): + dataset_config = CustomDatasetConfig(**dataset_config) + dataset = CustomDataset( + name=dataset_config.name, + size=dataset_config.size, + dim=dataset_config.dim, + metric_type=metric_type_map(dataset_config.metric_type), + use_shuffled=dataset_config.use_shuffled, + with_gt=dataset_config.with_gt, + dir=dataset_config.dir, + file_num=dataset_config.file_count, + ) + super().__init__(name=name, description=description, load_timeout=load_timeout, + optimize_timeout=optimize_timeout, dataset=DatasetManager(data=dataset)) + + type2case = { CaseType.CapacityDim960: CapacityDim960, CaseType.CapacityDim128: CapacityDim128, @@ -291,4 +332,5 @@ class Performance1536D5M99P(PerformanceCase): CaseType.Performance1536D500K99P: Performance1536D500K99P, CaseType.Performance1536D5M99P: Performance1536D5M99P, + CaseType.PerformanceCustomDataset: PerformanceCustomDataset, } diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py index 2b630eae3..d559eb6be 100644 --- a/vectordb_bench/backend/dataset.py +++ b/vectordb_bench/backend/dataset.py @@ -33,6 +33,7 @@ class BaseDataset(BaseModel): use_shuffled: bool with_gt: bool = False _size_label: dict[int, SizeLabel] = PrivateAttr() + isCustom: bool = False @validator("size") def verify_size(cls, v): @@ -52,7 +53,27 @@ def dir_name(self) -> str: def file_count(self) -> int: return self._size_label.get(self.size).file_count +class CustomDataset(BaseDataset): + dir: str + file_num: int + isCustom: bool = True + + @validator("size") + def verify_size(cls, v): + return v + + @property + def label(self) -> str: + return "Custom" + @property + def dir_name(self) -> str: + return self.dir + + @property + def file_count(self) -> int: + return self.file_num + class LAION(BaseDataset): name: str = "LAION" dim: int = 768 @@ -186,11 +207,12 @@ def prepare(self, gt_file, test_file = utils.compose_gt_file(filters), "test.parquet" all_files.extend([gt_file, test_file]) - source.reader().read( - dataset=self.data.dir_name.lower(), - files=all_files, - local_ds_root=self.data_dir, - ) + if not self.data.isCustom: + source.reader().read( + dataset=self.data.dir_name.lower(), + files=all_files, + local_ds_root=self.data_dir, + ) if gt_file is not None and test_file is not None: self.test_data = self._read_file(test_file) diff --git a/vectordb_bench/custom/custom_case.json b/vectordb_bench/custom/custom_case.json new file mode 100644 index 000000000..48ca8d8c4 --- /dev/null +++ b/vectordb_bench/custom/custom_case.json @@ -0,0 +1,18 @@ +[ + { + "name": "My Dataset (Performace Case)", + "description": "this is a customized dataset.", + "load_timeout": 36000, + "optimize_timeout": 36000, + "dataset_config": { + "name": "My Dataset", + "dir": "/my_dataset_path", + "size": 1000000, + "dim": 1024, + "metric_type": "L2", + "file_count": 1, + "use_shuffled": false, + "with_gt": true + } + } +] \ No newline at end of file diff --git a/vectordb_bench/frontend/components/check_results/charts.py b/vectordb_bench/frontend/components/check_results/charts.py index 7e28d1e66..c2b2813b8 100644 --- a/vectordb_bench/frontend/components/check_results/charts.py +++ b/vectordb_bench/frontend/components/check_results/charts.py @@ -1,19 +1,19 @@ from vectordb_bench.backend.cases import Case from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap -from vectordb_bench.frontend.const.styles import * +from vectordb_bench.frontend.config.styles import * from vectordb_bench.models import ResultLabel import plotly.express as px -def drawCharts(st, allData, failedTasks, cases: list[Case]): +def drawCharts(st, allData, failedTasks, caseNames: list[str]): initMainExpanderStyle(st) - for case in cases: - chartContainer = st.expander(case.name, True) - data = [data for data in allData if data["case_name"] == case.name] + for caseName in caseNames: + chartContainer = st.expander(caseName, True) + data = [data for data in allData if data["case_name"] == caseName] drawChart(data, chartContainer) - errorDBs = failedTasks[case.name] + errorDBs = failedTasks[caseName] showFailedDBs(chartContainer, errorDBs) diff --git a/vectordb_bench/frontend/components/check_results/data.py b/vectordb_bench/frontend/components/check_results/data.py index 10fa3f459..c8b10ee87 100644 --- a/vectordb_bench/frontend/components/check_results/data.py +++ b/vectordb_bench/frontend/components/check_results/data.py @@ -8,9 +8,9 @@ def getChartData( tasks: list[CaseResult], dbNames: list[str], - cases: list[Case], + caseNames: list[str], ): - filterTasks = getFilterTasks(tasks, dbNames, cases) + filterTasks = getFilterTasks(tasks, dbNames, caseNames) mergedTasks, failedTasks = mergeTasks(filterTasks) return mergedTasks, failedTasks @@ -18,14 +18,13 @@ def getChartData( def getFilterTasks( tasks: list[CaseResult], dbNames: list[str], - cases: list[Case], + caseNames: list[str], ) -> list[CaseResult]: - case_ids = [case.case_id for case in cases] filterTasks = [ task for task in tasks if task.task_config.db_name in dbNames - and task.task_config.case_config.case_id in case_ids + and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames ] return filterTasks @@ -36,16 +35,17 @@ def mergeTasks(tasks: list[CaseResult]): db_name = task.task_config.db_name db = task.task_config.db.value db_label = task.task_config.db_config.db_label or "" - case_id = task.task_config.case_config.case_id - dbCaseMetricsMap[db_name][case_id] = { + case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case) + dbCaseMetricsMap[db_name][case.name] = { "db": db, "db_label": db_label, "metrics": mergeMetrics( - dbCaseMetricsMap[db_name][case_id].get("metrics", {}), + dbCaseMetricsMap[db_name][case.name].get("metrics", {}), asdict(task.metrics), ), "label": getBetterLabel( - dbCaseMetricsMap[db_name][case_id].get("label", ResultLabel.FAILED), + dbCaseMetricsMap[db_name][case.name].get( + "label", ResultLabel.FAILED), task.label, ), } @@ -53,12 +53,11 @@ def mergeTasks(tasks: list[CaseResult]): mergedTasks = [] failedTasks = defaultdict(lambda: defaultdict(str)) for db_name, caseMetricsMap in dbCaseMetricsMap.items(): - for case_id, metricInfo in caseMetricsMap.items(): + for case_name, metricInfo in caseMetricsMap.items(): metrics = metricInfo["metrics"] db = metricInfo["db"] db_label = metricInfo["db_label"] label = metricInfo["label"] - case_name = case_id.case_name if label == ResultLabel.NORMAL: mergedTasks.append( { @@ -80,7 +79,8 @@ def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict: metrics = {**metrics_1} for key, value in metrics_2.items(): metrics[key] = ( - getBetterMetric(key, value, metrics[key]) if key in metrics else value + getBetterMetric( + key, value, metrics[key]) if key in metrics else value ) return metrics diff --git a/vectordb_bench/frontend/components/check_results/expanderStyle.py b/vectordb_bench/frontend/components/check_results/expanderStyle.py index 9496313e8..436eeec38 100644 --- a/vectordb_bench/frontend/components/check_results/expanderStyle.py +++ b/vectordb_bench/frontend/components/check_results/expanderStyle.py @@ -1,7 +1,7 @@ def initMainExpanderStyle(st): st.markdown( """""", + unsafe_allow_html=True, + ) \ No newline at end of file diff --git a/vectordb_bench/frontend/components/run_test/autoRefresh.py b/vectordb_bench/frontend/components/run_test/autoRefresh.py index fe31d8205..034ab5017 100644 --- a/vectordb_bench/frontend/components/run_test/autoRefresh.py +++ b/vectordb_bench/frontend/components/run_test/autoRefresh.py @@ -1,5 +1,5 @@ from streamlit_autorefresh import st_autorefresh -from vectordb_bench.frontend.const.styles import * +from vectordb_bench.frontend.config.styles import * def autoRefresh(): diff --git a/vectordb_bench/frontend/components/run_test/caseSelector.py b/vectordb_bench/frontend/components/run_test/caseSelector.py index 49b839163..58799deff 100644 --- a/vectordb_bench/frontend/components/run_test/caseSelector.py +++ b/vectordb_bench/frontend/components/run_test/caseSelector.py @@ -1,9 +1,13 @@ -from vectordb_bench.frontend.const.styles import * + +from vectordb_bench.frontend.config.styles import * from vectordb_bench.backend.cases import CaseType -from vectordb_bench.frontend.const.dbCaseConfigs import * +from vectordb_bench.frontend.config.dbCaseConfigs import * +from collections import defaultdict + +from vectordb_bench.frontend.utils import addHorizontalLine -def caseSelector(st, activedDbList): +def caseSelector(st, activedDbList: list[DB]): st.markdown( "
", unsafe_allow_html=True, @@ -14,41 +18,49 @@ def caseSelector(st, activedDbList): unsafe_allow_html=True, ) - caseIsActived = {case: False for case in CASE_LIST} - allCaseConfigs = {db: {case: {} for case in CASE_LIST} for db in DB_LIST} - for caseOrDivider in CASE_LIST_WITH_DIVIDER: - if caseOrDivider == DIVIDER: - caseItemContainer.markdown( - "
", - unsafe_allow_html=True, - ) + activedCaseList: list[CaseConfig] = [] + dbToCaseClusterConfigs = defaultdict(lambda: defaultdict(dict)) + dbToCaseConfigs = defaultdict(lambda: defaultdict(dict)) + caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter()] + for caseCluster in caseClusters: + activedCaseList += caseClusterExpander( + st, caseCluster, dbToCaseClusterConfigs, activedDbList) + for db in dbToCaseClusterConfigs: + for uiCaseItem in dbToCaseClusterConfigs[db]: + for case in uiCaseItem.cases: + dbToCaseConfigs[db][case] = dbToCaseClusterConfigs[db][uiCaseItem] + + return activedCaseList, dbToCaseConfigs + + +def caseClusterExpander(st, caseCluster: UICaseItemCluster, dbToCaseClusterConfigs, activedDbList: list[DB]): + expander = st.expander(caseCluster.label, False) + activedCases: list[CaseConfig] = [] + for uiCaseItem in caseCluster.uiCaseItems: + if uiCaseItem.isLine: + addHorizontalLine(expander) else: - case = caseOrDivider - caseItemContainer = st.container() - caseIsActived[case] = caseItem( - caseItemContainer, allCaseConfigs, case, activedDbList - ) - activedCaseList = [case for case in CASE_LIST if caseIsActived[case]] - return activedCaseList, allCaseConfigs + activedCases += caseItemCheckbox(expander, + dbToCaseClusterConfigs, uiCaseItem, activedDbList) + return activedCases -def caseItem(st, allCaseConfigs, case: CaseType, activedDbList): - selected = st.checkbox(case.case_name) +def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]): + selected = st.checkbox(uiCaseItem.label) st.markdown( - f"
{case.case_description}
", + f"
{uiCaseItem.description}
", unsafe_allow_html=True, ) if selected: - caseConfigSettingContainer = st.container() caseConfigSetting( - caseConfigSettingContainer, allCaseConfigs, case, activedDbList + st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList ) - return selected + return uiCaseItem.cases if selected else [] -def caseConfigSetting(st, allCaseConfigs, case, activedDbList): +def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, activedDbList: list[DB]): for db in activedDbList: columns = st.columns(1 + CASE_CONFIG_SETTING_COLUMNS) # column 0 - title @@ -57,12 +69,12 @@ def caseConfigSetting(st, allCaseConfigs, case, activedDbList): f"
{db.name}
", unsafe_allow_html=True, ) - caseConfig = allCaseConfigs[db][case] k = 0 - for config in CASE_CONFIG_MAP.get(db, {}).get(case.case_cls().label, []): + caseConfig = dbToCaseClusterConfigs[db][uiCaseItem] + for config in CASE_CONFIG_MAP.get(db, {}).get(uiCaseItem.caseLabel, []): if config.isDisplayed(caseConfig): column = columns[1 + k % CASE_CONFIG_SETTING_COLUMNS] - key = "%s-%s-%s" % (db, case, config.label.value) + key = "%s-%s-%s" % (db, uiCaseItem.label, config.label.value) if config.inputType == InputType.Text: caseConfig[config.label] = column.text_input( config.displayLabel if config.displayLabel else config.label.value, diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py index ffd52721f..8f4f35c93 100644 --- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py +++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py @@ -1,13 +1,9 @@ from pydantic import ValidationError -from vectordb_bench.frontend.const.styles import * +from vectordb_bench.frontend.config.styles import * from vectordb_bench.frontend.utils import inputIsPassword def dbConfigSettings(st, activedDbList): - st.markdown( - "", - unsafe_allow_html=True, - ) expander = st.expander("Configurations for the selected databases", True) dbConfigs = {} diff --git a/vectordb_bench/frontend/components/run_test/dbSelector.py b/vectordb_bench/frontend/components/run_test/dbSelector.py index 61db843f3..ad2f57a0f 100644 --- a/vectordb_bench/frontend/components/run_test/dbSelector.py +++ b/vectordb_bench/frontend/components/run_test/dbSelector.py @@ -1,5 +1,5 @@ -from vectordb_bench.frontend.const.styles import * -from vectordb_bench.frontend.const.dbCaseConfigs import DB_LIST +from vectordb_bench.frontend.config.styles import * +from vectordb_bench.frontend.config.dbCaseConfigs import DB_LIST def dbSelector(st): @@ -16,17 +16,6 @@ def dbSelector(st): dbContainerColumns = st.columns(DB_SELECTOR_COLUMNS, gap="small") dbIsActived = {db: False for db in DB_LIST} - # style - image; column gap; checkbox font; - st.markdown( - """ - - """, - unsafe_allow_html=True, - ) for i, db in enumerate(DB_LIST): column = dbContainerColumns[i % DB_SELECTOR_COLUMNS] dbIsActived[db] = column.checkbox(db.name) diff --git a/vectordb_bench/frontend/components/run_test/generateTasks.py b/vectordb_bench/frontend/components/run_test/generateTasks.py index 55f3c8399..828913f30 100644 --- a/vectordb_bench/frontend/components/run_test/generateTasks.py +++ b/vectordb_bench/frontend/components/run_test/generateTasks.py @@ -1,17 +1,15 @@ +from vectordb_bench.backend.clients import DB from vectordb_bench.models import CaseConfig, CaseConfigParamType, TaskConfig -def generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs): +def generate_tasks(activedDbList: list[DB], dbConfigs, activedCaseList: list[CaseConfig], allCaseConfigs): tasks = [] for db in activedDbList: for case in activedCaseList: task = TaskConfig( db=db.value, db_config=dbConfigs[db], - case_config=CaseConfig( - case_id=case.value, - custom_case={}, - ), + case_config=case, db_case_config=db.case_config_cls( allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None) )(**{key.value: value for key, value in allCaseConfigs[db][case].items()}), diff --git a/vectordb_bench/frontend/components/run_test/initStyle.py b/vectordb_bench/frontend/components/run_test/initStyle.py new file mode 100644 index 000000000..59dd438e1 --- /dev/null +++ b/vectordb_bench/frontend/components/run_test/initStyle.py @@ -0,0 +1,14 @@ +def initStyle(st): + st.markdown( + """""", + unsafe_allow_html=True, + ) \ No newline at end of file diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py index 22d34b4f5..acf4b2c67 100644 --- a/vectordb_bench/frontend/components/run_test/submitTask.py +++ b/vectordb_bench/frontend/components/run_test/submitTask.py @@ -1,5 +1,5 @@ from datetime import datetime -from vectordb_bench.frontend.const.styles import * +from vectordb_bench.frontend.config.styles import * from vectordb_bench.interface import benchMarkRunner diff --git a/vectordb_bench/frontend/const/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py similarity index 82% rename from vectordb_bench/frontend/const/dbCaseConfigs.py rename to vectordb_bench/frontend/config/dbCaseConfigs.py index 1e69c57aa..136dbaa70 100644 --- a/vectordb_bench/frontend/const/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -1,42 +1,134 @@ -from enum import IntEnum +from enum import IntEnum, Enum import typing from pydantic import BaseModel from vectordb_bench.backend.cases import CaseLabel, CaseType from vectordb_bench.backend.clients import DB from vectordb_bench.backend.clients.api import IndexType +from vectordb_bench.frontend.components.custom.getCustomConfig import get_custom_configs -from vectordb_bench.models import CaseConfigParamType +from vectordb_bench.models import CaseConfig, CaseConfigParamType MAX_STREAMLIT_INT = (1 << 53) - 1 DB_LIST = [d for d in DB] -DIVIDER = "DIVIDER" -CASE_LIST_WITH_DIVIDER = [ + +class Delimiter(Enum): + Line = "line" + + +class BatchCaseConfig(BaseModel): + label: str = "" + description: str = "" + cases: list[CaseConfig] = [] + + +class UICaseItem(BaseModel): + isLine: bool = False + label: str = "" + description: str = "" + cases: list[CaseConfig] = [] + caseLabel: CaseLabel = CaseLabel.Performance + + def __init__(self, isLine: bool = False, case_id: CaseType = None, custom_case: dict = {}, cases: list[CaseConfig] = [], label: str = "", description: str = "", caseLabel: CaseLabel = CaseLabel.Performance): + if isLine is True: + super().__init__(isLine=True) + elif case_id is not None and isinstance(case_id, CaseType): + c = case_id.case_cls(custom_case) + super().__init__( + label=c.name, + description=c.description, + cases=[CaseConfig(case_id=case_id, custom_case=custom_case)], + caseLabel=c.label, + ) + else: + super().__init__( + label=label, + description=description, + cases=cases, + caseLabel=caseLabel, + ) + + def __hash__(self) -> int: + return hash(self.json()) + + +class UICaseItemCluster(BaseModel): + label: str = "" + uiCaseItems: list[UICaseItem] = [] + +def get_custom_case_items() -> list[UICaseItem]: + custom_configs = get_custom_configs() + return [ + UICaseItem(case_id=CaseType.PerformanceCustomDataset, + custom_case=custom_config.dict()) + for custom_config in custom_configs + ] + +def get_custom_case_cluter() -> UICaseItemCluster: + return UICaseItemCluster( + label="Custom Search Performance Test", + uiCaseItems=get_custom_case_items() + ) + + +UI_CASE_CLUSTERS: list[UICaseItemCluster] = [ + UICaseItemCluster( + label="Search Performance Test", + uiCaseItems=[ + UICaseItem(case_id=CaseType.Performance768D100M), + UICaseItem(case_id=CaseType.Performance768D10M), + UICaseItem(case_id=CaseType.Performance768D1M), + UICaseItem(isLine=True), + UICaseItem(case_id=CaseType.Performance1536D5M), + UICaseItem(case_id=CaseType.Performance1536D500K), + ] + ), + UICaseItemCluster( + label="Filter Search Performance Test", + uiCaseItems=[ + UICaseItem(case_id=CaseType.Performance768D10M1P), + UICaseItem(case_id=CaseType.Performance768D10M99P), + UICaseItem(case_id=CaseType.Performance768D1M1P), + UICaseItem(case_id=CaseType.Performance768D1M99P), + UICaseItem(isLine=True), + UICaseItem(case_id=CaseType.Performance1536D5M1P), + UICaseItem(case_id=CaseType.Performance1536D5M99P), + UICaseItem(case_id=CaseType.Performance1536D500K1P), + UICaseItem(case_id=CaseType.Performance1536D500K99P), + ] + ), + UICaseItemCluster( + label="Capacity Test", + uiCaseItems=[ + UICaseItem(case_id=CaseType.CapacityDim960), + UICaseItem(case_id=CaseType.CapacityDim128), + ] + ), +] + +# DIVIDER = "DIVIDER" +DISPLAY_CASE_ORDER = [ CaseType.Performance768D100M, CaseType.Performance768D10M, CaseType.Performance768D1M, - DIVIDER, CaseType.Performance1536D5M, CaseType.Performance1536D500K, - DIVIDER, CaseType.Performance768D10M1P, CaseType.Performance768D1M1P, - DIVIDER, CaseType.Performance1536D5M1P, CaseType.Performance1536D500K1P, - DIVIDER, CaseType.Performance768D10M99P, CaseType.Performance768D1M99P, - DIVIDER, CaseType.Performance1536D5M99P, CaseType.Performance1536D500K99P, - DIVIDER, CaseType.CapacityDim960, CaseType.CapacityDim128, ] +CASE_NAME_ORDER = [case.case_cls().name for case in DISPLAY_CASE_ORDER] -CASE_LIST = [item for item in CASE_LIST_WITH_DIVIDER if isinstance(item, CaseType)] +# CASE_LIST = [ +# item for item in CASE_LIST_WITH_DIVIDER if isinstance(item, CaseType)] class InputType(IntEnum): @@ -512,7 +604,8 @@ class CaseConfigInput(BaseModel): inputConfig={ "options": ["x4", "x8", "x16", "x32", "x64"], }, - isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) + isDisplayed=lambda config: config.get( + CaseConfigParamType.quantizationType, None) == "product" and config.get(CaseConfigParamType.IndexType, None) in [ IndexType.HNSW.value, @@ -574,7 +667,8 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_EF_Weaviate, ] -ESLoadingConfig = [CaseConfigParamInput_EFConstruction_ES, CaseConfigParamInput_M_ES] +ESLoadingConfig = [CaseConfigParamInput_EFConstruction_ES, + CaseConfigParamInput_M_ES] ESPerformanceConfig = [ CaseConfigParamInput_EFConstruction_ES, CaseConfigParamInput_M_ES, diff --git a/vectordb_bench/frontend/const/dbPrices.py b/vectordb_bench/frontend/config/dbPrices.py similarity index 100% rename from vectordb_bench/frontend/const/dbPrices.py rename to vectordb_bench/frontend/config/dbPrices.py diff --git a/vectordb_bench/frontend/const/styles.py b/vectordb_bench/frontend/config/styles.py similarity index 100% rename from vectordb_bench/frontend/const/styles.py rename to vectordb_bench/frontend/config/styles.py diff --git a/vectordb_bench/frontend/pages/custom.py b/vectordb_bench/frontend/pages/custom.py new file mode 100644 index 000000000..28c249f78 --- /dev/null +++ b/vectordb_bench/frontend/pages/custom.py @@ -0,0 +1,64 @@ +import streamlit as st +from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon +from vectordb_bench.frontend.components.custom.displayCustomCase import displayCustomCase +from vectordb_bench.frontend.components.custom.displaypPrams import displayParams +from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig, generate_custom_case, get_custom_configs, save_custom_configs +from vectordb_bench.frontend.components.custom.initStyle import initStyle +from vectordb_bench.frontend.config.styles import FAVICON, PAGE_TITLE + + +class CustomCaseManager(): + customCaseItems: list[CustomCaseConfig] + + def __init__(self): + self.customCaseItems = get_custom_configs() + + def addCase(self): + new_custom_case = generate_custom_case() + new_custom_case.dataset_config.name = f"{new_custom_case.dataset_config.name} {len(self.customCaseItems)}" + self.customCaseItems += [new_custom_case] + self.save() + + def deleteCase(self, idx: int): + self.customCaseItems.pop(idx) + self.save() + + def save(self): + save_custom_configs(self.customCaseItems) + + +def main(): + st.set_page_config( + page_title=PAGE_TITLE, + page_icon=FAVICON, + # layout="wide", + # initial_sidebar_state="collapsed", + ) + + # header + drawHeaderIcon(st) + + # init style + initStyle(st) + + st.title("Custom Dataset") + displayParams(st) + customCaseManager = CustomCaseManager() + + for idx, customCase in enumerate(customCaseManager.customCaseItems): + expander = st.expander(customCase.dataset_config.name, expanded=True) + key = f"custom_case_{idx}" + displayCustomCase(customCase, expander, key=key) + + columns = expander.columns(8) + columns[0].button( + "Save", key=f"{key}_", type="secondary", on_click=lambda: customCaseManager.save()) + columns[1].button(":red[Delete]", key=f"{key}_delete", type="secondary", + on_click=lambda: customCaseManager.deleteCase(idx)) + + st.button("\+ New Dataset", key=f"add_custom_configs", + type="primary", on_click=lambda: customCaseManager.addCase()) + + +if __name__ == "__main__": + main() diff --git a/vectordb_bench/frontend/pages/quries_per_dollar.py b/vectordb_bench/frontend/pages/quries_per_dollar.py index 10c1ac8f1..0bb05294b 100644 --- a/vectordb_bench/frontend/pages/quries_per_dollar.py +++ b/vectordb_bench/frontend/pages/quries_per_dollar.py @@ -8,7 +8,7 @@ from vectordb_bench.frontend.components.check_results.charts import drawMetricChart from vectordb_bench.frontend.components.check_results.filters import getshownData from vectordb_bench.frontend.components.get_results.saveAsImage import getResults -from vectordb_bench.frontend.const.styles import * +from vectordb_bench.frontend.config.styles import * from vectordb_bench.interface import benchMarkRunner from vectordb_bench.metric import QURIES_PER_DOLLAR_METRIC @@ -26,7 +26,7 @@ def main(): # results selector resultSelectorContainer = st.sidebar.container() - shownData, _, showCases = getshownData(allResults, resultSelectorContainer) + shownData, _, showCaseNames = getshownData(allResults, resultSelectorContainer) resultSelectorContainer.divider() @@ -45,8 +45,8 @@ def main(): priceMap = priceTable(priceTableContainer, shownData) # charts - for case in showCases: - data = [data for data in shownData if data["case_name"] == case.name] + for caseName in showCaseNames: + data = [data for data in shownData if data["case_name"] == caseName] dataWithMetric = [] metric = QURIES_PER_DOLLAR_METRIC for d in data: @@ -56,7 +56,7 @@ def main(): d[metric] = d["qps"] / price * 3.6 dataWithMetric.append(d) if len(dataWithMetric) > 0: - chartContainer = st.expander(case.name, True) + chartContainer = st.expander(caseName, True) drawMetricChart(data, metric, chartContainer) # footer diff --git a/vectordb_bench/frontend/pages/run_test.py b/vectordb_bench/frontend/pages/run_test.py index 0712bb6cc..1297743ae 100644 --- a/vectordb_bench/frontend/pages/run_test.py +++ b/vectordb_bench/frontend/pages/run_test.py @@ -5,6 +5,7 @@ from vectordb_bench.frontend.components.run_test.dbSelector import dbSelector from vectordb_bench.frontend.components.run_test.generateTasks import generate_tasks from vectordb_bench.frontend.components.run_test.hideSidebar import hideSidebar +from vectordb_bench.frontend.components.run_test.initStyle import initStyle from vectordb_bench.frontend.components.run_test.submitTask import submitTask from vectordb_bench.frontend.components.check_results.nav import NavToResults from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon @@ -15,6 +16,9 @@ def main(): # set page config initRunTestPageConfig(st) + # init style + initStyle(st) + # header drawHeaderIcon(st) diff --git a/vectordb_bench/frontend/utils.py b/vectordb_bench/frontend/utils.py index 139854af6..787b67d03 100644 --- a/vectordb_bench/frontend/utils.py +++ b/vectordb_bench/frontend/utils.py @@ -1,6 +1,22 @@ -from vectordb_bench.models import CaseType +import random +import string + passwordKeys = ["password", "api_key"] + + def inputIsPassword(key: str) -> bool: return key.lower() in passwordKeys + +def addHorizontalLine(st): + st.markdown( + "
", + unsafe_allow_html=True, + ) + + +def generate_random_string(length): + letters = string.ascii_letters + string.digits + result = ''.join(random.choice(letters) for _ in range(length)) + return result diff --git a/vectordb_bench/frontend/vdb_benchmark.py b/vectordb_bench/frontend/vdb_benchmark.py index 0be43470e..b859c68b8 100644 --- a/vectordb_bench/frontend/vdb_benchmark.py +++ b/vectordb_bench/frontend/vdb_benchmark.py @@ -6,7 +6,7 @@ from vectordb_bench.frontend.components.check_results.charts import drawCharts from vectordb_bench.frontend.components.check_results.filters import getshownData from vectordb_bench.frontend.components.get_results.saveAsImage import getResults -from vectordb_bench.frontend.const.styles import * +from vectordb_bench.frontend.config.styles import * from vectordb_bench.interface import benchMarkRunner @@ -24,7 +24,7 @@ def main(): # results selector and filter resultSelectorContainer = st.sidebar.container() - shownData, failedTasks, showCases = getshownData( + shownData, failedTasks, showCaseNames = getshownData( allResults, resultSelectorContainer ) @@ -40,7 +40,7 @@ def main(): getResults(resultesContainer, "vectordb_bench") # charts - drawCharts(st, shownData, failedTasks, showCases) + drawCharts(st, shownData, failedTasks, showCaseNames) # footer footer(st.container()) diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index ec1b610e1..0141f2f13 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -74,6 +74,9 @@ class CaseConfig(BaseModel): case_id: CaseType custom_case: dict | None = None + + def __hash__(self) -> int: + return hash(self.json()) class TaskConfig(BaseModel): diff --git a/vectordb_bench/results/getLeaderboardData.py b/vectordb_bench/results/getLeaderboardData.py index 50f458533..c6484514d 100644 --- a/vectordb_bench/results/getLeaderboardData.py +++ b/vectordb_bench/results/getLeaderboardData.py @@ -2,7 +2,7 @@ import ujson import pathlib from vectordb_bench.backend.cases import CaseType -from vectordb_bench.frontend.const.dbPrices import DB_DBLABEL_TO_PRICE +from vectordb_bench.frontend.config.dbPrices import DB_DBLABEL_TO_PRICE from vectordb_bench.interface import benchMarkRunner from vectordb_bench.models import CaseResult, ResultLabel, TestResult