Skip to content

Commit

Permalink
Use go-gh/auth package for IsEnterprise, IsTenancy, and NormalizeHost…
Browse files Browse the repository at this point in the history
…name
  • Loading branch information
jtmcg committed Oct 15, 2024
1 parent 44fdb33 commit 81591a0
Show file tree
Hide file tree
Showing 18 changed files with 45 additions and 212 deletions.
4 changes: 2 additions & 2 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
"regexp"
"strings"

"github.com/cli/cli/v2/internal/ghinstance"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)

const (
Expand Down Expand Up @@ -249,7 +249,7 @@ func generateScopesSuggestion(statusCode int, endpointNeedsScopes, tokenHasScope
return fmt.Sprintf(
"This API operation needs the %[1]q scope. To request it, run: gh auth refresh -h %[2]s -s %[1]s",
s,
ghinstance.NormalizeHostname(hostname),
ghauth.NormalizeHostname(hostname),
)
}

Expand Down
4 changes: 2 additions & 2 deletions api/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (
"strings"
"time"

"github.com/cli/cli/v2/internal/ghinstance"
"github.com/cli/cli/v2/utils"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
)

type tokenGetter interface {
Expand Down Expand Up @@ -98,7 +98,7 @@ func AddAuthTokenHeader(rt http.RoundTripper, cfg tokenGetter) http.RoundTripper
// Only set header if an initial request or redirect request to the same host as the initial request.
// If the host has changed during a redirect do not add the authentication token header.
if !redirectHostnameChange {
hostname := ghinstance.NormalizeHostname(getHost(req))
hostname := ghauth.NormalizeHostname(getHost(req))
if token, _ := cfg.ActiveToken(hostname); token != "" {
req.Header.Set(authorization, fmt.Sprintf("token %s", token))
}
Expand Down
4 changes: 3 additions & 1 deletion internal/authflow/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/cli/cli/v2/utils"
"github.com/cli/oauth"
"github.com/henvic/httpretty"

ghauth "github.com/cli/go-gh/v2/pkg/auth"
)

var (
Expand Down Expand Up @@ -105,7 +107,7 @@ func AuthFlow(oauthHost string, IO *iostreams.IOStreams, notice string, addition

func getCallbackURI(oauthHost string) string {
callbackURI := "http://127.0.0.1/callback"
if ghinstance.IsEnterprise(oauthHost) {
if ghauth.IsEnterprise(oauthHost) {
// the OAuth app on Enterprise hosts is still registered with a legacy callback URL
// see https://github.com/cli/cli/pull/222, https://github.com/cli/cli/pull/650
callbackURI = "http://localhost/"
Expand Down
10 changes: 5 additions & 5 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/cli/cli/v2/internal/gh"
"github.com/cli/cli/v2/internal/keyring"
o "github.com/cli/cli/v2/pkg/option"
ghAuth "github.com/cli/go-gh/v2/pkg/auth"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
ghConfig "github.com/cli/go-gh/v2/pkg/config"
)

Expand Down Expand Up @@ -206,7 +206,7 @@ func (c *AuthConfig) ActiveToken(hostname string) (string, string) {
if c.tokenOverride != nil {
return c.tokenOverride(hostname)
}
token, source := ghAuth.TokenFromEnvOrConfig(hostname)
token, source := ghauth.TokenFromEnvOrConfig(hostname)
if token == "" {
var err error
token, err = c.TokenFromKeyring(hostname)
Expand Down Expand Up @@ -240,7 +240,7 @@ func (c *AuthConfig) HasEnvToken() bool {
// It has to use a hostname that is not going to be found in the hosts so that it
// can guarantee that tokens will only be returned from a set env var.
// Discussed here, but maybe worth revisiting: https://github.com/cli/cli/pull/7169#discussion_r1136979033
token, _ := ghAuth.TokenFromEnvOrConfig(hostname)
token, _ := ghauth.TokenFromEnvOrConfig(hostname)
return token != ""
}

Expand Down Expand Up @@ -282,7 +282,7 @@ func (c *AuthConfig) Hosts() []string {
if c.hostsOverride != nil {
return c.hostsOverride()
}
return ghAuth.KnownHosts()
return ghauth.KnownHosts()
}

// SetHosts will override any hosts resolution and return the given
Expand All @@ -297,7 +297,7 @@ func (c *AuthConfig) DefaultHost() (string, string) {
if c.defaultHostOverride != nil {
return c.defaultHostOverride()
}
return ghAuth.DefaultHost()
return ghauth.DefaultHost()
}

// SetDefaultHost will override any host resolution and return the given
Expand Down
7 changes: 4 additions & 3 deletions internal/featuredetection/feature_detection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"net/http"

"github.com/cli/cli/v2/api"
"github.com/cli/cli/v2/internal/ghinstance"
"golang.org/x/sync/errgroup"

ghauth "github.com/cli/go-gh/v2/pkg/auth"
)

type Detector interface {
Expand Down Expand Up @@ -62,7 +63,7 @@ func NewDetector(httpClient *http.Client, host string) Detector {
}

func (d *detector) IssueFeatures() (IssueFeatures, error) {
if !ghinstance.IsEnterprise(d.host) {
if !ghauth.IsEnterprise(d.host) {
return allIssueFeatures, nil
}

Expand Down Expand Up @@ -163,7 +164,7 @@ func (d *detector) PullRequestFeatures() (PullRequestFeatures, error) {
}

func (d *detector) RepositoryFeatures() (RepositoryFeatures, error) {
if !ghinstance.IsEnterprise(d.host) {
if !ghauth.IsEnterprise(d.host) {
return allRepositoryFeatures, nil
}

Expand Down
42 changes: 8 additions & 34 deletions internal/ghinstance/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"errors"
"fmt"
"strings"

ghauth "github.com/cli/go-gh/v2/pkg/auth"
)

// DefaultHostname is the domain name of the default GitHub instance.
Expand All @@ -20,45 +22,17 @@ func Default() string {
return defaultHostname
}

// IsEnterprise reports whether a non-normalized host name looks like a GHE instance.
func IsEnterprise(h string) bool {
normalizedHostName := NormalizeHostname(h)
return normalizedHostName != defaultHostname && normalizedHostName != localhost
}

// IsTenancy reports whether a non-normalized host name looks like a tenancy instance.
func IsTenancy(h string) bool {
normalizedHostName := NormalizeHostname(h)
return strings.HasSuffix(normalizedHostName, "."+tenancyHost)
}

// TenantName extracts the tenant name from tenancy host name and
// reports whether it found the tenant name.
func TenantName(h string) (string, bool) {
normalizedHostName := NormalizeHostname(h)
normalizedHostName := ghauth.NormalizeHostname(h)
return cutSuffix(normalizedHostName, "."+tenancyHost)
}

func isGarage(h string) bool {
return strings.EqualFold(h, "garage.github.com")
}

// NormalizeHostname returns the canonical host name of a GitHub instance.
func NormalizeHostname(h string) string {
hostname := strings.ToLower(h)
if strings.HasSuffix(hostname, "."+defaultHostname) {
return defaultHostname
}
if strings.HasSuffix(hostname, "."+localhost) {
return localhost
}
if before, found := cutSuffix(hostname, "."+tenancyHost); found {
idx := strings.LastIndex(before, ".")
return fmt.Sprintf("%s.%s", before[idx+1:], tenancyHost)
}
return hostname
}

func HostnameValidator(hostname string) error {
if len(strings.TrimSpace(hostname)) < 1 {
return errors.New("a value is required")
Expand All @@ -77,10 +51,10 @@ func GraphQLEndpoint(hostname string) string {
// conditional can be removed as the flow will fall through to the bottom.
// However, we can't do that until we've investigated all places in which
// Tenancy is currently treated as Enterprise.
if IsTenancy(hostname) {
if ghauth.IsTenancy(hostname) {
return fmt.Sprintf("https://api.%s/graphql", hostname)
}
if IsEnterprise(hostname) {
if ghauth.IsEnterprise(hostname) {
return fmt.Sprintf("https://%s/api/graphql", hostname)
}
if strings.EqualFold(hostname, localhost) {
Expand All @@ -97,10 +71,10 @@ func RESTPrefix(hostname string) string {
// conditional can be removed as the flow will fall through to the bottom.
// However, we can't do that until we've investigated all places in which
// Tenancy is currently treated as Enterprise.
if IsTenancy(hostname) {
if ghauth.IsTenancy(hostname) {
return fmt.Sprintf("https://api.%s/", hostname)
}
if IsEnterprise(hostname) {
if ghauth.IsEnterprise(hostname) {
return fmt.Sprintf("https://%s/api/v3/", hostname)
}
if strings.EqualFold(hostname, localhost) {
Expand All @@ -121,7 +95,7 @@ func GistHost(hostname string) string {
if isGarage(hostname) {
return fmt.Sprintf("%s/gist/", hostname)
}
if IsEnterprise(hostname) {
if ghauth.IsEnterprise(hostname) {
return fmt.Sprintf("%s/gist/", hostname)
}
if strings.EqualFold(hostname, localhost) {
Expand Down
145 changes: 0 additions & 145 deletions internal/ghinstance/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,88 +6,6 @@ import (
"github.com/stretchr/testify/assert"
)

func TestIsEnterprise(t *testing.T) {
tests := []struct {
host string
want bool
}{
{
host: "github.com",
want: false,
},
{
host: "api.github.com",
want: false,
},
{
host: "github.localhost",
want: false,
},
{
host: "api.github.localhost",
want: false,
},
{
host: "garage.github.com",
want: false,
},
{
host: "ghe.io",
want: true,
},
{
host: "example.com",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := IsEnterprise(tt.host); got != tt.want {
t.Errorf("IsEnterprise() = %v, want %v", got, tt.want)
}
})
}
}

func TestIsTenancy(t *testing.T) {
tests := []struct {
host string
want bool
}{
{
host: "github.com",
want: false,
},
{
host: "github.localhost",
want: false,
},
{
host: "garage.github.com",
want: false,
},
{
host: "ghe.com",
want: false,
},
{
host: "tenant.ghe.com",
want: true,
},
{
host: "api.tenant.ghe.com",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := IsTenancy(tt.host); got != tt.want {
t.Errorf("IsTenancy() = %v, want %v", got, tt.want)
}
})
}
}

func TestTenantName(t *testing.T) {
tests := []struct {
host string
Expand Down Expand Up @@ -130,69 +48,6 @@ func TestTenantName(t *testing.T) {
}
}

func TestNormalizeHostname(t *testing.T) {
tests := []struct {
host string
want string
}{
{
host: "GitHub.com",
want: "github.com",
},
{
host: "api.github.com",
want: "github.com",
},
{
host: "ssh.github.com",
want: "github.com",
},
{
host: "upload.github.com",
want: "github.com",
},
{
host: "GitHub.localhost",
want: "github.localhost",
},
{
host: "api.github.localhost",
want: "github.localhost",
},
{
host: "garage.github.com",
want: "github.com",
},
{
host: "GHE.IO",
want: "ghe.io",
},
{
host: "git.my.org",
want: "git.my.org",
},
{
host: "ghe.com",
want: "ghe.com",
},
{
host: "tenant.ghe.com",
want: "tenant.ghe.com",
},
{
host: "api.tenant.ghe.com",
want: "tenant.ghe.com",
},
}
for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
if got := NormalizeHostname(tt.host); got != tt.want {
t.Errorf("NormalizeHostname() = %v, want %v", got, tt.want)
}
})
}
}

func TestHostnameValidator(t *testing.T) {
tests := []struct {
name string
Expand Down
4 changes: 2 additions & 2 deletions internal/ghrepo/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"strings"

"github.com/cli/cli/v2/internal/ghinstance"
ghAuth "github.com/cli/go-gh/v2/pkg/auth"
ghauth "github.com/cli/go-gh/v2/pkg/auth"
"github.com/cli/go-gh/v2/pkg/repository"
)

Expand Down Expand Up @@ -37,7 +37,7 @@ func FullName(r Interface) string {
}

func defaultHost() string {
host, _ := ghAuth.DefaultHost()
host, _ := ghauth.DefaultHost()
return host
}

Expand Down
Loading

0 comments on commit 81591a0

Please sign in to comment.