Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support ak/sk auth #269

Merged
merged 14 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions go/qianfan/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package qianfan

import (
"fmt"
"sync"
"time"

"github.com/mitchellh/mapstructure"
)

type AccessTokenRequest struct {
GrantType string `mapstructure:"grant_type"`
ClientId string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
}

func newAccessTokenRequest(ak, sk string) *AccessTokenRequest {
return &AccessTokenRequest{
GrantType: "client_credentials",
ClientId: ak,
ClientSecret: sk,
}
}

type AccessTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
SessionKey string `json:"session_key"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
SessionSecret string `json:"session_secret"`
baseResponse
}

func (r *AccessTokenResponse) GetErrorCode() string {
return r.Error
}

type credential struct {
AK string
SK string
}

type accessToken struct {
token string
lastUpateTime time.Time
}

type AuthManager struct {
tokenMap map[credential]*accessToken
lock sync.Mutex
*Requestor
}

func maskAk(ak string) string {
unmaskLen := 6
if len(ak) < unmaskLen {
return ak
}
return fmt.Sprintf("%s******", ak[:unmaskLen])
}

var _authManager *AuthManager

func GetAuthManager() *AuthManager {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有没有场景,AuthManager 的 Requestor 需要一些特殊设置?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

貌似没有,python目前没遇到,需要设置的参数都走GetConfig()获取

if _authManager == nil {
_authManager = &AuthManager{
tokenMap: make(map[credential]*accessToken),
lock: sync.Mutex{},
Requestor: newRequestor(makeOptions()),
}
}
return _authManager
}

func (m *AuthManager) GetAccessToken(ak, sk string) (string, error) {
token, ok := func() (*accessToken, bool) {
m.lock.Lock()
defer m.lock.Unlock()
token, ok := m.tokenMap[credential{ak, sk}]
return token, ok
}()
if ok {
return token.token, nil
}
logger.Infof("Access token of ak `%s` not found, tring to refresh it...", maskAk(ak))
return m.GetAccessTokenWithRefresh(ak, sk)
}

func (m *AuthManager) GetAccessTokenWithRefresh(ak, sk string) (string, error) {
m.lock.Lock()
defer m.lock.Unlock()

token, ok := m.tokenMap[credential{ak, sk}]
if ok {
lastUpdate := token.lastUpateTime
current := time.Now()
// 最近更新时间小于最小刷新间隔,则直接返回
// 避免多个请求同时刷新,导致token被刷新多次
if current.Sub(lastUpdate) < time.Duration(GetConfig().AccessTokenRefreshMinInterval)*time.Second {
logger.Debugf("Access token of ak `%s` was freshed %s ago, skip refreshing", maskAk(ak), current.Sub(lastUpdate))
return token.token, nil
}
}

resp := AccessTokenResponse{}
req, err := newAuthRequest("POST", authAPIPrefix, nil)
if err != nil {
return "", err
}
params := newAccessTokenRequest(ak, sk)
paramsMap := make(map[string]string)
err = mapstructure.Decode(params, &paramsMap)
if err != nil {
return "", err
}
req.Params = paramsMap
err = m.Requestor.request(req, &resp)
if err != nil {
return "", err
}
if resp.Error != "" {
logger.Errorf("refresh access token of ak `%s` failed with error: %s", maskAk(ak), resp.ErrorDescription)
return "", &APIError{Msg: resp.ErrorDescription}
}
logger.Infof("Access token of ak `%s` was refreshed", maskAk(ak))
m.tokenMap[credential{ak, sk}] = &accessToken{
token: resp.AccessToken,
lastUpateTime: time.Now(),
}
return resp.AccessToken, nil
}
195 changes: 195 additions & 0 deletions go/qianfan/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
// Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package qianfan

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func fakeAccessToken(ak, sk string) string {
return fmt.Sprintf("%s.%s", ak, sk)
}

func resetAuthManager() {
_authManager = nil
}

func setAccessTokenExpired(ak, sk string) {
GetAuthManager().tokenMap[credential{ak, sk}] = &accessToken{
token: "expired_token",
lastUpateTime: time.Now().Add(-100 * time.Hour), // 100s 过期
}
}

func TestAuth(t *testing.T) {
resetAuthManager()
ak, sk := "ak_33", "sk_4235"
// 第一次获取前,缓存里应当没有
_, ok := GetAuthManager().tokenMap[credential{ak, sk}]
assert.False(t, ok)

accessTok, err := GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, fakeAccessToken(ak, sk))
updateTime := GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime
// 再测试一次,应当从缓存里获取,更新时间不变
accessTok, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, fakeAccessToken(ak, sk))
assert.Equal(
t,
updateTime,
GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime,
)
// 模拟过期
ak, sk = "ak_95411", "sk_87135"
setAccessTokenExpired(ak, sk)
// 设置一个附近的更新时间,用来测试是否会忽略刚更新过的 token
GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime = time.Now()

accessTok, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, "expired_token") // 直接获取还是从缓存获取

accessTok, err = GetAuthManager().GetAccessTokenWithRefresh(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, "expired_token") // 刷新后,由于 lastUpdateTime 太接近,依旧使用缓存
setAccessTokenExpired(ak, sk)

accessTok, err = GetAuthManager().GetAccessTokenWithRefresh(ak, sk)
assert.NoError(t, err)
assert.Equal(t, accessTok, fakeAccessToken(ak, sk)) // 应当刷新
elaplsed := time.Since(GetAuthManager().tokenMap[credential{ak, sk}].lastUpateTime)
assert.Less(t, elaplsed, 10*time.Second) // 刷新后,lastUpdateTime 应当更新
}

func TestAuthFailed(t *testing.T) {
ak, sk := "bad_ak", "bad_sk"
_, err := GetAuthManager().GetAccessToken(ak, sk)
assert.Error(t, err)
var target *APIError
assert.ErrorAs(t, err, &target)
assert.Contains(t, err.Error(), target.Msg)
assert.Equal(t, target.Msg, "Client authentication failed")
}

func TestAuthWhenUsing(t *testing.T) {
defer resetTestEnv()
_authManager = nil
GetConfig().AccessKey = "access_key_484913"
GetConfig().SecretKey = "secret_key_48135"
GetConfig().AK = ""
GetConfig().SK = ""
// 未设置 AKSK,所以用 IAM 鉴权
chat := NewChatCompletion()
resp, err := chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")},
},
)
assert.NoError(t, err)
signedKey, ok := resp.RawResponse.Request.Header["Authorization"]
assert.True(t, ok)
assert.Contains(t, signedKey[0], GetConfig().AccessKey)
assert.NotContains(t, resp.RawResponse.Request.URL.RawQuery, "access_token")
// 设置了 AKSK,所以用 AKSK 鉴权
GetConfig().AK = "ak_48915684"
GetConfig().SK = "sk_78941813"
resp, err = chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")},
},
)
assert.NoError(t, err)
_, ok = resp.RawResponse.Request.Header["Authorization"]
assert.False(t, ok)
assert.Contains(t, resp.RawResponse.Request.URL.RawQuery, "access_token")
assert.Equal(
t,
resp.RawResponse.Request.URL.Query().Get("access_token"),
fakeAccessToken(GetConfig().AK, GetConfig().SK),
)
// 如果只设置了部分鉴权信息,则报错
GetConfig().AK = ""
GetConfig().AccessKey = ""
_, err = chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage("你好")},
},
)
assert.Error(t, err)
var target *CredentialNotFoundError
assert.ErrorAs(t, err, &target)
}

func TestAccessTokenExpired(t *testing.T) {
defer resetTestEnv()
_authManager = nil
ak, sk := "ak_48915684", "sk_78941813"
GetConfig().AK = ak
GetConfig().SK = sk
setAccessTokenExpired(ak, sk)
token, err := GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Contains(t, token, "expired")
prompt := "你好"
chat := NewChatCompletion()
resp, err := chat.Do(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage(prompt)},
},
)
assert.NoError(t, err)
assert.Contains(t, resp.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk))
assert.Contains(t, resp.Result, prompt)
token, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, token, fakeAccessToken(ak, sk))

// 测试流式请求的刷新 token
setAccessTokenExpired(ak, sk)
token, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Contains(t, token, "expired")
stream, err := chat.Stream(
context.Background(),
&ChatCompletionRequest{
Messages: []ChatCompletionMessage{ChatCompletionUserMessage(prompt)},
},
)
assert.NoError(t, err)

for {
r, err := stream.Recv()
assert.NoError(t, err)
token, err = GetAuthManager().GetAccessToken(ak, sk)
assert.NoError(t, err)
assert.Equal(t, token, fakeAccessToken(ak, sk))
assert.Contains(t, r.RawResponse.Request.URL.Query().Get("access_token"), fakeAccessToken(ak, sk))
if r.IsEnd {
break
}
}

}
Loading
Loading