Skip to content

Commit

Permalink
[Core] PREVIEW: Support managed identity on Azure Arc-enabled Windows…
Browse files Browse the repository at this point in the history
… server (#29187)
  • Loading branch information
jiasli authored Nov 6, 2024
1 parent 5705899 commit d68ba44
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 8 deletions.
59 changes: 52 additions & 7 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def login(self,
return deepcopy(consolidated)

def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=None):
if _on_azure_arc_windows():
return self.login_with_managed_identity_azure_arc_windows(
identity_id=identity_id, allow_no_subscriptions=allow_no_subscriptions)

import jwt
from azure.mgmt.core.tools import is_valid_resource_id
from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper
Expand Down Expand Up @@ -282,6 +286,33 @@ def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=N
self._set_subscriptions(consolidated)
return deepcopy(consolidated)

def login_with_managed_identity_azure_arc_windows(self, identity_id=None, allow_no_subscriptions=None):
import jwt
identity_type = MsiAccountTypes.system_assigned
from .auth.msal_credentials import ManagedIdentityCredential

cred = ManagedIdentityCredential()
token = cred.get_token(*self._arm_scope).token
logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...')
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
tenant = decode['tid']

subscription_finder = SubscriptionFinder(self.cli_ctx)
subscriptions = subscription_finder.find_using_specific_tenant(tenant, cred)
base_name = ('{}-{}'.format(identity_type, identity_id) if identity_id else identity_type)
user = _USER_ASSIGNED_IDENTITY if identity_id else _SYSTEM_ASSIGNED_IDENTITY
if not subscriptions:
if allow_no_subscriptions:
subscriptions = self._build_tenant_level_accounts([tenant])
else:
raise CLIError('No access was configured for the managed identity, hence no subscriptions were found. '
"If this is expected, use '--allow-no-subscriptions' to have tenant level access.")

consolidated = self._normalize_properties(user, subscriptions, is_service_principal=True,
user_assigned_identity_id=base_name)
self._set_subscriptions(consolidated)
return deepcopy(consolidated)

def login_in_cloud_shell(self):
import jwt
from .auth.msal_credentials import CloudShellCredential
Expand Down Expand Up @@ -354,13 +385,18 @@ def get_login_credentials(self, resource=None, client_id=None, subscription_id=N
# Cloud Shell
from .auth.msal_credentials import CloudShellCredential
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
cs_cred = CloudShellCredential()
# The cloud shell credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(cs_cred, resource=resource)
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(CloudShellCredential(), resource=resource)

elif managed_identity_type:
# managed identity
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource)
if _on_azure_arc_windows():
from .auth.msal_credentials import ManagedIdentityCredential
from azure.cli.core.auth.credential_adaptor import CredentialAdaptor
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
cred = CredentialAdaptor(ManagedIdentityCredential(), resource=resource)
else:
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource)

else:
# user and service principal
Expand Down Expand Up @@ -415,9 +451,13 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
# managed identity
if tenant:
raise CLIError("Tenant shouldn't be specified for managed identity account")
from .auth.util import scopes_to_resource
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
scopes_to_resource(scopes))
if _on_azure_arc_windows():
from .auth.msal_credentials import ManagedIdentityCredential
cred = ManagedIdentityCredential()
else:
from .auth.util import scopes_to_resource
cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
scopes_to_resource(scopes))

else:
cred = self._create_credential(account, tenant)
Expand Down Expand Up @@ -918,3 +958,8 @@ def _create_identity_instance(cli_ctx, *args, **kwargs):
return Identity(*args, encrypt=encrypt, use_msal_http_cache=use_msal_http_cache,
enable_broker_on_windows=enable_broker_on_windows,
instance_discovery=instance_discovery, **kwargs)


def _on_azure_arc_windows():
# This indicates an Azure Arc-enabled Windows server
return "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ
21 changes: 20 additions & 1 deletion src/azure-cli-core/azure/cli/core/auth/msal_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from knack.log import get_logger
from knack.util import CLIError
from msal import PublicClientApplication, ConfidentialClientApplication
from msal import (PublicClientApplication, ConfidentialClientApplication,
ManagedIdentityClient, SystemAssignedManagedIdentity)

from .constants import AZURE_CLI_CLIENT_ID
from .util import check_result, build_sdk_access_token
Expand Down Expand Up @@ -131,3 +132,21 @@ def get_token(self, *scopes, **kwargs):
result = self._msal_app.acquire_token_interactive(list(scopes), prompt="none", **kwargs)
check_result(result, scopes=scopes)
return build_sdk_access_token(result)


class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
"""Managed identity credential implementing get_token interface.
Currently, only Azure Arc's system-assigned managed identity is supported.
"""

def __init__(self):
import requests
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())

def get_token(self, *scopes, **kwargs):
logger.debug("ManagedIdentityCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)

from .util import scopes_to_resource
result = self._msal_client.acquire_token_for_client(resource=scopes_to_resource(scopes))
check_result(result)
return build_sdk_access_token(result)

0 comments on commit d68ba44

Please sign in to comment.