Skip to content

Commit

Permalink
Fixed some integration bugs with remote-index-service, validated that…
Browse files Browse the repository at this point in the history
… index creation getting triggered from k-NN plugin

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Dec 21, 2024
1 parent 5606ac5 commit e088148
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 37 deletions.
35 changes: 31 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,10 @@ dependencies {
api group: 'com.google.guava', name: 'guava', version:'32.1.3-jre'
api group: 'commons-lang', name: 'commons-lang', version: '2.6'

api group: 'org.apache.httpcomponents', name: 'httpcore', version: "${versions.httpcore}"
api group: 'org.apache.httpcomponents', name: 'httpclient', version: "${versions.httpclient}"
api group: 'org.apache.httpcomponents', name: 'httpasyncclient', version: "${versions.httpasyncclient}"
implementation group: 'org.apache.httpcomponents', name: 'httpcore', version: "${versions.httpcore}"
implementation group: 'org.apache.httpcomponents', name: 'httpclient', version: "${versions.httpclient}"
implementation group: 'org.apache.httpcomponents', name: 'httpasyncclient', version: "${versions.httpasyncclient}"

testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
Expand All @@ -310,7 +311,6 @@ dependencies {

zipArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}"

// aws sdk v2 stack
api "software.amazon.awssdk:sdk-core:${versions.aws}"
api "software.amazon.awssdk:annotations:${versions.aws}"
api "software.amazon.awssdk:aws-core:${versions.aws}"
Expand All @@ -333,6 +333,33 @@ dependencies {
api "software.amazon.awssdk:aws-query-protocol:${versions.aws}"
api "software.amazon.awssdk:sts:${versions.aws}"
api "software.amazon.awssdk:netty-nio-client:${versions.aws}"

api "org.apache.httpcomponents:httpclient:${versions.httpclient}"
api "org.apache.httpcomponents:httpcore:${versions.httpcore}"
api "commons-logging:commons-logging:${versions.commonslogging}"
api "org.apache.logging.log4j:log4j-1.2-api:${versions.log4j}"
api "commons-codec:commons-codec:${versions.commonscodec}"
api "com.fasterxml.jackson.core:jackson-core:${versions.jackson}"
api "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}"
api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}"
api "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${versions.jackson}"
api "joda-time:joda-time:${versions.joda}"
api "org.slf4j:slf4j-api:${versions.slf4j}"

runtimeOnly "org.apache.logging.log4j:log4j-slf4j-impl:${versions.log4j}"

// network stack
api "io.netty:netty-buffer:${versions.netty}"
api "io.netty:netty-codec:${versions.netty}"
api "io.netty:netty-codec-http:${versions.netty}"
api "io.netty:netty-codec-http2:${versions.netty}"
api "io.netty:netty-common:${versions.netty}"
api "io.netty:netty-handler:${versions.netty}"
api "io.netty:netty-resolver:${versions.netty}"
api "io.netty:netty-transport:${versions.netty}"
api "io.netty:netty-transport-native-unix-common:${versions.netty}"
api "io.netty:netty-transport-classes-epoll:${versions.netty}"

}

task windowsPatches(type:Exec) {
Expand Down
7 changes: 2 additions & 5 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -575,10 +575,7 @@ public List<Setting<?>> getSettings() {
KNN_FAISS_AVX512_DISABLED_SETTING,
QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING,
QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING,
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING,
KNN_S3_ACCESS_KEY_SETTING,
KNN_S3_SECRET_KEY_SETTING,
KNN_S3_TOKEN_KEY_SETTING
KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING
);
final List<Stream<Setting<?>>> streamList = Arrays.asList(
settings.stream(),
Expand Down Expand Up @@ -659,7 +656,7 @@ public static String getKnnS3Token() {
}

public static String getRemoteServiceEndpoint() {
return KNNSettings.state().getSettingValue(REMOTE_SERVICE_PORT);
return KNNSettings.state().getSettingValue(REMOTE_SERVICE_ENDPOINT);
}

public static Integer getRemoteServicePort() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,11 @@ private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() {
// mapperService is already checked for null or valid instance type at caller, hence we don't need
// an additional isPresent check here.
int approximateThreshold = getApproximateThresholdValue();
final String indexUUID = mapperService.get().getIndexSettings().getIndex().getUUID();
return new NativeEngines990KnnVectorsFormat(
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()),
approximateThreshold
approximateThreshold,
indexUUID
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat {
private static FlatVectorsFormat flatVectorsFormat;
private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat";
private static int approximateThreshold;
private String indexUUID;

public NativeEngines990KnnVectorsFormat() {
this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()));
Expand All @@ -51,14 +52,21 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma
NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold;
}

/**
* Creates the format with a flat vectors format, an approximate threshold and the UUID of the index.
*
* @param flatVectorsFormat {@link FlatVectorsFormat} used to create the flat vectors writer in fieldsWriter
* @param approximateThreshold threshold value that is handed to the writer created in fieldsWriter
* @param indexUUID UUID of the index; kept on this instance and handed to the writer created in fieldsWriter
*/
public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int approximateThreshold, final String indexUUID) {
super(FORMAT_NAME);
// flatVectorsFormat and approximateThreshold are static fields of this class, so these two
// assignments overwrite the values seen by every instance, not only the one being constructed.
NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat;
NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold;
// indexUUID is an instance field, unlike the two values above.
this.indexUUID = indexUUID;
}

/**
* Returns a {@link KnnVectorsWriter} to write the vectors to the index.
*
* @param state {@link SegmentWriteState}
*/
@Override
public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException {
return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold);
return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold, indexUUID);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,32 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
private final Integer approximateThreshold;
private final S3Client s3Client;
private final IndexBuildServiceClient indexBuildServiceClient;
private final String indexUUID;

/**
* Creates a writer without an index UUID by delegating to the four-argument constructor with a
* {@code null} indexUUID.
*
* @param segmentWriteState {@link SegmentWriteState}
* @param flatVectorsWriter {@link FlatVectorsWriter}
* @param approximateThreshold approximate threshold value forwarded unchanged
*/
public NativeEngines990KnnVectorsWriter(
SegmentWriteState segmentWriteState,
FlatVectorsWriter flatVectorsWriter,
Integer approximateThreshold
) {
// NOTE(review): indexUUID is null here; createObjectKey concatenates indexUUID into the object key,
// so a writer built through this constructor would produce keys starting with "null_" — confirm intended.
this(segmentWriteState, flatVectorsWriter, approximateThreshold, null);
}

public NativeEngines990KnnVectorsWriter(
SegmentWriteState segmentWriteState,
FlatVectorsWriter flatVectorsWriter,
Integer approximateThreshold,
String indexUUID
) {
this.segmentWriteState = segmentWriteState;
this.flatVectorsWriter = flatVectorsWriter;
this.approximateThreshold = approximateThreshold;
this.indexUUID = indexUUID;
try {
s3Client = S3Client.getInstance();
indexBuildServiceClient = IndexBuildServiceClient.getInstance();
} catch (Exception e) {
throw new RuntimeException(e);
}
indexBuildServiceClient = IndexBuildServiceClient.getInstance();
}

/**
Expand Down Expand Up @@ -133,7 +144,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {

uploadToS3(fieldInfo, knnVectorValuesSupplier);
log.info("Creating the IndexRequest...");
CreateIndexRequest createIndexRequest = buildCreateIndexRequest(fieldInfo);
CreateIndexRequest createIndexRequest = buildCreateIndexRequest(fieldInfo, totalLiveDocs);
log.info("Submitting request to remote indexbuildService");
try {
CreateIndexResponse response = indexBuildServiceClient.createIndex(createIndexRequest);
Expand Down Expand Up @@ -195,17 +206,15 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState

private void uploadToS3(final FieldInfo fieldInfo, final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier) {
// s3 uploader
String segmentName = segmentWriteState.segmentInfo.name;
String fieldName = fieldInfo.getName();
String s3Key = segmentName + "_" + fieldName + ".s3vec";
String s3Key = createObjectKey(fieldInfo);
try (InputStream vectorInputStream = new VectorValuesInputStream((KNNFloatVectorValues) knnVectorValuesSupplier.get())) {
StopWatch stopWatch = new StopWatch().start();
// Lets upload data to s3.
long totalBytesUploaded = s3Client.uploadWithProgress(vectorInputStream, s3Key);
long time_in_millis = stopWatch.stop().totalTime().millis();
log.info(
"Time taken to upload vector for segment : {}, field: {}, totalBytes: {}, dimension: {} is : {}ms",
segmentName,
segmentWriteState.segmentInfo.name,
fieldInfo.getName(),
totalBytesUploaded,
fieldInfo.getVectorDimension(),
Expand All @@ -217,12 +226,22 @@ private void uploadToS3(final FieldInfo fieldInfo, final Supplier<KNNVectorValue
}
}

private CreateIndexRequest buildCreateIndexRequest(final FieldInfo fieldInfo) {
private CreateIndexRequest buildCreateIndexRequest(final FieldInfo fieldInfo, int totalLiveDocs) {
String s3Key = createObjectKey(fieldInfo);
int dimension = fieldInfo.getVectorDimension();
return CreateIndexRequest.builder()
.bucketName(S3Client.BUCKET_NAME)
.objectLocation(s3Key)
.dimensions(dimension)
.numberOfVectors(totalLiveDocs)
.build();
}

private String createObjectKey(FieldInfo fieldInfo) {
String segmentName = segmentWriteState.segmentInfo.name;
String fieldName = fieldInfo.getName();
String s3Key = segmentName + "_" + fieldName + ".s3vec";
int dimension = fieldInfo.getVectorDimension();
return CreateIndexRequest.builder().bucketName(S3Client.BUCKET_NAME).objectLocation(s3Key).dimensions(dimension).build();
// shard information will also be needed to ensure that we can construct correct paths
return indexUUID + "_" + segmentName + "_" + fieldName + ".s3vec";
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
Expand All @@ -25,8 +26,11 @@
import org.opensearch.knn.remote.index.model.CreateIndexRequest;
import org.opensearch.knn.remote.index.model.CreateIndexResponse;
import org.opensearch.knn.remote.index.s3.S3Client;
import org.opensearch.knn.remote.index.s3.SocketAccess;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;

/**
* Main class to call the IndexBuildService APIs
Expand Down Expand Up @@ -62,8 +66,13 @@ private IndexBuildServiceClient() {
* @param createIndexRequest {@link CreateIndexRequest}
* @throws IOException thrown if the createIndex request is not successful
*/
public CreateIndexResponse createIndex(final CreateIndexRequest createIndexRequest) throws IOException {
HttpPost request = new HttpPost();
public CreateIndexResponse createIndex(final CreateIndexRequest createIndexRequest) throws IOException, URISyntaxException {
String host = KNNSettings.getRemoteServiceEndpoint();
int port = KNNSettings.getRemoteServicePort();

URI uri = new URIBuilder().setScheme("http").setHost(host).setPort(port).setPath("/create_index").build();

HttpPost request = new HttpPost(uri);
request.setHeader(CONTENT_TYPE, APPLICATION_JSON);
request.setHeader(ACCEPT, APPLICATION_JSON);
XContentBuilder builder = XContentFactory.jsonBuilder();
Expand All @@ -81,8 +90,8 @@ public void checkIndexBuildStatus() {

}

private HttpResponse makeHTTPRequest(final HttpRequest request) throws IOException {
HttpResponse response = httpClient.execute(httpHost, request);
private HttpResponse makeHTTPRequest(final HttpUriRequest request) throws IOException {
HttpResponse response = SocketAccess.doPrivilegedException(() -> httpClient.execute(request));
HttpEntity entity = response.getEntity();
int statusCode = response.getStatusLine().getStatusCode();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@Value
@Builder
public class CreateIndexResponse {
private static final ParseField INDEX_CREATION_REQUEST_ID = new ParseField("indexCreationRequestId");
private static final ParseField INDEX_CREATION_REQUEST_ID = new ParseField("job_id");
private static final ParseField STATUS = new ParseField("status");
String indexCreationRequestId;
String status;
Expand Down
12 changes: 7 additions & 5 deletions src/main/java/org/opensearch/knn/remote/index/s3/S3Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.profiles.ProfileFileSystemSetting;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
Expand Down Expand Up @@ -65,8 +66,8 @@ public static S3Client getInstance() throws IOException {
}

@SuppressForbidden(reason = "Need to provide this override to v2 SDK so that path does not default to home path")
private S3Client() throws IOException {
SocketAccess.doPrivilegedIOException(() -> {
private S3Client() {
SocketAccess.doPrivilegedException(() -> {
if (ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.getStringValue().isEmpty()) {
System.setProperty(
ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.property(),
Expand All @@ -92,6 +93,7 @@ private S3Client() throws IOException {

software.amazon.awssdk.services.s3.S3AsyncClientBuilder builder = software.amazon.awssdk.services.s3.S3AsyncClient.builder()
.region(REGION)
.httpClientBuilder(NettyNioAsyncHttpClient.builder())
.credentialsProvider(StaticCredentialsProvider.create(credentials))
.overrideConfiguration(ClientOverrideConfiguration.builder().defaultProfileFile(null).defaultProfileName(null).build());

Expand Down Expand Up @@ -122,7 +124,7 @@ public long uploadWithProgress(final InputStream inputStream, final String key)
.key(key)
.build();

CreateMultipartUploadResponse multipartUpload = SocketAccess.doPrivilegedIOException(
CreateMultipartUploadResponse multipartUpload = SocketAccess.doPrivilegedException(
() -> s3AsyncClient.createMultipartUpload(createMultipartUploadRequest).get()
);

Expand Down Expand Up @@ -150,7 +152,7 @@ public long uploadWithProgress(final InputStream inputStream, final String key)
.partNumber(partNumber)
.build();

CompletableFuture<UploadPartResponse> uploadPartResponse = SocketAccess.doPrivilegedIOException(
CompletableFuture<UploadPartResponse> uploadPartResponse = SocketAccess.doPrivilegedException(
() -> s3AsyncClient.uploadPart(uploadPartRequest, AsyncRequestBody.fromBytes(partData))
);
completableFutureList.add(uploadPartResponse);
Expand Down Expand Up @@ -205,7 +207,7 @@ public long uploadWithProgress(final InputStream inputStream, final String key)
.multipartUpload(completedMultipartUpload)
.build();

CompleteMultipartUploadResponse response = SocketAccess.doPrivilegedIOException(
CompleteMultipartUploadResponse response = SocketAccess.doPrivilegedException(
() -> s3AsyncClient.completeMultipartUpload(completeRequest).get()
);
log.debug("********** CompleteMultipartUploadResponse : {} **************", response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@

import org.opensearch.SpecialPermission;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;

public final class SocketAccess {
Expand All @@ -22,12 +20,12 @@ public static <T> T doPrivileged(PrivilegedAction<T> operation) {
return AccessController.doPrivileged(operation);
}

public static <T> T doPrivilegedIOException(PrivilegedExceptionAction<T> operation) throws IOException {
public static <T> T doPrivilegedException(PrivilegedExceptionAction<T> operation) {
SpecialPermission.check();
try {
return AccessController.doPrivileged(operation);
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/main/plugin-metadata/plugin-security.policy
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,9 @@ grant {
permission java.util.PropertyPermission "opensearch.path.conf", "read,write";
permission java.io.FilePermission "config", "read";

// For accessing the local remote build service
//permission java.net.SocketPermission "127.0.0.1:5005", "connect,resolve";
//permission java.net.SocketPermission "localhost:5005", "connect,resolve";

permission java.lang.RuntimePermission "accessDeclaredMembers";
};

0 comments on commit e088148

Please sign in to comment.