Concurrency optimization for native graph loading #2345

Open — wants to merge 1 commit into base: main
CHANGELOG.md (2 additions, 0 deletions)
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
- Add concurrency optimizations with native memory graph loading and force eviction (#2265) [https://github.com/opensearch-project/k-NN/pull/2345]

### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
NativeMemoryCacheManager.java
@@ -35,12 +35,14 @@
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;

/**
* Manages native memory allocations made by JNI.
@@ -56,6 +58,7 @@ public class NativeMemoryCacheManager implements Closeable {

private Cache<String, NativeMemoryAllocation> cache;
private Deque<String> accessRecencyQueue;
private final ConcurrentHashMap<String, ReentrantLock> indexLocks = new ConcurrentHashMap<>();
private final ExecutorService executor;
private AtomicBoolean cacheCapacityReached;
private long maxWeight;
@@ -345,7 +348,22 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryC

// Cache Miss
// Evict before put
// open the graph file before proceeding to load the graph into memory
ReentrantLock indexFileLock = indexLocks.computeIfAbsent(key, k -> new ReentrantLock());
indexFileLock.lock();
nativeMemoryEntryContext.openVectorIndex();
indexFileLock.unlock();
[Comment on lines +352 to +355]
Collaborator: Can we please have a private method openIndex() here, so this is taken care of as and when the code changes?
if (!indexFileLock.hasQueuedThreads()) {
indexLocks.remove(key, indexFileLock);
}
synchronized (this) {
// recheck if another thread already loaded this entry into the cache
result = cache.getIfPresent(key);
if (result != null) {
accessRecencyQueue.remove(key);
accessRecencyQueue.addLast(key);
return result;
}
[Comment on lines +360 to +366]
Collaborator: A private method for this as well? There will be an additional null check, but every time get() returns, accessRecency should be updated.
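The recheck-and-touch step the reviewer suggests extracting can be sketched as follows. RecencyCacheExample is hypothetical, and a plain ConcurrentHashMap stands in for the Guava Cache the real class uses; the point is the double-checked miss path: look the key up again under synchronization, and if another thread has loaded it in the meantime, refresh its position in the recency queue instead of loading twice.

```java
import java.util.Deque;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;

// Hypothetical sketch of the "recheck under synchronization" step.
public class RecencyCacheExample {
    private final Map<String, Object> cache = new ConcurrentHashMap<>();
    private final Deque<String> accessRecencyQueue = new ConcurrentLinkedDeque<>();

    // Returns the cached value and moves the key to the most-recently-used
    // end of the queue, or returns null on a true miss.
    synchronized Object getIfPresentAndTouch(String key) {
        Object result = cache.get(key);
        if (result != null) {
            accessRecencyQueue.remove(key);
            accessRecencyQueue.addLast(key);
        }
        return result;
    }

    void put(String key, Object value) {
        cache.put(key, value);
        accessRecencyQueue.addLast(key);
    }

    String mostRecentKey() {
        return accessRecencyQueue.peekLast();
    }
}
```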

if (getCacheSizeInKilobytes() + nativeMemoryEntryContext.calculateSizeInKB() >= maxWeight) {
Iterator<String> lruIterator = accessRecencyQueue.iterator();
while (lruIterator.hasNext()
@@ -367,7 +385,15 @@ public NativeMemoryAllocation get(NativeMemoryEntryContext<?> nativeMemoryEntryC
return result;
}
} else {
return cache.get(nativeMemoryEntryContext.getKey(), nativeMemoryEntryContext::load);
// open graphFile before load
try (nativeMemoryEntryContext) {
[Review thread]
Collaborator: There could be a case where multiple threads trigger eviction and graph loading concurrently, leading to temporary spikes in memory usage. Can we think of using bounded concurrency for eviction and graph-loading tasks with thread pools?
Author: Will take it up in a separate issue.
Member: This is a fair callout. I think we need to improve on our cache operations in general. The problem we are going through right now is that the cache operations can be async in nature (cleanup, eviction), whereas we use the cache as a 1:1 reference for the off-heap memory in use. We can create a tracking issue and deal with this separately.
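The bounded-concurrency idea raised in this thread can be sketched with a Semaphore that caps how many threads may run the expensive open-and-load path at once, so a burst of concurrent cache misses cannot spike off-heap memory without limit. BoundedLoaderExample and its permit count are illustrative assumptions, not the plugin's implementation.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.Semaphore;

// Hypothetical sketch: cap concurrent graph loads with a Semaphore.
public class BoundedLoaderExample {
    private final Semaphore loadPermits;

    public BoundedLoaderExample(int maxConcurrentLoads) {
        this.loadPermits = new Semaphore(maxConcurrentLoads);
    }

    public <T> T loadBounded(Callable<T> loader) throws Exception {
        loadPermits.acquire();      // blocks once the cap is reached
        try {
            return loader.call();   // the expensive open + load work
        } finally {
            loadPermits.release();
        }
    }

    public int availablePermits() {
        return loadPermits.availablePermits();
    }
}
```

A fixed-size ExecutorService would achieve a similar bound while also queueing the work on dedicated threads, which may suit async eviction better.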

String key = nativeMemoryEntryContext.getKey();
ReentrantLock indexFileLock = indexLocks.computeIfAbsent(key, k -> new ReentrantLock());
indexFileLock.lock();
nativeMemoryEntryContext.openVectorIndex();
[Review thread]
Collaborator: Can we avoid the case where the graph is partially loaded or an error occurs during loading, which ends up leaving the cache in an inconsistent state? Can we ensure atomicity in graph loading and only put the entry in the cache if loading succeeds?
Author: If there is an error in graph loading, then the entry will not be in the cache. What would be the scenario where the cache ends up in an inconsistent state?
Member: @Gankris96 Can we wrap this call behind the same lock-based logic above, just to make sure we do not open the same index files concurrently in two different threads?
Author: Wrapping this within a lock still seems to fail some bwc search tests, where we end up getting incorrect results. Even doing so would not really help, because it does not solve the eventual problem of multiple graph files getting loaded at the same time, since the load is not synchronized anymore. This probably requires revisiting in a new, separate issue where we refactor the whole cache strategy, imo.
Member: Please create an issue so that we can track it.
Author: I guess the bwc failure was a different issue, unrelated to this. I did add back the locking logic for this as well. It seems to work fine, so we can keep this in.
indexFileLock.unlock();
return cache.get(key, nativeMemoryEntryContext::load);
}
}
}

NativeMemoryEntryContext.java
@@ -12,12 +12,16 @@
package org.opensearch.knn.index.memory;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.store.IndexInputWithBuffer;

import java.io.IOException;
import java.util.Map;
@@ -26,7 +30,7 @@
/**
* Encapsulates all information needed to load a component into native memory.
*/
public abstract class NativeMemoryEntryContext<T extends NativeMemoryAllocation> {
public abstract class NativeMemoryEntryContext<T extends NativeMemoryAllocation> implements AutoCloseable {

protected final String key;

@@ -55,13 +59,27 @@ public String getKey() {
*/
public abstract Integer calculateSizeInKB();

/**
* Opens the graph file by opening the corresponding indexInput so
* that it is available for graph loading
*/

public void openVectorIndex() {}

/**
* Provides the capability to close the closable objects in the {@link NativeMemoryEntryContext}
*/
@Override
public void close() {}

/**
* Loads entry into memory.
*
* @return NativeMemoryAllocation associated with NativeMemoryEntryContext
*/
public abstract T load() throws IOException;

@Log4j2
public static class IndexEntryContext extends NativeMemoryEntryContext<NativeMemoryAllocation.IndexAllocation> {

@Getter
@@ -75,6 +93,17 @@ public static class IndexEntryContext extends NativeMemoryEntryContext<NativeMem
@Getter
private final String modelId;

@Getter
private boolean indexGraphFileOpened = false;
@Getter
private int indexSizeKb;

@Getter
private IndexInput readStream;

@Getter
IndexInputWithBuffer indexInputWithBuffer;

/**
* Constructor
*
@@ -131,10 +160,61 @@ public Integer calculateSizeInKB() {
}
}

@Override
public void openVectorIndex() {
// if graph file is already opened for index, do nothing
if (isIndexGraphFileOpened()) {
return;
}
// Extract vector file name from the given cache key.
// Ex: _0_165_my_field.faiss@1vaqiupVUwvkXAG4Qc/RPg==
final String cacheKey = this.getKey();
final String vectorFileName = NativeMemoryCacheKeyHelper.extractVectorIndexFileName(cacheKey);
if (vectorFileName == null) {
throw new IllegalStateException(
"Invalid cache key was given. The key [" + cacheKey + "] does not contain the corresponding vector file name."
);
}

// Prepare for opening index input from directory.
final Directory directory = this.getDirectory();

// Try to open an index input then pass it down to native engine for loading an index.
try {
indexSizeKb = Math.toIntExact(directory.fileLength(vectorFileName) / 1024);
readStream = directory.openInput(vectorFileName, IOContext.READONCE);
readStream.seek(0);
indexInputWithBuffer = new IndexInputWithBuffer(readStream);
indexGraphFileOpened = true;
log.debug("[KNN] NativeMemoryCacheManager openVectorIndex successful");
} catch (IOException e) {
throw new RuntimeException("Failed to openVectorIndex the index " + openSearchIndexName);
}
}

@Override
public NativeMemoryAllocation.IndexAllocation load() throws IOException {
if (!isIndexGraphFileOpened()) {
throw new IllegalStateException("Index graph file is not open");
}
return indexLoadStrategy.load(this);
}

// close the indexInput
@Override
public void close() {
if (readStream != null) {
Gankris96 marked this conversation as resolved.
Show resolved Hide resolved
try {
readStream.close();
indexGraphFileOpened = false;
} catch (IOException e) {
throw new RuntimeException(
"Exception while closing the indexInput index [" + openSearchIndexName + "] for loading the graph file.",
e
);
}
}
}
}
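Why the context now implements AutoCloseable can be illustrated with try-with-resources: the underlying IndexInput is closed whether load() succeeds or throws. FakeEntryContext below is a stand-in for illustration, not the real NativeMemoryEntryContext.

```java
// Hypothetical sketch of the AutoCloseable contract used by the diff.
public class EntryContextCloseExample {
    static class FakeEntryContext implements AutoCloseable {
        boolean opened = false;
        boolean closed = false;

        void openVectorIndex() { opened = true; }

        String load() {
            // Mirrors the diff's precondition: open before load.
            if (!opened) throw new IllegalStateException("Index graph file is not open");
            return "allocation";
        }

        @Override
        public void close() { closed = true; }
    }

    static String openLoadAndClose() {
        FakeEntryContext ctx = new FakeEntryContext();
        ctx.openVectorIndex();
        // Java 9+: an effectively-final variable can be named directly in
        // the resource list, as in the diff's `try (nativeMemoryEntryContext)`.
        try (ctx) {
            return ctx.load();
        }
        // ctx.close() has run here, even if load() had thrown.
    }
}
```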

public static class TrainingDataEntryContext extends NativeMemoryEntryContext<NativeMemoryAllocation.TrainingDataAllocation> {
NativeMemoryLoadStrategy.java
@@ -13,12 +13,9 @@

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.core.action.ActionListener;
import org.opensearch.knn.index.codec.util.NativeMemoryCacheKeyHelper;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.store.IndexInputWithBuffer;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.engine.KNNEngine;
@@ -88,10 +85,16 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde
final int indexSizeKb = Math.toIntExact(directory.fileLength(vectorFileName) / 1024);

// Try to open an index input then pass it down to native engine for loading an index.
try (IndexInput readStream = directory.openInput(vectorFileName, IOContext.READONCE)) {
final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(readStream);
final long indexAddress = JNIService.loadIndex(indexInputWithBuffer, indexEntryContext.getParameters(), knnEngine);

// openVectorIndex takes care of opening the indexInput file
if (!indexEntryContext.isIndexGraphFileOpened()) {
throw new IllegalStateException("Index [" + indexEntryContext.getOpenSearchIndexName() + "] is not preloaded");
}
try (indexEntryContext) {
final long indexAddress = JNIService.loadIndex(
indexEntryContext.indexInputWithBuffer,
indexEntryContext.getParameters(),
knnEngine
);
return createIndexAllocation(indexEntryContext, knnEngine, indexAddress, indexSizeKb, vectorFileName);
}
}