Skip to content

Commit

Permalink
Bug Fix: Fix for rag processor throwing NPE when optional parameters …
Browse files Browse the repository at this point in the history
…are not provided (opensearch-project#3057)

* fix (rag npe): optional and empty fields are handled appropriately

Signed-off-by: Pavan Yekbote <[email protected]>

* fix: test cases

Signed-off-by: Pavan Yekbote <[email protected]>

* fix: format violations

Signed-off-by: Pavan Yekbote <[email protected]>

* tests: adding empty params test case

Signed-off-by: Pavan Yekbote <[email protected]>

* fix: remove wildcard import

Signed-off-by: Pavan Yekbote <[email protected]>

---------

Signed-off-by: Pavan Yekbote <[email protected]>
  • Loading branch information
pyek-bot authored Oct 4, 2024
1 parent 4850254 commit b7a0d78
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,25 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative.ext;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.opensearch.core.ParseField;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ObjectParser;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;

import com.google.common.base.Preconditions;

import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
Expand All @@ -48,60 +48,44 @@
@NoArgsConstructor
public class GenerativeQAParameters implements Writeable, ToXContentObject {

private static final ObjectParser<GenerativeQAParameters, Void> PARSER;

// Optional parameter; if provided, conversational memory will be used for RAG
// and the current interaction will be saved in the conversation referenced by this id.
private static final ParseField CONVERSATION_ID = new ParseField("memory_id");
private static final String CONVERSATION_ID = "memory_id";

// Optional parameter; if an LLM model is not set at the search pipeline level, one must be
// provided at the search request level.
private static final ParseField LLM_MODEL = new ParseField("llm_model");
private static final String LLM_MODEL = "llm_model";

// Required parameter; this is sent to LLMs as part of the user prompt.
// TODO support question rewriting when chat history is not used (conversation_id is not provided).
private static final ParseField LLM_QUESTION = new ParseField("llm_question");
private static final String LLM_QUESTION = "llm_question";

// Optional parameter; this parameter controls the number of search results ("contexts") to
// include in the user prompt.
private static final ParseField CONTEXT_SIZE = new ParseField("context_size");
private static final String CONTEXT_SIZE = "context_size";

// Optional parameter; this parameter controls the number of the interactions to include
// in the user prompt.
private static final ParseField INTERACTION_SIZE = new ParseField("message_size");
private static final String INTERACTION_SIZE = "message_size";

// Optional parameter; this parameter controls how long the search pipeline waits for a response
// from a remote inference endpoint before timing out the request.
private static final ParseField TIMEOUT = new ParseField("timeout");
private static final String TIMEOUT = "timeout";

// Optional parameter: this parameter allows request-level customization of the "system" (role) prompt.
private static final ParseField SYSTEM_PROMPT = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_SYSTEM_PROMPT);
private static final String SYSTEM_PROMPT = "system_prompt";

// Optional parameter: this parameter allows request-level customization of the "user" (role) prompt.
private static final ParseField USER_INSTRUCTIONS = new ParseField(GenerativeQAProcessorConstants.CONFIG_NAME_USER_INSTRUCTIONS);
private static final String USER_INSTRUCTIONS = "user_instructions";

// Optional parameter; this parameter indicates the name of the field in the LLM response
// that contains the chat completion text, i.e. "answer".
private static final ParseField LLM_RESPONSE_FIELD = new ParseField("llm_response_field");
private static final String LLM_RESPONSE_FIELD = "llm_response_field";

private static final ParseField LLM_MESSAGES_FIELD = new ParseField("llm_messages");
private static final String LLM_MESSAGES_FIELD = "llm_messages";

public static final int SIZE_NULL_VALUE = -1;

static {
PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new);
PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID);
PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL);
PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION);
PARSER.declareStringOrNull(GenerativeQAParameters::setSystemPrompt, SYSTEM_PROMPT);
PARSER.declareStringOrNull(GenerativeQAParameters::setUserInstructions, USER_INSTRUCTIONS);
PARSER.declareIntOrNull(GenerativeQAParameters::setContextSize, SIZE_NULL_VALUE, CONTEXT_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setInteractionSize, SIZE_NULL_VALUE, INTERACTION_SIZE);
PARSER.declareIntOrNull(GenerativeQAParameters::setTimeout, SIZE_NULL_VALUE, TIMEOUT);
PARSER.declareStringOrNull(GenerativeQAParameters::setLlmResponseField, LLM_RESPONSE_FIELD);
PARSER.declareObjectArray(GenerativeQAParameters::setMessageBlock, (p, c) -> MessageBlock.fromXContent(p), LLM_MESSAGES_FIELD);
}

@Setter
@Getter
private String conversationId;
Expand Down Expand Up @@ -167,6 +151,7 @@ public GenerativeQAParameters(
);
}

@Builder(toBuilder = true)
public GenerativeQAParameters(
String conversationId,
String llmModel,
Expand All @@ -184,7 +169,7 @@ public GenerativeQAParameters(

// TODO: keep this requirement until we can extract the question from the query or from the request processor parameters
// for question rewriting.
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION.getPreferredName() + " must be provided.");
Preconditions.checkArgument(!Strings.isNullOrEmpty(llmQuestion), LLM_QUESTION + " must be provided.");
this.llmQuestion = llmQuestion;
this.systemPrompt = systemPrompt;
this.userInstructions = userInstructions;
Expand Down Expand Up @@ -212,17 +197,49 @@ public GenerativeQAParameters(StreamInput input) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return xContentBuilder
.field(CONVERSATION_ID.getPreferredName(), this.conversationId)
.field(LLM_MODEL.getPreferredName(), this.llmModel)
.field(LLM_QUESTION.getPreferredName(), this.llmQuestion)
.field(SYSTEM_PROMPT.getPreferredName(), this.systemPrompt)
.field(USER_INSTRUCTIONS.getPreferredName(), this.userInstructions)
.field(CONTEXT_SIZE.getPreferredName(), this.contextSize)
.field(INTERACTION_SIZE.getPreferredName(), this.interactionSize)
.field(TIMEOUT.getPreferredName(), this.timeout)
.field(LLM_RESPONSE_FIELD.getPreferredName(), this.llmResponseField)
.field(LLM_MESSAGES_FIELD.getPreferredName(), this.llmMessages);
xContentBuilder.startObject();
if (this.conversationId != null) {
xContentBuilder.field(CONVERSATION_ID, this.conversationId);
}

if (this.llmModel != null) {
xContentBuilder.field(LLM_MODEL, this.llmModel);
}

if (this.llmQuestion != null) {
xContentBuilder.field(LLM_QUESTION, this.llmQuestion);
}

if (this.systemPrompt != null) {
xContentBuilder.field(SYSTEM_PROMPT, this.systemPrompt);
}

if (this.userInstructions != null) {
xContentBuilder.field(USER_INSTRUCTIONS, this.userInstructions);
}

if (this.contextSize != null) {
xContentBuilder.field(CONTEXT_SIZE, this.contextSize);
}

if (this.interactionSize != null) {
xContentBuilder.field(INTERACTION_SIZE, this.interactionSize);
}

if (this.timeout != null) {
xContentBuilder.field(TIMEOUT, this.timeout);
}

if (this.llmResponseField != null) {
xContentBuilder.field(LLM_RESPONSE_FIELD, this.llmResponseField);
}

if (this.llmMessages != null && !this.llmMessages.isEmpty()) {
xContentBuilder.field(LLM_MESSAGES_FIELD, this.llmMessages);
}

xContentBuilder.endObject();
return xContentBuilder;
}

@Override
Expand All @@ -242,7 +259,76 @@ public void writeTo(StreamOutput out) throws IOException {
}

public static GenerativeQAParameters parse(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
String conversationId = null;
String llmModel = null;
String llmQuestion = null;
String systemPrompt = null;
String userInstructions = null;
Integer contextSize = null;
Integer interactionSize = null;
Integer timeout = null;
String llmResponseField = null;
List<MessageBlock> llmMessages = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String field = parser.currentName();
parser.nextToken();

switch (field) {
case CONVERSATION_ID:
conversationId = parser.text();
break;
case LLM_MODEL:
llmModel = parser.text();
break;
case LLM_QUESTION:
llmQuestion = parser.text();
break;
case SYSTEM_PROMPT:
systemPrompt = parser.text();
break;
case USER_INSTRUCTIONS:
userInstructions = parser.text();
break;
case CONTEXT_SIZE:
contextSize = parser.intValue();
break;
case INTERACTION_SIZE:
interactionSize = parser.intValue();
break;
case TIMEOUT:
timeout = parser.intValue();
break;
case LLM_RESPONSE_FIELD:
llmResponseField = parser.text();
break;
case LLM_MESSAGES_FIELD:
llmMessages = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
llmMessages.add(MessageBlock.fromXContent(parser));
}
break;
default:
parser.skipChildren();
break;
}
}

return GenerativeQAParameters
.builder()
.conversationId(conversationId)
.llmModel(llmModel)
.llmQuestion(llmQuestion)
.systemPrompt(systemPrompt)
.userInstructions(userInstructions)
.contextSize(contextSize)
.interactionSize(interactionSize)
.timeout(timeout)
.llmResponseField(llmResponseField)
.llmMessages(llmMessages)
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,25 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

import java.io.EOFException;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.junit.Assert;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentHelper;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.search.SearchModule;
import org.opensearch.searchpipelines.questionanswering.generative.llm.MessageBlock;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -121,21 +126,38 @@ public void testMiscMethods() throws IOException {
}

public void testParse() throws IOException {
XContentParser xcParser = mock(XContentParser.class);
when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT);
GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser);
String requiredJsonStr = "{\"llm_question\":\"this is test llm question\"}";

XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
requiredJsonStr
);

parser.nextToken();
GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(parser);
assertNotNull(builder);
assertNotNull(builder.getParams());
GenerativeQAParameters params = builder.getParams();
Assert.assertEquals("this is test llm question", params.getLlmQuestion());
}

public void testXContentRoundTrip() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", null, null, null, null, messageList);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);

XContentType xContentType = randomFrom(XContentType.values());
BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
BytesReference serialized = BytesReference.bytes(builder);

XContentParser parser = createParser(xContentType.xContent(), serialized);
parser.nextToken();
GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);

assertEquals(extBuilder, deserialized);
GenerativeQAParameters parameters = deserialized.getParams();
assertTrue(GenerativeQAParameters.SIZE_NULL_VALUE == parameters.getContextSize());
Expand All @@ -147,10 +169,16 @@ public void testXContentRoundTripAllValues() throws IOException {
GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c", "s", "u", 1, 2, 3, null);
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(param1);

XContentType xContentType = randomFrom(XContentType.values());
BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true);
XContentBuilder builder = XContentBuilder.builder(xContentType.xContent());
builder = extBuilder.toXContent(builder, EMPTY_PARAMS);
BytesReference serialized = BytesReference.bytes(builder);

XContentParser parser = createParser(xContentType.xContent(), serialized);
parser.nextToken();
GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser);

assertEquals(extBuilder, deserialized);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,18 @@ public void testToXConent() throws IOException {
assertNotNull(parameters.toXContent(builder, null));
}

public void testToXConentAllOptionalParameters() throws IOException {
public void testToXContentEmptyParams() throws IOException {
GenerativeQAParameters parameters = new GenerativeQAParameters();
XContent xc = mock(XContent.class);
OutputStream os = mock(OutputStream.class);
XContentGenerator generator = mock(XContentGenerator.class);
when(xc.createGenerator(any(), any(), any())).thenReturn(generator);
XContentBuilder builder = new XContentBuilder(xc, os);
parameters.toXContent(builder, null);
assertNotNull(parameters.toXContent(builder, null));
}

public void testToXContentAllOptionalParameters() throws IOException {
String conversationId = "a";
String llmModel = "b";
String llmQuestion = "c";
Expand Down

0 comments on commit b7a0d78

Please sign in to comment.