Skip to content

Commit

Permalink
fix: address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Binbin Li <[email protected]>
  • Loading branch information
binbin-li committed Jan 26, 2025
1 parent 39af35f commit 429f09b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 25 deletions.
14 changes: 4 additions & 10 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,10 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
return nil, re.ErrorCodeEnvNotSet.WithDetail("AZURE_CLIENT_ID environment variable is empty").WithComponentType(re.AuthProvider)
}
}
if err != nil {
return nil, err
}

if len(conf.Endpoints) == 0 {
conf.Endpoints = []string{defaultACREndpoint}
} else {
if err := validateEndpoints(conf.Endpoints); err != nil {
return nil, re.ErrorCodeConfigInvalid.WithError(err)
}
endpoints, err := parseEndpoints(conf.Endpoints)

Check warning on line 112 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L112

Added line #L112 was not covered by tests
if err != nil {
return nil, re.ErrorCodeConfigInvalid.WithError(err)

Check warning on line 114 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L114

Added line #L114 was not covered by tests
}

// retrieve an AAD Access token
Expand All @@ -132,7 +126,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
tenantID: tenant,
authClientFactory: &defaultAuthClientFactoryImpl{}, // Concrete implementation
getManagedIdentityToken: &defaultManagedIdentityTokenGetterImpl{}, // Concrete implementation
endpoints: conf.Endpoints,
endpoints: endpoints,

Check warning on line 129 in pkg/common/oras/authprovider/azure/azureidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureidentity.go#L129

Added line #L129 was not covered by tests
}, nil
}

Expand Down
31 changes: 17 additions & 14 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,9 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
}
}

if len(conf.Endpoints) == 0 {
conf.Endpoints = []string{defaultACREndpoint}
} else {
if err := validateEndpoints(conf.Endpoints); err != nil {
return nil, re.ErrorCodeConfigInvalid.WithError(err)
}
endpoints, err := parseEndpoints(conf.Endpoints)
if err != nil {
return nil, re.ErrorCodeConfigInvalid.WithError(err)
}

Check warning on line 123 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L122-L123

Added lines #L122 - L123 were not covered by tests

// retrieve an AAD Access token
Expand All @@ -139,7 +136,7 @@ func (s *AzureWIProviderFactory) Create(authProviderConfig provider.AuthProvider
registryHostGetter: &defaultRegistryHostGetterImpl{}, // Concrete implementation
getAADAccessToken: &defaultAADAccessTokenGetterImpl{}, // Concrete implementation
reportMetrics: &defaultMetricsReporterImpl{},
endpoints: conf.Endpoints,
endpoints: endpoints,

Check warning on line 139 in pkg/common/oras/authprovider/azure/azureworkloadidentity.go

View check run for this annotation

Codecov / codecov/patch

pkg/common/oras/authprovider/azure/azureworkloadidentity.go#L139

Added line #L139 was not covered by tests
}, nil
}

Expand Down Expand Up @@ -222,7 +219,8 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
return authConfig, nil
}

// validateEndpoints checks if the endpoints are valid for auth provider.
// parseEndpoints checks if the endpoints are valid for auth provider. If no
// endpoints are provided, it defaults to the default ACR endpoint.
// A valid endpoint is either a fully qualified domain name or a wildcard domain
// name folloiwing RFC 1034.
// Valid examples:
Expand All @@ -233,23 +231,27 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
// - *
// - example.*
// - *example.com.*
func validateEndpoints(endpoints []string) error {
// - *.
func parseEndpoints(endpoints []string) ([]string, error) {
if len(endpoints) == 0 {
return []string{defaultACREndpoint}, nil
}
for _, endpoint := range endpoints {
switch strings.Count(endpoint, "*") {
case 0:
continue
case 1:
if !strings.HasPrefix(endpoint, "*.") {
return fmt.Errorf("invalid wildcard domain name: %s, it must start with '*.'", endpoint)
return nil, fmt.Errorf("invalid wildcard domain name: %s, it must start with '*.'", endpoint)
}
if len(endpoint) < 3 {
return fmt.Errorf("invalid wildcard domain name: %s, it must have at least one character after '*.'", endpoint)
return nil, fmt.Errorf("invalid wildcard domain name: %s, it must have at least one character after '*.'", endpoint)
}
default:
return fmt.Errorf("invalid wildcard domain name: %s, it must have at most one wildcard character", endpoint)
return nil, fmt.Errorf("invalid wildcard domain name: %s, it must have at most one wildcard character", endpoint)
}
}
return nil
return endpoints, nil
}

// validateHost checks if the host is matching endpoints supported by the auth
Expand All @@ -262,7 +264,8 @@ func validateHost(host string, endpoints []string) error {
return nil
}
case 1:
if strings.HasSuffix(host, strings.TrimPrefix(endpoint, "*")) {
index := strings.Index(host, ".")
if index > -1 && host[index:] == strings.TrimPrefix(endpoint, "*") {
return nil
}
default:
Expand Down
90 changes: 90 additions & 0 deletions pkg/common/oras/authprovider/azure/azureworkloadidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,93 @@ func TestAzureWIValidation_EnvironmentVariables_ExpectedResults(t *testing.T) {
t.Fatalf("create auth provider should have failed: expected err %s, but got err %s", expectedErr, err)
}
}

func TestValidateEndpoints(t *testing.T) {
tests := []struct {
name string
endpoint string
expectedErr bool
}{
{
name: "global wildcard",
endpoint: "*",
expectedErr: true,
},
{
name: "multiple wildcard",
endpoint: "*.example.*",
expectedErr: true,
},
{
name: "no subdomain",
endpoint: "*.",
expectedErr: true,
},
{
name: "full qualified domain",
endpoint: "example.com",
expectedErr: false,
},
{
name: "valid wildcard domain",
endpoint: "*.example.com",
expectedErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := parseEndpoints([]string{tt.endpoint})
if tt.expectedErr != (err != nil) {
t.Fatalf("expected error: %v, got error: %v", tt.expectedErr, err)
}
})
}
}

func TestValidateHost(t *testing.T) {
endpoints := []string{
"*.azurecr.io",
"example.azurecr.io",
}
tests := []struct {
name string
host string
expectedErr bool
}{
{
name: "empty host",
host: "",
expectedErr: true,
},
{
name: "valid host",
host: "example.azurecr.io",
expectedErr: false,
},
{
name: "no subdomain",
host: "azurecr.io",
expectedErr: true,
},
{
name: "multiple subdomains",
host: "example.test.azurecr.io",
expectedErr: true,
},
{
name: "matched host",
host: "test.azurecr.io",
expectedErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateHost(tt.host, endpoints)
if tt.expectedErr != (err != nil) {
t.Fatalf("expected error: %v, got error: %v", tt.expectedErr, err)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/common/oras/authprovider/azure/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ const (
dockerTokenLoginUsernameGUID = "00000000-0000-0000-0000-000000000000"
AADResource = "https://containerregistry.azure.net/.default"
defaultACRExpiryDuration time.Duration = 3 * time.Hour
defaultACREndpoint = ".*.azurecr.io"
defaultACREndpoint = "*.azurecr.io"
)

var logOpt = logger.Option{
Expand Down

0 comments on commit 429f09b

Please sign in to comment.