Skip to content

Commit

Permalink
Remove host entries without valid tokens during migration
Browse files Browse the repository at this point in the history
  • Loading branch information
williammartin committed Dec 6, 2023
1 parent 4f33d88 commit 06c36a7
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 13 deletions.
51 changes: 38 additions & 13 deletions internal/config/migration/multi_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/cli/go-gh/v2/pkg/config"
)

var noTokenError = errors.New("no token found")

type CowardlyRefusalError struct {
err error
}
Expand All @@ -21,6 +23,11 @@ func (e CowardlyRefusalError) Error() string {

var hostsKey = []string{"hosts"}

type tokenSource struct {
token string
inKeyring bool
}

// This migration exists to take a hosts section of the following structure:
//
// github.com:
Expand Down Expand Up @@ -95,12 +102,21 @@ func (m MultiAccount) Do(c *config.Config) error {

// Otherwise let's get to the business of migrating!
for _, hostname := range hostnames {
token, inKeyring, err := getToken(c, hostname)
tokenSource, err := getToken(c, hostname)
// If no token existed for this host we'll remove the entry from the hosts file
// by deleting it and moving on to the next one.
if errors.Is(err, noTokenError) {
// The only error that can be returned here is the key not existing, which
// we know can't be true.
_ = c.Remove(append(hostsKey, hostname))
continue
}
// For any other error we'll error out
if err != nil {
return CowardlyRefusalError{fmt.Errorf("couldn't find oauth token for %q: %w", hostname, err)}
}

username, err := getUsername(c, hostname, token, m.Transport)
username, err := getUsername(c, hostname, tokenSource.token, m.Transport)
if err != nil {
return CowardlyRefusalError{fmt.Errorf("couldn't get user name for %q: %w", hostname, err)}
}
Expand All @@ -109,26 +125,35 @@ func (m MultiAccount) Do(c *config.Config) error {
return CowardlyRefusalError{fmt.Errorf("couldn't not migrate config for %q: %w", hostname, err)}
}

if err := migrateToken(hostname, username, token, inKeyring); err != nil {
if err := migrateToken(hostname, username, tokenSource); err != nil {
return CowardlyRefusalError{fmt.Errorf("couldn't not migrate oauth token for %q: %w", hostname, err)}
}
}

return nil
}

func getToken(c *config.Config, hostname string) (string, bool, error) {
func getToken(c *config.Config, hostname string) (tokenSource, error) {
if token, _ := c.Get(append(hostsKey, hostname, "oauth_token")); token != "" {
return token, false, nil
return tokenSource{token: token, inKeyring: false}, nil
}
token, err := keyring.Get(keyringServiceName(hostname), "")
if err != nil {
return "", false, err

// If we have an error and it's not relating to there being no token
// then we'll return the error cause that's really unexpected.
if err != nil && !errors.Is(err, keyring.ErrNotFound) {
return tokenSource{}, err
}
if token == "" {
return "", false, errors.New("token not found in config or keyring")

// Otherwise we'll return a sentinel error
if err != nil || token == "" {
return tokenSource{}, noTokenError
}
return token, true, nil

return tokenSource{
token: token,
inKeyring: true,
}, nil
}

func getUsername(c *config.Config, hostname, token string, transport http.RoundTripper) (string, error) {
Expand Down Expand Up @@ -157,14 +182,14 @@ func getUsername(c *config.Config, hostname, token string, transport http.RoundT
return query.Viewer.Login, nil
}

func migrateToken(hostname, username, token string, inKeyring bool) error {
func migrateToken(hostname, username string, tokenSource tokenSource) error {
// If token is not currently stored in the keyring do not migrate it,
// as it is being stored in the config and is being handled when
// when migrating the config.
if !inKeyring {
if !tokenSource.inKeyring {
return nil
}
return keyring.Set(keyringServiceName(hostname), username, token)
return keyring.Set(keyringServiceName(hostname), username, tokenSource.token)
}

func migrateConfig(c *config.Config, hostname, username string) error {
Expand Down
43 changes: 43 additions & 0 deletions internal/config/migration/multi_account_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package migration_test

import (
"errors"
"fmt"
"testing"

Expand Down Expand Up @@ -252,6 +253,40 @@ hosts:
requireKeyWithValue(t, cfg, []string{"hosts", "github.com", "users", "monalisa", "git_protocol"}, "ssh")
}

func TestMigrationRemovesHostsWithInvalidTokens(t *testing.T) {
// Simulates config when user is logged in securely
// but no token entry is in the keyring.
keyring.MockInit()
cfg := config.ReadFromString(`
hosts:
github.com:
user: user1
git_protocol: ssh
`)

m := migration.MultiAccount{}
require.NoError(t, m.Do(cfg))

requireNoKey(t, cfg, []string{"hosts", "github.com"})
}

func TestMigrationErrorsWhenUnableToGetExpectedSecureToken(t *testing.T) {
// Simulates config when user is logged in securely
// but no token entry is in the keyring.
keyring.MockInitWithError(errors.New("keyring test error"))
cfg := config.ReadFromString(`
hosts:
github.com:
user: user1
git_protocol: ssh
`)

m := migration.MultiAccount{}
err := m.Do(cfg)

require.ErrorContains(t, err, `couldn't find oauth token for "github.com": keyring test error`)
}

func requireKeyWithValue(t *testing.T, cfg *config.Config, keys []string, value string) {
t.Helper()

Expand All @@ -260,3 +295,11 @@ func requireKeyWithValue(t *testing.T, cfg *config.Config, keys []string, value

require.Equal(t, value, actual)
}

func requireNoKey(t *testing.T, cfg *config.Config, keys []string) {
t.Helper()

_, err := cfg.Get(keys)
var keyNotFoundError *config.KeyNotFoundError
require.ErrorAs(t, err, &keyNotFoundError)
}
6 changes: 6 additions & 0 deletions internal/keyring/keyring.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
package keyring

import (
"errors"
"time"

"github.com/zalando/go-keyring"
)

var ErrNotFound = errors.New("secret not found in keyring")

type TimeoutError struct {
message string
}
Expand Down Expand Up @@ -46,6 +49,9 @@ func Get(service, user string) (string, error) {
}()
select {
case res := <-ch:
if errors.Is(res.err, keyring.ErrNotFound) {
return "", ErrNotFound
}
return res.val, res.err
case <-time.After(3 * time.Second):
return "", &TimeoutError{"timeout while trying to get secret from keyring"}
Expand Down

0 comments on commit 06c36a7

Please sign in to comment.