From f06b6cf2e095c415f68c394efe54cacd7ce0e3be Mon Sep 17 00:00:00 2001 From: Azure99 <961523404@qq.com> Date: Tue, 28 May 2024 17:35:45 +0800 Subject: [PATCH 1/2] feat: add reranker support --- .../java/com/baidubce/qianfan/Qianfan.java | 16 ++-- .../qianfan/core/builder/RerankBuilder.java | 69 +++++++++++++++++ .../qianfan/model/constant/ModelType.java | 1 + .../qianfan/model/rerank/RerankData.java | 61 +++++++++++++++ .../qianfan/model/rerank/RerankRequest.java | 75 +++++++++++++++++++ .../qianfan/model/rerank/RerankResponse.java | 51 +++++++++++++ .../qianfan/model/rerank/RerankUsage.java | 47 ++++++++++++ 7 files changed, 315 insertions(+), 5 deletions(-) create mode 100644 java/src/main/java/com/baidubce/qianfan/core/builder/RerankBuilder.java create mode 100644 java/src/main/java/com/baidubce/qianfan/model/rerank/RerankData.java create mode 100644 java/src/main/java/com/baidubce/qianfan/model/rerank/RerankRequest.java create mode 100644 java/src/main/java/com/baidubce/qianfan/model/rerank/RerankResponse.java create mode 100644 java/src/main/java/com/baidubce/qianfan/model/rerank/RerankUsage.java diff --git a/java/src/main/java/com/baidubce/qianfan/Qianfan.java b/java/src/main/java/com/baidubce/qianfan/Qianfan.java index a6c9e6a1..ae129d03 100644 --- a/java/src/main/java/com/baidubce/qianfan/Qianfan.java +++ b/java/src/main/java/com/baidubce/qianfan/Qianfan.java @@ -16,11 +16,7 @@ package com.baidubce.qianfan; -import com.baidubce.qianfan.core.builder.ChatBuilder; -import com.baidubce.qianfan.core.builder.CompletionBuilder; -import com.baidubce.qianfan.core.builder.EmbeddingBuilder; -import com.baidubce.qianfan.core.builder.Image2TextBuilder; -import com.baidubce.qianfan.core.builder.Text2ImageBuilder; +import com.baidubce.qianfan.core.builder.*; import com.baidubce.qianfan.model.BaseRequest; import com.baidubce.qianfan.model.BaseResponse; import com.baidubce.qianfan.model.RateLimitConfig; @@ -35,6 +31,8 @@ import com.baidubce.qianfan.model.image.Image2TextResponse; import com.baidubce.qianfan.model.image.Text2ImageRequest; import com.baidubce.qianfan.model.image.Text2ImageResponse; +import com.baidubce.qianfan.model.rerank.RerankRequest; +import com.baidubce.qianfan.model.rerank.RerankResponse; import java.util.Iterator; @@ -118,6 +116,14 @@ public Iterator image2TextStream(Image2TextRequest request) return requestStream(request, Image2TextResponse.class); } + public RerankBuilder rerank() { + return new RerankBuilder(this); + } + + public RerankResponse rerank(RerankRequest request) { + return request(request, RerankResponse.class); + } + public , U extends BaseRequest> T request(BaseRequest request, Class responseClass) { return client.request(request, responseClass); } diff --git a/java/src/main/java/com/baidubce/qianfan/core/builder/RerankBuilder.java b/java/src/main/java/com/baidubce/qianfan/core/builder/RerankBuilder.java new file mode 100644 index 00000000..13a084c5 --- /dev/null +++ b/java/src/main/java/com/baidubce/qianfan/core/builder/RerankBuilder.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce.qianfan.core.builder; + +import com.baidubce.qianfan.Qianfan; +import com.baidubce.qianfan.model.rerank.RerankRequest; +import com.baidubce.qianfan.model.rerank.RerankResponse; + +import java.util.List; + +public class RerankBuilder extends BaseBuilder { + private String query; + + private List documents; + + private Integer topN; + + public RerankBuilder() { + super(); + } + + public RerankBuilder(Qianfan qianfan) { + super(qianfan); + } + + public RerankBuilder query(String query) { + this.query = query; + return this; + } + + public RerankBuilder documents(List documents) { + this.documents = documents; + return this; + } + + public RerankBuilder topN(Integer topN) { + this.topN = topN; + return this; + } + + public RerankRequest build() { + return new RerankRequest() + .setQuery(query) + .setDocuments(documents) + .setTopN(topN) + .setModel(super.getModel()) + .setEndpoint(super.getEndpoint()) + .setUserId(super.getUserId()) + .setExtraParameters(super.getExtraParameters()); + } + + public RerankResponse execute() { + return super.getQianfan().rerank(build()); + } +} diff --git a/java/src/main/java/com/baidubce/qianfan/model/constant/ModelType.java b/java/src/main/java/com/baidubce/qianfan/model/constant/ModelType.java index 12dace8a..853ca721 100644 --- a/java/src/main/java/com/baidubce/qianfan/model/constant/ModelType.java +++ b/java/src/main/java/com/baidubce/qianfan/model/constant/ModelType.java @@ -22,6 +22,7 @@ public class ModelType { public static final String EMBEDDINGS = "embeddings"; public static final String TEXT_2_IMAGE = "text2image"; public static final String IMAGE_2_TEXT = "image2text"; + public static final String RERANK = "reranker"; private ModelType() { } diff --git a/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankData.java b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankData.java new file mode 100644 index 00000000..5052fdb5 --- /dev/null +++ b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankData.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce.qianfan.model.rerank; + +public class RerankData { + /** + * 文本内容 + */ + private String document; + + /** + * 相似性得分 + */ + private Double relevanceScore; + + /** + * 序号 + */ + private Integer index; + + public String getDocument() { + return document; + } + + public RerankData setDocument(String document) { + this.document = document; + return this; + } + + public Double getRelevanceScore() { + return relevanceScore; + } + + public RerankData setRelevanceScore(Double relevanceScore) { + this.relevanceScore = relevanceScore; + return this; + } + + public Integer getIndex() { + return index; + } + + public RerankData setIndex(Integer index) { + this.index = index; + return this; + } +} diff --git a/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankRequest.java b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankRequest.java new file mode 100644 index 00000000..41de83c4 --- /dev/null +++ b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankRequest.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce.qianfan.model.rerank; + +import com.baidubce.qianfan.model.BaseRequest; +import com.baidubce.qianfan.model.constant.ModelType; + +import java.util.List; + +public class RerankRequest extends BaseRequest { + /** + * 查询文本,长度不超过1600个字符,token数若超过400做截断 + */ + private String query; + + /** + * 需要重排序的文本,说明: + * (1)不能为空List,List的每个成员不能为空字符串 + * (2)文本数量不超过64 + * (3)每条document文本长度不超过4096个字符,token数若超过1024做截断 + */ + private List documents; + + /** + * 返回的最相关文本的数量,默认为document的数量 + */ + private Integer topN; + + + @Override + public String getType() { + return ModelType.RERANK; + } + + public String getQuery() { + return query; + } + + public RerankRequest setQuery(String query) { + this.query = query; + return this; + } + + public List getDocuments() { + return documents; + } + + public RerankRequest setDocuments(List documents) { + this.documents = documents; + return this; + } + + public Integer getTopN() { + return topN; + } + + public RerankRequest setTopN(Integer topN) { + this.topN = topN; + return this; + } +} diff --git a/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankResponse.java b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankResponse.java new file mode 100644 index 00000000..9f4ff1be --- /dev/null +++ b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankResponse.java @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce.qianfan.model.rerank; + +import com.baidubce.qianfan.model.BaseResponse; + +import java.util.List; + +public class RerankResponse extends BaseResponse { + /** + * 重排序结果,按相似性得分倒序 + */ + private List results; + + /** + * token统计信息,token数 = 汉字数+单词数*1.3 (仅为估算逻辑) + */ + private RerankUsage usage; + + public List getResults() { + return results; + } + + public RerankResponse setResults(List results) { + this.results = results; + return this; + } + + public RerankUsage getUsage() { + return usage; + } + + public RerankResponse setUsage(RerankUsage usage) { + this.usage = usage; + return this; + } +} diff --git a/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankUsage.java b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankUsage.java new file mode 100644 index 00000000..57f86fbc --- /dev/null +++ b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankUsage.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce.qianfan.model.rerank; + +public class RerankUsage { + /** + * 问题tokens数 + */ + private Integer promptTokens; + + /** + * tokens总数 + */ + private Integer totalTokens; + + public Integer getPromptTokens() { + return promptTokens; + } + + public RerankUsage setPromptTokens(Integer promptTokens) { + this.promptTokens = promptTokens; + return this; + } + + public Integer getTotalTokens() { + return totalTokens; + } + + public RerankUsage setTotalTokens(Integer totalTokens) { + this.totalTokens = totalTokens; + return this; + } +} From 8b1e1b333848595b9958f469fe596fa156af5715 Mon Sep 17 00:00:00 2001 From: Azure99 <961523404@qq.com> Date: Tue, 28 May 2024 17:39:05 +0800 Subject: [PATCH 2/2] feat: add reranker support --- .../com/baidubce/qianfan/core/ModelEndpointRetriever.java | 8 +++++++- .../com/baidubce/qianfan/model/constant/ModelType.java | 2 +- .../com/baidubce/qianfan/model/rerank/RerankRequest.java | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/java/src/main/java/com/baidubce/qianfan/core/ModelEndpointRetriever.java b/java/src/main/java/com/baidubce/qianfan/core/ModelEndpointRetriever.java index 75b4ec35..2fca20a8 100644 --- a/java/src/main/java/com/baidubce/qianfan/core/ModelEndpointRetriever.java +++ b/java/src/main/java/com/baidubce/qianfan/core/ModelEndpointRetriever.java @@ -39,11 +39,15 @@ public class ModelEndpointRetriever { private static final String DEFAULT_EMBEDDING_MODEL = "embedding-v1"; private static final String DEFAULT_TEXT_2_IMAGE_MODEL = "stable-diffusion-xl"; private static final String DEFAULT_IMAGE_2_TEXT_MODEL = "fuyu-8b"; + private static final String DEFAULT_RERANKER_MODEL = "bce-reranker-base_v1"; private static final String LIST_MODEL_SERVICE_URL = "%s/wenxinworkshop/service/list"; private static final String ENDPOINT_TEMPLATE = "/%s/%s"; private static final int DYNAMIC_MAP_REFRESH_INTERVAL = 3600; - private static final String[] MODEL_TYPES = {ModelType.CHAT, ModelType.COMPLETIONS, ModelType.EMBEDDINGS, ModelType.TEXT_2_IMAGE, ModelType.IMAGE_2_TEXT}; + private static final String[] MODEL_TYPES = { + ModelType.CHAT, ModelType.COMPLETIONS, ModelType.EMBEDDINGS, + ModelType.TEXT_2_IMAGE, ModelType.IMAGE_2_TEXT, ModelType.RERANKER + }; // type -> (model -> endpoint) private final Map> typeModelEndpointMap = new HashMap<>(); @@ -62,6 +66,7 @@ public ModelEndpointRetriever(IAuth auth) { defaultTypeModelMap.put(ModelType.EMBEDDINGS, DEFAULT_EMBEDDING_MODEL); defaultTypeModelMap.put(ModelType.TEXT_2_IMAGE, DEFAULT_TEXT_2_IMAGE_MODEL); defaultTypeModelMap.put(ModelType.IMAGE_2_TEXT, DEFAULT_IMAGE_2_TEXT_MODEL); + defaultTypeModelMap.put(ModelType.RERANKER, DEFAULT_RERANKER_MODEL); for (String type : MODEL_TYPES) { typeModelEndpointMap.put(type, new HashMap<>()); @@ -113,6 +118,7 @@ public ModelEndpointRetriever(IAuth auth) { typeModelEndpointMap.get(ModelType.EMBEDDINGS).put("tao-8k", "tao_8k"); typeModelEndpointMap.get(ModelType.TEXT_2_IMAGE).put("stable-diffusion-xl", "sd_xl"); typeModelEndpointMap.get(ModelType.IMAGE_2_TEXT).put("fuyu-8b", "fuyu_8b"); + typeModelEndpointMap.get(ModelType.RERANKER).put("bce-reranker-base_v1", "bce_reranker_base"); // Compatibility for old model names typeModelEndpointMap.get(ModelType.CHAT).put("ernie-bot-turbo", "eb-instant"); typeModelEndpointMap.get(ModelType.CHAT).put("ernie-bot", "completions"); diff --git a/java/src/main/java/com/baidubce/qianfan/model/constant/ModelType.java b/java/src/main/java/com/baidubce/qianfan/model/constant/ModelType.java index 853ca721..d7a09838 100644 --- a/java/src/main/java/com/baidubce/qianfan/model/constant/ModelType.java +++ b/java/src/main/java/com/baidubce/qianfan/model/constant/ModelType.java @@ -22,7 +22,7 @@ public class ModelType { public static final String EMBEDDINGS = "embeddings"; public static final String TEXT_2_IMAGE = "text2image"; public static final String IMAGE_2_TEXT = "image2text"; - public static final String RERANK = "reranker"; + public static final String RERANKER = "reranker"; private ModelType() { } diff --git a/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankRequest.java b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankRequest.java index 41de83c4..2e5c5d3e 100644 --- a/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankRequest.java +++ b/java/src/main/java/com/baidubce/qianfan/model/rerank/RerankRequest.java @@ -43,7 +43,7 @@ public class RerankRequest extends BaseRequest { @Override public String getType() { - return ModelType.RERANK; + return ModelType.RERANKER; } public String getQuery() {