format
Martin7-1 committed Dec 6, 2024
1 parent 3a5db3e commit 14772ce
Showing 3 changed files with 158 additions and 192 deletions.
@@ -1,5 +1,18 @@
package dev.langchain4j.community.model.dashscope;

+import static dev.langchain4j.data.message.ChatMessageType.AI;
+import static dev.langchain4j.data.message.ChatMessageType.SYSTEM;
+import static dev.langchain4j.data.message.ChatMessageType.TOOL_EXECUTION_RESULT;
+import static dev.langchain4j.data.message.ChatMessageType.USER;
+import static dev.langchain4j.internal.Utils.getOrDefault;
+import static dev.langchain4j.internal.Utils.isNullOrBlank;
+import static dev.langchain4j.internal.Utils.isNullOrEmpty;
+import static dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper.toMap;
+import static dev.langchain4j.model.output.FinishReason.LENGTH;
+import static dev.langchain4j.model.output.FinishReason.STOP;
+import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
+import static java.util.stream.Collectors.toList;
+
import com.alibaba.dashscope.aigc.generation.GenerationOutput;
import com.alibaba.dashscope.aigc.generation.GenerationOutput.Choice;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
@@ -41,9 +54,6 @@
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Files;
@@ -62,27 +72,15 @@
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;

-import static dev.langchain4j.data.message.ChatMessageType.AI;
-import static dev.langchain4j.data.message.ChatMessageType.SYSTEM;
-import static dev.langchain4j.data.message.ChatMessageType.TOOL_EXECUTION_RESULT;
-import static dev.langchain4j.data.message.ChatMessageType.USER;
-import static dev.langchain4j.internal.Utils.getOrDefault;
-import static dev.langchain4j.internal.Utils.isNullOrBlank;
-import static dev.langchain4j.internal.Utils.isNullOrEmpty;
-import static dev.langchain4j.model.chat.request.json.JsonSchemaElementHelper.toMap;
-import static dev.langchain4j.model.output.FinishReason.LENGTH;
-import static dev.langchain4j.model.output.FinishReason.STOP;
-import static dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION;
-import static java.util.stream.Collectors.toList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

class QwenHelper {

private static final Logger log = LoggerFactory.getLogger(QwenHelper.class);

static List<Message> toQwenMessages(List<ChatMessage> messages) {
-return sanitizeMessages(messages)
-.stream()
+return sanitizeMessages(messages).stream()
.map(QwenHelper::toQwenMessage)
.collect(toList());
}
@@ -105,12 +103,12 @@ static Message toQwenMessage(ChatMessage message) {

static String toSingleText(ChatMessage message) {
return switch (message.type()) {
-case USER -> ((UserMessage) message).contents()
-.stream()
-.filter(TextContent.class::isInstance)
-.map(TextContent.class::cast)
-.map(TextContent::text)
-.collect(Collectors.joining("\n"));
+case USER -> ((UserMessage) message)
+.contents().stream()
+.filter(TextContent.class::isInstance)
+.map(TextContent.class::cast)
+.map(TextContent::text)
+.collect(Collectors.joining("\n"));
case AI -> ((AiMessage) message).text();
case SYSTEM -> ((SystemMessage) message).text();
case TOOL_EXECUTION_RESULT -> ((ToolExecutionResultMessage) message).text();
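
Side note on the USER branch above: it flattens a multimodal user message by keeping only its text parts and joining them with newlines. Below is a self-contained sketch of the same pattern against langchain4j's public message API (UserMessage, TextContent, and ImageContent are real langchain4j types; the message contents are invented for illustration):

    import static java.util.stream.Collectors.joining;

    import dev.langchain4j.data.message.ImageContent;
    import dev.langchain4j.data.message.TextContent;
    import dev.langchain4j.data.message.UserMessage;

    public class SingleTextSketch {
        public static void main(String[] args) {
            // A multimodal user message: two text parts and one image part.
            UserMessage message = UserMessage.from(
                    TextContent.from("Describe this picture."),
                    ImageContent.from("https://example.com/cat.png"),
                    TextContent.from("Focus on the colors."));

            // Same idea as toSingleText's USER branch: keep only TextContent,
            // join the text parts with newlines, drop everything else.
            String singleText = message.contents().stream()
                    .filter(TextContent.class::isInstance)
                    .map(TextContent.class::cast)
                    .map(TextContent::text)
                    .collect(joining("\n"));

            System.out.println(singleText);
            // Describe this picture.
            // Focus on the colors.
        }
    }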
@@ -141,9 +139,7 @@ static List<ToolCallBase> toolCallsFrom(ChatMessage message) {
}

static List<MultiModalMessage> toQwenMultiModalMessages(List<ChatMessage> messages) {
-return messages.stream()
-.map(QwenHelper::toQwenMultiModalMessage)
-.collect(toList());
+return messages.stream().map(QwenHelper::toQwenMultiModalMessage).collect(toList());
}

static MultiModalMessage toQwenMultiModalMessage(ChatMessage message) {
@@ -155,12 +151,9 @@ static MultiModalMessage toQwenMultiModalMessage(ChatMessage message) {

static List<Map<String, Object>> toMultiModalContents(ChatMessage message) {
return switch (message.type()) {
-case USER -> ((UserMessage) message).contents()
-.stream()
-.map(QwenHelper::toMultiModalContent)
-.collect(Collectors.toList());
-case AI -> Collections.singletonList(
-Collections.singletonMap("text", ((AiMessage) message).text()));
+case USER -> ((UserMessage) message)
+.contents().stream().map(QwenHelper::toMultiModalContent).collect(Collectors.toList());
+case AI -> Collections.singletonList(Collections.singletonMap("text", ((AiMessage) message).text()));
case SYSTEM -> Collections.singletonList(
Collections.singletonMap("text", ((SystemMessage) message).text()));
case TOOL_EXECUTION_RESULT -> Collections.singletonList(
@@ -325,9 +318,8 @@ static TokenUsage tokenUsageFrom(MultiModalConversationResult result) {
static FinishReason finishReasonFrom(GenerationResult result) {
Choice choice = result.getOutput().getChoices().get(0);
// Upon observation, when tool_calls occur, the returned finish_reason may be null or "stop", not "tool_calls".
-String finishReason = isNullOrEmpty(choice.getMessage().getToolCalls()) ?
-choice.getFinishReason() :
-"tool_calls";
+String finishReason =
+isNullOrEmpty(choice.getMessage().getToolCalls()) ? choice.getFinishReason() : "tool_calls";

return switch (finishReason) {
case "stop" -> STOP;
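
The comment in this hunk records the quirk the method compensates for: DashScope may report finish_reason as null or "stop" even when the choice carries tool calls, so the presence of tool calls takes precedence. Here is a plain-Java sketch of that decision logic; the "length" and default branches are assumptions inferred from the imported LENGTH/TOOL_EXECUTION constants, since the diff collapses the rest of the switch:

    import java.util.List;

    public class FinishReasonSketch {

        enum FinishReason { STOP, LENGTH, TOOL_EXECUTION }

        // Tool calls win over whatever finish_reason the API reported;
        // otherwise map DashScope's string onto the langchain4j-style enum.
        static FinishReason finishReason(List<String> toolCalls, String rawFinishReason) {
            String finishReason = (toolCalls == null || toolCalls.isEmpty())
                    ? rawFinishReason
                    : "tool_calls";
            if (finishReason == null) {
                return null; // no tool calls and no reason reported yet
            }
            return switch (finishReason) {
                case "stop" -> FinishReason.STOP;
                case "length" -> FinishReason.LENGTH;
                case "tool_calls" -> FinishReason.TOOL_EXECUTION;
                default -> null;
            };
        }

        public static void main(String[] args) {
            // The observed quirk: tool calls present, yet finish_reason says "stop".
            System.out.println(finishReason(List.of("get_weather"), "stop")); // TOOL_EXECUTION
            System.out.println(finishReason(List.of(), "stop"));              // STOP
        }
    }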
@@ -362,9 +354,7 @@ static List<ToolBase> toToolFunctions(Collection<ToolSpecification> toolSpecific
return Collections.emptyList();
}

-return toolSpecifications.stream()
-.map(QwenHelper::toToolFunction)
-.collect(Collectors.toList());
+return toolSpecifications.stream().map(QwenHelper::toToolFunction).collect(Collectors.toList());
}

static ToolBase toToolFunction(ToolSpecification toolSpecification) {
@@ -390,9 +380,9 @@ private static JsonObject toParameters(ToolSpecification toolSpecification) {
static AiMessage aiMessageFrom(GenerationResult result) {
if (isFunctionToolCalls(result)) {
String text = answerFrom(result);
-return isNullOrBlank(text) ?
-new AiMessage(toolExecutionRequestsFrom(result)) :
-new AiMessage(text, toolExecutionRequestsFrom(result));
+return isNullOrBlank(text)
+? new AiMessage(toolExecutionRequestsFrom(result))
+: new AiMessage(text, toolExecutionRequestsFrom(result));
} else {
return new AiMessage(answerFrom(result));
}
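
aiMessageFrom chooses between AiMessage's two constructors depending on whether the model emitted text alongside its tool calls; both constructors are part of langchain4j and visible in the hunk above. A short usage sketch with invented tool-call values:

    import dev.langchain4j.agent.tool.ToolExecutionRequest;
    import dev.langchain4j.data.message.AiMessage;
    import java.util.List;

    public class AiMessageSketch {
        public static void main(String[] args) {
            ToolExecutionRequest request = ToolExecutionRequest.builder()
                    .id("call-1") // hypothetical id
                    .name("get_weather")
                    .arguments("{\"city\":\"Beijing\"}")
                    .build();

            // Model produced tool calls but no (or blank) text: requests-only constructor.
            AiMessage withoutText = new AiMessage(List.of(request));

            // Model produced both text and tool calls.
            AiMessage withText = new AiMessage("Let me check the weather.", List.of(request));

            System.out.println(withoutText.hasToolExecutionRequests()); // true
            System.out.println(withText.text());                        // Let me check the weather.
        }
    }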
@@ -446,9 +436,7 @@ static boolean isFunctionToolCalls(GenerationResult result) {
}

private static List<ToolCallBase> toToolCalls(Collection<ToolExecutionRequest> toolExecutionRequests) {
-return toolExecutionRequests.stream()
-.map(QwenHelper::toToolCall)
-.collect(toList());
+return toolExecutionRequests.stream().map(QwenHelper::toToolCall).collect(toList());
}

private static ToolCallBase toToolCall(ToolExecutionRequest toolExecutionRequest) {
@@ -462,8 +450,8 @@ private static ToolCallBase toToolCall(ToolExecutionRequest toolExecutionRequest
}

static List<ChatMessage> sanitizeMessages(List<ChatMessage> messages) {
-LinkedList<ChatMessage> sanitizedMessages = messages.stream()
-.reduce(new LinkedList<>(), messageAccumulator(), messageCombiner());
+LinkedList<ChatMessage> sanitizedMessages =
+messages.stream().reduce(new LinkedList<>(), messageAccumulator(), messageCombiner());

// Ensure the last message is a user/tool_execution_result message
while (!sanitizedMessages.isEmpty() && !isInputMessageType(sanitizedMessages.getLast())) {
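
sanitizeMessages relies on the three-argument Stream.reduce(identity, accumulator, combiner): the accumulator folds each message into a LinkedList, repairing ordering violations as it goes, while the combiner is only consulted for parallel streams. A generic sketch of that reduce shape, independent of the langchain4j types (duplicate-dropping stands in for the real ordering rules):

    import java.util.LinkedList;
    import java.util.function.BiFunction;
    import java.util.function.BinaryOperator;
    import java.util.stream.Stream;

    public class ReduceSketch {
        public static void main(String[] args) {
            // Accumulator: fold each element into the list, with a chance to
            // reject or repair state first (here: drop consecutive duplicates).
            BiFunction<LinkedList<String>, String, LinkedList<String>> accumulator = (acc, next) -> {
                if (!acc.isEmpty() && acc.getLast().equals(next)) {
                    acc.removeLast(); // mimic the "drop invalid predecessor" loops
                }
                acc.add(next);
                return acc;
            };

            // Combiner: merges partial lists; only invoked for parallel streams,
            // which is why it can stay trivial here.
            BinaryOperator<LinkedList<String>> combiner = (left, right) -> {
                left.addAll(right);
                return left;
            };

            LinkedList<String> result = Stream.of("user", "user", "ai", "user")
                    .reduce(new LinkedList<>(), accumulator, combiner);

            System.out.println(result); // [user, ai, user]
        }
    }

Strictly speaking, mutating a shared identity inside reduce is only safe for sequential streams; Stream.collect is the canonical tool for mutable accumulation, so read this as mirroring the source's choice rather than endorsing it.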
@@ -503,17 +491,23 @@ private static BiFunction<LinkedList<ChatMessage>, ChatMessage, LinkedList<ChatM
if (type == USER) {
while (acc.getLast().type() != SYSTEM && !isNormalAiType(acc.getLast())) {
ChatMessage removedMessage = acc.removeLast();
-log.warn("Tool execution result should follow a tool execution request message. Drop duplicated message: {}", removedMessage);
+log.warn(
+"Tool execution result should follow a tool execution request message. Drop duplicated message: {}",
+removedMessage);
}
} else if (type == TOOL_EXECUTION_RESULT) {
while (!isToolExecutionRequestsAiType(acc.getLast())) {
ChatMessage removedMessage = acc.removeLast();
-log.warn("Tool execution result should follow a tool execution request message. Drop duplicated message: {}", removedMessage);
+log.warn(
+"Tool execution result should follow a tool execution request message. Drop duplicated message: {}",
+removedMessage);
}
} else if (type == AI) {
while (!isInputMessageType(acc.getLast())) {
ChatMessage removedMessage = acc.removeLast();
-log.warn("AI message should follow a user/tool_execution_result message. Drop duplicated message: {}", removedMessage);
+log.warn(
+"AI message should follow a user/tool_execution_result message. Drop duplicated message: {}",
+removedMessage);
}
}
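
The three while-loops above implement one repair rule each: a user message may not follow a dangling tool-call exchange, a tool result must directly follow a tool-call request, and an AI message must follow user input or a tool result. A toy trace of this tail-popping behavior with plain enums (the isEmpty() guards and the exact type predicates are simplifications, not the real ChatMessage logic):

    import java.util.LinkedList;
    import java.util.List;

    public class SanitizeTraceSketch {

        enum Type { SYSTEM, USER, AI, AI_WITH_TOOL_CALLS, TOOL_RESULT }

        static boolean isInput(Type t) { return t == Type.USER || t == Type.TOOL_RESULT; }

        // Same shape as the accumulator above: before adding a message,
        // pop anything off the tail that would violate the expected ordering.
        static LinkedList<Type> accumulate(LinkedList<Type> acc, Type next) {
            if (acc.isEmpty()) { acc.add(next); return acc; }
            switch (next) {
                case USER -> {
                    while (!acc.isEmpty() && acc.getLast() != Type.SYSTEM && acc.getLast() != Type.AI) {
                        acc.removeLast(); // e.g. two user messages in a row: keep the later one
                    }
                }
                case TOOL_RESULT -> {
                    while (!acc.isEmpty() && acc.getLast() != Type.AI_WITH_TOOL_CALLS) {
                        acc.removeLast(); // tool result must follow a tool-call request
                    }
                }
                case AI, AI_WITH_TOOL_CALLS -> {
                    while (!acc.isEmpty() && !isInput(acc.getLast())) {
                        acc.removeLast(); // AI reply must follow user input or a tool result
                    }
                }
                default -> { }
            }
            acc.add(next);
            return acc;
        }

        public static void main(String[] args) {
            LinkedList<Type> acc = new LinkedList<>();
            for (Type t : List.of(Type.USER, Type.USER, Type.AI, Type.AI, Type.USER)) {
                acc = accumulate(acc, t);
            }
            System.out.println(acc); // [USER, AI, USER]
        }
    }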

@@ -541,10 +535,10 @@ private static boolean isToolExecutionRequestsAiType(ChatMessage message) {
return message.type() == AI && ((AiMessage) message).hasToolExecutionRequests();
}

-static ChatModelRequest createModelListenerRequest(GenerationParam request,
-List<ChatMessage> messages,
-List<ToolSpecification> toolSpecifications) {
-Double temperature = request.getTemperature() != null ? request.getTemperature().doubleValue() : null;
+static ChatModelRequest createModelListenerRequest(
+GenerationParam request, List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
+Double temperature =
+request.getTemperature() != null ? request.getTemperature().doubleValue() : null;
return ChatModelRequest.builder()
.model(request.getModel())
.temperature(temperature)
@@ -555,10 +549,12 @@ static ChatModelRequest createModelListenerRequest(GenerationParam request,
.build();
}

-static ChatModelRequest createModelListenerRequest(MultiModalConversationParam request,
-List<ChatMessage> messages,
-List<ToolSpecification> toolSpecifications) {
-Double temperature = request.getTemperature() != null ? request.getTemperature().doubleValue() : null;
+static ChatModelRequest createModelListenerRequest(
+MultiModalConversationParam request,
+List<ChatMessage> messages,
+List<ToolSpecification> toolSpecifications) {
+Double temperature =
+request.getTemperature() != null ? request.getTemperature().doubleValue() : null;
return ChatModelRequest.builder()
.model(request.getModel())
.temperature(temperature)
@@ -569,9 +565,8 @@ static ChatModelRequest createModelListenerRequest(MultiModalConversationParam r
.build();
}

-static ChatModelResponse createModelListenerResponse(String responseId,
-String responseModel,
-Response<AiMessage> response) {
+static ChatModelResponse createModelListenerResponse(
+String responseId, String responseModel, Response<AiMessage> response) {
if (response == null) {
return null;
}
@@ -585,9 +580,8 @@ static ChatModelResponse createModelListenerResponse(String responseId,
.build();
}

-static void onListenRequest(List<ChatModelListener> listeners,
-ChatModelRequest modelListenerRequest,
-Map<Object, Object> attributes) {
+static void onListenRequest(
+List<ChatModelListener> listeners, ChatModelRequest modelListenerRequest, Map<Object, Object> attributes) {
ChatModelRequestContext context = new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
@@ -598,19 +592,17 @@ static void onListenRequest(List<ChatModelListener> listeners,
});
}
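
onListenRequest (and its response/error siblings below) notifies each listener inside its own try/catch, so a throwing listener cannot abort the model call or skip the remaining listeners; the catch body is collapsed in this diff but presumably just logs a warning. A dependency-free sketch of that fail-safe broadcast pattern:

    import java.util.List;

    public class ListenerBroadcastSketch {

        interface Listener {
            void onRequest(String request);
        }

        // Notify every listener; log-and-continue instead of propagating,
        // mirroring the try/catch inside the forEach above.
        static void broadcast(List<Listener> listeners, String request) {
            listeners.forEach(listener -> {
                try {
                    listener.onRequest(request);
                } catch (Exception e) {
                    // The real code logs via slf4j; System.err keeps this sketch dependency-free.
                    System.err.println("Exception while calling model listener: " + e.getMessage());
                }
            });
        }

        public static void main(String[] args) {
            Listener ok = request -> System.out.println("saw: " + request);
            Listener broken = request -> { throw new IllegalStateException("boom"); };

            // The broken listener does not prevent the healthy one from running.
            broadcast(List.of(broken, ok), "chat request");
        }
    }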

-static void onListenResponse(List<ChatModelListener> listeners,
-String responseId,
-Response<AiMessage> response,
-ChatModelRequest modelListenerRequest,
-Map<Object, Object> attributes) {
-ChatModelResponse modelListenerResponse = createModelListenerResponse(
-responseId, modelListenerRequest.model(), response);
+static void onListenResponse(
+List<ChatModelListener> listeners,
+String responseId,
+Response<AiMessage> response,
+ChatModelRequest modelListenerRequest,
+Map<Object, Object> attributes) {
+ChatModelResponse modelListenerResponse =
+createModelListenerResponse(responseId, modelListenerRequest.model(), response);

-ChatModelResponseContext context = new ChatModelResponseContext(
-modelListenerResponse,
-modelListenerRequest,
-attributes
-);
+ChatModelResponseContext context =
+new ChatModelResponseContext(modelListenerResponse, modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onResponse(context);
@@ -620,19 +612,17 @@ static void onListenResponse(List<ChatModelListener> listeners,
});
}

-static void onListenError(List<ChatModelListener> listeners,
-String responseId,
-Throwable error,
-ChatModelRequest modelListenerRequest,
-Response<AiMessage> partialResponse, Map<Object, Object> attributes) {
-ChatModelResponse partialModelListenerResponse = createModelListenerResponse(
-responseId, modelListenerRequest.model(), partialResponse);
-ChatModelErrorContext context = new ChatModelErrorContext(
-error,
-modelListenerRequest,
-partialModelListenerResponse,
-attributes
-);
+static void onListenError(
+List<ChatModelListener> listeners,
+String responseId,
+Throwable error,
+ChatModelRequest modelListenerRequest,
+Response<AiMessage> partialResponse,
+Map<Object, Object> attributes) {
+ChatModelResponse partialModelListenerResponse =
+createModelListenerResponse(responseId, modelListenerRequest.model(), partialResponse);
+ChatModelErrorContext context =
+new ChatModelErrorContext(error, modelListenerRequest, partialModelListenerResponse, attributes);
listeners.forEach(listener -> {
try {
listener.onError(context);