From 2cfab29f22069d17ad5f613bdcad968638bc5db5 Mon Sep 17 00:00:00 2001 From: Vivek Narang Date: Mon, 11 Nov 2024 09:58:22 -0500 Subject: [PATCH] make getTopK configurable --- .../main/java/com/nvidia/cuvs/ExampleApp.java | 1 + .../com/nvidia/cuvs/cagra/CagraIndex.java | 6 ++--- .../java/com/nvidia/cuvs/cagra/CuVSQuery.java | 25 +++++++++++++++++-- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/ExampleApp.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/ExampleApp.java index 8f2e5ec2e..aa8f3dbd9 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/ExampleApp.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/ExampleApp.java @@ -54,6 +54,7 @@ public static void main(String[] args) throws Throwable { // Query CuVSQuery query = new CuVSQuery.Builder() + .withTopK(1) .withSearchParams(cagraSearchParams) .withQueryVectors(queries) .withMapping(map) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CagraIndex.java index 440ada990..5b8557d20 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CagraIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CagraIndex.java @@ -168,10 +168,10 @@ public SearchResult search(CuVSQuery query) throws Throwable { MemoryLayout rvML = linker.canonicalLayouts().get("int"); MemorySegment rvMS = arena.allocate(rvML); - searchMH.invokeExact(ref.indexMemorySegment, getMemorySegment(query.queryVectors), 2, 4L, 2L, res.getResource(), - neighborsMS, distancesMS, rvMS, query.searchParams.cagraSearchParamsMS); + searchMH.invokeExact(ref.indexMemorySegment, getMemorySegment(query.getQueries()), query.getTopK(), 4L, 2L, res.getResource(), + neighborsMS, distancesMS, rvMS, query.getSearchParams().cagraSearchParamsMS); - return new SearchResult(neighborsSL, distancesSL, neighborsMS, distancesMS, 2, query.mapping); + return new SearchResult(neighborsSL, distancesSL, neighborsMS, distancesMS, query.getTopK(), query.getMapping()); } /** diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CuVSQuery.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CuVSQuery.java index 027ceebe8..b99314503 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CuVSQuery.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CuVSQuery.java @@ -9,14 +9,16 @@ public class CuVSQuery { PreFilter preFilter; float[][] queryVectors; public Map mapping; + int topK; public CuVSQuery(CagraSearchParams searchParams, PreFilter preFilter, float[][] queryVectors, - Map mapping) { + Map mapping, int topK) { super(); this.searchParams = searchParams; this.preFilter = preFilter; this.queryVectors = queryVectors; this.mapping = mapping; + this.topK = topK; } @Override @@ -37,11 +39,20 @@ public float[][] getQueries() { return queryVectors; } + public Map getMapping() { + return mapping; + } + + public int getTopK() { + return topK; + } + public static class Builder { CagraSearchParams searchParams; PreFilter preFilter; float[][] queryVectors; Map mapping; + int topK = 2; /** * @@ -89,6 +100,16 @@ public Builder withMapping(Map mapping) { this.mapping = mapping; return this; } + + /** + * + * @param topK + * @return + */ + public Builder withTopK(int topK) { + this.topK = topK; + return this; + } /** * @@ -96,7 +117,7 @@ public Builder withMapping(Map mapping) { * @throws Throwable */ public CuVSQuery build() throws Throwable { - return new CuVSQuery(searchParams, preFilter, queryVectors, mapping); + return new CuVSQuery(searchParams, preFilter, queryVectors, mapping, topK); } }