Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to AI Proxy for full Open AI protocol #1752

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions go/ai-proxy/api/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/pluralsh/console/go/ai-proxy/api/ollama"
"github.com/pluralsh/console/go/ai-proxy/api/openai"
"github.com/pluralsh/console/go/ai-proxy/api/openai_standard"
"github.com/pluralsh/console/go/ai-proxy/api/vertex"
)

Expand Down Expand Up @@ -51,9 +50,6 @@ var (
ollamaToVertex ProviderAPIMapping = map[string]string{
ollama.EndpointChat: vertex.EndpointChat,
}
openAIToOpenAI ProviderAPIMapping = map[string]string{
openai_standard.EndpointChat: openai.EndpointChat,
}
)

func ToProviderAPIPath(target Provider, path string) string {
Expand All @@ -66,13 +62,6 @@ func ToProviderAPIPath(target Provider, path string) string {
panic(fmt.Sprintf("path %s not registered for provider %s", path, target))
}

return targetPath
case ProviderOpenAIStandard:
targetPath, exists := openAIToOpenAI[path]
if !exists {
panic(fmt.Sprintf("path %s not registered for provider %s", path, target))
}

return targetPath
case ProviderVertex:
targetPath, exists := ollamaToVertex[path]
Expand Down
6 changes: 5 additions & 1 deletion go/ai-proxy/args/args.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,15 @@ func ProviderCredentials() string {
return *argProviderToken
}

if len(*argProviderToken) > 0 && Provider() == api.ProviderOpenAIStandard {
return *argProviderToken
}

if len(*argProviderServiceAccount) > 0 && Provider() == api.ProviderVertex {
return *argProviderServiceAccount
}

if Provider() == defaultProvider || Provider() == api.ProviderOpenAIStandard {
if Provider() == defaultProvider {
return ""
}

Expand Down
48 changes: 36 additions & 12 deletions go/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,52 @@ import (
"github.com/gorilla/mux"
"k8s.io/klog/v2"

"github.com/pluralsh/console/go/ai-proxy/api"
"github.com/pluralsh/console/go/ai-proxy/api/ollama"
"github.com/pluralsh/console/go/ai-proxy/api/openai_standard"
"github.com/pluralsh/console/go/ai-proxy/args"
"github.com/pluralsh/console/go/ai-proxy/internal/log"
"github.com/pluralsh/console/go/ai-proxy/proxy"
)

func main() {
//p, err := proxy.NewTranslationProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials())
//if err != nil {
// klog.ErrorS(err, "Could not create translation proxy")
// os.Exit(1)
//}

op, err := proxy.NewOpenAIProxy(args.Provider(), args.ProviderHost())
if err != nil {
klog.ErrorS(err, "Could not create openai proxy")
os.Exit(1)
provider := args.Provider()
host := args.ProviderHost()
creds := args.ProviderCredentials()

var translationProxy api.TranslationProxy
if provider != api.ProviderOpenAIStandard {
tp, err := proxy.NewTranslationProxy(provider, host, creds)
if err != nil {
klog.ErrorS(err, "Could not create translation proxy")
os.Exit(1)
}
translationProxy = tp
} else {
translationProxy = nil
}

var openaiProxy api.OpenAIProxy
if provider == api.ProviderOpenAIStandard {
op, err := proxy.NewOpenAIProxy(provider, host, creds)
if err != nil {
klog.ErrorS(err, "Could not create openai proxy")
os.Exit(1)
}
openaiProxy = op
} else {
openaiProxy = nil
}

router := mux.NewRouter()
//router.HandleFunc(ollama.EndpointChat, p.Proxy())
router.HandleFunc(openai_standard.EndpointChat, op.Proxy())

if translationProxy != nil {
router.HandleFunc(ollama.EndpointChat, translationProxy.Proxy())
}

if openaiProxy != nil {
router.HandleFunc(openai_standard.EndpointChat, openaiProxy.Proxy())
}

klog.V(log.LogLevelMinimal).InfoS("Listening and serving HTTP", "address", args.Address())
if err := http.ListenAndServe(args.Address(), router); err != nil {
Expand Down
102 changes: 102 additions & 0 deletions go/ai-proxy/proxy/openai/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package openai

import (
"bytes"
"compress/gzip"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/url"
"strings"

"github.com/andybalholm/brotli"
"k8s.io/klog/v2"

"github.com/pluralsh/console/go/ai-proxy/api"
"github.com/pluralsh/console/go/ai-proxy/api/openai"
"github.com/pluralsh/console/go/ai-proxy/internal/log"
)

const headerContentEncoding = "Content-Encoding"

type OpenAIProxy struct {
proxy *httputil.ReverseProxy
token string
}

func (o *OpenAIProxy) Proxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
o.proxy.ServeHTTP(w, r)
}
}

func NewOpenAIStandardProxy(host, token string) (api.OpenAIProxy, error) {
parsedURL, err := url.Parse(host)
if err != nil {
return nil, err
}

reverse := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.Out.Header.Set("Authorization", "Bearer "+token)

r.SetXForwarded()

targetURL, err := url.Parse(openai.EndpointChat)
if err != nil {
klog.ErrorS(err, "failed to parse target url")
return
}

r.Out.URL.Scheme = parsedURL.Scheme
r.Out.URL.Host = parsedURL.Host
r.Out.Host = parsedURL.Host
r.Out.URL.Path = targetURL.Path

klog.V(log.LogLevelDebug).InfoS(
"proxying request",
"from", fmt.Sprintf("%s %s", r.In.Method, r.In.URL.Path),
"to", r.Out.URL.String(),
)
},

ModifyResponse: func(resp *http.Response) error {
contentEncoding := resp.Header.Get(headerContentEncoding)
if contentEncoding == "" {
return nil
}

var reader io.Reader
switch strings.TrimSpace(contentEncoding) {
case "br":
resp.Header.Del(headerContentEncoding)
reader = brotli.NewReader(resp.Body)
case "gzip":
resp.Header.Del(headerContentEncoding)
gzr, err := gzip.NewReader(resp.Body)
if err != nil {
return err
}
reader = gzr
default:
return nil
}

decompressed, err := io.ReadAll(reader)
if err != nil {
return err
}

resp.Body = io.NopCloser(bytes.NewReader(decompressed))
resp.ContentLength = int64(len(decompressed))

return nil
},
}

return &OpenAIProxy{
proxy: reverse,
token: token,
}, nil
}
50 changes: 0 additions & 50 deletions go/ai-proxy/proxy/provider/openai_standard.go

This file was deleted.

7 changes: 3 additions & 4 deletions go/ai-proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/pluralsh/console/go/ai-proxy/api"
"github.com/pluralsh/console/go/ai-proxy/proxy/openai"
"github.com/pluralsh/console/go/ai-proxy/proxy/provider"
)

Expand All @@ -13,8 +14,6 @@ func NewTranslationProxy(p api.Provider, host string, credentials string) (api.T
return provider.NewOllamaProxy(host)
case api.ProviderOpenAI:
return provider.NewOpenAIProxy(host, credentials)
//case api.ProviderOpenAIStandard:
// return provider.NewOpenAIStandardProxy(host)
case api.ProviderVertex:
return provider.NewVertexProxy(host, credentials)
case api.ProviderAnthropic:
Expand All @@ -24,10 +23,10 @@ func NewTranslationProxy(p api.Provider, host string, credentials string) (api.T
return nil, fmt.Errorf("invalid provider: %s", p)
}

func NewOpenAIProxy(p api.Provider, host string) (api.OpenAIProxy, error) {
func NewOpenAIProxy(p api.Provider, host, token string) (api.OpenAIProxy, error) {
switch p {
case api.ProviderOpenAIStandard:
return provider.NewOpenAIStandardProxy(host)
return openai.NewOpenAIStandardProxy(host, token)
}
return nil, fmt.Errorf("invalid provider: %s", p)
}
30 changes: 0 additions & 30 deletions go/ai-proxy/router/router.go

This file was deleted.

32 changes: 26 additions & 6 deletions go/ai-proxy/test/helpers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,40 @@ import (
"net/http/httptest"
"strings"

"github.com/gorilla/mux"

"github.com/pluralsh/console/go/ai-proxy/api"
"github.com/pluralsh/console/go/ai-proxy/api/ollama"
"github.com/pluralsh/console/go/ai-proxy/api/openai_standard"
"github.com/pluralsh/console/go/ai-proxy/args"
"github.com/pluralsh/console/go/ai-proxy/proxy"
"github.com/pluralsh/console/go/ai-proxy/router"
)

func SetupServer() (*httptest.Server, error) {
p, err := proxy.NewTranslationProxy(args.Provider(), args.ProviderHost(), args.ProviderCredentials())
if err != nil {
return nil, err
provider := args.Provider()
host := args.ProviderHost()
creds := args.ProviderCredentials()

router := mux.NewRouter()

if provider == api.ProviderOpenAIStandard {
op, err := proxy.NewOpenAIProxy(provider, host, creds)
if err != nil {
return nil, err
}
router.HandleFunc(openai_standard.EndpointChat, op.Proxy())

} else {
p, err := proxy.NewTranslationProxy(provider, host, creds)
if err != nil {
return nil, err
}

router.HandleFunc(ollama.EndpointChat, p.Proxy())
}

return httptest.NewServer(router.NewRouter(p)), nil
return httptest.NewServer(router), nil
}

func SetupProviderServer(handlers map[string]http.HandlerFunc) (*httptest.Server, error) {
server := httptest.NewUnstartedServer(http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
if handler, exists := handlers[request.URL.Path]; exists {
Expand Down
Loading