From 0a895375b8569e593f45ffac296bd4143c5be7fb Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Wed, 28 Aug 2024 14:08:00 -0700 Subject: [PATCH] Add processed function for remote inference input dataset parameters to convert it back to its orignal datatype (#2852) * Add processed function for remote inference input dataset parameters to convert it back to its orignal datatype Signed-off-by: b4sjoo * spotless Signed-off-by: b4sjoo * remove debugging print Signed-off-by: b4sjoo * Add UTs Signed-off-by: b4sjoo * Add UTs Signed-off-by: b4sjoo * Spotless Signed-off-by: b4sjoo --------- Signed-off-by: b4sjoo --- .../ml/common/utils/ModelInterfaceUtils.java | 5 +- .../TransportPredictionTaskAction.java | 9 ++- .../org/opensearch/ml/utils/MLNodeUtils.java | 34 ++++++++++ .../opensearch/ml/utils/MLNodeUtilsTests.java | 62 +++++++++++++++++++ 4 files changed, 104 insertions(+), 6 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java index 61bba1065d..5c5cc5fd99 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/ModelInterfaceUtils.java @@ -42,7 +42,10 @@ public class ModelInterfaceUtils { + " \"type\": \"object\",\n" + " \"properties\": {\n" + " \"texts\": {\n" - + " \"type\": \"string\"\n" + + " \"type\": \"array\",\n" + + " \"items\": {\n" + + " \"type\": \"string\"\n" + + " }\n" + " }\n" + " },\n" + " \"required\": [\n" diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 4cf957c499..a0e5018ad4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -242,11 +242,10 @@ public void validateInputSchema(String modelId, MLInput mlInput) { if (modelCacheHelper.getModelInterface(modelId) != null && modelCacheHelper.getModelInterface(modelId).get("input") != null) { String inputSchemaString = modelCacheHelper.getModelInterface(modelId).get("input"); try { - MLNodeUtils - .validateSchema( - inputSchemaString, - mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString() - ); + String InputString = mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString(); + // Process the parameters field in the input dataset to convert it back to its original datatype, instead of a string + String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(InputString); + MLNodeUtils.validateSchema(inputSchemaString, processedInputString); } catch (Exception e) { throw new OpenSearchStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST); } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java index 86fbfb1605..3cbbc62ef5 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java @@ -30,6 +30,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; import com.networknt.schema.JsonSchema; import com.networknt.schema.JsonSchemaFactory; import com.networknt.schema.SpecVersion.VersionFlag; @@ -89,6 +90,39 @@ public static void validateSchema(String schemaString, String instanceString) th } } + /** + * This method processes the input JSON string and replaces the string values of the parameters with JSON objects if the string is a valid JSON. + * @param inputJson The input JSON string + * @return The processed JSON string + */ + public static String processRemoteInferenceInputDataSetParametersValue(String inputJson) throws IOException { + ObjectMapper mapper = new ObjectMapper(); + JsonNode rootNode = mapper.readTree(inputJson); + + if (rootNode.has("parameters") && rootNode.get("parameters").isObject()) { + ObjectNode parametersNode = (ObjectNode) rootNode.get("parameters"); + + parametersNode.fields().forEachRemaining(entry -> { + String key = entry.getKey(); + JsonNode value = entry.getValue(); + + if (value.isTextual()) { + String textValue = value.asText(); + try { + // Try to parse the string as JSON + JsonNode parsedValue = mapper.readTree(textValue); + // If successful, replace the string with the parsed JSON + parametersNode.set(key, parsedValue); + } catch (IOException e) { + // If parsing fails, it's not a valid JSON string, so keep it as is + parametersNode.set(key, value); + } + } + }); + } + return mapper.writeValueAsString(rootNode); + } + public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBreakerService, MLStats mlStats) { ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB(); if (openCircuitBreaker != null) { diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java index 7838308834..5b12e73d3c 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java @@ -26,6 +26,8 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.test.OpenSearchTestCase; +import com.fasterxml.jackson.core.JsonParseException; + public class MLNodeUtilsTests extends OpenSearchTestCase { public void testIsMLNode() { @@ -63,4 +65,64 @@ public void testValidateSchema() throws IOException { String json = "{\"key1\": \"foo\", \"key2\": 123}"; MLNodeUtils.validateSchema(schema, json); } + + @Test + public void testProcessRemoteInferenceInputDataSetParametersValueNoParameters() throws IOException { + String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true}"; + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + assertEquals(json, processedJson); + } + + @Test + public void testProcessRemoteInferenceInputDataSetInvalidJson() { + String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"a\"}}"; + assertThrows(JsonParseException.class, () -> MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json)); + } + + @Test + public void testProcessRemoteInferenceInputDataSetEmptyParameters() throws IOException { + String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{}}"; + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + assertEquals(json, processedJson); + } + + @Test + public void testProcessRemoteInferenceInputDataSetParametersValueParametersWrongType() throws IOException { + String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":[\"Hello\",\"world\"]}"; + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + assertEquals(json, processedJson); + } + + @Test + public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersProcessArray() throws IOException { + String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"texts\":\"[\\\"Hello\\\",\\\"world\\\"]\"}}"; + String expectedJson = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"texts\":[\"Hello\",\"world\"]}}"; + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + assertEquals(expectedJson, processedJson); + } + + @Test + public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersProcessObject() throws IOException { + String json = + "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":\"{\\\"role\\\":\\\"system\\\",\\\"foo\\\":\\\"{\\\\\\\"a\\\\\\\": \\\\\\\"b\\\\\\\"}\\\",\\\"content\\\":{\\\"a\\\":\\\"b\\\"}}\"}}}"; + String expectedJson = + "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":{\"role\":\"system\",\"foo\":\"{\\\"a\\\": \\\"b\\\"}\",\"content\":{\"a\":\"b\"}}}}"; + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + assertEquals(expectedJson, processedJson); + } + + @Test + public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersNoProcess() throws IOException { + String json = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"foo\",\"key2\":123,\"key3\":true}}"; + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + assertEquals(json, processedJson); + } + + @Test + public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersInvalidJson() throws IOException { + String json = + "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"texts\":\"[\\\"Hello\\\",\\\"world\\\"\"}}"; + String processedJson = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(json); + assertEquals(json, processedJson); + } }