Skip to content

Commit

Permalink
feat: Go 支持 rerank (#900)
Browse files Browse the repository at this point in the history
* feat:go提供reranker

* feat:go提供reranker

* 更新 Java 版本到 0.1.6 (#899)

---------

Co-authored-by: zhangxin <[email protected]>
Co-authored-by: Dobiichi-Origami <[email protected]>
  • Loading branch information
3 people authored Feb 21, 2025
1 parent cf94b38 commit d5d16ba
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 1 deletion.
7 changes: 7 additions & 0 deletions go/qianfan/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,15 @@ const (
DefaultChatCompletionModel = "ERNIE-Lite-8K"
DefaultCompletionModel = "ERNIE-Lite-8K"
DefaultEmbeddingModel = "Embedding-V1"
DefaultReRankerModel = "bce-reranker-base_v1"
DefaultText2ImageModel = "Stable-Diffusion-XL"
)

// v2可以使用的预制模型
const (
DefaultReRankerV2Model = "bce-reranker-base"
)

// API 错误码
const (
NoErrorErrCode = 0
Expand Down Expand Up @@ -68,4 +74,5 @@ const (
const (
ChatV2API = "/v2/chat/completions"
EmbeddingV2API = "/v2/embeddings"
ReRankerV2API = "/v2/rerankers"
)
1 change: 1 addition & 0 deletions go/qianfan/model_endpoint_retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func getModelEndpointRetriever() *modelEndpointRetriever {
"chat": ChatModelEndpoint,
"completions": CompletionModelEndpoint,
"embeddings": EmbeddingEndpoint,
"reranker": ReRankerEndpoint,
"text2image": Text2ImageEndpoint,
"image2text": make(map[string]string),
}
Expand Down
140 changes: 140 additions & 0 deletions go/qianfan/reranker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package qianfan

import (
"context"
)

type ReRanker struct {
BaseModel
}

type ReRankerRequest struct {
BaseRequestBody `mapstructure:"-"`
Query string `mapstructure:"query"` //查询文本
Documents []string `mapstructure:"documents"` //需要重排序的文本
TopN int `mapstructure:"top_n,omitempty"` //返回的最相关文本的数量
UserID string `mapstructure:"user_id,omitempty"` // 表示最终用户的唯一标识符
}
type ReRankerResponse struct {
Id string `json:"id"` //本轮对话的id
Object string `json:"object"` //回包类型
Created int `json:"created"` //创建时间
Results []ReRankerData `json:"results"` //重排序结果,按相似性得分倒序
Usage ModelUsage `json:"usage"` // token统计信息
ModelAPIError
baseResponse
}

type ReRankerData struct {
Document string `json:"document"` //文档内容
RelevanceScore float64 `json:"relevance_score"` //相似性得分
Index int `json:"index"` //排序
}

func NewReRanker(optionList ...Option) *ReRanker {
options := makeOptions(optionList...)
return newReRanker(options)
}

// 内部根据 options 创建 Embedding 实例
func newReRanker(options *Options) *ReRanker {
reRanker := &ReRanker{
BaseModel{
Model: DefaultReRankerModel,
Endpoint: "",
Requestor: newRequestor(options),
},
}
if options.Model != nil {
reRanker.Model = *options.Model
}
if options.Endpoint != nil {
reRanker.Endpoint = *options.Endpoint
}
return reRanker
}

var ReRankerEndpoint = map[string]string{
"bce-reranker-base_v1": "/reranker/bce_reranker_base",
}

func (r *ReRanker) Do(ctx context.Context, request *ReRankerRequest) (*ReRankerResponse, error) {
var resp *ReRankerResponse
var err error

if request.TopN <= 0 {
request.TopN = len(request.Documents)
}
runErr := runWithContext(ctx, func() {
resp, err = r.do(ctx, request)
})
if runErr != nil {
return nil, runErr
}
return resp, err
}

func (r *ReRanker) do(ctx context.Context, request *ReRankerRequest) (*ReRankerResponse, error) {
do := func() (*ReRankerResponse, error) {
url, err := r.realEndpoint(ctx)
if err != nil {
return nil, err
}
req, err := NewModelRequest("POST", url, request)
if err != nil {
return nil, err
}
resp := &ReRankerResponse{}

err = r.requestResource(ctx, req, resp)
if err != nil {
return nil, err
}

return resp, nil
}
resp, err := do()
if err != nil {
if r.Endpoint == "" && isUnsupportedModelError(err) {
// 根据 model 获得的 endpoint 错误,刷新模型列表后重试
refreshErr := getModelEndpointRetriever().Refresh(ctx)
if refreshErr != nil {
logger.Errorf("refresh endpoint failed: %s", refreshErr)
return resp, err
}
return do()
}
return resp, err
}
return resp, err
}

// ModelList 获取 ReRanker 支持的模型列表
func (r *ReRanker) ModelList() []string {
models := getModelEndpointRetriever().GetModelList(context.TODO(), "reranker")
list := make([]string, len(models))
i := 0
for k := range EmbeddingEndpoint {
list[i] = k
i++
}
return list
}

func (r *ReRanker) realEndpoint(ctx context.Context) (string, error) {
url := modelAPIPrefix
if r.Endpoint == "" {
endpoint := getModelEndpointRetriever().GetEndpoint(ctx, "reranker", r.Model)
if endpoint == "" {
endpoint = getModelEndpointRetriever().GetEndpointWithRefresh(ctx, "reranker", r.Model)
if endpoint == "" {
return "", &ModelNotSupportedError{Model: r.Model}
}
}
url += endpoint
} else {
url += "/reranker/" + r.Endpoint
}
logger.Debugf("requesting endpoint: %s", url)
return url, nil
}
98 changes: 98 additions & 0 deletions go/qianfan/reranker_v2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package qianfan

import (
"context"
"fmt"
)

type ReRankerV2 struct {
BaseModel
}

type ReRankerV2Request struct {
BaseRequestBody `mapstructure:"-"`
Query string `mapstructure:"query"` //查询文本
Documents []string `mapstructure:"documents"` //需要重排序的文本
TopN int `mapstructure:"top_n,omitempty"` //返回的最相关文本的数量
User string `mapstructure:"user,omitempty"` //用户标识
Model string `mapstructure:"model"` //模型名称
}

type ReRankerV2Data struct {
*ReRankerData
}

type ReRankerV2Response struct {
baseResponse // 基础的响应字段
Id string `json:"id"` // 请求的id
Object string `json:"object"` // 回包类型,固定值“ReRanker_list”
Created int `json:"created"` // 创建时间
Results []ReRankerV2Data `json:"results"` // 嵌入向量数据列表
Usage *ModelUsage `json:"usage"` // token统计信息
Error *ChatCompletionV2Error `json:"error"` // 错误信息
}

func NewReRankerV2(optionList ...Option) *ReRankerV2 {
options := makeOptions(optionList...)
return newReRankerV2(options)
}

func newReRankerV2(options *Options) *ReRankerV2 {
reRankerV2 := &ReRankerV2{
BaseModel{
Model: DefaultReRankerV2Model,
Requestor: newRequestor(options),
},
}
if options.Model != nil {
reRankerV2.Model = *options.Model
}

return reRankerV2
}

func (c *ReRankerV2) Do(ctx context.Context, request *ReRankerV2Request) (*ReRankerV2Response, error) {
var resp *ReRankerV2Response
var err error
if request.TopN <= 0 {
request.TopN = len(request.Documents)
}
runErr := runWithContext(ctx, func() {
resp, err = c.do(ctx, request)
})
if runErr != nil {
return nil, runErr
}
return resp, err
}

func (c *ReRankerV2) do(ctx context.Context, request *ReRankerV2Request) (*ReRankerV2Response, error) {
do := func() (*ReRankerV2Response, error) {
req, err := NewBearerTokenRequest("POST", ReRankerV2API, request)
if err != nil {
return nil, err
}
resp := &ReRankerV2Response{}

err = c.Requestor.request(ctx, req, resp)
if err != nil {
return nil, err
}

if resp.Error != nil {
return nil, fmt.Errorf(
"code: %s, type: %s, message: %s",
resp.Error.Code,
resp.Error.Type,
resp.Error.Message,
)
}

return resp, nil
}
return do()
}

func (c *ReRankerV2Response) GetErrorCode() string {
return c.Error.Message
}
2 changes: 1 addition & 1 deletion go/qianfan/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@
package qianfan

// SDK 版本
const Version = "v0.0.14"
const Version = "v0.0.15"
const versionIndicator = "qianfan_go_sdk_" + Version

0 comments on commit d5d16ba

Please sign in to comment.