From a5bf64f596fff0679352b0af19ad10cfe32e478f Mon Sep 17 00:00:00 2001 From: John Blackwell Date: Thu, 20 Feb 2025 14:40:08 -0500 Subject: [PATCH 1/5] Support openai embeddings endpoints in ai proxy --- go/ai-proxy/api/bedrock/bedrock.go | 5 +- go/ai-proxy/api/openai/openai.go | 1 + go/ai-proxy/go.mod | 1 + go/ai-proxy/go.sum | 2 + go/ai-proxy/main.go | 14 ++ go/ai-proxy/proxy/bedrock/embeddings.go | 133 ++++++++++++++++++ go/ai-proxy/proxy/openai/embeddings.go | 50 +++++++ go/ai-proxy/proxy/proxy.go | 16 +++ go/ai-proxy/test/bedrock/bedrock_test.go | 48 +++++++ go/ai-proxy/test/helpers/common.go | 16 +++ .../openai_standard/openai_standard_test.go | 57 +++++++- 11 files changed, 339 insertions(+), 4 deletions(-) create mode 100644 go/ai-proxy/proxy/bedrock/embeddings.go create mode 100644 go/ai-proxy/proxy/openai/embeddings.go diff --git a/go/ai-proxy/api/bedrock/bedrock.go b/go/ai-proxy/api/bedrock/bedrock.go index c0d50bf5cf..83a0be778e 100644 --- a/go/ai-proxy/api/bedrock/bedrock.go +++ b/go/ai-proxy/api/bedrock/bedrock.go @@ -2,4 +2,7 @@ package bedrock type Endpoint string -const EndpointChat = "/chat/bedrock" +const ( + EndpointChat = "/chat/bedrock" + EndpointEmbeddings = "/embeddings/bedrock" +) diff --git a/go/ai-proxy/api/openai/openai.go b/go/ai-proxy/api/openai/openai.go index 3847b830fd..bd83808527 100644 --- a/go/ai-proxy/api/openai/openai.go +++ b/go/ai-proxy/api/openai/openai.go @@ -15,6 +15,7 @@ type Endpoint string const ( EndpointChat = "/openai/v1/chat/completions" EndpointChatCompletions = "/v1/chat/completions" + EndpointEmbeddings = "/v1/embeddings" ) type ChatCompletionRequest struct { diff --git a/go/ai-proxy/go.mod b/go/ai-proxy/go.mod index 71c141f4b7..fa98f6bd51 100644 --- a/go/ai-proxy/go.mod +++ b/go/ai-proxy/go.mod @@ -53,6 +53,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/rogpeppe/go-internal v1.8.0 // indirect + github.com/sashabaranov/go-openai v1.37.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.11.0 // indirect diff --git a/go/ai-proxy/go.sum b/go/ai-proxy/go.sum index 56cd75f8c1..b2f6d22293 100644 --- a/go/ai-proxy/go.sum +++ b/go/ai-proxy/go.sum @@ -109,6 +109,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc= github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU= +github.com/sashabaranov/go-openai v1.37.0 h1:hQQowgYm4OXJ1Z/wTrE+XZaO20BYsL0R3uRPSpfNZkY= +github.com/sashabaranov/go-openai v1.37.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/go/ai-proxy/main.go b/go/ai-proxy/main.go index cf774102b9..b4fe2d79bc 100644 --- a/go/ai-proxy/main.go +++ b/go/ai-proxy/main.go @@ -32,16 +32,30 @@ func main() { os.Exit(1) } + eop, err := proxy.NewOpenAIEmbeddingsProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials()) + if err != nil { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) + } + bp, err := proxy.NewBedrockProxy(api.ProviderBedrock, args.ProviderCredentials()) if err != nil { klog.ErrorS(err, "Could not create proxy") os.Exit(1) } + ebp, err := proxy.NewBedrockEmbeddingsProxy(api.ProviderBedrock, args.ProviderCredentials()) + if err != nil { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) + } + router := mux.NewRouter() router.HandleFunc(ollama.EndpointChat, p.Proxy()) router.HandleFunc(openai.EndpointChat, op.Proxy()) + router.HandleFunc(openai.EndpointEmbeddings, eop.Proxy()) router.HandleFunc(bedrock.EndpointChat, bp.Proxy()) + router.HandleFunc(bedrock.EndpointEmbeddings, ebp.Proxy()) klog.V(log.LogLevelMinimal).InfoS("Listening and serving HTTP", "address", args.Address()) if err := http.ListenAndServe(args.Address(), router); err != nil { diff --git a/go/ai-proxy/proxy/bedrock/embeddings.go b/go/ai-proxy/proxy/bedrock/embeddings.go new file mode 100644 index 0000000000..ba0e03b4e5 --- /dev/null +++ b/go/ai-proxy/proxy/bedrock/embeddings.go @@ -0,0 +1,133 @@ +package bedrock + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/ollama/ollama/openai" + "k8s.io/klog/v2" + + "github.com/pluralsh/console/go/ai-proxy/api" +) + +const ( + Titan = "titan" + Cohere = "cohere" +) + +type titanEmbeddingResponse struct { + Embedding []float32 `json:"embedding"` +} + +type cohereEmbeddingResponse struct { + Embedding []float32 `json:"embeddings"` +} + +type BedrockEmbeddingsProxy struct { + bedrockClient *bedrockruntime.Client +} + +func NewBedrockEmbeddingsProxy(region string) (api.OpenAIProxy, error) { + ctx := context.Background() + + var loadOptions []func(options *config.LoadOptions) error + if region != "" { + loadOptions = append(loadOptions, config.WithRegion(region)) + } + + sdkConfig, err := config.LoadDefaultConfig(ctx, loadOptions...) + if err != nil { + klog.ErrorS(err, "Couldn't load default configuration. Have you set up your AWS account?") + return nil, err + } + bedrockClient := bedrockruntime.NewFromConfig(sdkConfig) + return &BedrockEmbeddingsProxy{ + bedrockClient: bedrockClient, + }, nil +} + +func (b *BedrockEmbeddingsProxy) Proxy() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var openAIReq openai.EmbedRequest + if err := json.NewDecoder(r.Body).Decode(&openAIReq); err != nil { + http.Error(w, "failed to parse openai request", http.StatusBadRequest) + return + } + b.handleEmbeddingBedrock(w, &openAIReq) + } +} + +func (b *BedrockEmbeddingsProxy) handleEmbeddingBedrock( + w http.ResponseWriter, + req *openai.EmbedRequest, +) { + input := map[string]interface{}{} + + switch { + case strings.Contains(strings.ToLower(req.Model), Titan): + input["inputText"] = req.Input + case strings.Contains(strings.ToLower(req.Model), Cohere): + input["texts"] = []string{req.Input.(string)} + default: + klog.Errorf("model doesn't support embedding at this time %s", req.Model) + return + } + + payloadBytes, err := json.Marshal(input) + if err != nil { + klog.ErrorS(err, "failed to convert to bedrock request") + return + } + + output, err := b.bedrockClient.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(req.Model), + Body: payloadBytes, + }) + + response, err := convertBedrockEmbeddingToOpenAI(output, req.Model) + if err != nil { + klog.ErrorS(err, "failed to convert bedrock response to openai format") + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + klog.Errorf("Error encoding response: %v", err) + return + } +} + +func convertBedrockEmbeddingToOpenAI(output *bedrockruntime.InvokeModelOutput, model string) (*openai.EmbeddingList, error) { + switch { + case strings.Contains(model, Titan): + var embed titanEmbeddingResponse + if err := json.Unmarshal(output.Body, &embed); err != nil { + return nil, fmt.Errorf("failed to unmarshal Titan embedding response: %v", err) + } + var embedding openai.Embedding + embedding.Embedding = embed.Embedding + return &openai.EmbeddingList{ + Model: model, + Data: []openai.Embedding{embedding}, + }, nil + case strings.Contains(model, Cohere): + var embed cohereEmbeddingResponse + if err := json.Unmarshal(output.Body, &embed); err != nil { + return nil, fmt.Errorf("failed to unmarshal Cohere embedding response: %v", err) + } + var embedding openai.Embedding + embedding.Embedding = embed.Embedding + return &openai.EmbeddingList{ + Model: model, + Data: []openai.Embedding{embedding}, + }, nil + default: + return nil, fmt.Errorf("model doesn't support embedding at this time %s", model) + } +} diff --git a/go/ai-proxy/proxy/openai/embeddings.go b/go/ai-proxy/proxy/openai/embeddings.go new file mode 100644 index 0000000000..6ec27ab290 --- /dev/null +++ b/go/ai-proxy/proxy/openai/embeddings.go @@ -0,0 +1,50 @@ +package openai + +import ( + "fmt" + "net/http/httputil" + "net/url" + + "k8s.io/klog/v2" + + "github.com/pluralsh/console/go/ai-proxy/api" + "github.com/pluralsh/console/go/ai-proxy/api/openai" + "github.com/pluralsh/console/go/ai-proxy/internal/log" +) + +func NewOpenAIEmbeddingsProxy(host, token string) (api.OpenAIProxy, error) { + parsedURL, err := url.Parse(host) + if err != nil { + return nil, err + } + + reverse := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.Out.Header.Set("Authorization", "Bearer "+token) + + r.SetXForwarded() + + targetURL, err := url.Parse(openai.EndpointEmbeddings) + if err != nil { + klog.ErrorS(err, "failed to parse target url") + return + } + + r.Out.URL.Scheme = parsedURL.Scheme + r.Out.URL.Host = parsedURL.Host + r.Out.Host = parsedURL.Host + r.Out.URL.Path = targetURL.Path + + klog.V(log.LogLevelDebug).InfoS( + "proxying request", + "from", fmt.Sprintf("%s %s", r.In.Method, r.In.URL.Path), + "to", r.Out.URL.String(), + ) + }, + } + + return &OpenAIProxy{ + proxy: reverse, + token: token, + }, nil +} diff --git a/go/ai-proxy/proxy/proxy.go b/go/ai-proxy/proxy/proxy.go index 57d9bf3a65..4958c12594 100644 --- a/go/ai-proxy/proxy/proxy.go +++ b/go/ai-proxy/proxy/proxy.go @@ -32,6 +32,14 @@ func NewOpenAIProxy(p api.Provider, host, token string) (api.OpenAIProxy, error) return nil, fmt.Errorf("invalid provider: %s", p) } +func NewOpenAIEmbeddingsProxy(p api.Provider, host, token string) (api.OpenAIProxy, error) { + switch p { + case api.ProviderOpenAI: + return openai.NewOpenAIEmbeddingsProxy(host, token) + } + return nil, fmt.Errorf("invalid provider: %s", p) +} + func NewBedrockProxy(p api.Provider, region string) (api.OpenAIProxy, error) { switch p { case api.ProviderBedrock: @@ -39,3 +47,11 @@ func NewBedrockProxy(p api.Provider, region string) (api.OpenAIProxy, error) { } return nil, fmt.Errorf("invalid provider: %s", p) } + +func NewBedrockEmbeddingsProxy(p api.Provider, region string) (api.OpenAIProxy, error) { + switch p { + case api.ProviderBedrock: + return bedrock.NewBedrockEmbeddingsProxy(region) + } + return nil, fmt.Errorf("invalid provider: %s", p) +} diff --git a/go/ai-proxy/test/bedrock/bedrock_test.go b/go/ai-proxy/test/bedrock/bedrock_test.go index 29532cacb2..5a7c3330ce 100644 --- a/go/ai-proxy/test/bedrock/bedrock_test.go +++ b/go/ai-proxy/test/bedrock/bedrock_test.go @@ -147,3 +147,51 @@ func TestBedrockProxy_Streaming(t *testing.T) { } }) } + +func TestBedrockEmbeddingsProxy(t *testing.T) { + cases := []helpers.TestStruct[any, any]{ + { + Name: "embeddings request should return correct openai response", + Method: "POST", + Endpoint: bedrock.EndpointEmbeddings, + Request: openai.EmbedRequest{ + Model: "amazon.titan-embed-text-v2:0", + Input: "Hello from Titan embeddings test.", + }, + WantData: openai.EmbeddingList{ + Model: "amazon.titan-embed-text-v2:0", + Data: []openai.Embedding{ + { + Embedding: make([]float32, 5)}, + }, + }, + WantErr: nil, + WantStatus: http.StatusOK, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + wantDataBytes, err := json.Marshal(tc.WantData) + if err != nil { + t.Fatal(err) + } + + mockResponseFunc := helpers.MockResponse(tc.Endpoint, wantDataBytes, tc.WantErr, tc.WantStatus) + err = mockResponseFunc(handlers) + if err != nil { + t.Fatal(err) + } + + requestFunc := helpers.CreateRequest(tc.Method, tc.Endpoint, tc.Request) + res, err := requestFunc(server, providerServer) + if !errors.Is(err, tc.WantErr) { + t.Fatalf("\nwant:\n%v\ngot:\n%v", tc.WantErr, err) + } + + if !bytes.Equal(wantDataBytes, res) { + t.Errorf("\nwant:\n%s\ngot:\n%s", tc.WantData, res) + } + }) + } +} diff --git a/go/ai-proxy/test/helpers/common.go b/go/ai-proxy/test/helpers/common.go index cb4072bb61..58f961f8b7 100644 --- a/go/ai-proxy/test/helpers/common.go +++ b/go/ai-proxy/test/helpers/common.go @@ -8,9 +8,11 @@ import ( "net" "net/http" "net/http/httptest" + "os" "strings" "github.com/gorilla/mux" + "k8s.io/klog/v2" "github.com/pluralsh/console/go/ai-proxy/api" "github.com/pluralsh/console/go/ai-proxy/api/bedrock" @@ -32,15 +34,29 @@ func SetupServer() (*httptest.Server, error) { return nil, err } + eop, err := proxy.NewOpenAIEmbeddingsProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials()) + if err != nil { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) + } + bp, err := proxy.NewBedrockProxy(api.ProviderBedrock, args.ProviderCredentials()) if err != nil { return nil, err } + ebp, err := proxy.NewBedrockEmbeddingsProxy(api.ProviderBedrock, args.ProviderCredentials()) + if err != nil { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) + } + router := mux.NewRouter() router.HandleFunc(ollama.EndpointChat, p.Proxy()) router.HandleFunc(openai.EndpointChat, op.Proxy()) + router.HandleFunc(openai.EndpointEmbeddings, eop.Proxy()) router.HandleFunc(bedrock.EndpointChat, bp.Proxy()) + router.HandleFunc(bedrock.EndpointEmbeddings, ebp.Proxy()) return httptest.NewServer(router), nil } diff --git a/go/ai-proxy/test/openai_standard/openai_standard_test.go b/go/ai-proxy/test/openai_standard/openai_standard_test.go index a3b0fa98b8..101e46da23 100644 --- a/go/ai-proxy/test/openai_standard/openai_standard_test.go +++ b/go/ai-proxy/test/openai_standard/openai_standard_test.go @@ -29,7 +29,10 @@ var ( handlers = make(map[string]http.HandlerFunc) ) -const endpoint = "/openai/chat/completions" +const ( + endpointChat = "/openai/chat/completions" + endpointEmbeddings = "/embeddings/bedrock" +) func TestMain(m *testing.M) { var err error @@ -51,7 +54,7 @@ func TestOpenAIStandardProxy(t *testing.T) { { Name: "chat request should return correct openai response", Method: "POST", - Endpoint: endpoint, + Endpoint: endpointChat, Request: openai.ChatCompletionRequest{ Model: "testmodel", Messages: []openai.Message{{ @@ -102,7 +105,7 @@ func TestOpenAIStandardProxy_Streaming(t *testing.T) { streamTest := helpers.TestStruct[openai.ChatCompletionRequest, any]{ Name: "chat request with streaming should return SSE headers", Method: "POST", - Endpoint: endpoint, + Endpoint: endpointChat, Request: openai.ChatCompletionRequest{ Model: "testmodel", Stream: true, @@ -147,3 +150,51 @@ func TestOpenAIStandardProxy_Streaming(t *testing.T) { } }) } + +func TestOpenAIEmbeddingsProxy(t *testing.T) { + cases := []helpers.TestStruct[any, any]{ + { + Name: "chat request should return correct openai response", + Method: "POST", + Endpoint: endpointEmbeddings, + Request: openai.EmbedRequest{ + Model: "openai-embedding-model", + Input: "Hello from embeddings test.", + }, + WantData: openai.EmbeddingList{ + Model: "openai-embedding-model", + Data: []openai.Embedding{ + { + Embedding: make([]float32, 5)}, + }, + }, + WantErr: nil, + WantStatus: http.StatusOK, + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + wantDataBytes, err := json.Marshal(tc.WantData) + if err != nil { + t.Fatal(err) + } + + mockResponseFunc := helpers.MockResponse(tc.Endpoint, wantDataBytes, tc.WantErr, tc.WantStatus) + err = mockResponseFunc(handlers) + if err != nil { + t.Fatal(err) + } + + requestFunc := helpers.CreateRequest(tc.Method, tc.Endpoint, tc.Request) + res, err := requestFunc(server, providerServer) + if !errors.Is(err, tc.WantErr) { + t.Fatalf("\nwant:\n%v\ngot:\n%v", tc.WantErr, err) + } + + if !bytes.Equal(wantDataBytes, res) { + t.Errorf("\nwant:\n%s\ngot:\n%s", tc.WantData, res) + } + }) + } +} From 23e2fd0740d49a5b2cc7e7f84d52a664aa024111 Mon Sep 17 00:00:00 2001 From: John Blackwell Date: Thu, 20 Feb 2025 14:47:19 -0500 Subject: [PATCH 2/5] fix ineffectual error handling --- go/ai-proxy/proxy/bedrock/embeddings.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/go/ai-proxy/proxy/bedrock/embeddings.go b/go/ai-proxy/proxy/bedrock/embeddings.go index ba0e03b4e5..3c9aa00d3e 100644 --- a/go/ai-proxy/proxy/bedrock/embeddings.go +++ b/go/ai-proxy/proxy/bedrock/embeddings.go @@ -89,6 +89,10 @@ func (b *BedrockEmbeddingsProxy) handleEmbeddingBedrock( ModelId: aws.String(req.Model), Body: payloadBytes, }) + if err != nil { + klog.ErrorS(err, "request to bedrock failed") + return + } response, err := convertBedrockEmbeddingToOpenAI(output, req.Model) if err != nil { From fb8f403bbe7c366eb68658eef6e0ca01a1eda191 Mon Sep 17 00:00:00 2001 From: John Blackwell Date: Mon, 24 Feb 2025 16:04:37 -0500 Subject: [PATCH 3/5] cleanup routes --- go/ai-proxy/api/bedrock/bedrock.go | 7 --- go/ai-proxy/api/provider.go | 2 + go/ai-proxy/args/args.go | 4 ++ go/ai-proxy/main.go | 54 ++++++++----------- go/ai-proxy/proxy/proxy.go | 26 +++------ go/ai-proxy/test/bedrock/bedrock_test.go | 7 ++- go/ai-proxy/test/helpers/common.go | 50 ++++++++--------- .../openai_standard/openai_standard_test.go | 5 +- 8 files changed, 61 insertions(+), 94 deletions(-) diff --git a/go/ai-proxy/api/bedrock/bedrock.go b/go/ai-proxy/api/bedrock/bedrock.go index 83a0be778e..8a0a1908d5 100644 --- a/go/ai-proxy/api/bedrock/bedrock.go +++ b/go/ai-proxy/api/bedrock/bedrock.go @@ -1,8 +1 @@ package bedrock - -type Endpoint string - -const ( - EndpointChat = "/chat/bedrock" - EndpointEmbeddings = "/embeddings/bedrock" -) diff --git a/go/ai-proxy/api/provider.go b/go/ai-proxy/api/provider.go index a51f9f31a4..393f8d874c 100644 --- a/go/ai-proxy/api/provider.go +++ b/go/ai-proxy/api/provider.go @@ -24,6 +24,8 @@ func ToProvider(s string) (Provider, error) { return ProviderAnthropic, nil case ProviderVertex.String(): return ProviderVertex, nil + case ProviderBedrock.String(): + return ProviderBedrock, nil } return "", fmt.Errorf("invalid provider: %s", s) diff --git a/go/ai-proxy/args/args.go b/go/ai-proxy/args/args.go index 881558820b..46281c063c 100644 --- a/go/ai-proxy/args/args.go +++ b/go/ai-proxy/args/args.go @@ -65,6 +65,10 @@ func Provider() api.Provider { } func ProviderHost() string { + if Provider() == api.ProviderBedrock { + return "" + } + if len(*argProviderHost) == 0 { panic(fmt.Errorf("provider host is required")) } diff --git a/go/ai-proxy/main.go b/go/ai-proxy/main.go index b4fe2d79bc..402c653a66 100644 --- a/go/ai-proxy/main.go +++ b/go/ai-proxy/main.go @@ -8,7 +8,6 @@ import ( "k8s.io/klog/v2" "github.com/pluralsh/console/go/ai-proxy/api" - "github.com/pluralsh/console/go/ai-proxy/api/bedrock" "github.com/pluralsh/console/go/ai-proxy/api/ollama" "github.com/pluralsh/console/go/ai-proxy/api/openai" "github.com/pluralsh/console/go/ai-proxy/args" @@ -20,43 +19,34 @@ import ( func main() { klog.V(log.LogLevelMinimal).InfoS("Starting AI Proxy", "provider", args.Provider(), "version", environment.Version, "commit", environment.Commit) + router := mux.NewRouter() p, err := proxy.NewOllamaTranslationProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) if err != nil { - klog.ErrorS(err, "Could not create proxy") - os.Exit(1) - } - - op, err := proxy.NewOpenAIProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials()) - if err != nil { - klog.ErrorS(err, "Could not create proxy") - os.Exit(1) - } - - eop, err := proxy.NewOpenAIEmbeddingsProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials()) - if err != nil { - klog.ErrorS(err, "Could not create proxy") - os.Exit(1) + if args.Provider() == api.ProviderBedrock { + + } else { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) + } + } else { + router.HandleFunc(ollama.EndpointChat, p.Proxy()) } - bp, err := proxy.NewBedrockProxy(api.ProviderBedrock, args.ProviderCredentials()) - if err != nil { - klog.ErrorS(err, "Could not create proxy") - os.Exit(1) + if args.Provider() == api.ProviderOpenAI || args.Provider() == api.ProviderBedrock { + op, err := proxy.NewOpenAIProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) + if err != nil { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) + } + ep, err := proxy.NewOpenAIEmbeddingsProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) + if err != nil { + klog.ErrorS(err, "Could not create embedding proxy") + os.Exit(1) + } + router.HandleFunc(openai.EndpointChat, op.Proxy()) + router.HandleFunc(openai.EndpointEmbeddings, ep.Proxy()) } - ebp, err := proxy.NewBedrockEmbeddingsProxy(api.ProviderBedrock, args.ProviderCredentials()) - if err != nil { - klog.ErrorS(err, "Could not create proxy") - os.Exit(1) - } - - router := mux.NewRouter() - router.HandleFunc(ollama.EndpointChat, p.Proxy()) - router.HandleFunc(openai.EndpointChat, op.Proxy()) - router.HandleFunc(openai.EndpointEmbeddings, eop.Proxy()) - router.HandleFunc(bedrock.EndpointChat, bp.Proxy()) - router.HandleFunc(bedrock.EndpointEmbeddings, ebp.Proxy()) - klog.V(log.LogLevelMinimal).InfoS("Listening and serving HTTP", "address", args.Address()) if err := http.ListenAndServe(args.Address(), router); err != nil { klog.ErrorS(err, "Could not run the router") diff --git a/go/ai-proxy/proxy/proxy.go b/go/ai-proxy/proxy/proxy.go index 4958c12594..61387b793e 100644 --- a/go/ai-proxy/proxy/proxy.go +++ b/go/ai-proxy/proxy/proxy.go @@ -24,34 +24,22 @@ func NewOllamaTranslationProxy(p api.Provider, host string, credentials string) return nil, fmt.Errorf("invalid provider: %s", p) } -func NewOpenAIProxy(p api.Provider, host, token string) (api.OpenAIProxy, error) { +func NewOpenAIProxy(p api.Provider, host, credentials string) (api.OpenAIProxy, error) { switch p { case api.ProviderOpenAI: - return openai.NewOpenAIProxy(host, token) - } - return nil, fmt.Errorf("invalid provider: %s", p) -} - -func NewOpenAIEmbeddingsProxy(p api.Provider, host, token string) (api.OpenAIProxy, error) { - switch p { - case api.ProviderOpenAI: - return openai.NewOpenAIEmbeddingsProxy(host, token) - } - return nil, fmt.Errorf("invalid provider: %s", p) -} - -func NewBedrockProxy(p api.Provider, region string) (api.OpenAIProxy, error) { - switch p { + return openai.NewOpenAIProxy(host, credentials) case api.ProviderBedrock: - return bedrock.NewBedrockProxy(region) + return bedrock.NewBedrockProxy(credentials) } return nil, fmt.Errorf("invalid provider: %s", p) } -func NewBedrockEmbeddingsProxy(p api.Provider, region string) (api.OpenAIProxy, error) { +func NewOpenAIEmbeddingsProxy(p api.Provider, host, credentials string) (api.OpenAIProxy, error) { switch p { + case api.ProviderOpenAI: + return openai.NewOpenAIEmbeddingsProxy(host, credentials) case api.ProviderBedrock: - return bedrock.NewBedrockEmbeddingsProxy(region) + return bedrock.NewBedrockEmbeddingsProxy(credentials) } return nil, fmt.Errorf("invalid provider: %s", p) } diff --git a/go/ai-proxy/test/bedrock/bedrock_test.go b/go/ai-proxy/test/bedrock/bedrock_test.go index 5a7c3330ce..c5eac66ec9 100644 --- a/go/ai-proxy/test/bedrock/bedrock_test.go +++ b/go/ai-proxy/test/bedrock/bedrock_test.go @@ -13,7 +13,6 @@ import ( "github.com/spf13/pflag" "k8s.io/klog/v2" - "github.com/pluralsh/console/go/ai-proxy/api/bedrock" "github.com/pluralsh/console/go/ai-proxy/test/helpers" ) @@ -50,7 +49,7 @@ func TestBedrockProxy(t *testing.T) { { Name: "chat request should return correct openai response", Method: "POST", - Endpoint: bedrock.EndpointChat, + Endpoint: "/openai/v1/chat/completions", Request: openai.ChatCompletionRequest{ Model: "anthropic.claude-v2", Messages: []openai.Message{{ @@ -102,7 +101,7 @@ func TestBedrockProxy_Streaming(t *testing.T) { streamTest := helpers.TestStruct[openai.ChatCompletionRequest, any]{ Name: "chat request with streaming should return SSE headers", Method: "POST", - Endpoint: bedrock.EndpointChat, + Endpoint: "/openai/v1/chat/completions", Request: openai.ChatCompletionRequest{ Model: "testmodel", Stream: true, @@ -153,7 +152,7 @@ func TestBedrockEmbeddingsProxy(t *testing.T) { { Name: "embeddings request should return correct openai response", Method: "POST", - Endpoint: bedrock.EndpointEmbeddings, + Endpoint: "/v1/embeddings", Request: openai.EmbedRequest{ Model: "amazon.titan-embed-text-v2:0", Input: "Hello from Titan embeddings test.", diff --git a/go/ai-proxy/test/helpers/common.go b/go/ai-proxy/test/helpers/common.go index 58f961f8b7..3344fbcb44 100644 --- a/go/ai-proxy/test/helpers/common.go +++ b/go/ai-proxy/test/helpers/common.go @@ -15,7 +15,6 @@ import ( "k8s.io/klog/v2" "github.com/pluralsh/console/go/ai-proxy/api" - "github.com/pluralsh/console/go/ai-proxy/api/bedrock" "github.com/pluralsh/console/go/ai-proxy/api/ollama" "github.com/pluralsh/console/go/ai-proxy/api/openai" "github.com/pluralsh/console/go/ai-proxy/args" @@ -23,41 +22,34 @@ import ( ) func SetupServer() (*httptest.Server, error) { + router := mux.NewRouter() p, err := proxy.NewOllamaTranslationProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) if err != nil { - fmt.Println("Failed") - return nil, err - } - - op, err := proxy.NewOpenAIProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials()) - if err != nil { - return nil, err - } - - eop, err := proxy.NewOpenAIEmbeddingsProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials()) - if err != nil { - klog.ErrorS(err, "Could not create proxy") - os.Exit(1) - } + if args.Provider() == api.ProviderBedrock { - bp, err := proxy.NewBedrockProxy(api.ProviderBedrock, args.ProviderCredentials()) - if err != nil { - return nil, err + } else { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) + } + } else { + router.HandleFunc(ollama.EndpointChat, p.Proxy()) } - ebp, err := proxy.NewBedrockEmbeddingsProxy(api.ProviderBedrock, args.ProviderCredentials()) - if err != nil { - klog.ErrorS(err, "Could not create proxy") - os.Exit(1) + if args.Provider() == api.ProviderOpenAI || args.Provider() == api.ProviderBedrock { + op, err := proxy.NewOpenAIProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) + if err != nil { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) + } + ep, err := proxy.NewOpenAIEmbeddingsProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) + if err != nil { + klog.ErrorS(err, "Could not create embedding proxy") + os.Exit(1) + } + router.HandleFunc(openai.EndpointChat, op.Proxy()) + router.HandleFunc(openai.EndpointEmbeddings, ep.Proxy()) } - router := mux.NewRouter() - router.HandleFunc(ollama.EndpointChat, p.Proxy()) - router.HandleFunc(openai.EndpointChat, op.Proxy()) - router.HandleFunc(openai.EndpointEmbeddings, eop.Proxy()) - router.HandleFunc(bedrock.EndpointChat, bp.Proxy()) - router.HandleFunc(bedrock.EndpointEmbeddings, ebp.Proxy()) - return httptest.NewServer(router), nil } func SetupProviderServer(handlers map[string]http.HandlerFunc) (*httptest.Server, error) { diff --git a/go/ai-proxy/test/openai_standard/openai_standard_test.go b/go/ai-proxy/test/openai_standard/openai_standard_test.go index 101e46da23..c13f6bac4f 100644 --- a/go/ai-proxy/test/openai_standard/openai_standard_test.go +++ b/go/ai-proxy/test/openai_standard/openai_standard_test.go @@ -30,8 +30,7 @@ var ( ) const ( - endpointChat = "/openai/chat/completions" - endpointEmbeddings = "/embeddings/bedrock" + endpointChat = "/openai/chat/completions" ) func TestMain(m *testing.M) { @@ -156,7 +155,7 @@ func TestOpenAIEmbeddingsProxy(t *testing.T) { { Name: "chat request should return correct openai response", Method: "POST", - Endpoint: endpointEmbeddings, + Endpoint: "/v1/embeddings", Request: openai.EmbedRequest{ Model: "openai-embedding-model", Input: "Hello from embeddings test.", From e9a7b4bd06767190a2cf2723de2896ed61c862c9 Mon Sep 17 00:00:00 2001 From: John Blackwell Date: Mon, 24 Feb 2025 16:26:37 -0500 Subject: [PATCH 4/5] helper function --- go/ai-proxy/args/args.go | 4 ++++ go/ai-proxy/main.go | 2 +- go/ai-proxy/test/helpers/common.go | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/go/ai-proxy/args/args.go b/go/ai-proxy/args/args.go index 46281c063c..8b24e3de37 100644 --- a/go/ai-proxy/args/args.go +++ b/go/ai-proxy/args/args.go @@ -112,3 +112,7 @@ func Address() string { return fmt.Sprintf("%s:%d", *argAddress, *argPort) } + +func OpenAICompatible() bool { + return Provider() == api.ProviderOpenAI || Provider() == api.ProviderBedrock +} diff --git a/go/ai-proxy/main.go b/go/ai-proxy/main.go index 402c653a66..04288e5cd6 100644 --- a/go/ai-proxy/main.go +++ b/go/ai-proxy/main.go @@ -32,7 +32,7 @@ func main() { router.HandleFunc(ollama.EndpointChat, p.Proxy()) } - if args.Provider() == api.ProviderOpenAI || args.Provider() == api.ProviderBedrock { + if args.OpenAICompatible() { op, err := proxy.NewOpenAIProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) if err != nil { klog.ErrorS(err, "Could not create proxy") diff --git a/go/ai-proxy/test/helpers/common.go b/go/ai-proxy/test/helpers/common.go index 3344fbcb44..2245602140 100644 --- a/go/ai-proxy/test/helpers/common.go +++ b/go/ai-proxy/test/helpers/common.go @@ -35,7 +35,7 @@ func SetupServer() (*httptest.Server, error) { router.HandleFunc(ollama.EndpointChat, p.Proxy()) } - if args.Provider() == api.ProviderOpenAI || args.Provider() == api.ProviderBedrock { + if args.OpenAICompatible() { op, err := proxy.NewOpenAIProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) if err != nil { klog.ErrorS(err, "Could not create proxy") From 4972a6ef5e0e73d7985892bd8a3f4ff169198305 Mon Sep 17 00:00:00 2001 From: John Blackwell Date: Mon, 24 Feb 2025 16:42:43 -0500 Subject: [PATCH 5/5] cleanup --- go/ai-proxy/api/openai/openai.go | 2 +- go/ai-proxy/main.go | 12 ++++-------- go/ai-proxy/test/bedrock/bedrock_test.go | 2 +- .../test/openai_standard/openai_standard_test.go | 2 +- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/go/ai-proxy/api/openai/openai.go b/go/ai-proxy/api/openai/openai.go index bd83808527..1e8bbfcaab 100644 --- a/go/ai-proxy/api/openai/openai.go +++ b/go/ai-proxy/api/openai/openai.go @@ -15,7 +15,7 @@ type Endpoint string const ( EndpointChat = "/openai/v1/chat/completions" EndpointChatCompletions = "/v1/chat/completions" - EndpointEmbeddings = "/v1/embeddings" + EndpointEmbeddings = "/openai/v1/embeddings" ) type ChatCompletionRequest struct { diff --git a/go/ai-proxy/main.go b/go/ai-proxy/main.go index 04288e5cd6..9d5f86ac13 100644 --- a/go/ai-proxy/main.go +++ b/go/ai-proxy/main.go @@ -21,15 +21,11 @@ func main() { router := mux.NewRouter() p, err := proxy.NewOllamaTranslationProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials()) - if err != nil { - if args.Provider() == api.ProviderBedrock { - - } else { - klog.ErrorS(err, "Could not create proxy") - os.Exit(1) - } - } else { + if err == nil { router.HandleFunc(ollama.EndpointChat, p.Proxy()) + } else if args.Provider() != api.ProviderBedrock { + klog.ErrorS(err, "Could not create proxy") + os.Exit(1) } if args.OpenAICompatible() { diff --git a/go/ai-proxy/test/bedrock/bedrock_test.go b/go/ai-proxy/test/bedrock/bedrock_test.go index c5eac66ec9..9ac70c001c 100644 --- a/go/ai-proxy/test/bedrock/bedrock_test.go +++ b/go/ai-proxy/test/bedrock/bedrock_test.go @@ -152,7 +152,7 @@ func TestBedrockEmbeddingsProxy(t *testing.T) { { Name: "embeddings request should return correct openai response", Method: "POST", - Endpoint: "/v1/embeddings", + Endpoint: "/openai/v1/embeddings", Request: openai.EmbedRequest{ Model: "amazon.titan-embed-text-v2:0", Input: "Hello from Titan embeddings test.", diff --git a/go/ai-proxy/test/openai_standard/openai_standard_test.go b/go/ai-proxy/test/openai_standard/openai_standard_test.go index c13f6bac4f..66e2ef813f 100644 --- a/go/ai-proxy/test/openai_standard/openai_standard_test.go +++ b/go/ai-proxy/test/openai_standard/openai_standard_test.go @@ -155,7 +155,7 @@ func TestOpenAIEmbeddingsProxy(t *testing.T) { { Name: "chat request should return correct openai response", Method: "POST", - Endpoint: "/v1/embeddings", + Endpoint: "/openai/v1/embeddings", Request: openai.EmbedRequest{ Model: "openai-embedding-model", Input: "Hello from embeddings test.",