diff --git a/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md b/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md index 340928d8759d..cc1929a10aa2 100644 --- a/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` +- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests. ### Breaking Changes @@ -12,6 +13,7 @@ ([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744)) ### Other Changes +- Updated minimum `azure-core` version to 1.31.0 - Key Vault API version `7.6-preview.1` is now the default ## 4.4.0 (2024-02-22) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index 1a872f36b6a8..e9b44fc68e55 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -16,33 +16,140 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from typing_extensions import ParamSpec + +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest +from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache -from .challenge_auth_policy import _enforce_tls, _update_challenge +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge + + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """Policy for handling HTTP authentication challenges. :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity.aio` - :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider """ - def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - super().__init__(credential, *scopes, **kwargs) - self._credential: AsyncTokenCredential = credential - self._token: Optional[AccessToken] = None + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request + super().__init__(credential, *scopes, enable_cae=True, **kwargs) + self._credential: AsyncTokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + async def send( + self, request: PipelineRequest[HttpRequest] + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HttpRequest, AsyncHttpResponse] + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + return await self.handle_challenge_flow(request, response) + return response + + async def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, AsyncHttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + if not claims_challenge: + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) + await await_result(self.on_response, request, response) + return response + + async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) @@ -51,14 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None: if self._need_new_token(): # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) - else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + await self._request_kv_token(scope, challenge) + + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -78,7 +181,19 @@ async def on_request(self, request: PipelineRequest) -> None: async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -104,11 +219,38 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request( + request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) return True def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = await self._credential.get_token(scope, enable_cae=True) + else: + self._token = await cast(AsyncTokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index f16297aa5026..eb4073d0e699 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -16,14 +16,21 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, cast, Optional, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenCredential, + TokenProvider, + TokenRequestOptions, + SupportsTokenInfo, +) from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache @@ -36,6 +43,20 @@ def _enforce_tls(request: PipelineRequest) -> None: ) +def _has_claims(challenge: str) -> bool: + """Check if a challenge header contains claims. + + :param challenge: The challenge header to check. + :type challenge: str + + :returns: True if the challenge contains claims; False otherwise. + :rtype: bool + """ + # Split the challenge into its scheme and parameters, then check if any parameter contains claims + split_challenge = challenge.strip().split(" ", 1) + return any("claims=" in item for item in split_challenge[1].split(",")) + + def _update_challenge(request: PipelineRequest, challenger: PipelineResponse) -> HttpChallenge: """Parse challenge from a challenge response, cache it, and return it. @@ -62,16 +83,89 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity` - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) - self._credential: TokenCredential = credential - self._token: Optional[AccessToken] = None + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request + super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) + self._credential: TokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + return self.handle_challenge_flow(request, response) + return response + + def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, HttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, HttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + if not claims_challenge: + return self.handle_challenge_flow(request, response, consecutive_challenge=True) + self.on_response(request, response) + return response + def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) @@ -80,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None: if self._need_new_token: # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) - else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + self._request_kv_token(scope, challenge) + + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -106,7 +196,19 @@ def on_request(self, request: PipelineRequest) -> None: def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -132,12 +234,37 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = self._credential.get_token(scope, enable_cae=True) + else: + self._token = cast(TokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py index df9055c7bda6..0320df5a868b 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,27 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # Special case for claims, which can contain = symbols as padding. Assume at most one claim per challenge + if "claims=" in item: + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: # pylint:disable=broad-except + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: diff --git a/sdk/keyvault/azure-keyvault-administration/setup.py b/sdk/keyvault/azure-keyvault-administration/setup.py index 60bbdd2a58a6..30c88eb275bd 100644 --- a/sdk/keyvault/azure-keyvault-administration/setup.py +++ b/sdk/keyvault/azure-keyvault-administration/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "isodate>=0.6.1", "typing-extensions>=4.0.1", ], diff --git a/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md b/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md index 086f14d6e349..40fbe4b1674a 100644 --- a/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` +- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests. ### Breaking Changes @@ -12,6 +13,7 @@ ([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744)) ### Other Changes +- Updated minimum `azure-core` version to 1.31.0 - Key Vault API version `7.6-preview.1` is now the default ## 4.8.0 (2024-02-22) diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index 1a872f36b6a8..e9b44fc68e55 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -16,33 +16,140 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from typing_extensions import ParamSpec + +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest +from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache -from .challenge_auth_policy import _enforce_tls, _update_challenge +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge + + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """Policy for handling HTTP authentication challenges. :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity.aio` - :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider """ - def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - super().__init__(credential, *scopes, **kwargs) - self._credential: AsyncTokenCredential = credential - self._token: Optional[AccessToken] = None + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request + super().__init__(credential, *scopes, enable_cae=True, **kwargs) + self._credential: AsyncTokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + async def send( + self, request: PipelineRequest[HttpRequest] + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HttpRequest, AsyncHttpResponse] + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + return await self.handle_challenge_flow(request, response) + return response + + async def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, AsyncHttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + if not claims_challenge: + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) + await await_result(self.on_response, request, response) + return response + + async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) @@ -51,14 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None: if self._need_new_token(): # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) - else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + await self._request_kv_token(scope, challenge) + + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -78,7 +181,19 @@ async def on_request(self, request: PipelineRequest) -> None: async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -104,11 +219,38 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request( + request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) return True def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = await self._credential.get_token(scope, enable_cae=True) + else: + self._token = await cast(AsyncTokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index f16297aa5026..eb4073d0e699 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -16,14 +16,21 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, cast, Optional, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenCredential, + TokenProvider, + TokenRequestOptions, + SupportsTokenInfo, +) from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache @@ -36,6 +43,20 @@ def _enforce_tls(request: PipelineRequest) -> None: ) +def _has_claims(challenge: str) -> bool: + """Check if a challenge header contains claims. + + :param challenge: The challenge header to check. + :type challenge: str + + :returns: True if the challenge contains claims; False otherwise. + :rtype: bool + """ + # Split the challenge into its scheme and parameters, then check if any parameter contains claims + split_challenge = challenge.strip().split(" ", 1) + return any("claims=" in item for item in split_challenge[1].split(",")) + + def _update_challenge(request: PipelineRequest, challenger: PipelineResponse) -> HttpChallenge: """Parse challenge from a challenge response, cache it, and return it. @@ -62,16 +83,89 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity` - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) - self._credential: TokenCredential = credential - self._token: Optional[AccessToken] = None + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request + super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) + self._credential: TokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + return self.handle_challenge_flow(request, response) + return response + + def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, HttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, HttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + if not claims_challenge: + return self.handle_challenge_flow(request, response, consecutive_challenge=True) + self.on_response(request, response) + return response + def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) @@ -80,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None: if self._need_new_token: # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) - else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + self._request_kv_token(scope, challenge) + + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -106,7 +196,19 @@ def on_request(self, request: PipelineRequest) -> None: def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -132,12 +234,37 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = self._credential.get_token(scope, enable_cae=True) + else: + self._token = cast(TokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py index df9055c7bda6..0320df5a868b 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,27 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # Special case for claims, which can contain = symbols as padding. Assume at most one claim per challenge + if "claims=" in item: + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: # pylint:disable=broad-except + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: diff --git a/sdk/keyvault/azure-keyvault-certificates/setup.py b/sdk/keyvault/azure-keyvault-certificates/setup.py index 00390ee336a0..347d79fb93a0 100644 --- a/sdk/keyvault/azure-keyvault-certificates/setup.py +++ b/sdk/keyvault/azure-keyvault-certificates/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "isodate>=0.6.1", "typing-extensions>=4.0.1", ], diff --git a/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md b/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md index 4c7444151bf1..faf4085e14e1 100644 --- a/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` +- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests. ### Breaking Changes @@ -12,6 +13,7 @@ ([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744)) ### Other Changes +- Updated minimum `azure-core` version to 1.31.0 - Key Vault API version `7.6-preview.1` is now the default ## 4.9.0 (2024-02-22) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 1a872f36b6a8..e9b44fc68e55 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -16,33 +16,140 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from typing_extensions import ParamSpec + +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest +from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache -from .challenge_auth_policy import _enforce_tls, _update_challenge +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge + + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """Policy for handling HTTP authentication challenges. :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity.aio` - :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider """ - def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - super().__init__(credential, *scopes, **kwargs) - self._credential: AsyncTokenCredential = credential - self._token: Optional[AccessToken] = None + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request + super().__init__(credential, *scopes, enable_cae=True, **kwargs) + self._credential: AsyncTokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + async def send( + self, request: PipelineRequest[HttpRequest] + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HttpRequest, AsyncHttpResponse] + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + return await self.handle_challenge_flow(request, response) + return response + + async def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, AsyncHttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + if not claims_challenge: + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) + await await_result(self.on_response, request, response) + return response + + async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) @@ -51,14 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None: if self._need_new_token(): # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) - else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + await self._request_kv_token(scope, challenge) + + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -78,7 +181,19 @@ async def on_request(self, request: PipelineRequest) -> None: async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -104,11 +219,38 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request( + request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) return True def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = await self._credential.get_token(scope, enable_cae=True) + else: + self._token = await cast(AsyncTokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index f16297aa5026..eb4073d0e699 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -16,14 +16,21 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, cast, Optional, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenCredential, + TokenProvider, + TokenRequestOptions, + SupportsTokenInfo, +) from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache @@ -36,6 +43,20 @@ def _enforce_tls(request: PipelineRequest) -> None: ) +def _has_claims(challenge: str) -> bool: + """Check if a challenge header contains claims. + + :param challenge: The challenge header to check. + :type challenge: str + + :returns: True if the challenge contains claims; False otherwise. + :rtype: bool + """ + # Split the challenge into its scheme and parameters, then check if any parameter contains claims + split_challenge = challenge.strip().split(" ", 1) + return any("claims=" in item for item in split_challenge[1].split(",")) + + def _update_challenge(request: PipelineRequest, challenger: PipelineResponse) -> HttpChallenge: """Parse challenge from a challenge response, cache it, and return it. @@ -62,16 +83,89 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity` - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) - self._credential: TokenCredential = credential - self._token: Optional[AccessToken] = None + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request + super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) + self._credential: TokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + return self.handle_challenge_flow(request, response) + return response + + def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, HttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, HttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + if not claims_challenge: + return self.handle_challenge_flow(request, response, consecutive_challenge=True) + self.on_response(request, response) + return response + def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) @@ -80,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None: if self._need_new_token: # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) - else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + self._request_kv_token(scope, challenge) + + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -106,7 +196,19 @@ def on_request(self, request: PipelineRequest) -> None: def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -132,12 +234,37 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = self._credential.get_token(scope, enable_cae=True) + else: + self._token = cast(TokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py index df9055c7bda6..0320df5a868b 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,27 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # Special case for claims, which can contain = symbols as padding. Assume at most one claim per challenge + if "claims=" in item: + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: # pylint:disable=broad-except + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: diff --git a/sdk/keyvault/azure-keyvault-keys/setup.py b/sdk/keyvault/azure-keyvault-keys/setup.py index cbb26cf86d49..7bbe28af42c5 100644 --- a/sdk/keyvault/azure-keyvault-keys/setup.py +++ b/sdk/keyvault/azure-keyvault-keys/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "cryptography>=2.1.4", "isodate>=0.6.1", "typing-extensions>=4.0.1", diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index 8d9eff46b208..de26cc59f07d 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -6,7 +6,9 @@ Tests for the HTTP challenge authentication implementation. These tests aren't parallelizable, because the challenge cache is global to the process. """ +import base64 import functools +from itertools import product import os import time from unittest.mock import Mock, patch @@ -16,12 +18,11 @@ from devtools_testutils import recorded_by_proxy import pytest -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import Pipeline from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.rest import HttpRequest -from azure.identity import AzureCliCredential, AzurePowerShellCredential, ClientSecretCredential from azure.keyvault.keys import KeyClient from azure.keyvault.keys._shared import ChallengeAuthPolicy, HttpChallenge, HttpChallengeCache from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION @@ -33,6 +34,8 @@ only_default_version = get_decorator(api_versions=[DEFAULT_VERSION]) +TOKEN_TYPES = [AccessToken, AccessTokenInfo] + class TestChallengeAuth(KeyVaultTestCase, KeysTestCase): @pytest.mark.parametrize("api_version,is_hsm", only_default_version) @KeysClientPreparer() @@ -69,6 +72,7 @@ def test_multitenant_authentication(self, client, is_hsm, **kwargs): else: os.environ.pop("AZURE_TENANT_ID") + def empty_challenge_cache(fn): @functools.wraps(fn) def wrapper(**kwargs): @@ -85,6 +89,25 @@ def get_random_url(): return f"https://{uuid4()}.vault.azure.net/{uuid4()}".replace("-", "") +URL = f'authorization_uri="{get_random_url()}"' +CLIENT_ID = 'client_id="00000003-0000-0000-c000-000000000000"' +CAE_ERROR = 'error="insufficient_claims"' +CAE_DECODED_CLAIM = '{"access_token": {"foo": "bar"}}' +# Claim token is a string of the base64 encoding of the claim +CLAIM_TOKEN = base64.b64encode(CAE_DECODED_CLAIM.encode()).decode() +# Note that no resource or scope is necessarily provided in a CAE challenge +CLAIM_CHALLENGE = f'Bearer realm="", {URL}, {CLIENT_ID}, {CAE_ERROR}, claims="{CLAIM_TOKEN}"' +CAE_CHALLENGE_RESPONSE = Mock(status_code=401, headers={"WWW-Authenticate": CLAIM_CHALLENGE}) + +KV_CHALLENGE_TENANT = "tenant-id" +ENDPOINT = f"https://authority.net/{KV_CHALLENGE_TENANT}" +RESOURCE = "https://vault.azure.net" +KV_CHALLENGE_RESPONSE = Mock( + status_code=401, + headers={"WWW-Authenticate": f'Bearer authorization="{ENDPOINT}", resource={RESOURCE}'}, +) + + def add_url_port(url: str): """Like `get_random_url`, but includes a port number (comes after the domain, and before the path of the URL).""" @@ -135,7 +158,8 @@ def test_challenge_parsing(): @empty_challenge_cache -def test_scope(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_scope(token_type): """The policy's token requests should always be for an AADv2 scope""" expected_content = b"a duck" @@ -164,15 +188,21 @@ def send(request): def get_token(*scopes, **_): assert len(scopes) == 1 assert scopes[0] == expected_scope - return AccessToken(expected_token, 0) + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 endpoint = "https://authority.net/tenant" @@ -194,7 +224,8 @@ def get_token(*scopes, **_): @empty_challenge_cache -def test_tenant(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_tenant(token_type): """The policy's token requests should pass the parsed tenant ID from the challenge""" expected_content = b"a duck" @@ -220,17 +251,24 @@ def send(request): return Mock(status_code=200) raise ValueError("unexpected request") - def get_token(*_, **kwargs): - assert kwargs.get("tenant_id") == expected_tenant - return AccessToken(expected_token, 0) + def get_token(*_, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("tenant_id") == expected_tenant + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 tenant = "tenant-id" endpoint = f"https://authority.net/{tenant}" @@ -245,7 +283,8 @@ def get_token(*_, **kwargs): @empty_challenge_cache -def test_adfs(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_adfs(token_type): """The policy should handle AD FS challenges as a special case and omit the tenant ID from token requests""" expected_content = b"a duck" @@ -274,15 +313,21 @@ def send(request): def get_token(*_, **kwargs): # we shouldn't provide a tenant ID during AD FS authentication assert "tenant_id" not in kwargs - return AccessToken(expected_token, 0) + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) policy = ChallengeAuthPolicy(credential=credential) pipeline = Pipeline(policies=[policy], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 # Regression test: https://github.com/Azure/azure-sdk-for-python/issues/33621 policy._token = None @@ -301,7 +346,8 @@ def get_token(*_, **kwargs): test_with_challenge(challenge, tenant) -def test_policy_updates_cache(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_policy_updates_cache(token_type): """ It's possible for the challenge returned for a request to change, e.g. when a vault is moved to a new tenant. When the policy receives a 401, it should update the cached challenge for the requested URL, if one exists. @@ -340,23 +386,38 @@ def test_policy_updates_cache(): ), ) - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken(first_token, time.time() + 3600))) + token = token_type(first_token, time.time() + 3600) + + def get_token(*_, **__): + return token + + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=transport) # policy should complete and cache the first challenge and access token for _ in range(2): pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 # The next request will receive a new challenge. The policy should handle it and update caches. - credential.get_token.return_value = AccessToken(second_token, time.time() + 3600) + token = token_type(second_token, time.time() + 3600) for _ in range(2): pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 2 - + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 -def test_token_expiration(): +@empty_challenge_cache +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_token_expiration(token_type): """policy should not use a cached token which has expired""" url = get_random_url() @@ -366,12 +427,15 @@ def test_token_expiration(): second_token = "**" resource = "https://vault.azure.net" - token = AccessToken(first_token, expires_on) + token = token_type(first_token, expires_on) def get_token(*_, **__): return token - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = validating_transport( requests=[ Request(), @@ -390,16 +454,23 @@ def get_token(*_, **__): for _ in range(2): pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 - token = AccessToken(second_token, time.time() + 3600) + token = token_type(second_token, time.time() + 3600) with patch("time.time", lambda: expires_on): pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 2 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 @empty_challenge_cache -def test_preserves_options_and_headers(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_preserves_options_and_headers(token_type): """After a challenge, the policy should send the original request with its options and headers preserved""" url = get_random_url() @@ -407,9 +478,12 @@ def test_preserves_options_and_headers(): resource = "https://vault.azure.net" def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"Authorization": "Bearer " + token})], @@ -451,8 +525,8 @@ def verify(request): @empty_challenge_cache -@pytest.mark.parametrize("verify_challenge_resource", [True, False]) -def test_verify_challenge_resource_matches(verify_challenge_resource): +@pytest.mark.parametrize("verify_challenge_resource,token_type", product([True, False], TOKEN_TYPES)) +def test_verify_challenge_resource_matches(verify_challenge_resource, token_type): """The auth policy should raise if the challenge resource doesn't match the request URL unless check is disabled""" url = get_random_url() @@ -461,9 +535,12 @@ def test_verify_challenge_resource_matches(verify_challenge_resource): resource = "https://myvault.azure.net" # Doesn't match a "".vault.azure.net" resource because of the "my" prefix def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = validating_transport( requests=[Request(), Request(required_headers={"Authorization": f"Bearer {token}"})], @@ -504,8 +581,8 @@ def get_token(*_, **__): @empty_challenge_cache -@pytest.mark.parametrize("verify_challenge_resource", [True, False]) -def test_verify_challenge_resource_valid(verify_challenge_resource): +@pytest.mark.parametrize("verify_challenge_resource,token_type", product([True, False], TOKEN_TYPES)) +def test_verify_challenge_resource_valid(verify_challenge_resource, token_type): """The auth policy should raise if the challenge resource isn't a valid URL unless check is disabled""" url = get_random_url() @@ -513,9 +590,12 @@ def test_verify_challenge_resource_valid(verify_challenge_resource): resource = "bad-resource" def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = validating_transport( requests=[Request(), Request(required_headers={"Authorization": f"Bearer {token}"})], @@ -536,3 +616,247 @@ def get_token(*_, **__): else: key = client.get_key("key-name") assert key.name == "key-name" + + +@empty_challenge_cache +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_cae(token_type): + """The policy should handle claims in a challenge response after having successfully authenticated prior.""" + + expected_content = b"a duck" + + def test_with_challenge(claims_challenge, expected_claim): + first_token = "first_token" + expected_token = "expected_token" + + class Requests: + count = 0 + + def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content; triggers a KV challenge response + assert not request.body + assert "Authorization" not in request.headers + assert request.headers["Content-Length"] == "0" + return KV_CHALLENGE_RESPONSE + elif Requests.count == 2: + # second request should be authorized according to challenge and have the expected content + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 3: + # third request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return claims_challenge + elif Requests.count == 4: + # fourth request should include the required claims and correctly use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 5: + # fifth request should be a regular request with the expected token + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return KV_CHALLENGE_RESPONSE + elif Requests.count == 6: + # sixth request should respond to the KV challenge WITHOUT including claims + # we return another challenge to confirm that the policy will return consecutive 401s to the user + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return KV_CHALLENGE_RESPONSE + raise ValueError("unexpected request") + + def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" + # Response to KV challenge + if Requests.count == 1: + assert options_bag.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + # Response to CAE challenge + elif Requests.count == 3: + assert options_bag.get("claims") == expected_claim + return AccessToken(expected_token, time.time() + 3600) + # Response to second KV challenge + elif Requests.count == 5: + assert options_bag.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + elif Requests.count == 6: + raise ValueError("unexpected token request") + + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) + pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + pipeline.run(request) # Send the request once to trigger a regular auth challenge + pipeline.run(request) # Send the request again to trigger a CAE challenge + pipeline.run(request) # Send the request once to trigger another regular auth challenge + + # token requests made for the CAE challenge and first two KV challenges, but not the final KV challenge + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 3 + else: + assert credential.get_token_info.call_count == 3 + + test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) + + +@empty_challenge_cache +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_cae_consecutive_challenges(token_type): + """The policy should correctly handle consecutive challenges in cases where the flow is valid or invalid.""" + + expected_content = b"a duck" + + def test_with_challenge(claims_challenge, expected_claim): + first_token = "first_token" + expected_token = "expected_token" + + class Requests: + count = 0 + + def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content; triggers a KV challenge response + assert not request.body + assert "Authorization" not in request.headers + assert request.headers["Content-Length"] == "0" + return KV_CHALLENGE_RESPONSE + elif Requests.count == 2: + # second request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return claims_challenge + elif Requests.count == 3: + # third request should include the required claims and correctly use content from the first challenge + # we return another CAE challenge to verify that the policy will return consecutive CAE 401s to the user + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return claims_challenge + raise ValueError("unexpected request") + + def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" + # Response to KV challenge + if Requests.count == 1: + assert options_bag.get("claims") == None + return token_type(first_token, time.time() + 3600) + # Response to first CAE challenge + elif Requests.count == 2: + assert options_bag.get("claims") == expected_claim + return token_type(expected_token, time.time() + 3600) + + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) + pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + pipeline.run(request) + + # token requests made for the KV challenge and first CAE challenge, but not the second CAE challenge + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 + + test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) + + +@empty_challenge_cache +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_cae_token_expiry(token_type): + """The policy should avoid sending claims more than once when a token expires.""" + + expected_content = b"a duck" + + def test_with_challenge(claims_challenge, expected_claim): + first_token = "first_token" + second_token = "second_token" + third_token = "third_token" + + class Requests: + count = 0 + + def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content; triggers a KV challenge response + assert not request.body + assert "Authorization" not in request.headers + assert request.headers["Content-Length"] == "0" + return KV_CHALLENGE_RESPONSE + elif Requests.count == 2: + # second request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return claims_challenge + elif Requests.count == 3: + # third request should include the required claims and correctly use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert second_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 4: + # fourth request should not include claims, but otherwise use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert third_token in request.headers["Authorization"] + return Mock(status_code=200) + raise ValueError("unexpected request") + + def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" + # Response to KV challenge + if Requests.count == 1: + assert options_bag.get("claims") == None + return token_type(first_token, time.time() + 3600) + # Response to first CAE challenge + elif Requests.count == 2: + assert options_bag.get("claims") == expected_claim + return token_type(second_token, 0) # Return a token that expires immediately to trigger a refresh + # Token refresh before making the final request + elif Requests.count == 3: + assert options_bag.get("claims") == None + return token_type(third_token, time.time() + 3600) + + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) + pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + pipeline.run(request) + pipeline.run(request) # Send the request again to trigger a token refresh upon expiry + + # token requests made for the KV and CAE challenges, as well as for the token refresh + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 3 + else: + assert credential.get_token_info.call_count == 3 + + test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index 980753fbdc0a..81ec711f6ad2 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -7,18 +7,18 @@ the challenge cache is global to the process. """ import asyncio +from itertools import product import os import time from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from azure.core.credentials import AccessToken +from azure.core.credentials import AccessToken, AccessTokenInfo from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.rest import HttpRequest -from azure.identity.aio import AzureCliCredential, AzurePowerShellCredential, ClientSecretCredential from azure.keyvault.keys._shared import AsyncChallengeAuthPolicy,HttpChallenge, HttpChallengeCache from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION from azure.keyvault.keys.aio import KeyClient @@ -28,7 +28,17 @@ from _shared.helpers import Request, mock_response from _shared.helpers_async import async_validating_transport from _shared.test_case_async import KeyVaultTestCase -from test_challenge_auth import empty_challenge_cache, get_random_url, add_url_port +from test_challenge_auth import ( + empty_challenge_cache, + get_random_url, + add_url_port, + CAE_CHALLENGE_RESPONSE, + CAE_DECODED_CLAIM, + KV_CHALLENGE_RESPONSE, + KV_CHALLENGE_TENANT, + RESOURCE, + TOKEN_TYPES, +) only_default_version = get_decorator(is_async=True, api_versions=[DEFAULT_VERSION]) @@ -85,7 +95,8 @@ async def test_enforces_tls(): @pytest.mark.asyncio @empty_challenge_cache -async def test_scope(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_scope(token_type): """The policy's token requests should always be for an AADv2 scope""" expected_content = b"a duck" @@ -114,9 +125,12 @@ async def send(request): async def get_token(*scopes, **_): assert len(scopes) == 1 assert scopes[0] == expected_scope - return AccessToken(expected_token, 0) + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline( policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send) ) @@ -124,7 +138,10 @@ async def get_token(*scopes, **_): request.set_bytes_body(expected_content) await pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 endpoint = "https://authority.net/tenant" @@ -147,7 +164,8 @@ async def get_token(*scopes, **_): @pytest.mark.asyncio @empty_challenge_cache -async def test_tenant(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_tenant(token_type): """The policy's token requests should pass the parsed tenant ID from the challenge""" expected_content = b"a duck" @@ -173,11 +191,15 @@ async def send(request): return Mock(status_code=200) raise ValueError("unexpected request") - async def get_token(*_, **kwargs): - assert kwargs.get("tenant_id") == expected_tenant - return AccessToken(expected_token, 0) + async def get_token(*_, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("tenant_id") == expected_tenant + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline( policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send) ) @@ -185,7 +207,10 @@ async def get_token(*_, **kwargs): request.set_bytes_body(expected_content) await pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 tenant = "tenant-id" endpoint = f"https://authority.net/{tenant}" @@ -201,7 +226,8 @@ async def get_token(*_, **kwargs): @pytest.mark.asyncio @empty_challenge_cache -async def test_adfs(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_adfs(token_type): """The policy should handle AD FS challenges as a special case and omit the tenant ID from token requests""" expected_content = b"a duck" @@ -230,15 +256,21 @@ async def send(request): async def get_token(*_, **kwargs): # we shouldn't provide a tenant ID during AD FS authentication assert "tenant_id" not in kwargs - return AccessToken(expected_token, 0) + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) policy = AsyncChallengeAuthPolicy(credential=credential) pipeline = AsyncPipeline(policies=[policy], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) await pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 # Regression test: https://github.com/Azure/azure-sdk-for-python/issues/33621 policy._token = None @@ -259,7 +291,8 @@ async def get_token(*_, **kwargs): @pytest.mark.asyncio @empty_challenge_cache -async def test_policy_updates_cache(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_policy_updates_cache(token_type): """ It's possible for the challenge returned for a request to change, e.g. when a vault is moved to a new tenant. When the policy receives a 401, it should update the cached challenge for the requested URL, if one exists. @@ -298,29 +331,39 @@ async def test_policy_updates_cache(): ), ) - token = AccessToken(first_token, time.time() + 3600) + token = token_type(first_token, time.time() + 3600) async def get_token(*_, **__): return token - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=transport) # policy should complete and cache the first challenge and access token for _ in range(2): await pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 # The next request will receive a new challenge. The policy should handle it and update caches. - token = AccessToken(second_token, time.time() + 3600) + token = token_type(second_token, time.time() + 3600) for _ in range(2): await pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 2 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 @pytest.mark.asyncio @empty_challenge_cache -async def test_token_expiration(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_token_expiration(token_type): """policy should not use a cached token which has expired""" url = get_random_url() @@ -330,12 +373,15 @@ async def test_token_expiration(): second_token = "**" resource = "https://vault.azure.net" - token = AccessToken(first_token, expires_on) + token = token_type(first_token, expires_on) async def get_token(*_, **__): return token - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = async_validating_transport( requests=[ Request(), @@ -354,17 +400,24 @@ async def get_token(*_, **__): for _ in range(2): await pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 - token = AccessToken(second_token, time.time() + 3600) + token = token_type(second_token, time.time() + 3600) with patch("time.time", lambda: expires_on): await pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 2 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 @pytest.mark.asyncio @empty_challenge_cache -async def test_preserves_options_and_headers(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_preserves_options_and_headers(token_type): """After a challenge, the policy should send the original request with its options and headers preserved""" url = get_random_url() @@ -372,9 +425,12 @@ async def test_preserves_options_and_headers(): resource = "https://vault.azure.net" async def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = async_validating_transport( requests=[Request()] * 2 + [Request(required_headers={"Authorization": "Bearer " + token})], @@ -416,8 +472,8 @@ def verify(request): @pytest.mark.asyncio @empty_challenge_cache -@pytest.mark.parametrize("verify_challenge_resource", [True, False]) -async def test_verify_challenge_resource_matches(verify_challenge_resource): +@pytest.mark.parametrize("verify_challenge_resource,token_type", product([True, False], TOKEN_TYPES)) +async def test_verify_challenge_resource_matches(verify_challenge_resource, token_type): """The auth policy should raise if the challenge resource doesn't match the request URL unless check is disabled""" url = get_random_url() @@ -426,9 +482,12 @@ async def test_verify_challenge_resource_matches(verify_challenge_resource): resource = "https://myvault.azure.net" # Doesn't match a "".vault.azure.net" resource because of the "my" prefix async def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = async_validating_transport( requests=[Request(), Request(required_headers={"Authorization": f"Bearer {token}"})], @@ -469,9 +528,8 @@ async def get_token(*_, **__): @pytest.mark.asyncio -@empty_challenge_cache -@pytest.mark.parametrize("verify_challenge_resource", [True, False]) -async def test_verify_challenge_resource_valid(verify_challenge_resource): +@pytest.mark.parametrize("verify_challenge_resource,token_type", product([True, False], TOKEN_TYPES)) +async def test_verify_challenge_resource_valid(verify_challenge_resource, token_type): """The auth policy should raise if the challenge resource isn't a valid URL unless check is disabled""" url = get_random_url() @@ -479,9 +537,12 @@ async def test_verify_challenge_resource_valid(verify_challenge_resource): resource = "bad-resource" async def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = async_validating_transport( requests=[Request(), Request(required_headers={"Authorization": f"Bearer {token}"})], @@ -502,3 +563,250 @@ async def get_token(*_, **__): else: key = await client.get_key("key-name") assert key.name == "key-name" + + +@pytest.mark.asyncio +@empty_challenge_cache +@pytest.mark.parametrize("token_type", [AccessToken, AccessTokenInfo]) +async def test_cae(token_type): + """The policy should handle claims in a challenge response after having successfully authenticated prior.""" + + expected_content = b"a duck" + + async def test_with_challenge(claims_challenge, expected_claim): + first_token = "first_token" + expected_token = "expected_token" + + class Requests: + count = 0 + + async def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content; triggers a KV challenge response + assert not request.body + assert "Authorization" not in request.headers + assert request.headers["Content-Length"] == "0" + return KV_CHALLENGE_RESPONSE + elif Requests.count == 2: + # second request should be authorized according to challenge and have the expected content + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 3: + # third request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return claims_challenge + elif Requests.count == 4: + # fourth request should include the required claims and correctly use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 5: + # fifth request should be a regular request with the expected token + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return KV_CHALLENGE_RESPONSE + elif Requests.count == 6: + # sixth request should respond to the KV challenge WITHOUT including claims + # we return another challenge to confirm that the policy will return consecutive 401s to the user + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return KV_CHALLENGE_RESPONSE + raise ValueError("unexpected request") + + async def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" + # Response to KV challenge + if Requests.count == 1: + assert options_bag.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + # Response to CAE challenge + elif Requests.count == 3: + assert options_bag.get("claims") == expected_claim + return AccessToken(expected_token, time.time() + 3600) + # Response to second KV challenge + elif Requests.count == 5: + assert options_bag.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + elif Requests.count == 6: + raise ValueError("unexpected token request") + + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) + pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + await pipeline.run(request) # Send the request once to trigger a regular auth challenge + await pipeline.run(request) # Send the request again to trigger a CAE challenge + await pipeline.run(request) # Send the request once to trigger another regular auth challenge + + # token requests made for the CAE challenge and first two KV challenges, but not the final KV challenge + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 3 + else: + assert credential.get_token_info.call_count == 3 + + await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) + + +@pytest.mark.asyncio +@empty_challenge_cache +@pytest.mark.parametrize("token_type", [AccessToken, AccessTokenInfo]) +async def test_cae_consecutive_challenges(token_type): + """The policy should correctly handle consecutive challenges in cases where the flow is valid or invalid.""" + + expected_content = b"a duck" + + async def test_with_challenge(claims_challenge, expected_claim): + first_token = "first_token" + expected_token = "expected_token" + + class Requests: + count = 0 + + async def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content; triggers a KV challenge response + assert not request.body + assert "Authorization" not in request.headers + assert request.headers["Content-Length"] == "0" + return KV_CHALLENGE_RESPONSE + elif Requests.count == 2: + # second request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return claims_challenge + elif Requests.count == 3: + # third request should include the required claims and correctly use content from the first challenge + # we return another CAE challenge to verify that the policy will return consecutive CAE 401s to the user + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return claims_challenge + raise ValueError("unexpected request") + + async def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" + # Response to KV challenge + if Requests.count == 1: + assert options_bag.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + # Response to first CAE challenge + elif Requests.count == 2: + assert options_bag.get("claims") == expected_claim + return AccessToken(expected_token, time.time() + 3600) + + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) + pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + await pipeline.run(request) + + # token requests made for the KV challenge and first CAE challenge, but not the second CAE challenge + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 + + await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) + + +@pytest.mark.asyncio +@empty_challenge_cache +@pytest.mark.parametrize("token_type", [AccessToken, AccessTokenInfo]) +async def test_cae_token_expiry(token_type): + """The policy should avoid sending claims more than once when a token expires.""" + + expected_content = b"a duck" + + async def test_with_challenge(claims_challenge, expected_claim): + first_token = "first_token" + second_token = "second_token" + third_token = "third_token" + + class Requests: + count = 0 + + async def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content; triggers a KV challenge response + assert not request.body + assert "Authorization" not in request.headers + assert request.headers["Content-Length"] == "0" + return KV_CHALLENGE_RESPONSE + elif Requests.count == 2: + # second request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return claims_challenge + elif Requests.count == 3: + # third request should include the required claims and correctly use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert second_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 4: + # fourth request should not include claims, but otherwise use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert third_token in request.headers["Authorization"] + return Mock(status_code=200) + raise ValueError("unexpected request") + + async def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" + # Response to KV challenge + if Requests.count == 1: + assert options_bag.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + # Response to first CAE challenge + elif Requests.count == 2: + assert options_bag.get("claims") == expected_claim + return AccessToken(second_token, 0) # Return a token that expires immediately to trigger a refresh + # Token refresh before making the final request + elif Requests.count == 3: + assert options_bag.get("claims") == None + return AccessToken(third_token, time.time() + 3600) + + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) + pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + await pipeline.run(request) + await pipeline.run(request) # Send the request again to trigger a token refresh upon expiry + + # token requests made for the KV and CAE challenges, as well as for the token refresh + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 3 + else: + assert credential.get_token_info.call_count == 3 + + await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) diff --git a/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md b/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md index a366b4af33a0..e02843b0c0cd 100644 --- a/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` +- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests. ### Breaking Changes @@ -12,6 +13,7 @@ ([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744)) ### Other Changes +- Updated minimum `azure-core` version to 1.31.0 - Key Vault API version `7.6-preview.1` is now the default ## 4.8.0 (2024-02-22) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index 1a872f36b6a8..e9b44fc68e55 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -16,33 +16,140 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from typing_extensions import ParamSpec + +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest +from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache -from .challenge_auth_policy import _enforce_tls, _update_challenge +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge + + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """Policy for handling HTTP authentication challenges. :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity.aio` - :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider """ - def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - super().__init__(credential, *scopes, **kwargs) - self._credential: AsyncTokenCredential = credential - self._token: Optional[AccessToken] = None + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request + super().__init__(credential, *scopes, enable_cae=True, **kwargs) + self._credential: AsyncTokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + async def send( + self, request: PipelineRequest[HttpRequest] + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HttpRequest, AsyncHttpResponse] + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + return await self.handle_challenge_flow(request, response) + return response + + async def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, AsyncHttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + if not claims_challenge: + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) + await await_result(self.on_response, request, response) + return response + + async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) @@ -51,14 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None: if self._need_new_token(): # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) - else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + await self._request_kv_token(scope, challenge) + + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -78,7 +181,19 @@ async def on_request(self, request: PipelineRequest) -> None: async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -104,11 +219,38 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request( + request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) return True def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = await self._credential.get_token(scope, enable_cae=True) + else: + self._token = await cast(AsyncTokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index f16297aa5026..eb4073d0e699 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -16,14 +16,21 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, cast, Optional, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenCredential, + TokenProvider, + TokenRequestOptions, + SupportsTokenInfo, +) from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache @@ -36,6 +43,20 @@ def _enforce_tls(request: PipelineRequest) -> None: ) +def _has_claims(challenge: str) -> bool: + """Check if a challenge header contains claims. + + :param challenge: The challenge header to check. + :type challenge: str + + :returns: True if the challenge contains claims; False otherwise. + :rtype: bool + """ + # Split the challenge into its scheme and parameters, then check if any parameter contains claims + split_challenge = challenge.strip().split(" ", 1) + return any("claims=" in item for item in split_challenge[1].split(",")) + + def _update_challenge(request: PipelineRequest, challenger: PipelineResponse) -> HttpChallenge: """Parse challenge from a challenge response, cache it, and return it. @@ -62,16 +83,89 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity` - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) - self._credential: TokenCredential = credential - self._token: Optional[AccessToken] = None + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request + super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) + self._credential: TokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + return self.handle_challenge_flow(request, response) + return response + + def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, HttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, HttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + if not claims_challenge: + return self.handle_challenge_flow(request, response, consecutive_challenge=True) + self.on_response(request, response) + return response + def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) @@ -80,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None: if self._need_new_token: # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) - else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + self._request_kv_token(scope, challenge) + + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -106,7 +196,19 @@ def on_request(self, request: PipelineRequest) -> None: def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -132,12 +234,37 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = self._credential.get_token(scope, enable_cae=True) + else: + self._token = cast(TokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py index df9055c7bda6..0320df5a868b 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,27 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # Special case for claims, which can contain = symbols as padding. Assume at most one claim per challenge + if "claims=" in item: + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: # pylint:disable=broad-except + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: diff --git a/sdk/keyvault/azure-keyvault-secrets/setup.py b/sdk/keyvault/azure-keyvault-secrets/setup.py index 989d08c1e98a..54019bafe3b8 100644 --- a/sdk/keyvault/azure-keyvault-secrets/setup.py +++ b/sdk/keyvault/azure-keyvault-secrets/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "isodate>=0.6.1", "typing-extensions>=4.0.1", ],