From e088148b12671d35e1bcb5a91c6df0855d81dc5e Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Fri, 20 Dec 2024 19:46:51 -0800 Subject: [PATCH] Fixed some integration bugs with remote-index-service, validated that index creation getting triggered from k-NN plugin Signed-off-by: Navneet Verma --- build.gradle | 35 +++++++++++++++-- .../org/opensearch/knn/index/KNNSettings.java | 7 +--- .../codec/BasePerFieldKnnVectorsFormat.java | 4 +- .../NativeEngines990KnnVectorsFormat.java | 10 ++++- .../NativeEngines990KnnVectorsWriter.java | 39 ++++++++++++++----- .../index/client/IndexBuildServiceClient.java | 19 ++++++--- .../index/model/CreateIndexResponse.java | 2 +- .../knn/remote/index/s3/S3Client.java | 12 +++--- .../knn/remote/index/s3/SocketAccess.java | 8 ++-- .../plugin-metadata/plugin-security.policy | 4 ++ 10 files changed, 103 insertions(+), 37 deletions(-) diff --git a/build.gradle b/build.gradle index 83cee83d8..ff49f061d 100644 --- a/build.gradle +++ b/build.gradle @@ -295,9 +295,10 @@ dependencies { api group: 'com.google.guava', name: 'guava', version:'32.1.3-jre' api group: 'commons-lang', name: 'commons-lang', version: '2.6' - api group: 'org.apache.httpcomponents', name: 'httpcore', version: "${versions.httpcore}" - api group: 'org.apache.httpcomponents', name: 'httpclient', version: "${versions.httpclient}" - api group: 'org.apache.httpcomponents', name: 'httpasyncclient', version: "${versions.httpasyncclient}" + implementation group: 'org.apache.httpcomponents', name: 'httpcore', version: "${versions.httpcore}" + implementation group: 'org.apache.httpcomponents', name: 'httpclient', version: "${versions.httpclient}" + implementation group: 'org.apache.httpcomponents', name: 'httpasyncclient', version: "${versions.httpasyncclient}" + testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}" testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10' testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3' @@ -310,7 +311,6 @@ dependencies { zipArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}" - // aws sdk v2 stack api "software.amazon.awssdk:sdk-core:${versions.aws}" api "software.amazon.awssdk:annotations:${versions.aws}" api "software.amazon.awssdk:aws-core:${versions.aws}" @@ -333,6 +333,33 @@ dependencies { api "software.amazon.awssdk:aws-query-protocol:${versions.aws}" api "software.amazon.awssdk:sts:${versions.aws}" api "software.amazon.awssdk:netty-nio-client:${versions.aws}" + + api "org.apache.httpcomponents:httpclient:${versions.httpclient}" + api "org.apache.httpcomponents:httpcore:${versions.httpcore}" + api "commons-logging:commons-logging:${versions.commonslogging}" + api "org.apache.logging.log4j:log4j-1.2-api:${versions.log4j}" + api "commons-codec:commons-codec:${versions.commonscodec}" + api "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" + api "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" + api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" + api "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${versions.jackson}" + api "joda-time:joda-time:${versions.joda}" + api "org.slf4j:slf4j-api:${versions.slf4j}" + + runtimeOnly "org.apache.logging.log4j:log4j-slf4j-impl:${versions.log4j}" + + // network stack + api "io.netty:netty-buffer:${versions.netty}" + api "io.netty:netty-codec:${versions.netty}" + api "io.netty:netty-codec-http:${versions.netty}" + api "io.netty:netty-codec-http2:${versions.netty}" + api "io.netty:netty-common:${versions.netty}" + api "io.netty:netty-handler:${versions.netty}" + api "io.netty:netty-resolver:${versions.netty}" + api "io.netty:netty-transport:${versions.netty}" + api "io.netty:netty-transport-native-unix-common:${versions.netty}" + api "io.netty:netty-transport-classes-epoll:${versions.netty}" + } task windowsPatches(type:Exec) { diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index c553d4073..1c2532d68 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -575,10 +575,7 @@ public List> getSettings() { KNN_FAISS_AVX512_DISABLED_SETTING, QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, - KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING, - KNN_S3_ACCESS_KEY_SETTING, - KNN_S3_SECRET_KEY_SETTING, - KNN_S3_TOKEN_KEY_SETTING + KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED_SETTING ); final List>> streamList = Arrays.asList( settings.stream(), @@ -659,7 +656,7 @@ public static String getKnnS3Token() { } public static String getRemoteServiceEndpoint() { - return KNNSettings.state().getSettingValue(REMOTE_SERVICE_PORT); + return KNNSettings.state().getSettingValue(REMOTE_SERVICE_ENDPOINT); } public static Integer getRemoteServicePort() { diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 72187516f..0c3374a38 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -134,9 +134,11 @@ private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() { // mapperService is already checked for null or valid instance type at caller, hence we don't need // addition isPresent check here. int approximateThreshold = getApproximateThresholdValue(); + final String indexUUID = mapperService.get().getIndexSettings().getIndex().getUUID(); return new NativeEngines990KnnVectorsFormat( new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()), - approximateThreshold + approximateThreshold, + indexUUID ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java index dd326123e..79504fda2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormat.java @@ -32,6 +32,7 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat { private static FlatVectorsFormat flatVectorsFormat; private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat"; private static int approximateThreshold; + private String indexUUID; public NativeEngines990KnnVectorsFormat() { this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer())); @@ -51,6 +52,13 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold; } + public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int approximateThreshold, final String indexUUID) { + super(FORMAT_NAME); + NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat; + NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold; + this.indexUUID = indexUUID; + } + /** * Returns a {@link KnnVectorsWriter} to write the vectors to the index. * @@ -58,7 +66,7 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsForma */ @Override public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException { - return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold); + return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold, indexUUID); } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index d1ce0bb94..fcef39e95 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -63,21 +63,32 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final Integer approximateThreshold; private final S3Client s3Client; private final IndexBuildServiceClient indexBuildServiceClient; + private final String indexUUID; public NativeEngines990KnnVectorsWriter( SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter, Integer approximateThreshold + ) { + this(segmentWriteState, flatVectorsWriter, approximateThreshold, null); + } + + public NativeEngines990KnnVectorsWriter( + SegmentWriteState segmentWriteState, + FlatVectorsWriter flatVectorsWriter, + Integer approximateThreshold, + String indexUUID ) { this.segmentWriteState = segmentWriteState; this.flatVectorsWriter = flatVectorsWriter; this.approximateThreshold = approximateThreshold; + this.indexUUID = indexUUID; try { s3Client = S3Client.getInstance(); + indexBuildServiceClient = IndexBuildServiceClient.getInstance(); } catch (Exception e) { throw new RuntimeException(e); } - indexBuildServiceClient = IndexBuildServiceClient.getInstance(); } /** @@ -133,7 +144,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { uploadToS3(fieldInfo, knnVectorValuesSupplier); log.info("Creating the IndexRequest..."); - CreateIndexRequest createIndexRequest = buildCreateIndexRequest(fieldInfo); + CreateIndexRequest createIndexRequest = buildCreateIndexRequest(fieldInfo, totalLiveDocs); log.info("Submitting request to remote indexbuildService"); try { CreateIndexResponse response = indexBuildServiceClient.createIndex(createIndexRequest); @@ -195,9 +206,7 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState private void uploadToS3(final FieldInfo fieldInfo, final Supplier> knnVectorValuesSupplier) { // s3 uploader - String segmentName = segmentWriteState.segmentInfo.name; - String fieldName = fieldInfo.getName(); - String s3Key = segmentName + "_" + fieldName + ".s3vec"; + String s3Key = createObjectKey(fieldInfo); try (InputStream vectorInputStream = new VectorValuesInputStream((KNNFloatVectorValues) knnVectorValuesSupplier.get())) { StopWatch stopWatch = new StopWatch().start(); // Lets upload data to s3. @@ -205,7 +214,7 @@ private void uploadToS3(final FieldInfo fieldInfo, final Supplier httpClient.execute(request)); HttpEntity entity = response.getEntity(); int statusCode = response.getStatusLine().getStatusCode(); diff --git a/src/main/java/org/opensearch/knn/remote/index/model/CreateIndexResponse.java b/src/main/java/org/opensearch/knn/remote/index/model/CreateIndexResponse.java index 83f297453..532c7b7be 100644 --- a/src/main/java/org/opensearch/knn/remote/index/model/CreateIndexResponse.java +++ b/src/main/java/org/opensearch/knn/remote/index/model/CreateIndexResponse.java @@ -15,7 +15,7 @@ @Value @Builder public class CreateIndexResponse { - private static final ParseField INDEX_CREATION_REQUEST_ID = new ParseField("indexCreationRequestId"); + private static final ParseField INDEX_CREATION_REQUEST_ID = new ParseField("job_id"); private static final ParseField STATUS = new ParseField("status"); String indexCreationRequestId; String status; diff --git a/src/main/java/org/opensearch/knn/remote/index/s3/S3Client.java b/src/main/java/org/opensearch/knn/remote/index/s3/S3Client.java index 8a32912ee..d37c12da8 100644 --- a/src/main/java/org/opensearch/knn/remote/index/s3/S3Client.java +++ b/src/main/java/org/opensearch/knn/remote/index/s3/S3Client.java @@ -15,6 +15,7 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.profiles.ProfileFileSystemSetting; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; @@ -65,8 +66,8 @@ public static S3Client getInstance() throws IOException { } @SuppressForbidden(reason = "Need to provide this override to v2 SDK so that path does not default to home path") - private S3Client() throws IOException { - SocketAccess.doPrivilegedIOException(() -> { + private S3Client() { + SocketAccess.doPrivilegedException(() -> { if (ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.getStringValue().isEmpty()) { System.setProperty( ProfileFileSystemSetting.AWS_SHARED_CREDENTIALS_FILE.property(), @@ -92,6 +93,7 @@ private S3Client() throws IOException { software.amazon.awssdk.services.s3.S3AsyncClientBuilder builder = software.amazon.awssdk.services.s3.S3AsyncClient.builder() .region(REGION) + .httpClientBuilder(NettyNioAsyncHttpClient.builder()) .credentialsProvider(StaticCredentialsProvider.create(credentials)) .overrideConfiguration(ClientOverrideConfiguration.builder().defaultProfileFile(null).defaultProfileName(null).build()); @@ -122,7 +124,7 @@ public long uploadWithProgress(final InputStream inputStream, final String key) .key(key) .build(); - CreateMultipartUploadResponse multipartUpload = SocketAccess.doPrivilegedIOException( + CreateMultipartUploadResponse multipartUpload = SocketAccess.doPrivilegedException( () -> s3AsyncClient.createMultipartUpload(createMultipartUploadRequest).get() ); @@ -150,7 +152,7 @@ public long uploadWithProgress(final InputStream inputStream, final String key) .partNumber(partNumber) .build(); - CompletableFuture uploadPartResponse = SocketAccess.doPrivilegedIOException( + CompletableFuture uploadPartResponse = SocketAccess.doPrivilegedException( () -> s3AsyncClient.uploadPart(uploadPartRequest, AsyncRequestBody.fromBytes(partData)) ); completableFutureList.add(uploadPartResponse); @@ -205,7 +207,7 @@ public long uploadWithProgress(final InputStream inputStream, final String key) .multipartUpload(completedMultipartUpload) .build(); - CompleteMultipartUploadResponse response = SocketAccess.doPrivilegedIOException( + CompleteMultipartUploadResponse response = SocketAccess.doPrivilegedException( () -> s3AsyncClient.completeMultipartUpload(completeRequest).get() ); log.debug("********** CompleteMultipartUploadResponse : {} **************", response); diff --git a/src/main/java/org/opensearch/knn/remote/index/s3/SocketAccess.java b/src/main/java/org/opensearch/knn/remote/index/s3/SocketAccess.java index 02d68cdd4..b46b6462a 100644 --- a/src/main/java/org/opensearch/knn/remote/index/s3/SocketAccess.java +++ b/src/main/java/org/opensearch/knn/remote/index/s3/SocketAccess.java @@ -7,10 +7,8 @@ import org.opensearch.SpecialPermission; -import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; -import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; public final class SocketAccess { @@ -22,12 +20,12 @@ public static T doPrivileged(PrivilegedAction operation) { return AccessController.doPrivileged(operation); } - public static T doPrivilegedIOException(PrivilegedExceptionAction operation) throws IOException { + public static T doPrivilegedException(PrivilegedExceptionAction operation) { SpecialPermission.check(); try { return AccessController.doPrivileged(operation); - } catch (PrivilegedActionException e) { - throw (IOException) e.getCause(); + } catch (Exception e) { + throw new RuntimeException(e); } } diff --git a/src/main/plugin-metadata/plugin-security.policy b/src/main/plugin-metadata/plugin-security.policy index 0d2d67718..8b0b68fef 100644 --- a/src/main/plugin-metadata/plugin-security.policy +++ b/src/main/plugin-metadata/plugin-security.policy @@ -38,5 +38,9 @@ grant { permission java.util.PropertyPermission "opensearch.path.conf", "read,write"; permission java.io.FilePermission "config", "read"; + // For accessing the local remote build service + //permission java.net.SocketPermission "127.0.0.1:5005", "connect,resolve"; + //permission java.net.SocketPermission "localhost:5005", "connect,resolve"; + permission java.lang.RuntimePermission "accessDeclaredMembers"; };