Skip to content

Commit

Permalink
Add approximate filters (#607)
Browse files Browse the repository at this point in the history
Signed-off-by: Finn Roblin <[email protected]>
  • Loading branch information
finnroblin authored Aug 30, 2024
1 parent 8ff3896 commit d039711
Show file tree
Hide file tree
Showing 7 changed files with 591 additions and 187 deletions.
5 changes: 5 additions & 0 deletions osbenchmark/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Context(Enum):
MAX_DISTANCE_NEIGHBORS = 4
MIN_SCORE_NEIGHBORS = 5
PARENTS = 6
ATTRIBUTES = 7


class DataSet(ABC):
Expand Down Expand Up @@ -133,6 +134,7 @@ def size(self):
def reset(self):
    """Move the read position back to the start by setting ``current`` to ``BEGINNING``."""
    self.current = self.BEGINNING

# pylint: disable=R0911
@staticmethod
def parse_context(context: Context) -> str:
if context == Context.NEIGHBORS:
Expand All @@ -152,6 +154,9 @@ def parse_context(context: Context) -> str:
if context == Context.MIN_SCORE_NEIGHBORS:
return "min_score_neighbors"

if context == Context.ATTRIBUTES:
return "attributes"

raise Exception("Unsupported context")


Expand Down
16 changes: 15 additions & 1 deletion osbenchmark/utils/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def parse_string_parameter(key: str, params: dict, default: str = None) -> str:

def parse_int_parameter(key: str, params: dict, default: int = None) -> int:
if key not in params:
if default:
if default is not None:
return default
raise ConfigurationError(
"Value cannot be None for param {}".format(key)
Expand All @@ -46,3 +46,17 @@ def parse_float_parameter(key: str, params: dict, default: float = None) -> floa
return params[key]

raise ConfigurationError("Value must be a float for param {}".format(key))


def parse_bool_parameter(key: str, params: dict, default: bool = None) -> bool:
    """
    Read a boolean parameter from ``params``.

    Args:
        key: Name of the parameter to look up.
        params: Mapping of parameter names to values.
        default: Value returned when ``key`` is absent. A default of ``None``
            means the parameter is required.

    Returns:
        The value stored under ``key`` when it is a ``bool``, otherwise
        ``default`` when ``key`` is absent and a default was supplied.

    Raises:
        ConfigurationError: If ``key`` is absent and ``default`` is ``None``,
            or if the stored value is not a ``bool`` (integers such as 0/1
            are rejected because they are not ``bool`` instances).
    """
    if key in params:
        value = params[key]
        if not isinstance(value, bool):
            raise ConfigurationError("Value must be a bool for param {}".format(key))
        return value

    # Key is missing: fall back to the default, which may legitimately be False.
    if default is None:
        raise ConfigurationError(
            "Value cannot be None for param {}".format(key)
        )
    return default
26 changes: 25 additions & 1 deletion osbenchmark/worker_coordinator/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,17 @@ def _get_field_value(content, field_name):
return _get_field_value(content["_source"], field_name)
return None

def binary_search_for_last_negative_1(neighbors):
    """
    Locate the last real neighbor in a list whose tail is padded with -1.

    Despite the name, the value returned is the index of the element just
    before the first -1 entry, not the index of a -1 entry. The binary search
    only gives a meaningful answer when every -1 entry sits in one contiguous
    run at the end of ``neighbors``.

    Args:
        neighbors: Sequence of neighbor ids. A padding entry may be either
            the string ``"-1"`` or the integer ``-1``.

    Returns:
        Index of the last non-padding element, so ``neighbors[:result + 1]``
        is the padding-free prefix. This is ``len(neighbors) - 1`` when there
        is no padding and ``-1`` when the list is empty or entirely padding.
    """
    # Accept both representations so the helper agrees with callers that
    # detect padding through int(value) == -1.
    padding_markers = ("-1", -1)

    low = 0
    high = len(neighbors)
    while low < high:
        mid = (low + high) // 2
        if neighbors[mid] in padding_markers:
            # mid is padding: the first padding entry is at mid or earlier.
            high = mid
        else:
            # mid is a real neighbor: the first padding entry is after mid.
            low = mid + 1
    # low is now the index of the first padding entry (or len(neighbors)).
    return low - 1

def calculate_topk_search_recall(predictions, neighbors, top_k):
"""
Calculates the recall by comparing top_k neighbors with predictions.
Expand All @@ -1270,7 +1281,20 @@ def calculate_topk_search_recall(predictions, neighbors, top_k):
self.logger.info("No neighbors are provided for recall calculation")
return 0.0
min_num_of_results = min(top_k, len(neighbors))
last_neighbor_is_negative_1 = int(neighbors[min_num_of_results-1]) == -1
truth_set = neighbors[:min_num_of_results]
if last_neighbor_is_negative_1:
self.logger.debug("Last neighbor is -1")
last_neighbor_idx = binary_search_for_last_negative_1(truth_set)

# Note: we do - 1 since list indexing is inclusive, and we want to ignore the first '-1' in neighbors.
truth_set = truth_set[:last_neighbor_idx-1]
if not truth_set:
self.logger.info("No true neighbors after filtering, returning recall = 1.\n"
"Total neighbors in prediction: [%d].", len(predictions))
return 1.0


for j in range(min_num_of_results):
if j >= len(predictions):
self.logger.info("No more neighbors in prediction to compare against ground truth.\n"
Expand All @@ -1280,7 +1304,7 @@ def calculate_topk_search_recall(predictions, neighbors, top_k):
if predictions[j] in truth_set:
correct += 1.0

return correct / min_num_of_results
return correct / len(truth_set)

def calculate_radial_search_recall(predictions, neighbors, enable_top_1_recall=False):
"""
Expand Down
Loading

0 comments on commit d039711

Please sign in to comment.