Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support openai embeddings endpoints in ai proxy #1934

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion go/ai-proxy/api/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ package bedrock

type Endpoint string

const EndpointChat = "/chat/bedrock"
const (
EndpointChat = "/chat/bedrock"
EndpointEmbeddings = "/embeddings/bedrock"
)
1 change: 1 addition & 0 deletions go/ai-proxy/api/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Endpoint string
const (
EndpointChat = "/openai/v1/chat/completions"
EndpointChatCompletions = "/v1/chat/completions"
EndpointEmbeddings = "/v1/embeddings"
)

type ChatCompletionRequest struct {
Expand Down
1 change: 1 addition & 0 deletions go/ai-proxy/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/sashabaranov/go-openai v1.37.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.11.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go/ai-proxy/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc=
github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU=
github.com/sashabaranov/go-openai v1.37.0 h1:hQQowgYm4OXJ1Z/wTrE+XZaO20BYsL0R3uRPSpfNZkY=
github.com/sashabaranov/go-openai v1.37.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
14 changes: 14 additions & 0 deletions go/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,30 @@ func main() {
os.Exit(1)
}

eop, err := proxy.NewOpenAIEmbeddingsProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
}

bp, err := proxy.NewBedrockProxy(api.ProviderBedrock, args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
}

ebp, err := proxy.NewBedrockEmbeddingsProxy(api.ProviderBedrock, args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
}

router := mux.NewRouter()
router.HandleFunc(ollama.EndpointChat, p.Proxy())
router.HandleFunc(openai.EndpointChat, op.Proxy())
router.HandleFunc(openai.EndpointEmbeddings, eop.Proxy())
router.HandleFunc(bedrock.EndpointChat, bp.Proxy())
router.HandleFunc(bedrock.EndpointEmbeddings, ebp.Proxy())

klog.V(log.LogLevelMinimal).InfoS("Listening and serving HTTP", "address", args.Address())
if err := http.ListenAndServe(args.Address(), router); err != nil {
Expand Down
137 changes: 137 additions & 0 deletions go/ai-proxy/proxy/bedrock/embeddings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package bedrock

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/ollama/ollama/openai"
"k8s.io/klog/v2"

"github.com/pluralsh/console/go/ai-proxy/api"
)

const (
Titan = "titan"
Cohere = "cohere"
)

type titanEmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}

type cohereEmbeddingResponse struct {
Embedding []float32 `json:"embeddings"`
}

type BedrockEmbeddingsProxy struct {
bedrockClient *bedrockruntime.Client
}

func NewBedrockEmbeddingsProxy(region string) (api.OpenAIProxy, error) {
ctx := context.Background()

var loadOptions []func(options *config.LoadOptions) error
if region != "" {
loadOptions = append(loadOptions, config.WithRegion(region))
}

sdkConfig, err := config.LoadDefaultConfig(ctx, loadOptions...)
if err != nil {
klog.ErrorS(err, "Couldn't load default configuration. Have you set up your AWS account?")
return nil, err
}
bedrockClient := bedrockruntime.NewFromConfig(sdkConfig)
return &BedrockEmbeddingsProxy{
bedrockClient: bedrockClient,
}, nil
}

func (b *BedrockEmbeddingsProxy) Proxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var openAIReq openai.EmbedRequest
if err := json.NewDecoder(r.Body).Decode(&openAIReq); err != nil {
http.Error(w, "failed to parse openai request", http.StatusBadRequest)
return
}
b.handleEmbeddingBedrock(w, &openAIReq)
}
}

func (b *BedrockEmbeddingsProxy) handleEmbeddingBedrock(
w http.ResponseWriter,
req *openai.EmbedRequest,
) {
input := map[string]interface{}{}

switch {
case strings.Contains(strings.ToLower(req.Model), Titan):
input["inputText"] = req.Input
case strings.Contains(strings.ToLower(req.Model), Cohere):
input["texts"] = []string{req.Input.(string)}
default:
klog.Errorf("model doesn't support embedding at this time %s", req.Model)
return
}

payloadBytes, err := json.Marshal(input)
if err != nil {
klog.ErrorS(err, "failed to convert to bedrock request")
return
}

output, err := b.bedrockClient.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{
ModelId: aws.String(req.Model),
Body: payloadBytes,
})
if err != nil {
klog.ErrorS(err, "request to bedrock failed")
return
}

response, err := convertBedrockEmbeddingToOpenAI(output, req.Model)
if err != nil {
klog.ErrorS(err, "failed to convert bedrock response to openai format")
return
}

w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(response); err != nil {
klog.Errorf("Error encoding response: %v", err)
return
}
}

func convertBedrockEmbeddingToOpenAI(output *bedrockruntime.InvokeModelOutput, model string) (*openai.EmbeddingList, error) {
switch {
case strings.Contains(model, Titan):
var embed titanEmbeddingResponse
if err := json.Unmarshal(output.Body, &embed); err != nil {
return nil, fmt.Errorf("failed to unmarshal Titan embedding response: %v", err)
}
var embedding openai.Embedding
embedding.Embedding = embed.Embedding
return &openai.EmbeddingList{
Model: model,
Data: []openai.Embedding{embedding},
}, nil
case strings.Contains(model, Cohere):
var embed cohereEmbeddingResponse
if err := json.Unmarshal(output.Body, &embed); err != nil {
return nil, fmt.Errorf("failed to unmarshal Cohere embedding response: %v", err)
}
var embedding openai.Embedding
embedding.Embedding = embed.Embedding
return &openai.EmbeddingList{
Model: model,
Data: []openai.Embedding{embedding},
}, nil
default:
return nil, fmt.Errorf("model doesn't support embedding at this time %s", model)
}
}
50 changes: 50 additions & 0 deletions go/ai-proxy/proxy/openai/embeddings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package openai

import (
"fmt"
"net/http/httputil"
"net/url"

"k8s.io/klog/v2"

"github.com/pluralsh/console/go/ai-proxy/api"
"github.com/pluralsh/console/go/ai-proxy/api/openai"
"github.com/pluralsh/console/go/ai-proxy/internal/log"
)

func NewOpenAIEmbeddingsProxy(host, token string) (api.OpenAIProxy, error) {
parsedURL, err := url.Parse(host)
if err != nil {
return nil, err
}

reverse := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.Out.Header.Set("Authorization", "Bearer "+token)

r.SetXForwarded()

targetURL, err := url.Parse(openai.EndpointEmbeddings)
if err != nil {
klog.ErrorS(err, "failed to parse target url")
return
}

r.Out.URL.Scheme = parsedURL.Scheme
r.Out.URL.Host = parsedURL.Host
r.Out.Host = parsedURL.Host
r.Out.URL.Path = targetURL.Path

klog.V(log.LogLevelDebug).InfoS(
"proxying request",
"from", fmt.Sprintf("%s %s", r.In.Method, r.In.URL.Path),
"to", r.Out.URL.String(),
)
},
}

return &OpenAIProxy{
proxy: reverse,
token: token,
}, nil
}
16 changes: 16 additions & 0 deletions go/ai-proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,26 @@ func NewOpenAIProxy(p api.Provider, host, token string) (api.OpenAIProxy, error)
return nil, fmt.Errorf("invalid provider: %s", p)
}

func NewOpenAIEmbeddingsProxy(p api.Provider, host, token string) (api.OpenAIProxy, error) {
switch p {
case api.ProviderOpenAI:
return openai.NewOpenAIEmbeddingsProxy(host, token)
}
return nil, fmt.Errorf("invalid provider: %s", p)
}

func NewBedrockProxy(p api.Provider, region string) (api.OpenAIProxy, error) {
switch p {
case api.ProviderBedrock:
return bedrock.NewBedrockProxy(region)
}
return nil, fmt.Errorf("invalid provider: %s", p)
}

func NewBedrockEmbeddingsProxy(p api.Provider, region string) (api.OpenAIProxy, error) {
switch p {
case api.ProviderBedrock:
return bedrock.NewBedrockEmbeddingsProxy(region)
}
return nil, fmt.Errorf("invalid provider: %s", p)
}
48 changes: 48 additions & 0 deletions go/ai-proxy/test/bedrock/bedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,51 @@ func TestBedrockProxy_Streaming(t *testing.T) {
}
})
}

func TestBedrockEmbeddingsProxy(t *testing.T) {
cases := []helpers.TestStruct[any, any]{
{
Name: "embeddings request should return correct openai response",
Method: "POST",
Endpoint: bedrock.EndpointEmbeddings,
Request: openai.EmbedRequest{
Model: "amazon.titan-embed-text-v2:0",
Input: "Hello from Titan embeddings test.",
},
WantData: openai.EmbeddingList{
Model: "amazon.titan-embed-text-v2:0",
Data: []openai.Embedding{
{
Embedding: make([]float32, 5)},
},
},
WantErr: nil,
WantStatus: http.StatusOK,
},
}

for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
wantDataBytes, err := json.Marshal(tc.WantData)
if err != nil {
t.Fatal(err)
}

mockResponseFunc := helpers.MockResponse(tc.Endpoint, wantDataBytes, tc.WantErr, tc.WantStatus)
err = mockResponseFunc(handlers)
if err != nil {
t.Fatal(err)
}

requestFunc := helpers.CreateRequest(tc.Method, tc.Endpoint, tc.Request)
res, err := requestFunc(server, providerServer)
if !errors.Is(err, tc.WantErr) {
t.Fatalf("\nwant:\n%v\ngot:\n%v", tc.WantErr, err)
}

if !bytes.Equal(wantDataBytes, res) {
t.Errorf("\nwant:\n%s\ngot:\n%s", tc.WantData, res)
}
})
}
}
16 changes: 16 additions & 0 deletions go/ai-proxy/test/helpers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import (
"net"
"net/http"
"net/http/httptest"
"os"
"strings"

"github.com/gorilla/mux"
"k8s.io/klog/v2"

"github.com/pluralsh/console/go/ai-proxy/api"
"github.com/pluralsh/console/go/ai-proxy/api/bedrock"
Expand All @@ -32,15 +34,29 @@ func SetupServer() (*httptest.Server, error) {
return nil, err
}

eop, err := proxy.NewOpenAIEmbeddingsProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
}

bp, err := proxy.NewBedrockProxy(api.ProviderBedrock, args.ProviderCredentials())
if err != nil {
return nil, err
}

ebp, err := proxy.NewBedrockEmbeddingsProxy(api.ProviderBedrock, args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
}

router := mux.NewRouter()
router.HandleFunc(ollama.EndpointChat, p.Proxy())
router.HandleFunc(openai.EndpointChat, op.Proxy())
router.HandleFunc(openai.EndpointEmbeddings, eop.Proxy())
router.HandleFunc(bedrock.EndpointChat, bp.Proxy())
router.HandleFunc(bedrock.EndpointEmbeddings, ebp.Proxy())

return httptest.NewServer(router), nil
}
Expand Down
Loading
Loading