From 1bdc0ac5fc2a8eeed2bb1acaaf9a38da1a50228b Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 23 Sep 2024 11:00:11 -0700 Subject: [PATCH 1/5] add config field in MLToolSpec for static parameters Signed-off-by: Jing Zhang --- .../ml/common/agent/MLToolSpec.java | 29 ++++- .../ml/common/agent/MLAgentTest.java | 27 ++-- .../ml/common/agent/MLToolSpecTest.java | 119 +++++++++++++++++- .../agent/MLAgentGetResponseTest.java | 2 +- .../algorithms/agent/MLChatAgentRunner.java | 6 + .../MLConversationalFlowAgentRunner.java | 6 + .../algorithms/agent/MLFlowAgentRunner.java | 5 + .../agent/MLChatAgentRunnerTest.java | 49 ++++++++ .../agent/MLFlowAgentRunnerTest.java | 20 +++ .../agents/GetAgentTransportActionTests.java | 2 +- 10 files changed, 249 insertions(+), 16 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 98f7e1f33c..f39f799cae 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -29,15 +29,24 @@ public class MLToolSpec implements ToXContentObject { public static final String DESCRIPTION_FIELD = "description"; public static final String PARAMETERS_FIELD = "parameters"; public static final String INCLUDE_OUTPUT_IN_AGENT_RESPONSE = "include_output_in_agent_response"; + public static final String CONFIG_FIELD = "config"; private String type; private String name; private String description; private Map parameters; private boolean includeOutputInAgentResponse; + private Map configMap; @Builder(toBuilder = true) - public MLToolSpec(String type, String name, String description, Map parameters, boolean includeOutputInAgentResponse) { + public MLToolSpec( + String type, + String name, + String description, + Map parameters, + boolean includeOutputInAgentResponse, + Map configMap + ) { if (type == null) { throw new IllegalArgumentException("tool type is null"); } @@ -46,6 +55,7 @@ public MLToolSpec(String type, String name, String description, Map 0) { + builder.field(CONFIG_FIELD, configMap); + } builder.endObject(); return builder; } @@ -97,6 +119,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException { String description = null; Map parameters = null; boolean includeOutputInAgentResponse = false; + Map configMap = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -119,6 +142,9 @@ public static MLToolSpec parse(XContentParser parser) throws IOException { case INCLUDE_OUTPUT_IN_AGENT_RESPONSE: includeOutputInAgentResponse = parser.booleanValue(); break; + case CONFIG_FIELD: + configMap = getParameterMap(parser.map()); + break; default: parser.skipChildren(); break; @@ -131,6 +157,7 @@ public static MLToolSpec parse(XContentParser parser) throws IOException { .description(description) .parameters(parameters) .includeOutputInAgentResponse(includeOutputInAgentResponse) + .configMap(configMap) .build(); } diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index b83758fc23..e83df00a6a 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -46,7 +46,7 @@ public void constructor_NullName() { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), null, null, Instant.EPOCH, @@ -66,7 +66,7 @@ public void constructor_NullType() { null, "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), null, null, Instant.EPOCH, @@ -86,7 +86,7 @@ public void constructor_NullLLMSpec() { MLAgentType.CONVERSATIONAL.name(), "test", null, - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), null, null, Instant.EPOCH, @@ -100,7 +100,14 @@ public void constructor_NullLLMSpec() { public void constructor_DuplicateTool() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Duplicate tool defined: test_tool_name"); - MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false); + MLToolSpec mlToolSpec = new MLToolSpec( + "test_tool_type", + "test_tool_name", + "test", + Collections.EMPTY_MAP, + false, + Collections.EMPTY_MAP + ); MLAgent agent = new MLAgent( "test_name", MLAgentType.CONVERSATIONAL.name(), @@ -123,7 +130,7 @@ public void writeTo() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -150,7 +157,7 @@ public void writeTo_NullLLM() throws IOException { "FLOW", "test", null, - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -194,7 +201,7 @@ public void writeTo_NullParameters() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), null, new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -216,7 +223,7 @@ public void writeTo_NullMemory() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), Map.of("test", "test"), null, Instant.EPOCH, @@ -238,7 +245,7 @@ public void toXContent() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false)), + List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.EMPTY_MAP)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -294,7 +301,7 @@ public void fromStream() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java index 3d4d9a2ce5..12e641038a 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java @@ -22,7 +22,7 @@ public class MLToolSpecTest { @Test public void writeTo() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Map.of("configKey", "configValue")); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); @@ -32,11 +32,70 @@ public void writeTo() throws IOException { Assert.assertEquals(spec.getParameters(), spec1.getParameters()); Assert.assertEquals(spec.getDescription(), spec1.getDescription()); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); + } + + @Test + public void writeToEmptyConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.EMPTY_MAP); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); + } + + @Test + public void writeToNullConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, null); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertNull(spec1.getConfigMap()); } @Test public void toXContent() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Map.of("configKey", "configValue")); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert + .assertEquals( + "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false,\"config\":{\"configKey\":\"configValue\"}}", + content + ); + } + + @Test + public void toXContentEmptyConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.EMPTY_MAP); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + spec.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert + .assertEquals( + "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}", + content + ); + } + + @Test + public void toXContentNullConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); spec.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); @@ -50,6 +109,28 @@ public void toXContent() throws IOException { @Test public void parse() throws IOException { + String jsonStr = + "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false,\"config\":{\"configKey\":\"configValue\"}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLToolSpec spec = MLToolSpec.parse(parser); + + Assert.assertEquals(spec.getType(), "test"); + Assert.assertEquals(spec.getName(), "test"); + Assert.assertEquals(spec.getDescription(), "test"); + Assert.assertEquals(spec.getParameters(), Map.of("test", "test")); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), false); + Assert.assertEquals(spec.getConfigMap(), Map.of("configKey", "configValue")); + } + + @Test + public void parseEmptyConfigMap() throws IOException { String jsonStr = "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}"; XContentParser parser = XContentType.JSON @@ -67,11 +148,42 @@ public void parse() throws IOException { Assert.assertEquals(spec.getDescription(), "test"); Assert.assertEquals(spec.getParameters(), Map.of("test", "test")); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), false); + Assert.assertEquals(spec.getConfigMap(), null); } @Test public void fromStream() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false); + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Map.of("configKey", "configValue")); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); + } + + @Test + public void fromStreamEmptyConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.EMPTY_MAP); + BytesStreamOutput output = new BytesStreamOutput(); + spec.writeTo(output); + MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); + + Assert.assertEquals(spec.getType(), spec1.getType()); + Assert.assertEquals(spec.getName(), spec1.getName()); + Assert.assertEquals(spec.getParameters(), spec1.getParameters()); + Assert.assertEquals(spec.getDescription(), spec1.getDescription()); + Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); + } + + @Test + public void fromStreamNullConfigMap() throws IOException { + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, null); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); @@ -81,5 +193,6 @@ public void fromStream() throws IOException { Assert.assertEquals(spec.getParameters(), spec1.getParameters()); Assert.assertEquals(spec.getDescription(), spec1.getDescription()); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), spec1.isIncludeOutputInAgentResponse()); + Assert.assertEquals(spec.getConfigMap(), spec1.getConfigMap()); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index cad3794134..9065a7c58b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -76,7 +76,7 @@ public void writeTo() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 4b14f1af17..a053a96c8f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -470,12 +470,18 @@ private static void runTool( Map llmToolTmpParameters = new HashMap<>(); llmToolTmpParameters.putAll(tmpParameters); llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters()); + if (toolSpecMap.get(action).getConfigMap() != null) { + llmToolTmpParameters.putAll(toolSpecMap.get(action).getConfigMap()); + } llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput); tools.get(action).run(llmToolTmpParameters, toolListener); // run tool } else { Map parameters = new HashMap<>(); parameters.putAll(tmpParameters); parameters.putAll(toolParams); + if (toolSpecMap.get(action).getConfigMap() != null) { + parameters.putAll(toolSpecMap.get(action).getConfigMap()); + } tools.get(action).run(parameters, toolListener); // run tool } } catch (Exception e) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 672890c030..fdc4fb86d5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -434,6 +434,12 @@ Map getToolExecuteParams(MLToolSpec toolSpec, Map getToolExecuteParams(MLToolSpec toolSpec, Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); + params.put("question", "raw input"); + doReturn(true).when(firstTool).useOriginalInput(); + + // Run the MLChatAgentRunner. + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify that the tool's run method was called. + verify(firstTool).run(any(), any()); + // Verify the size of parameters passed in the tool run method. + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); + verify(firstTool).run((Map) argumentCaptor.capture(), any()); + assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + // The value of input should be "config_value", and not be "raw input". + assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); + + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + assertNotNull(modelTensorOutput); + } + @Test public void testSaveLastTraceFailure() { // Mock tool validation to return true. @@ -838,6 +868,25 @@ private MLAgent createMLAgentWithTools() { .build(); } + private MLAgent createMLAgentWithToolsConfig() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLToolSpec firstToolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .type(FIRST_TOOL) + .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) + .configMap(ImmutableMap.of("input", "config_value")) + .build(); + return MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .tools(Arrays.asList(firstToolSpec)) + .memory(mlMemorySpec) + .llm(llmSpec) + .build(); + } + private Map createAgentParamsWithAction(String action, String actionInput) { Map params = new HashMap<>(); params.put("action", action); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index 609609438a..b0225abc49 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -300,6 +300,26 @@ public void testGetToolExecuteParams() { assertFalse(result.containsKey("toolType.param2")); } + @Test + public void testGetToolExecuteParamsWithConfig() { + MLToolSpec toolSpec = mock(MLToolSpec.class); + when(toolSpec.getParameters()).thenReturn(Map.of("param1", "value1", "tool_key", "value_from_parameters")); + when(toolSpec.getConfigMap()).thenReturn(Map.of("tool_key", "tool_config_value")); + when(toolSpec.getType()).thenReturn("toolType"); + when(toolSpec.getName()).thenReturn("toolName"); + + Map params = Map + .of("toolType.param2", "value2", "toolName.param3", "value3", "param4", "value4", "toolName.tool_key", "dynamic value"); + + Map result = mlFlowAgentRunner.getToolExecuteParams(toolSpec, params); + + assertEquals("value1", result.get("param1")); + assertEquals("value3", result.get("param3")); + assertEquals("value4", result.get("param4")); + assertFalse(result.containsKey("toolType.param2")); + assertEquals("tool_config_value", result.get("tool_key")); + } + @Test public void testGetToolExecuteParamsWithInputSubstitution() { // Setup ToolSpec with parameters diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java index 8a0ab62168..d197afe387 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -282,7 +282,7 @@ public GetResponse prepareMLAgent(String agentId, boolean isHidden) throws IOExc MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), + List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, From 1eb2fd213ba7143c9e1a2975ed8add95bd83ae58 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 23 Sep 2024 13:38:23 -0700 Subject: [PATCH 2/5] add version control Signed-off-by: Jing Zhang --- .../opensearch/ml/common/agent/MLToolSpec.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index f39f799cae..2286f3545a 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -11,11 +11,13 @@ import java.io.IOException; import java.util.Map; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.CommonValue; import lombok.Builder; import lombok.EqualsAndHashCode; @@ -24,6 +26,8 @@ @EqualsAndHashCode @Getter public class MLToolSpec implements ToXContentObject { + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG = CommonValue.VERSION_2_17_0; + public static final String TOOL_TYPE_FIELD = "type"; public static final String TOOL_NAME_FIELD = "name"; public static final String DESCRIPTION_FIELD = "description"; @@ -66,7 +70,7 @@ public MLToolSpec(StreamInput input) throws IOException { parameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString); } includeOutputInAgentResponse = input.readBoolean(); - if (input.readBoolean()) { + if (input.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG) && input.readBoolean()) { configMap = input.readMap(StreamInput::readOptionalString, StreamInput::readOptionalString); } } @@ -82,11 +86,13 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeBoolean(includeOutputInAgentResponse); - if (configMap != null) { - out.writeBoolean(true); - out.writeMap(configMap, StreamOutput::writeOptionalString, StreamOutput::writeOptionalString); - } else { - out.writeBoolean(false); + if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG)) { + if (configMap != null) { + out.writeBoolean(true); + out.writeMap(configMap, StreamOutput::writeOptionalString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } } } From 9176b93049bca3172c27d0e413e47ddb0803bccb Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 23 Sep 2024 16:31:49 -0700 Subject: [PATCH 3/5] address comments I Signed-off-by: Jing Zhang --- .../ml/common/agent/MLToolSpec.java | 6 ++--- .../ml/common/agent/MLAgentTest.java | 22 +++++++++---------- .../ml/common/agent/MLToolSpecTest.java | 6 ++--- .../agent/MLAgentGetResponseTest.java | 2 +- .../agents/GetAgentTransportActionTests.java | 2 +- 5 files changed, 19 insertions(+), 19 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 2286f3545a..6da7602df3 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -71,7 +71,7 @@ public MLToolSpec(StreamInput input) throws IOException { } includeOutputInAgentResponse = input.readBoolean(); if (input.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG) && input.readBoolean()) { - configMap = input.readMap(StreamInput::readOptionalString, StreamInput::readOptionalString); + configMap = input.readMap(StreamInput::readString, StreamInput::readOptionalString); } } @@ -89,7 +89,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG)) { if (configMap != null) { out.writeBoolean(true); - out.writeMap(configMap, StreamOutput::writeOptionalString, StreamOutput::writeOptionalString); + out.writeMap(configMap, StreamOutput::writeString, StreamOutput::writeOptionalString); } else { out.writeBoolean(false); } @@ -112,7 +112,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(PARAMETERS_FIELD, parameters); } builder.field(INCLUDE_OUTPUT_IN_AGENT_RESPONSE, includeOutputInAgentResponse); - if (configMap != null && configMap.size() > 0) { + if (configMap != null && !configMap.isEmpty()) { builder.field(CONFIG_FIELD, configMap); } builder.endObject(); diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index e83df00a6a..c72da18a30 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -46,7 +46,7 @@ public void constructor_NullName() { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), null, null, Instant.EPOCH, @@ -66,7 +66,7 @@ public void constructor_NullType() { null, "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), null, null, Instant.EPOCH, @@ -86,7 +86,7 @@ public void constructor_NullLLMSpec() { MLAgentType.CONVERSATIONAL.name(), "test", null, - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), null, null, Instant.EPOCH, @@ -104,9 +104,9 @@ public void constructor_DuplicateTool() { "test_tool_type", "test_tool_name", "test", - Collections.EMPTY_MAP, + Collections.emptyMap(), false, - Collections.EMPTY_MAP + Collections.emptyMap() ); MLAgent agent = new MLAgent( "test_name", @@ -130,7 +130,7 @@ public void writeTo() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -157,7 +157,7 @@ public void writeTo_NullLLM() throws IOException { "FLOW", "test", null, - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -201,7 +201,7 @@ public void writeTo_NullParameters() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), null, new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -223,7 +223,7 @@ public void writeTo_NullMemory() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), null, Instant.EPOCH, @@ -245,7 +245,7 @@ public void toXContent() throws IOException { "CONVERSATIONAL", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, @@ -301,7 +301,7 @@ public void fromStream() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java index 12e641038a..88218f80a9 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java @@ -37,7 +37,7 @@ public void writeTo() throws IOException { @Test public void writeToEmptyConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.EMPTY_MAP); + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap()); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); @@ -81,7 +81,7 @@ public void toXContent() throws IOException { @Test public void toXContentEmptyConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.EMPTY_MAP); + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap()); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); spec.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); @@ -168,7 +168,7 @@ public void fromStream() throws IOException { @Test public void fromStreamEmptyConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.EMPTY_MAP); + MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap()); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index 9065a7c58b..50acb7f927 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -76,7 +76,7 @@ public void writeTo() throws IOException { MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java index d197afe387..d5e2d40e50 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -282,7 +282,7 @@ public GetResponse prepareMLAgent(String agentId, boolean isHidden) throws IOExc MLAgentType.CONVERSATIONAL.name(), "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), - List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false, Collections.EMPTY_MAP)), + List.of(new MLToolSpec("test", "test", "test", Collections.emptyMap(), false, Collections.emptyMap())), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, From cf680596d65341daf65c767cfc05195bc7d234f4 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 14 Oct 2024 09:42:49 -0700 Subject: [PATCH 4/5] address commits II Signed-off-by: Jing Zhang --- .../org/opensearch/ml/common/CommonValue.java | 1 + .../ml/common/agent/MLToolSpec.java | 2 +- .../ml/common/agent/MLToolSpecTest.java | 86 ++++++++++++++----- .../engine/algorithms/agent/AgentUtils.java | 9 ++ .../algorithms/agent/MLChatAgentRunner.java | 6 -- .../MLConversationalFlowAgentRunner.java | 10 +-- .../algorithms/agent/MLFlowAgentRunner.java | 8 +- .../agent/MLChatAgentRunnerTest.java | 40 +++++++-- 8 files changed, 119 insertions(+), 43 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 6a48cd5081..99c0e86390 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -581,4 +581,5 @@ public class CommonValue { public static final Version VERSION_2_15_0 = Version.fromString("2.15.0"); public static final Version VERSION_2_16_0 = Version.fromString("2.16.0"); public static final Version VERSION_2_17_0 = Version.fromString("2.17.0"); + public static final Version VERSION_2_18_0 = Version.fromString("2.18.0"); } diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java index 6da7602df3..c144d5cda9 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLToolSpec.java @@ -26,7 +26,7 @@ @EqualsAndHashCode @Getter public class MLToolSpec implements ToXContentObject { - public static final Version MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG = CommonValue.VERSION_2_17_0; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_TOOL_CONFIG = CommonValue.VERSION_2_18_0; public static final String TOOL_TYPE_FIELD = "type"; public static final String TOOL_NAME_FIELD = "name"; diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java index 88218f80a9..ecbf4d0ba1 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLToolSpecTest.java @@ -22,7 +22,14 @@ public class MLToolSpecTest { @Test public void writeTo() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Map.of("configKey", "configValue")); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Map.of("configKey", "configValue") + ); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); @@ -37,7 +44,14 @@ public void writeTo() throws IOException { @Test public void writeToEmptyConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap()); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Collections.emptyMap() + ); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); @@ -52,7 +66,7 @@ public void writeToEmptyConfigMap() throws IOException { @Test public void writeToNullConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, null); + MLToolSpec spec = new MLToolSpec("test_type", "test_name", "test_desc", Map.of("test_key", "test_value"), false, null); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = new MLToolSpec(output.bytes().streamInput()); @@ -67,42 +81,56 @@ public void writeToNullConfigMap() throws IOException { @Test public void toXContent() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Map.of("configKey", "configValue")); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Map.of("configKey", "configValue") + ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); spec.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); Assert .assertEquals( - "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false,\"config\":{\"configKey\":\"configValue\"}}", + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false,\"config\":{\"configKey\":\"configValue\"}}", content ); } @Test public void toXContentEmptyConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap()); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Collections.emptyMap() + ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); spec.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); Assert .assertEquals( - "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}", + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false}", content ); } @Test public void toXContentNullConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, null); + MLToolSpec spec = new MLToolSpec("test_type", "test_name", "test_desc", Map.of("test_key", "test_value"), false, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); spec.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); Assert .assertEquals( - "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}", + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false}", content ); } @@ -110,7 +138,7 @@ public void toXContentNullConfigMap() throws IOException { @Test public void parse() throws IOException { String jsonStr = - "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false,\"config\":{\"configKey\":\"configValue\"}}"; + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false,\"config\":{\"configKey\":\"configValue\"}}"; XContentParser parser = XContentType.JSON .xContent() .createParser( @@ -121,10 +149,10 @@ public void parse() throws IOException { parser.nextToken(); MLToolSpec spec = MLToolSpec.parse(parser); - Assert.assertEquals(spec.getType(), "test"); - Assert.assertEquals(spec.getName(), "test"); - Assert.assertEquals(spec.getDescription(), "test"); - Assert.assertEquals(spec.getParameters(), Map.of("test", "test")); + Assert.assertEquals(spec.getType(), "test_type"); + Assert.assertEquals(spec.getName(), "test_name"); + Assert.assertEquals(spec.getDescription(), "test_desc"); + Assert.assertEquals(spec.getParameters(), Map.of("test_key", "test_value")); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), false); Assert.assertEquals(spec.getConfigMap(), Map.of("configKey", "configValue")); } @@ -132,7 +160,7 @@ public void parse() throws IOException { @Test public void parseEmptyConfigMap() throws IOException { String jsonStr = - "{\"type\":\"test\",\"name\":\"test\",\"description\":\"test\",\"parameters\":{\"test\":\"test\"},\"include_output_in_agent_response\":false}"; + "{\"type\":\"test_type\",\"name\":\"test_name\",\"description\":\"test_desc\",\"parameters\":{\"test_key\":\"test_value\"},\"include_output_in_agent_response\":false}"; XContentParser parser = XContentType.JSON .xContent() .createParser( @@ -143,17 +171,24 @@ public void parseEmptyConfigMap() throws IOException { parser.nextToken(); MLToolSpec spec = MLToolSpec.parse(parser); - Assert.assertEquals(spec.getType(), "test"); - Assert.assertEquals(spec.getName(), "test"); - Assert.assertEquals(spec.getDescription(), "test"); - Assert.assertEquals(spec.getParameters(), Map.of("test", "test")); + Assert.assertEquals(spec.getType(), "test_type"); + Assert.assertEquals(spec.getName(), "test_name"); + Assert.assertEquals(spec.getDescription(), "test_desc"); + Assert.assertEquals(spec.getParameters(), Map.of("test_key", "test_value")); Assert.assertEquals(spec.isIncludeOutputInAgentResponse(), false); Assert.assertEquals(spec.getConfigMap(), null); } @Test public void fromStream() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Map.of("configKey", "configValue")); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Map.of("configKey", "configValue") + ); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); @@ -168,7 +203,14 @@ public void fromStream() throws IOException { @Test public void fromStreamEmptyConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, Collections.emptyMap()); + MLToolSpec spec = new MLToolSpec( + "test_type", + "test_name", + "test_desc", + Map.of("test_key", "test_value"), + false, + Collections.emptyMap() + ); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); @@ -183,7 +225,7 @@ public void fromStreamEmptyConfigMap() throws IOException { @Test public void fromStreamNullConfigMap() throws IOException { - MLToolSpec spec = new MLToolSpec("test", "test", "test", Map.of("test", "test"), false, null); + MLToolSpec spec = new MLToolSpec("test_type", "test_name", "test_desc", Map.of("test_key", "test_value"), false, null); BytesStreamOutput output = new BytesStreamOutput(); spec.writeTo(output); MLToolSpec spec1 = MLToolSpec.fromStream(output.bytes().streamInput()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index f424b3f624..bbe199ca69 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -465,12 +465,21 @@ public static Map constructToolParams( ) { Map toolParams = new HashMap<>(); Map toolSpecParams = toolSpecMap.get(action).getParameters(); + Map toolSpecConfigMap = toolSpecMap.get(action).getConfigMap(); if (toolSpecParams != null) { toolParams.putAll(toolSpecParams); } + if (toolSpecConfigMap != null) { + toolParams.putAll(toolSpecConfigMap); + } if (tools.get(action).useOriginalInput()) { toolParams.put("input", question); lastActionInput.set(question); + } else if (toolSpecConfigMap != null && toolSpecConfigMap.containsKey("input")) { + String input = toolSpecConfigMap.get("input"); + StringSubstitutor substitutor = new StringSubstitutor(toolParams, "${parameters.", "}"); + input = substitutor.replace(input); + toolParams.put("input", input); } else { toolParams.put("input", actionInput); if (isJson(actionInput)) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index a053a96c8f..4b14f1af17 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -470,18 +470,12 @@ private static void runTool( Map llmToolTmpParameters = new HashMap<>(); llmToolTmpParameters.putAll(tmpParameters); llmToolTmpParameters.putAll(toolSpecMap.get(action).getParameters()); - if (toolSpecMap.get(action).getConfigMap() != null) { - llmToolTmpParameters.putAll(toolSpecMap.get(action).getConfigMap()); - } llmToolTmpParameters.put(MLAgentExecutor.QUESTION, actionInput); tools.get(action).run(llmToolTmpParameters, toolListener); // run tool } else { Map parameters = new HashMap<>(); parameters.putAll(tmpParameters); parameters.putAll(toolParams); - if (toolSpecMap.get(action).getConfigMap() != null) { - parameters.putAll(toolSpecMap.get(action).getConfigMap()); - } tools.get(action).run(parameters, toolListener); // run tool } } catch (Exception e) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index fdc4fb86d5..3891caf8e7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -428,6 +428,11 @@ Map getToolExecuteParams(MLToolSpec toolSpec, Map getToolExecuteParams(MLToolSpec toolSpec, Map getToolExecuteParams(MLToolSpec toolSpec, Map getToolExecuteParams(MLToolSpec toolSpec, Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); params.put("question", "raw input"); - doReturn(true).when(firstTool).useOriginalInput(); + doReturn(false).when(firstTool).useOriginalInput(); // Run the MLChatAgentRunner. mlChatAgentRunner.run(mlAgent, params, agentActionListener); @@ -764,7 +764,7 @@ public void testToolConfig() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); assertEquals(15, ((Map) argumentCaptor.getValue()).size()); - // The value of input should be "config_value", and not be "raw input". + // The value of input should be "config_value". assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -772,6 +772,36 @@ public void testToolConfig() { assertNotNull(modelTensorOutput); } + @Test + public void testToolConfigWithInputPlaceholder() { + // Mock tool validation to return false. + when(firstTool.validate(any())).thenReturn(true); + + // Create an MLAgent with a tool including two parameters. + MLAgent mlAgent = createMLAgentWithToolsConfig(ImmutableMap.of("input", "${parameters.key2}")); + + // Create parameters for the agent. + Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); + params.put("question", "raw input"); + doReturn(false).when(firstTool).useOriginalInput(); + + // Run the MLChatAgentRunner. + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify that the tool's run method was called. + verify(firstTool).run(any(), any()); + // Verify the size of parameters passed in the tool run method. + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); + verify(firstTool).run((Map) argumentCaptor.capture(), any()); + assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + // The value of input should be replaced with the value associated with the key "key2" of the first tool. + assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input")); + + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); + assertNotNull(modelTensorOutput); + } + @Test public void testSaveLastTraceFailure() { // Mock tool validation to return true. @@ -868,14 +898,14 @@ private MLAgent createMLAgentWithTools() { .build(); } - private MLAgent createMLAgentWithToolsConfig() { + private MLAgent createMLAgentWithToolsConfig(Map configMap) { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); MLToolSpec firstToolSpec = MLToolSpec .builder() .name(FIRST_TOOL) .type(FIRST_TOOL) .parameters(ImmutableMap.of("key1", "value1", "key2", "value2")) - .configMap(ImmutableMap.of("input", "config_value")) + .configMap(configMap) .build(); return MLAgent .builder() From 8bca2181a2cea9391314d4ef2eea88a17b6d2356 Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Mon, 21 Oct 2024 11:54:07 -0700 Subject: [PATCH 5/5] address comments III Signed-off-by: Jing Zhang --- .../org/opensearch/ml/engine/algorithms/agent/AgentUtils.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index bbe199ca69..d8f8d6da94 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -480,6 +480,10 @@ public static Map constructToolParams( StringSubstitutor substitutor = new StringSubstitutor(toolParams, "${parameters.", "}"); input = substitutor.replace(input); toolParams.put("input", input); + if (isJson(input)) { + Map params = getParameterMap(gson.fromJson(input, Map.class)); + toolParams.putAll(params); + } } else { toolParams.put("input", actionInput); if (isJson(actionInput)) {