Skip to content

Commit

Permalink
Update: Open API method.
Browse files Browse the repository at this point in the history
  • Loading branch information
cuiwei4j committed Jan 11, 2025
1 parent 73acfc8 commit b3fb9e4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import java.time.Duration;
import java.util.List;

/**
* <a href="https://open.bigmodel.cn/dev/api/Agent_Platform/agent">QingLiu Agent</a>
*/
public class ZhipuAiAssistant {

private final String appId;
Expand Down Expand Up @@ -50,36 +53,51 @@ public static ZhipuAiAssistantBuilder builder() {
return new ZhipuAiAssistantBuilder();
}

/**
* Obtain the input parameters of the intelligent agent (application).
*/
public List<AssistantKeyValuePair> variables() {
return client.variables(appId);
}

public ConversationId conversation() {
return client.conversation(appId);
/**
* Create a new session
*/
public String getConversationId() {
ConversationId conversation = client.conversation(appId);
return conversation.getConversationId();
}

public void generate(
String conversationId,
List<AssistantKeyValuePair> keyValuePairs,
StreamingResponseHandler<AiMessage> handler) {
final ConversationRequest request = ConversationRequest.builder()
/**
* Create session request.
* @param conversationId Conversation ID
* @param keyValuePairs input parameters
*/
public String getRequestId(String conversationId, List<AssistantKeyValuePair> keyValuePairs) {
ConversationRequest request = ConversationRequest.builder()
.appId(appId)
.conversationId(conversationId)
.keyValuePairs(keyValuePairs)
.build();
final ConversationId reqId = client.generate(request);
this.generate(reqId, handler);
return client.generate(request).getId();
}

public void generate(
String conversationId,
List<AssistantKeyValuePair> keyValuePairs,
StreamingResponseHandler<AiMessage> handler) {
String requestId = getRequestId(conversationId, keyValuePairs);
this.generate(requestId, handler);
}

public void generate(ConversationId request, StreamingResponseHandler<AiMessage> handler) {
client.sseInvoke(request, handler);
public void generate(String requestId, StreamingResponseHandler<AiMessage> handler) {
client.sseInvoke(requestId, handler);
}

/**
* Recommended questions
*
* @param conversationId Conversation ID
* @return Problems
*/
public Problems sessionRecord(String conversationId) {
return client.sessionRecord(appId, conversationId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ public List<AssistantKeyValuePair> variables(String appId) {
if (Objects.nonNull(body)) {
if (!body.isSuccess()) {
log.error("获取智能体输入参数失败,原因为:【{}】", body.getMessage());
throw new ZhipuAiException(body.getCode() + "", body.getMessage());
}
return body.getData();
}
Expand All @@ -107,6 +108,7 @@ public ConversationId conversation(String appId) {
if (Objects.nonNull(body)) {
if (!body.isSuccess()) {
log.error("创建新会话失败,原因为:【{}】", body.getMessage());
throw new ZhipuAiException(body.getCode() + "", body.getMessage());
}
return body.getData();
}
Expand All @@ -128,6 +130,7 @@ public ConversationId generate(ConversationRequest request) {
if (Objects.nonNull(body)) {
if (!body.isSuccess()) {
log.error("创建对话或创作请求失败,原因为:【{}】", body.getMessage());
throw new ZhipuAiException(body.getCode() + "", body.getMessage());
}
return body.getData();
}
Expand All @@ -140,7 +143,7 @@ public ConversationId generate(ConversationRequest request) {
}
}

void sseInvoke(ConversationId request, StreamingResponseHandler<AiMessage> handler) {
void sseInvoke(String requestId, StreamingResponseHandler<AiMessage> handler) {
EventSourceListener eventSourceListener = new EventSourceListener() {
final StringBuffer contentBuilder = new StringBuffer();
TokenUsage tokenUsage;
Expand Down Expand Up @@ -229,7 +232,7 @@ public void onClosed(@NotNull EventSource eventSource) {
}
};
EventSources.createFactory(this.okHttpClient)
.newEventSource(zhipuAiApi.sseInvoke(request.getId()).request(), eventSourceListener);
.newEventSource(zhipuAiApi.sseInvoke(requestId).request(), eventSourceListener);
}

public Problems sessionRecord(String appId, String conversationId) {
Expand All @@ -241,6 +244,7 @@ public Problems sessionRecord(String appId, String conversationId) {
if (Objects.nonNull(body)) {
if (!body.isSuccess()) {
log.error("获取推荐问题失败,原因为:【{}】", body.getMessage());
throw new ZhipuAiException(body.getCode() + "", body.getMessage());
}
return body.getData();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import static org.assertj.core.api.Assertions.fail;

import dev.langchain4j.community.model.zhipu.assistant.AssistantKeyValuePair;
import dev.langchain4j.community.model.zhipu.assistant.conversation.ConversationId;
import dev.langchain4j.community.model.zhipu.assistant.problem.Problems;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
Expand Down Expand Up @@ -85,7 +84,6 @@ void recommend_problems() {
* @return conversationId
*/
public String getConversationId() {
ConversationId conversationId = chatModel.conversation();
return conversationId.getConversationId();
return chatModel.getConversationId();
}
}

0 comments on commit b3fb9e4

Please sign in to comment.