Skip to content

Commit

Permalink
Add processed function for remote inference input dataset parameters …
Browse files Browse the repository at this point in the history
…to convert it back to its original datatype (opensearch-project#2852)

* Add processed function for remote inference input dataset parameters to convert it back to its original datatype

Signed-off-by: b4sjoo <[email protected]>

* spotless

Signed-off-by: b4sjoo <[email protected]>

* remove debugging print

Signed-off-by: b4sjoo <[email protected]>

* Add UTs

Signed-off-by: b4sjoo <[email protected]>

* Add UTs

Signed-off-by: b4sjoo <[email protected]>

* Spotless

Signed-off-by: b4sjoo <[email protected]>

---------

Signed-off-by: b4sjoo <[email protected]>
  • Loading branch information
b4sjoo authored Aug 28, 2024
1 parent 05eb53f commit 0a89537
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ public class ModelInterfaceUtils {
+ " \"type\": \"object\",\n"
+ " \"properties\": {\n"
+ " \"texts\": {\n"
+ " \"type\": \"string\"\n"
+ " \"type\": \"array\",\n"
+ " \"items\": {\n"
+ " \"type\": \"string\"\n"
+ " }\n"
+ " }\n"
+ " },\n"
+ " \"required\": [\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,10 @@ public void validateInputSchema(String modelId, MLInput mlInput) {
if (modelCacheHelper.getModelInterface(modelId) != null && modelCacheHelper.getModelInterface(modelId).get("input") != null) {
String inputSchemaString = modelCacheHelper.getModelInterface(modelId).get("input");
try {
MLNodeUtils
.validateSchema(
inputSchemaString,
mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()
);
String InputString = mlInput.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString();
// Process the parameters field in the input dataset to convert it back to its original datatype, instead of a string
String processedInputString = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(InputString);
MLNodeUtils.validateSchema(inputSchemaString, processedInputString);
} catch (Exception e) {
throw new OpenSearchStatusException("Error validating input schema: " + e.getMessage(), RestStatus.BAD_REQUEST);
}
Expand Down
34 changes: 34 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.networknt.schema.JsonSchema;
import com.networknt.schema.JsonSchemaFactory;
import com.networknt.schema.SpecVersion.VersionFlag;
Expand Down Expand Up @@ -89,6 +90,39 @@ public static void validateSchema(String schemaString, String instanceString) th
}
}

/**
 * Rewrites the top-level {@code parameters} object of a JSON document so that every string value which is
 * itself one complete JSON document is replaced by the parsed JSON value.
 * <p>
 * Only direct children of {@code parameters} are examined; strings nested deeper are left untouched. A string
 * is kept exactly as it is when it cannot be parsed as JSON, when it is only a JSON prefix followed by other
 * text (for example {@code "1 apple"}), or when it is empty/blank.
 * <p>
 * If the document has no {@code parameters} field, or that field is not a JSON object, nothing is rewritten
 * and the document is simply re-serialized.
 *
 * @param inputJson The input JSON string
 * @return The processed JSON string
 * @throws IOException if {@code inputJson} itself cannot be parsed or the result cannot be serialized
 */
public static String processRemoteInferenceInputDataSetParametersValue(String inputJson) throws IOException {
    // The root document is parsed with default (lenient) settings, exactly as before.
    ObjectMapper mapper = new ObjectMapper();
    // Embedded values get a stricter mapper: by default Jackson stops after the first JSON value and ignores
    // whatever follows, which would turn plain text such as "1 apple" into the number 1.
    ObjectMapper strictValueMapper = new ObjectMapper()
        .enable(com.fasterxml.jackson.databind.DeserializationFeature.FAIL_ON_TRAILING_TOKENS);

    JsonNode rootNode = mapper.readTree(inputJson);
    JsonNode parameters = rootNode.get("parameters");

    if (parameters != null && parameters.isObject()) {
        ObjectNode parametersNode = (ObjectNode) parameters;

        // Replacing the value of an existing key does not add or remove fields, so it is done while iterating
        // (same approach as the previous implementation).
        parametersNode.fields().forEachRemaining(entry -> {
            JsonNode value = entry.getValue();
            if (!value.isTextual()) {
                return;
            }
            JsonNode parsedValue = parseCompleteJsonOrNull(strictValueMapper, value.asText());
            if (parsedValue != null) {
                parametersNode.set(entry.getKey(), parsedValue);
            }
        });
    }
    return mapper.writeValueAsString(rootNode);
}

/**
 * Parses {@code text} as a single JSON value using the supplied mapper.
 *
 * @param valueMapper mapper used for parsing; expected to reject trailing tokens
 * @param text candidate JSON text
 * @return the parsed node, or {@code null} when {@code text} is not valid JSON or yields no value at all
 */
private static JsonNode parseCompleteJsonOrNull(ObjectMapper valueMapper, String text) {
    try {
        JsonNode parsed = valueMapper.readTree(text);
        // Empty/blank input produces no value (null or a "missing" node depending on the Jackson version);
        // report that as "not JSON" so the caller keeps the original string instead of writing null.
        if (parsed == null || parsed.isMissingNode()) {
            return null;
        }
        return parsed;
    } catch (IOException ignored) {
        // Not a valid, complete JSON document: the caller keeps the original string value.
        return null;
    }
}

public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBreakerService, MLStats mlStats) {
ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
if (openCircuitBreaker != null) {
Expand Down
62 changes: 62 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.opensearch.ml.common.MLTask;
import org.opensearch.test.OpenSearchTestCase;

import com.fasterxml.jackson.core.JsonParseException;

public class MLNodeUtilsTests extends OpenSearchTestCase {

public void testIsMLNode() {
Expand Down Expand Up @@ -63,4 +65,64 @@ public void testValidateSchema() throws IOException {
String json = "{\"key1\": \"foo\", \"key2\": 123}";
MLNodeUtils.validateSchema(schema, json);
}

// A document without a "parameters" field is expected to come back unchanged.
@Test
public void testProcessRemoteInferenceInputDataSetParametersValueNoParameters() throws IOException {
    final String input = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true}";
    assertEquals(input, MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(input));
}

// A syntactically broken root document must surface as a JsonParseException.
@Test
public void testProcessRemoteInferenceInputDataSetInvalidJson() {
    final String malformedInput = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"a\"}}";
    assertThrows(
        JsonParseException.class,
        () -> MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(malformedInput)
    );
}

// An empty "parameters" object has nothing to convert, so the output equals the input.
@Test
public void testProcessRemoteInferenceInputDataSetEmptyParameters() throws IOException {
    final String input = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{}}";
    assertEquals(input, MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(input));
}

// When "parameters" is an array rather than an object, the document is expected to pass through untouched.
@Test
public void testProcessRemoteInferenceInputDataSetParametersValueParametersWrongType() throws IOException {
    final String input = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":[\"Hello\",\"world\"]}";
    assertEquals(input, MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(input));
}

// A parameter whose string value is a JSON array must be emitted as a real array.
@Test
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersProcessArray() throws IOException {
    final String input = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"texts\":\"[\\\"Hello\\\",\\\"world\\\"]\"}}";
    final String expected = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"texts\":[\"Hello\",\"world\"]}}";

    final String actual = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(input);

    assertEquals(expected, actual);
}

// A parameter whose string value is a JSON object must be emitted as a real object, while a JSON-looking
// string nested inside that object ("foo") stays a string.
// NOTE(review): the input literal appears to end with one more '}' than the document needs; the test seems to
// rely on the root parse tolerating trailing content - confirm before tightening the parser.
@Test
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersProcessObject() throws IOException {
    final String input =
        "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":\"{\\\"role\\\":\\\"system\\\",\\\"foo\\\":\\\"{\\\\\\\"a\\\\\\\": \\\\\\\"b\\\\\\\"}\\\",\\\"content\\\":{\\\"a\\\":\\\"b\\\"}}\"}}}";
    final String expected =
        "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"messages\":{\"role\":\"system\",\"foo\":\"{\\\"a\\\": \\\"b\\\"}\",\"content\":{\"a\":\"b\"}}}}";

    final String actual = MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(input);

    assertEquals(expected, actual);
}

// Parameters holding a plain word, a number and a boolean are expected to come back unchanged.
@Test
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersNoProcess() throws IOException {
    final String input = "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"foo\",\"key2\":123,\"key3\":true}}";
    assertEquals(input, MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(input));
}

// A parameter string that only looks like JSON (unterminated array) must be left as the original string.
@Test
public void testProcessRemoteInferenceInputDataSetParametersValueWithParametersInvalidJson() throws IOException {
    final String input =
        "{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"parameters\":{\"key1\":\"foo\",\"key2\":123,\"key3\":true,\"texts\":\"[\\\"Hello\\\",\\\"world\\\"\"}}";
    assertEquals(input, MLNodeUtils.processRemoteInferenceInputDataSetParametersValue(input));
}
}

0 comments on commit 0a89537

Please sign in to comment.