Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auth: support caching of values retrieves from Azure CLI to avoid unnecessary repeated invocations #753

Merged
merged 9 commits into from
Dec 12, 2023
55 changes: 23 additions & 32 deletions sdk/auth/azure_cli_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package auth

import (
"context"
"errors"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -60,11 +59,6 @@ func (a *AzureCliAuthorizer) Token(_ context.Context, _ *http.Request) (*oauth2.

azArgs := []string{"account", "get-access-token"}

// verify that the Azure CLI supports MSAL - ADAL is no longer supported
err := azurecli.CheckAzVersion(azurecli.MsalVersion, nil)
if err != nil {
return nil, fmt.Errorf("checking the version of the Azure CLI: %+v", err)
}
scope, err := environments.Scope(a.conf.Api)
if err != nil {
return nil, fmt.Errorf("determining scope for %q: %+v", a.conf.Api.Name(), err)
Expand All @@ -78,7 +72,7 @@ func (a *AzureCliAuthorizer) Token(_ context.Context, _ *http.Request) (*oauth2.
}

var token azureCliToken
if err := azurecli.JSONUnmarshalAzCmd(&token, azArgs...); err != nil {
if err = azurecli.JSONUnmarshalAzCmd(false, &token, azArgs...); err != nil {
return nil, err
}

Expand Down Expand Up @@ -114,11 +108,6 @@ func (a *AzureCliAuthorizer) AuxiliaryTokens(_ context.Context, _ *http.Request)

azArgs := []string{"account", "get-access-token"}

// verify that the Azure CLI supports MSAL - ADAL is no longer supported
err := azurecli.CheckAzVersion(AzureCliMsalVersion, nil)
if err != nil {
return nil, fmt.Errorf("checking the version of the Azure CLI: %+v", err)
}
scope, err := environments.Scope(a.conf.Api)
if err != nil {
return nil, fmt.Errorf("determining scope for %q: %+v", a.conf.Api.Name(), err)
Expand All @@ -130,7 +119,7 @@ func (a *AzureCliAuthorizer) AuxiliaryTokens(_ context.Context, _ *http.Request)
argsWithTenant := append(azArgs, "--tenant", tenantId)

var token azureCliToken
if err := azurecli.JSONUnmarshalAzCmd(&token, argsWithTenant...); err != nil {
if err = azurecli.JSONUnmarshalAzCmd(false, &token, argsWithTenant...); err != nil {
return nil, err
}

Expand All @@ -143,12 +132,6 @@ func (a *AzureCliAuthorizer) AuxiliaryTokens(_ context.Context, _ *http.Request)
return tokens, nil
}

const (
AzureCliMinimumVersion = "2.0.81"
AzureCliMsalVersion = "2.30.0"
AzureCliNextMajorVersion = "3.0.0"
)

// azureCliConfig configures an AzureCliAuthorizer.
type azureCliConfig struct {
Api environments.Api
Expand All @@ -165,27 +148,35 @@ type azureCliConfig struct {

// newAzureCliConfig validates the supplied tenant ID and returns a new azureCliConfig.
func newAzureCliConfig(api environments.Api, tenantId string, auxiliaryTenantIds []string) (*azureCliConfig, error) {
var err error

// check az-cli version
nextMajor := azurecli.NextMajorVersion
if err = azurecli.CheckAzVersion(azurecli.MinimumVersion, &nextMajor); err != nil {
// check az-cli version, ensure that MSAL is supported
if err := azurecli.CheckAzVersion(); err != nil {
return nil, err
}

// check tenant ID
tenantId, err = azurecli.CheckTenantID(tenantId)
if err != nil {
return nil, err
// obtain default tenant ID if no tenant ID was provided
if strings.TrimSpace(tenantId) == "" {
if defaultTenantId, err := azurecli.GetDefaultTenantID(); err != nil {
return nil, fmt.Errorf("tenant ID was not specified and the default tenant ID could not be determined: %v", err)
} else if defaultTenantId == nil {
return nil, fmt.Errorf("tenant ID was not specified and the default tenant ID could not be determined")
} else {
tenantId = *defaultTenantId
}
}
if tenantId == "" {
return nil, errors.New("invalid tenantId or unable to determine tenantId")

// validate tenant ID
if valid, err := azurecli.ValidateTenantID(tenantId); err != nil {
return nil, err
} else if !valid {
return nil, fmt.Errorf("invalid tenant ID was provided")
}

// get the default subscription ID
subscriptionId, err := azurecli.GetDefaultSubscriptionID()
if err != nil {
var subscriptionId string
if defaultSubscriptionId, err := azurecli.GetDefaultSubscriptionID(); err != nil {
return nil, err
} else if defaultSubscriptionId != nil {
subscriptionId = *defaultSubscriptionId
}

return &azureCliConfig{
Expand Down
149 changes: 89 additions & 60 deletions sdk/internal/azurecli/azcli.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"log"
"os/exec"
"regexp"
"strings"
Expand All @@ -15,41 +16,29 @@ import (
)

// CheckAzVersion tries to determine the version of Azure CLI in the path and checks for a compatible version
func CheckAzVersion(minVersion string, nextMajorVersion *string) error {
var cliVersion *struct {
AzureCli *string `json:"azure-cli,omitempty"`
AzureCliCore *string `json:"azure-cli-core,omitempty"`
AzureCliTelemetry *string `json:"azure-cli-telemetry,omitempty"`
Extensions *interface{} `json:"extensions,omitempty"`
}
err := JSONUnmarshalAzCmd(&cliVersion, "version")
func CheckAzVersion() error {
currentVersion, err := getAzVersion()
if err != nil {
return fmt.Errorf("could not parse Azure CLI version: %v", err)
}

if cliVersion.AzureCli == nil {
return fmt.Errorf("could not detect Azure CLI version. Please ensure you have installed Azure CLI version %s or newer", minVersion)
return err
}

actual, err := version.NewVersion(*cliVersion.AzureCli)
actual, err := version.NewVersion(*currentVersion)
if err != nil {
return fmt.Errorf("could not parse detected Azure CLI version %q: %+v", *cliVersion.AzureCli, err)
return fmt.Errorf("could not parse detected Azure CLI version %q: %+v", *currentVersion, err)
}

supported, err := version.NewVersion(minVersion)
supported, err := version.NewVersion(MinimumVersion)
if err != nil {
return fmt.Errorf("could not parse supported Azure CLI version: %+v", err)
}

if nextMajorVersion != nil {
nextMajor, err := version.NewVersion(*nextMajorVersion)
if err != nil {
return fmt.Errorf("could not parse next major Azure CLI version: %+v", err)
}
nextMajor, err := version.NewVersion(NextMajorVersion)
if err != nil {
return fmt.Errorf("could not parse next major Azure CLI version: %+v", err)
}

if nextMajor.LessThanOrEqual(actual) {
return fmt.Errorf("unsupported Azure CLI version %q detected, please install a version newer than %s but older than %s", actual, supported, nextMajor)
}
if nextMajor.LessThanOrEqual(actual) {
return fmt.Errorf("unsupported Azure CLI version %q detected, please install a version newer than %s but older than %s", actual, supported, nextMajor)
}

if actual.LessThan(supported) {
Expand All @@ -59,68 +48,108 @@ func CheckAzVersion(minVersion string, nextMajorVersion *string) error {
return nil
}

// ValidateTenantID validates the supplied tenant ID, and tries to determine the default tenant if a valid one is not supplied.
func ValidateTenantID(tenantId string) (bool, error) {
validTenantId, err := regexp.MatchString("^[a-zA-Z0-9._-]+$", tenantId)
if err != nil {
return false, fmt.Errorf("could not parse tenant ID %q: %s", tenantId, err)
}

return validTenantId, nil
}

// GetDefaultTenantID tries to determine the default tenant
func GetDefaultTenantID() (*string, error) {
var account struct {
TenantID string `json:"tenantId"`
}
if err := JSONUnmarshalAzCmd(true, &account, "account", "show"); err != nil {
return nil, fmt.Errorf("obtaining tenant ID: %s", err)
}

return &account.TenantID, nil
}

// GetDefaultSubscriptionID tries to determine the default subscription
func GetDefaultSubscriptionID() (string, error) {
func GetDefaultSubscriptionID() (*string, error) {
var account struct {
SubscriptionID string `json:"id"`
}
err := JSONUnmarshalAzCmd(&account, "account", "show")
err := JSONUnmarshalAzCmd(true, &account, "account", "show")
if err != nil {
return "", fmt.Errorf("obtaining subscription ID: %s", err)
return nil, fmt.Errorf("obtaining subscription ID: %s", err)
}

return account.SubscriptionID, nil
return &account.SubscriptionID, nil
}

// CheckTenantID validates the supplied tenant ID, and tries to determine the default tenant if a valid one is not supplied.
func CheckTenantID(tenantId string) (string, error) {
validTenantId, err := regexp.MatchString("^[a-zA-Z0-9._-]+$", tenantId)
// getAzVersion tries to determine the version of Azure CLI in the path.
func getAzVersion() (*string, error) {
var cliVersion *struct {
AzureCli *string `json:"azure-cli,omitempty"`
AzureCliCore *string `json:"azure-cli-core,omitempty"`
AzureCliTelemetry *string `json:"azure-cli-telemetry,omitempty"`
Extensions *interface{} `json:"extensions,omitempty"`
}
err := JSONUnmarshalAzCmd(true, &cliVersion, "version")
if err != nil {
return "", fmt.Errorf("could not parse tenant ID %q: %s", tenantId, err)
return nil, fmt.Errorf("could not parse Azure CLI version: %v", err)
}

if !validTenantId {
var account struct {
ID string `json:"id"`
TenantID string `json:"tenantId"`
}
err := JSONUnmarshalAzCmd(&account, "account", "show")
if err != nil {
return "", fmt.Errorf("obtaining tenant ID: %s", err)
}
tenantId = account.TenantID
if cliVersion.AzureCli == nil {
return nil, fmt.Errorf("could not detect Azure CLI version")
}

return tenantId, nil
return cliVersion.AzureCli, nil
}

// JSONUnmarshalAzCmd executes an Azure CLI command and unmarshalls the JSON output.
func JSONUnmarshalAzCmd(i interface{}, arg ...string) error {
// JSONUnmarshalAzCmd executes an Azure CLI command and unmarshalls the JSON output, optionally retrieving from and
// populating the command result cache, to avoid unnecessary repeated invocations of Azure CLI.
func JSONUnmarshalAzCmd(cacheable bool, i interface{}, arg ...string) error {
var stderr bytes.Buffer
var stdout bytes.Buffer

arg = append(arg, "-o=json")
cmd := exec.Command("az", arg...)
cmd.Stderr = &stderr
cmd.Stdout = &stdout

if err := cmd.Start(); err != nil {
err := fmt.Errorf("launching Azure CLI: %+v", err)
if stdErrStr := stderr.String(); stdErrStr != "" {
err = fmt.Errorf("%s: %s", err, strings.TrimSpace(stdErrStr))
argstring := strings.Join(arg, " ")

var result []byte
if cacheable {
if cachedResult, ok := cache.Get(argstring); ok {
result = cachedResult
}
return err
}

if err := cmd.Wait(); err != nil {
err := fmt.Errorf("running Azure CLI: %+v", err)
if stdErrStr := stderr.String(); stdErrStr != "" {
err = fmt.Errorf("%s: %s", err, strings.TrimSpace(stdErrStr))
if result == nil {
log.Printf("[DEBUG] az-cli invocation: az %s", argstring)

cmd := exec.Command("az", arg...)
cmd.Stderr = &stderr
cmd.Stdout = &stdout

if err := cmd.Start(); err != nil {
err := fmt.Errorf("launching Azure CLI: %+v", err)
if stdErrStr := stderr.String(); stdErrStr != "" {
err = fmt.Errorf("%s: %s", err, strings.TrimSpace(stdErrStr))
}
return err
}

if err := cmd.Wait(); err != nil {
err := fmt.Errorf("running Azure CLI: %+v", err)
if stdErrStr := stderr.String(); stdErrStr != "" {
err = fmt.Errorf("%s: %s", err, strings.TrimSpace(stdErrStr))
}
return err
}

result = stdout.Bytes()

if cacheable {
cache.Set(argstring, result)
}
return err
}

if err := json.Unmarshal(stdout.Bytes(), &i); err != nil {
if err := json.Unmarshal(result, &i); err != nil {
return fmt.Errorf("unmarshaling the output of Azure CLI: %v", err)
}

Expand Down
24 changes: 24 additions & 0 deletions sdk/internal/azurecli/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package azurecli

var cache *cachedCliData

type cachedCliData struct {
data map[string][]byte
}

func (c *cachedCliData) Set(index string, data []byte) {
c.data[index] = data
}

func (c *cachedCliData) Get(index string) ([]byte, bool) {
if data, ok := c.data[index]; ok {
return data, true
}
return nil, false
}

func init() {
cache = &cachedCliData{
data: make(map[string][]byte),
}
}
9 changes: 7 additions & 2 deletions sdk/internal/azurecli/versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
package azurecli

const (
MinimumVersion = "2.0.81"
MsalVersion = "2.30.0"
// MsalVersion is the first known version of Azure CLI to support MSAL / v2 tokens
MsalVersion = "2.30.0"

// MinimumVersion is the oldest supported version of Azure CLI by this package
MinimumVersion = "2.0.81"

// NextMajorVersion is the next (possibly upcoming) major version that is not yet supported by this package
NextMajorVersion = "3.0.0"
)