Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Devin/1736522054 allow arbitrary bedrock models #296

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ go 1.23.1

require (
github.com/atotto/clipboard v0.1.4
github.com/aws/aws-sdk-go-v2 v1.32.8
github.com/aws/aws-sdk-go-v2/config v1.28.10
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.2
github.com/jarcoal/httpmock v1.3.1
github.com/muesli/mango-cobra v1.2.0
github.com/muesli/roff v0.1.0
Expand All @@ -15,6 +18,18 @@ require (
)

require (
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.51 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.23 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.27 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.27 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.8 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.9 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.8 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.6 // indirect
github.com/aws/smithy-go v1.22.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
Expand Down
30 changes: 30 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,35 @@
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aws/aws-sdk-go-v2 v1.32.8 h1:cZV+NUS/eGxKXMtmyhtYPJ7Z4YLoI/V8bkTdRZfYhGo=
github.com/aws/aws-sdk-go-v2 v1.32.8/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc=
github.com/aws/aws-sdk-go-v2/config v1.28.10 h1:fKODZHfqQu06pCzR69KJ3GuttraRJkhlC8g80RZ0Dfg=
github.com/aws/aws-sdk-go-v2/config v1.28.10/go.mod h1:PvdxRYZ5Um9QMq9PQ0zHHNdtKK+he2NHtFCUFMXWXeg=
github.com/aws/aws-sdk-go-v2/credentials v1.17.51 h1:F/9Sm6Y6k4LqDesZDPJCLxQGXNNHd/ZtJiWd0lCZKRk=
github.com/aws/aws-sdk-go-v2/credentials v1.17.51/go.mod h1:TKbzCHm43AoPyA+iLGGcruXd4AFhF8tOmLex2R9jWNQ=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.23 h1:IBAoD/1d8A8/1aA8g4MBVtTRHhXRiNAgwdbo/xRM2DI=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.23/go.mod h1:vfENuCM7dofkgKpYzuzf1VT1UKkA/YL3qanfBn7HCaA=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.27 h1:jSJjSBzw8VDIbWv+mmvBSP8ezsztMYJGH+eKqi9AmNs=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.27/go.mod h1:/DAhLbFRgwhmvJdOfSm+WwikZrCuUJiA4WgJG0fTNSw=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.27 h1:l+X4K77Dui85pIj5foXDhPlnqcNRG2QUyvca300lXh8=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.27/go.mod h1:KvZXSFEXm6x84yE8qffKvT3x8J5clWnVFXphpohhzJ8=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.2 h1:8CcCVDj3hdUJoa1aOxdsKl6c73bC80x9ZylUWCytmgk=
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.23.2/go.mod h1:pCst69koE8+hbZ7EohPkOrOhyvqWqXxIVo8cp655yAg=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 h1:iXtILhvDxB6kPvEXgsDhGaZCSC6LQET5ZHSdJozeI0Y=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1/go.mod h1:9nu0fVANtYiAePIBh2/pFUSwtJ402hLnp854CNoDOeE=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.8 h1:cWno7lefSH6Pp+mSznagKCgfDGeZRin66UvYUqAkyeA=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.8/go.mod h1:tPD+VjU3ABTBoEJ3nctu5Nyg4P4yjqSH5bJGGkY4+XE=
github.com/aws/aws-sdk-go-v2/service/sso v1.24.9 h1:YqtxripbjWb2QLyzRK9pByfEDvgg95gpC2AyDq4hFE8=
github.com/aws/aws-sdk-go-v2/service/sso v1.24.9/go.mod h1:lV8iQpg6OLOfBnqbGMBKYjilBlf633qwHnBEiMSPoHY=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.8 h1:6dBT1Lz8fK11m22R+AqfRsFn8320K0T5DTGxxOQBSMw=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.8/go.mod h1:/kiBvRQXBc6xeJTYzhSdGvJ5vm1tjaDEjH+MSeRJnlY=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.6 h1:VwhTrsTuVn52an4mXx29PqRzs2Dvu921NpGk7y43tAM=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.6/go.mod h1:+8h7PZb3yY5ftmVLD7ocEoE98hdc8PoKS0H3wfx1dlc=
github.com/aws/smithy-go v1.22.1 h1:/HPHZQ0g7f4eUeK6HKglFz8uwVfZKgoI25rb/J+dnro=
github.com/aws/smithy-go v1.22.1/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
14 changes: 12 additions & 2 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,25 @@ var (

// OpenAIClient is a client for the OpenAI API.
type OpenAIClient struct {
HTTPClient *http.Client
httpClient *http.Client
config *viper.Viper
api *openai.Client
out io.Writer
chatSessionManager chat.SessionManager
}

// GetHTTPClient implements the Provider interface
func (c *OpenAIClient) GetHTTPClient() *http.Client {
return c.httpClient
}

// CreateClient creates a new OpenAI client with the given config and output writer.
func CreateClient(config *viper.Viper, out io.Writer) (*OpenAIClient, error) {
// Don't create OpenAI client if we're using Bedrock
if config.GetString("provider") == "bedrock" {
return nil, fmt.Errorf("cannot create OpenAI client when using Bedrock provider")
}

// Check, if api key was set
apiKey, exists := os.LookupEnv(envKeyOpenAIApi)
if !exists {
Expand Down Expand Up @@ -95,7 +105,7 @@ func CreateClient(config *viper.Viper, out io.Writer) (*OpenAIClient, error) {

// Create client
client := &OpenAIClient{
HTTPClient: httpClient,
httpClient: httpClient,
config: config,
api: openai.NewClientWithConfig(clientConfig),
out: out,
Expand Down
53 changes: 34 additions & 19 deletions pkg/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,33 @@ import (

"github.com/jarcoal/httpmock"
"github.com/sashabaranov/go-openai"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
"github.com/tbckr/sgpt/v2/internal/testlib"
"github.com/tbckr/sgpt/v2/pkg/chat"
)

func TestCreateClient(t *testing.T) {
func TestCreateProvider(t *testing.T) {
// Set the api key
err := os.Setenv("OPENAI_API_KEY", "test")
require.NoError(t, err)

var client *OpenAIClient
client, err = CreateClient(nil, nil)
config := viper.New()
config.Set("provider", "openai")

var provider Provider
provider, err = CreateProvider(config, nil)
require.NoError(t, err)
require.NotNil(t, client)
require.NotNil(t, provider)
}

func TestCreateClientMissingApiKey(t *testing.T) {
err := os.Unsetenv("OPENAI_API_KEY")
require.NoError(t, err)

config := viper.New()
var client *OpenAIClient
client, err = CreateClient(nil, nil)
client, err = CreateClient(config, nil)
require.Error(t, err)
require.ErrorIs(t, err, ErrMissingAPIKey)
require.Nil(t, client)
Expand All @@ -73,7 +78,7 @@ func TestSimplePrompt(t *testing.T) {
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(expected)

Expand Down Expand Up @@ -122,7 +127,7 @@ func TestStreamSimplePrompt(t *testing.T) {
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponseStream(expected)

Expand Down Expand Up @@ -160,7 +165,7 @@ func TestPromptSaveAsChat(t *testing.T) {
prompt := []string{"Say: Hello World!"}
expected := "Hello World!"

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(expected)

Expand All @@ -180,7 +185,9 @@ func TestPromptSaveAsChat(t *testing.T) {
require.Equal(t, expected, result)
require.NoError(t, writer.Close())

require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat"))
chatDir := filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")
require.DirExists(t, chatDir)
require.FileExists(t, filepath.Join(chatDir, "messages.json"))

var manager chat.SessionManager
manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config)
Expand Down Expand Up @@ -215,7 +222,7 @@ func TestPromptLoadChat(t *testing.T) {
prompt := []string{"Repeat last message"}
expected := "World!"

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(expected)

Expand Down Expand Up @@ -281,7 +288,7 @@ func TestPromptWithModifier(t *testing.T) {
response := `echo \"Hello World\"`
expected := `echo "Hello World"`

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(response)

Expand Down Expand Up @@ -309,7 +316,9 @@ func TestPromptWithModifier(t *testing.T) {
require.Equal(t, expected, result)
require.NoError(t, writer.Close())

require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat"))
chatDir := filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")
require.DirExists(t, chatDir)
require.FileExists(t, filepath.Join(chatDir, "messages.json"))

var manager chat.SessionManager
manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config)
Expand Down Expand Up @@ -348,7 +357,7 @@ func TestSimplePromptWithLocalImage(t *testing.T) {
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "testdata/marvin.jpg"

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(expected)

Expand Down Expand Up @@ -385,7 +394,7 @@ func TestSimplePromptWithLocalImageAndChat(t *testing.T) {
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "testdata/marvin.jpg"

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(expected)

Expand All @@ -405,7 +414,9 @@ func TestSimplePromptWithLocalImageAndChat(t *testing.T) {
require.Equal(t, expected, result)
require.NoError(t, writer.Close())

require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat"))
chatDir := filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")
require.DirExists(t, chatDir)
require.FileExists(t, filepath.Join(chatDir, "messages.json"))

var manager chat.SessionManager
manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config)
Expand Down Expand Up @@ -450,7 +461,7 @@ func TestSimplePromptWithURLImageAndChat(t *testing.T) {
expected := "The image shows a character that appears to be a stylized robot. It has"
inputImage := "https://upload.wikimedia.org/wikipedia/en/c/cb/Marvin_%28HHGG%29.jpg"

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(expected)

Expand All @@ -470,7 +481,9 @@ func TestSimplePromptWithURLImageAndChat(t *testing.T) {
require.Equal(t, expected, result)
require.NoError(t, writer.Close())

require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat"))
chatDir := filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")
require.DirExists(t, chatDir)
require.FileExists(t, filepath.Join(chatDir, "messages.json"))

var manager chat.SessionManager
manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config)
Expand Down Expand Up @@ -515,7 +528,7 @@ func TestSimplePromptWithMixedImagesAndChat(t *testing.T) {
inputImageFile := "testdata/marvin.jpg"
inputImageURL := "https://upload.wikimedia.org/wikipedia/en/c/cb/Marvin_%28HHGG%29.jpg"

httpmock.ActivateNonDefault(client.HTTPClient)
httpmock.ActivateNonDefault(client.GetHTTPClient())
t.Cleanup(httpmock.DeactivateAndReset)
testlib.RegisterExpectedChatResponse(expected)

Expand All @@ -535,7 +548,9 @@ func TestSimplePromptWithMixedImagesAndChat(t *testing.T) {
require.Equal(t, expected, result)
require.NoError(t, writer.Close())

require.FileExists(t, filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat"))
chatDir := filepath.Join(testCtx.Config.GetString("cacheDir"), "test_chat")
require.DirExists(t, chatDir)
require.FileExists(t, filepath.Join(chatDir, "messages.json"))

var manager chat.SessionManager
manager, err = chat.NewFilesystemChatSessionManager(testCtx.Config)
Expand Down
Loading
Loading