Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support openai embeddings endpoints in ai proxy #1934

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions go/ai-proxy/api/bedrock/bedrock.go
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
package bedrock

type Endpoint string

const EndpointChat = "/chat/bedrock"
1 change: 1 addition & 0 deletions go/ai-proxy/api/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Endpoint string
const (
EndpointChat = "/openai/v1/chat/completions"
EndpointChatCompletions = "/v1/chat/completions"
EndpointEmbeddings = "/openai/v1/embeddings"
)

type ChatCompletionRequest struct {
Expand Down
2 changes: 2 additions & 0 deletions go/ai-proxy/api/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ func ToProvider(s string) (Provider, error) {
return ProviderAnthropic, nil
case ProviderVertex.String():
return ProviderVertex, nil
case ProviderBedrock.String():
return ProviderBedrock, nil
}

return "", fmt.Errorf("invalid provider: %s", s)
Expand Down
8 changes: 8 additions & 0 deletions go/ai-proxy/args/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ func Provider() api.Provider {
}

func ProviderHost() string {
if Provider() == api.ProviderBedrock {
return ""
}

if len(*argProviderHost) == 0 {
panic(fmt.Errorf("provider host is required"))
}
Expand Down Expand Up @@ -108,3 +112,7 @@ func Address() string {

return fmt.Sprintf("%s:%d", *argAddress, *argPort)
}

func OpenAICompatible() bool {
return Provider() == api.ProviderOpenAI || Provider() == api.ProviderBedrock
}
1 change: 1 addition & 0 deletions go/ai-proxy/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/sashabaranov/go-openai v1.37.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.11.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go/ai-proxy/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/samber/lo v1.47.0 h1:z7RynLwP5nbyRscyvcD043DWYoOcYRv3mV8lBeqOCLc=
github.com/samber/lo v1.47.0/go.mod h1:RmDH9Ct32Qy3gduHQuKJ3gW1fMHAnE/fAzQuf6He5cU=
github.com/sashabaranov/go-openai v1.37.0 h1:hQQowgYm4OXJ1Z/wTrE+XZaO20BYsL0R3uRPSpfNZkY=
github.com/sashabaranov/go-openai v1.37.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down
34 changes: 17 additions & 17 deletions go/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"k8s.io/klog/v2"

"github.com/pluralsh/console/go/ai-proxy/api"
"github.com/pluralsh/console/go/ai-proxy/api/bedrock"
"github.com/pluralsh/console/go/ai-proxy/api/ollama"
"github.com/pluralsh/console/go/ai-proxy/api/openai"
"github.com/pluralsh/console/go/ai-proxy/args"
Expand All @@ -20,29 +19,30 @@ import (
func main() {
klog.V(log.LogLevelMinimal).InfoS("Starting AI Proxy", "provider", args.Provider(), "version", environment.Version, "commit", environment.Commit)

router := mux.NewRouter()
p, err := proxy.NewOllamaTranslationProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials())
if err != nil {
if err == nil {
router.HandleFunc(ollama.EndpointChat, p.Proxy())
} else if args.Provider() != api.ProviderBedrock {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
}

op, err := proxy.NewOpenAIProxy(api.ProviderOpenAI, args.ProviderHost(), args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
if args.OpenAICompatible() {
op, err := proxy.NewOpenAIProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
}
ep, err := proxy.NewOpenAIEmbeddingsProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create embedding proxy")
os.Exit(1)
}
router.HandleFunc(openai.EndpointChat, op.Proxy())
router.HandleFunc(openai.EndpointEmbeddings, ep.Proxy())
}

bp, err := proxy.NewBedrockProxy(api.ProviderBedrock, args.ProviderCredentials())
if err != nil {
klog.ErrorS(err, "Could not create proxy")
os.Exit(1)
}

router := mux.NewRouter()
router.HandleFunc(ollama.EndpointChat, p.Proxy())
router.HandleFunc(openai.EndpointChat, op.Proxy())
router.HandleFunc(bedrock.EndpointChat, bp.Proxy())

klog.V(log.LogLevelMinimal).InfoS("Listening and serving HTTP", "address", args.Address())
if err := http.ListenAndServe(args.Address(), router); err != nil {
klog.ErrorS(err, "Could not run the router")
Expand Down
137 changes: 137 additions & 0 deletions go/ai-proxy/proxy/bedrock/embeddings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package bedrock

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/ollama/ollama/openai"
"k8s.io/klog/v2"

"github.com/pluralsh/console/go/ai-proxy/api"
)

// Substrings used to detect which Bedrock embedding model family a
// request targets; the request path matches them case-insensitively
// against the model ID (e.g. "amazon.titan-embed-text-v2:0").
const (
	Titan = "titan"
	Cohere = "cohere"
)

// titanEmbeddingResponse mirrors the body returned by Amazon Titan
// embedding models: a single vector under "embedding".
type titanEmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

// cohereEmbeddingResponse maps the "embeddings" key of a Cohere response
// onto a single vector.
// NOTE(review): Cohere's embed API documents "embeddings" as a list of
// vectors ([][]float32); unmarshaling that shape into []float32 would
// fail — confirm against a real Cohere response.
type cohereEmbeddingResponse struct {
	Embedding []float32 `json:"embeddings"`
}

// BedrockEmbeddingsProxy serves OpenAI-style embedding requests by
// invoking AWS Bedrock models through the Bedrock runtime client.
type BedrockEmbeddingsProxy struct {
	bedrockClient *bedrockruntime.Client
}

// NewBedrockEmbeddingsProxy builds an embeddings proxy backed by the AWS
// Bedrock runtime. An empty region defers to the SDK's default region
// resolution (environment variables, shared config, etc.).
func NewBedrockEmbeddingsProxy(region string) (api.OpenAIProxy, error) {
	opts := make([]func(options *config.LoadOptions) error, 0, 1)
	if len(region) > 0 {
		opts = append(opts, config.WithRegion(region))
	}

	cfg, err := config.LoadDefaultConfig(context.Background(), opts...)
	if err != nil {
		klog.ErrorS(err, "Couldn't load default configuration. Have you set up your AWS account?")
		return nil, err
	}

	return &BedrockEmbeddingsProxy{
		bedrockClient: bedrockruntime.NewFromConfig(cfg),
	}, nil
}

// Proxy returns an http.HandlerFunc that decodes an OpenAI-compatible
// embedding request body and forwards it to Bedrock.
func (b *BedrockEmbeddingsProxy) Proxy() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		req := new(openai.EmbedRequest)
		if err := json.NewDecoder(r.Body).Decode(req); err != nil {
			http.Error(w, "failed to parse openai request", http.StatusBadRequest)
			return
		}
		b.handleEmbeddingBedrock(w, req)
	}
}

// handleEmbeddingBedrock translates an OpenAI embedding request into a
// Bedrock InvokeModel call and writes the response back in OpenAI format.
//
// Fixes over the original: every error path now reports a status code to
// the client via http.Error (previously errors were only logged and the
// client received an empty 200), and the Cohere input type assertion is
// guarded (req.Input is declared `any`; a non-string input panicked).
func (b *BedrockEmbeddingsProxy) handleEmbeddingBedrock(
	w http.ResponseWriter,
	req *openai.EmbedRequest,
) {
	input := map[string]interface{}{}

	model := strings.ToLower(req.Model)
	switch {
	case strings.Contains(model, Titan):
		// Titan accepts the raw input value under "inputText".
		input["inputText"] = req.Input
	case strings.Contains(model, Cohere):
		// Cohere expects a list of texts.
		text, ok := req.Input.(string)
		if !ok {
			klog.Errorf("unsupported input type %T for cohere embedding", req.Input)
			http.Error(w, "input must be a string", http.StatusBadRequest)
			return
		}
		input["texts"] = []string{text}
	default:
		klog.Errorf("model doesn't support embedding at this time %s", req.Model)
		http.Error(w, fmt.Sprintf("model doesn't support embedding at this time %s", req.Model), http.StatusBadRequest)
		return
	}

	payloadBytes, err := json.Marshal(input)
	if err != nil {
		klog.ErrorS(err, "failed to convert to bedrock request")
		http.Error(w, "failed to build bedrock request", http.StatusInternalServerError)
		return
	}

	output, err := b.bedrockClient.InvokeModel(context.Background(), &bedrockruntime.InvokeModelInput{
		ModelId: aws.String(req.Model),
		Body:    payloadBytes,
	})
	if err != nil {
		klog.ErrorS(err, "request to bedrock failed")
		http.Error(w, "request to bedrock failed", http.StatusBadGateway)
		return
	}

	response, err := convertBedrockEmbeddingToOpenAI(output, req.Model)
	if err != nil {
		klog.ErrorS(err, "failed to convert bedrock response to openai format")
		http.Error(w, "failed to convert bedrock response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		klog.Errorf("Error encoding response: %v", err)
		return
	}
}

// convertBedrockEmbeddingToOpenAI converts a Bedrock InvokeModel response
// body into the OpenAI embedding-list shape.
//
// Fix: the model ID is now matched case-insensitively, consistent with the
// request-side dispatch in handleEmbeddingBedrock (which lowercases before
// comparing); previously a model ID like "Titan..." would dispatch the
// request but fail conversion. The duplicated branch bodies are merged.
func convertBedrockEmbeddingToOpenAI(output *bedrockruntime.InvokeModelOutput, model string) (*openai.EmbeddingList, error) {
	var vector []float32

	switch lower := strings.ToLower(model); {
	case strings.Contains(lower, Titan):
		var embed titanEmbeddingResponse
		if err := json.Unmarshal(output.Body, &embed); err != nil {
			return nil, fmt.Errorf("failed to unmarshal Titan embedding response: %v", err)
		}
		vector = embed.Embedding
	case strings.Contains(lower, Cohere):
		var embed cohereEmbeddingResponse
		if err := json.Unmarshal(output.Body, &embed); err != nil {
			return nil, fmt.Errorf("failed to unmarshal Cohere embedding response: %v", err)
		}
		vector = embed.Embedding
	default:
		return nil, fmt.Errorf("model doesn't support embedding at this time %s", model)
	}

	return &openai.EmbeddingList{
		Model: model,
		Data:  []openai.Embedding{{Embedding: vector}},
	}, nil
}
50 changes: 50 additions & 0 deletions go/ai-proxy/proxy/openai/embeddings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package openai

import (
"fmt"
"net/http/httputil"
"net/url"

"k8s.io/klog/v2"

"github.com/pluralsh/console/go/ai-proxy/api"
"github.com/pluralsh/console/go/ai-proxy/api/openai"
"github.com/pluralsh/console/go/ai-proxy/internal/log"
)

// NewOpenAIEmbeddingsProxy returns a reverse proxy that rewrites incoming
// requests to the OpenAI embeddings endpoint on the given host, attaching
// the bearer token to every outgoing request.
//
// Fix: openai.EndpointEmbeddings is a compile-time constant path, so it is
// assigned directly to the outgoing URL. The previous implementation
// re-parsed it with url.Parse inside Rewrite on every request and, on a
// parse error, silently returned with the path left unrewritten.
func NewOpenAIEmbeddingsProxy(host, token string) (api.OpenAIProxy, error) {
	parsedURL, err := url.Parse(host)
	if err != nil {
		return nil, err
	}

	reverse := &httputil.ReverseProxy{
		Rewrite: func(r *httputil.ProxyRequest) {
			r.Out.Header.Set("Authorization", "Bearer "+token)

			r.SetXForwarded()

			r.Out.URL.Scheme = parsedURL.Scheme
			r.Out.URL.Host = parsedURL.Host
			r.Out.Host = parsedURL.Host
			r.Out.URL.Path = openai.EndpointEmbeddings

			klog.V(log.LogLevelDebug).InfoS(
				"proxying request",
				"from", fmt.Sprintf("%s %s", r.In.Method, r.In.URL.Path),
				"to", r.Out.URL.String(),
			)
		},
	}

	return &OpenAIProxy{
		proxy: reverse,
		token: token,
	}, nil
}
12 changes: 8 additions & 4 deletions go/ai-proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@ func NewOllamaTranslationProxy(p api.Provider, host string, credentials string)
return nil, fmt.Errorf("invalid provider: %s", p)
}

func NewOpenAIProxy(p api.Provider, host, token string) (api.OpenAIProxy, error) {
func NewOpenAIProxy(p api.Provider, host, credentials string) (api.OpenAIProxy, error) {
switch p {
case api.ProviderOpenAI:
return openai.NewOpenAIProxy(host, token)
return openai.NewOpenAIProxy(host, credentials)
case api.ProviderBedrock:
return bedrock.NewBedrockProxy(credentials)
}
return nil, fmt.Errorf("invalid provider: %s", p)
}

func NewBedrockProxy(p api.Provider, region string) (api.OpenAIProxy, error) {
func NewOpenAIEmbeddingsProxy(p api.Provider, host, credentials string) (api.OpenAIProxy, error) {
switch p {
case api.ProviderOpenAI:
return openai.NewOpenAIEmbeddingsProxy(host, credentials)
case api.ProviderBedrock:
return bedrock.NewBedrockProxy(region)
return bedrock.NewBedrockEmbeddingsProxy(credentials)
}
return nil, fmt.Errorf("invalid provider: %s", p)
}
Loading
Loading