Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial loading implementation for FAISS HNSW #2405

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Bug Fixes
### Infrastructure
* Removed JDK 11 and 17 version from CI runs [#1921](https://github.com/opensearch-project/k-NN/pull/1921)
* Added initial implementation of partial loading [#2405](https://github.com/opensearch-project/k-NN/pull/2405)
### Documentation
### Maintenance
### Refactoring
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import org.apache.lucene.util.VectorUtil;

/**
 * Distance functions over float and byte vectors.
 *
 * <p>Every constant follows the same convention: the smaller the returned value, the closer the
 * two vectors are. Similarity measures (dot product, cosine) are therefore converted so that a
 * higher similarity yields a lower distance.
 */
public enum KNNVectorDistanceFunction {
    /** Squared Euclidean distance, as computed by {@link VectorUtil#squareDistance}. */
    EUCLIDEAN {
        @Override
        public float distance(float[] vec1, float[] vec2) {
            return VectorUtil.squareDistance(vec1, vec2);
        }

        @Override
        public float distance(byte[] vec1, byte[] vec2) {
            return VectorUtil.squareDistance(vec1, vec2);
        }
    },
    /** Negated dot product, so that a larger inner product maps to a smaller distance. */
    DOT_PRODUCT {
        @Override
        public float distance(float[] vec1, float[] vec2) {
            return -VectorUtil.dotProduct(vec1, vec2);
        }

        @Override
        public float distance(byte[] vec1, byte[] vec2) {
            return -VectorUtil.dotProduct(vec1, vec2);
        }
    },
    /**
     * Cosine distance, defined as {@code 1 - cosine similarity}.
     *
     * <p>The raw cosine similarity grows as vectors get closer, which is the opposite of what a
     * distance must do; subtracting it from one keeps this constant consistent with the others.
     */
    COSINE {
        @Override
        public float distance(float[] vec1, float[] vec2) {
            return 1 - VectorUtil.cosine(vec1, vec2);
        }

        @Override
        public float distance(byte[] vec1, byte[] vec2) {
            return 1 - VectorUtil.cosine(vec1, vec2);
        }
    };

    /**
     * Computes the distance between two float vectors.
     *
     * @param vec1 first vector
     * @param vec2 second vector
     * @return distance between the vectors; lower means closer
     */
    public abstract float distance(float[] vec1, float[] vec2);

    /**
     * Computes the distance between two byte vectors.
     *
     * @param vec1 first vector
     * @param vec2 second vector
     * @return distance between the vectors; lower means closer
     */
    public abstract float distance(byte[] vec1, byte[] vec2);
}
19 changes: 19 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.EUCLIDEAN;
}

/**
 * Returns the distance function paired with this space type.
 *
 * @return {@link KNNVectorDistanceFunction#EUCLIDEAN}
 */
@Override
public KNNVectorDistanceFunction getKnnVectorDistanceFunction() {
return KNNVectorDistanceFunction.EUCLIDEAN;
}

@Override
public float scoreToDistanceTranslation(float score) {
if (score == 0) {
Expand Down Expand Up @@ -82,6 +87,11 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
}

/**
 * Returns the distance function paired with this space type.
 *
 * @return {@link KNNVectorDistanceFunction#COSINE}
 */
@Override
public KNNVectorDistanceFunction getKnnVectorDistanceFunction() {
return KNNVectorDistanceFunction.COSINE;
}

@Override
public void validateVector(byte[] vector) {
if (isZeroVector(vector)) {
Expand Down Expand Up @@ -133,6 +143,11 @@ public float scoreTranslation(float rawScore) {
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
}

/**
 * Returns the distance function paired with this space type.
 *
 * @return {@link KNNVectorDistanceFunction#DOT_PRODUCT}
 */
@Override
public KNNVectorDistanceFunction getKnnVectorDistanceFunction() {
return KNNVectorDistanceFunction.DOT_PRODUCT;
}
},
HAMMING("hamming") {
@Override
Expand Down Expand Up @@ -177,6 +192,10 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {

public abstract float scoreTranslation(float rawScore);

/**
 * Get the KNNVectorDistanceFunction that maps to this SpaceType.
 *
 * <p>This base implementation always fails; space types that support a distance function
 * override it.
 *
 * @return never returns normally from this base implementation
 * @throws UnsupportedOperationException always, with a message naming this space's value
 */
public KNNVectorDistanceFunction getKnnVectorDistanceFunction() {
    final String message = String.format("Space [%s] does not have a knn vector distance function", getValue());
    throw new UnsupportedOperationException(message);
}

/**
* Get KNNVectorSimilarityFunction that maps to this SpaceType
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
// Check only after quantization state writer finish writing its state, since it is required
// even if there are no graph files in segment, which will be later used by exact search
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
if (false /*TMP*/ && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is temp code. Will revert it back before merging.

log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush",
fieldInfo.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.partialloading.PartialLoadingContext;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.locks.ReadWriteLock;
Expand Down Expand Up @@ -100,6 +102,10 @@ default boolean decRef() {
return true;
}

/**
 * Returns the partial loading context attached to this allocation.
 *
 * @return {@code null} in this default implementation; callers must handle a null result
 */
default PartialLoadingContext getPartialLoadingContext() {
return null;
}

/**
* Represents native indices loaded into memory. Because these indices are backed by files, they should be
* freed when file is deleted.
Expand All @@ -121,6 +127,27 @@ class IndexAllocation implements NativeMemoryAllocation {
@Getter
private final boolean isBinaryIndex;
private final RefCountedReleasable<IndexAllocation> refCounted;
@Getter
private final PartialLoadingContext partialLoadingContext;

/**
 * Constructor for an allocation that carries a partial loading context.
 *
 * <p>Delegates to the full constructor with a memory address of 0, a size of 0 KB and no
 * shared index state.
 *
 * @param executorService Executor service used to close the allocation
 * @param knnEngine KNNEngine associated with the index allocation
 * @param vectorFileName Vector file name. Ex: _0_165_my_field.faiss
 * @param openSearchIndexName Name of OpenSearch index this index is associated with
 * @param isBinaryIndex Whether the index is a binary index
 * @param partialLoadingContext Partial loading context to attach to this allocation
 */
IndexAllocation(
ExecutorService executorService,
KNNEngine knnEngine,
String vectorFileName,
String openSearchIndexName,
boolean isBinaryIndex,
PartialLoadingContext partialLoadingContext
) {
// No native memory is referenced here, hence the zero address/size and null shared state.
this(executorService, 0, 0, knnEngine, vectorFileName, openSearchIndexName, null, isBinaryIndex, partialLoadingContext);
}

/**
* Constructor
Expand All @@ -140,7 +167,7 @@ class IndexAllocation implements NativeMemoryAllocation {
String vectorFileName,
String openSearchIndexName
) {
this(executorService, memoryAddress, sizeKb, knnEngine, vectorFileName, openSearchIndexName, null, false);
this(executorService, memoryAddress, sizeKb, knnEngine, vectorFileName, openSearchIndexName, null, false, null);
}

/**
Expand All @@ -163,6 +190,41 @@ class IndexAllocation implements NativeMemoryAllocation {
String openSearchIndexName,
SharedIndexState sharedIndexState,
boolean isBinaryIndex
) {
this(
executorService,
memoryAddress,
sizeKb,
knnEngine,
vectorFileName,
openSearchIndexName,
sharedIndexState,
isBinaryIndex,
null
);
}

/**
* Constructor
*
* @param executorService Executor service used to close the allocation
* @param memoryAddress Pointer in memory to the index
* @param sizeKb Size this index consumes in kilobytes
* @param knnEngine KNNEngine associated with the index allocation
* @param vectorFileName Vector file name. Ex: _0_165_my_field.faiss
* @param openSearchIndexName Name of OpenSearch index this index is associated with
* @param sharedIndexState Shared index state. If not shared state present, pass null.
*/
IndexAllocation(
ExecutorService executorService,
long memoryAddress,
int sizeKb,
KNNEngine knnEngine,
String vectorFileName,
String openSearchIndexName,
SharedIndexState sharedIndexState,
boolean isBinaryIndex,
PartialLoadingContext partialLoadingContext
) {
this.executor = executorService;
this.closed = false;
Expand All @@ -175,6 +237,7 @@ class IndexAllocation implements NativeMemoryAllocation {
this.sharedIndexState = sharedIndexState;
this.isBinaryIndex = isBinaryIndex;
this.refCounted = new RefCountedReleasable<>("IndexAllocation-Reference", this, this::closeInternal);
this.partialLoadingContext = partialLoadingContext;
}

protected void closeInternal() {
Expand Down Expand Up @@ -218,6 +281,14 @@ private void cleanup() {
if (sharedIndexState != null) {
SharedIndexStateManager.getInstance().release(sharedIndexState);
}

if (partialLoadingContext != null) {
try {
partialLoadingContext.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.partialloading.PartialLoadingContext;
import org.opensearch.knn.partialloading.faiss.FaissIndex;
import org.opensearch.knn.partialloading.search.PartialLoadingMode;
import org.opensearch.knn.training.TrainingDataConsumer;
import org.opensearch.knn.training.VectorReader;

import java.io.Closeable;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

Expand Down Expand Up @@ -87,6 +91,15 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde
final Directory directory = indexEntryContext.getDirectory();
final int indexSizeKb = Math.toIntExact(directory.fileLength(vectorFileName) / 1024);

// TMP
final PartialLoadingMode partialLoadingMode = PartialLoadingMode.DISABLED;
// final PartialLoadingMode partialLoadingMode = PartialLoadingMode.MEMORY_EFFICIENT;
// TMP

if (partialLoadingMode != PartialLoadingMode.DISABLED) {
return createPartialLoadedIndexAllocation(directory, indexEntryContext, knnEngine, vectorFileName, partialLoadingMode);
}

// Try to open an index input then pass it down to native engine for loading an index.
try (IndexInput readStream = directory.openInput(vectorFileName, IOContext.READONCE)) {
final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(readStream);
Expand All @@ -96,6 +109,45 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde
}
}

/**
 * Creates an index allocation that wraps a partially loaded FAISS index.
 *
 * <p>The index type is validated first, then the vector file is opened through the directory
 * and handed to {@link FaissIndex#partiallyLoad}; the input is closed once that call returns.
 *
 * @param directory Directory holding the vector file
 * @param indexEntryContext Entry context supplying the index parameters and OpenSearch index name
 * @param knnEngine Engine associated with the index
 * @param vectorFileName Name of the vector file to open
 * @param partialLoadingMode Partial loading mode stored in the created context
 * @return allocation carrying the newly created partial loading context
 * @throws IOException if validation fails or the vector file cannot be read
 */
private NativeMemoryAllocation.IndexAllocation createPartialLoadedIndexAllocation(
    Directory directory,
    NativeMemoryEntryContext.IndexEntryContext indexEntryContext,
    KNNEngine knnEngine,
    String vectorFileName,
    PartialLoadingMode partialLoadingMode
) throws IOException {
    // Fail fast on unsupported index types before touching the file.
    validatePartialLoadingSupported(indexEntryContext, knnEngine);

    // Open the vector file and let FaissIndex read what it needs from it.
    final FaissIndex partiallyLoadedIndex;
    try (IndexInput vectorFileInput = directory.openInput(vectorFileName, IOContext.READONCE)) {
        partiallyLoadedIndex = FaissIndex.partiallyLoad(vectorFileInput);
    }

    final PartialLoadingContext context = new PartialLoadingContext(partiallyLoadedIndex, vectorFileName, partialLoadingMode);

    final String openSearchIndexName = indexEntryContext.getOpenSearchIndexName();
    final boolean binaryIndex = IndexUtil.isBinaryIndex(knnEngine, indexEntryContext.getParameters());
    return new NativeMemoryAllocation.IndexAllocation(executor, knnEngine, vectorFileName, openSearchIndexName, binaryIndex, context);
}

/**
 * Verifies that the index described by the entry context can be served by partial loading.
 *
 * <p>NOTE(review): {@link UnsupportedEncodingException} is documented by the JDK as a character
 * encoding error; consider whether a different exception type fits this condition better.
 *
 * @param indexEntryContext Entry context whose parameters are inspected
 * @param knnEngine Engine associated with the index
 * @throws UnsupportedEncodingException if the index is a binary index or a byte index
 */
private void validatePartialLoadingSupported(NativeMemoryEntryContext.IndexEntryContext indexEntryContext, KNNEngine knnEngine)
    throws UnsupportedEncodingException {
    final boolean binaryIndex = IndexUtil.isBinaryIndex(knnEngine, indexEntryContext.getParameters());
    if (binaryIndex) {
        throw new UnsupportedEncodingException("Partial loading search does not support binary index.");
    }

    final boolean byteIndex = IndexUtil.isByteIndex(indexEntryContext.getParameters());
    if (byteIndex) {
        throw new UnsupportedEncodingException("Partial loading search does not support byte index.");
    }
}

private NativeMemoryAllocation.IndexAllocation createIndexAllocation(
final NativeMemoryEntryContext.IndexEntryContext indexEntryContext,
final KNNEngine knnEngine,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
System.out.println(" +++++++++++++++++++++++++++++++++ parentFilter = context.getParentFilter(), " + parentFilter);
}

if (parentFilter == null && expandNested) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
* Place holder for the score of the document
*/
public class KNNQueryResult {
private final int id;
private final float score;
private int id;
private float score;

public KNNQueryResult(final int id, final float score) {
this.id = id;
Expand All @@ -24,4 +24,9 @@ public int getId() {
public float getScore() {
return this.score;
}

/**
 * Overwrites both fields of this result in place.
 *
 * @param id new id to store
 * @param score new score to store
 */
public void reset(final int id, final float score) {
this.id = id;
this.score = score;
}
}
Loading
Loading