Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use linked cancellation token source during proactive refresh. #4471

Merged
merged 6 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok

SilentRequestHelper.ProcessFetchInBackground(
cachedAccessTokenItem,
() => GetAccessTokenAsync(cancellationToken, logger), logger);
() =>
{
using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
pmaytak marked this conversation as resolved.
Show resolved Hide resolved
return GetAccessTokenAsync(tokenSource.Token, logger);
}, logger);
}
}
catch (MsalServiceException e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok

SilentRequestHelper.ProcessFetchInBackground(
cachedAccessTokenItem,
() => GetAccessTokenAsync(cancellationToken, logger), logger);
() =>
{
using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
return GetAccessTokenAsync(tokenSource.Token, logger);
}, logger);
}
}
catch (MsalServiceException e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ protected override async Task<AuthenticationResult> ExecuteAsync(CancellationTok

SilentRequestHelper.ProcessFetchInBackground(
cachedAccessToken,
() => RefreshRtOrFetchNewAccessTokenAsync(cancellationToken), logger);
() =>
{
using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
return RefreshRtOrFetchNewAccessTokenAsync(tokenSource.Token);
}, logger);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ public async Task<AuthenticationResult> ExecuteAsync(CancellationToken cancellat

SilentRequestHelper.ProcessFetchInBackground(
cachedAccessTokenItem,
() => RefreshRtOrFailAsync(cancellationToken), logger);
() =>
{
using var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
return RefreshRtOrFailAsync(tokenSource.Token);
}, logger);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ namespace Microsoft.Identity.Client.Internal
internal static class SilentRequestHelper
{
internal const string MamEnrollmentIdKey = "microsoft_enrollment_id";
internal const string ProactiveRefreshServiceError = "Proactive token refresh failed with MsalServiceException.";
internal const string ProactiveRefreshGeneralError = "Proactive token refresh failed with exception.";
internal const string ProactiveRefreshCancellationError = "Proactive token refresh was canceled.";

internal static async Task<MsalTokenResponse> RefreshAccessTokenAsync(MsalRefreshTokenCacheItem msalRefreshTokenItem, RequestBase request, AuthenticationRequestParameters authenticationRequestParameters, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -89,20 +92,23 @@ internal static void ProcessFetchInBackground(
}
catch (MsalServiceException ex)
{
string logMsg = $"Background fetch failed with MsalServiceException. Is exception retryable? { ex.IsRetryable}";
string logMsg = $"{ProactiveRefreshServiceError} Is exception retryable? {ex.IsRetryable}";
if (ex.StatusCode == 400)
{
logger.ErrorPiiWithPrefix(ex, logMsg);
}
else
{
logger.WarningPiiWithPrefix(ex, logMsg);
logger.ErrorPiiWithPrefix(ex, logMsg);
}
}
catch (OperationCanceledException ex)
{
logger.WarningPiiWithPrefix(ex, ProactiveRefreshCancellationError);
}
catch (Exception ex)
{
string logMsg = $"Background fetch failed with exception.";
logger.WarningPiiWithPrefix(ex, logMsg);
logger.ErrorPiiWithPrefix(ex, ProactiveRefreshGeneralError);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,61 @@ public async Task ManagedIdentityIsProactivelyRefreshedAsync()
}
}

[TestMethod]
public async Task ProactiveRefresh_CancelsSuccessfully_Async()
{
bool wasErrorLogged = false;

using (new EnvVariableContext())
using (var httpManager = new MockHttpManager(isManagedIdentity: true))
{
SetEnvironmentVariables(ManagedIdentitySource.AppService, AppServiceEndpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
.WithLogging(LocalLogCallback)
.WithHttpManager(httpManager);

// Disabling shared cache options to avoid cross test pollution.
miBuilder.Config.AccessorOptions = null;

var mi = miBuilder.BuildConcrete();

httpManager.AddManagedIdentityMockHandler(
AppServiceEndpoint,
Resource,
MockHelpers.GetMsiSuccessfulResponse(),
ManagedIdentitySource.AppService);

AuthenticationResult result = await mi.AcquireTokenForManagedIdentity(Resource)
.ExecuteAsync()
.ConfigureAwait(false);

TestCommon.UpdateATWithRefreshOn(mi.AppTokenCacheInternal.Accessor);

var cts = new CancellationTokenSource();
var cancellationToken = cts.Token;
cts.Cancel();
cts.Dispose();

// Act
result = await mi.AcquireTokenForManagedIdentity(Resource)
.ExecuteAsync(cancellationToken)
.ConfigureAwait(false);

// Assert
Assert.IsTrue(TestCommon.YieldTillSatisfied(() => wasErrorLogged));

void LocalLogCallback(LogLevel level, string message, bool containsPii)
{
if (level == LogLevel.Warning &&
message.Contains(SilentRequestHelper.ProactiveRefreshCancellationError))
{
wasErrorLogged = true;
}
}
}
}

[TestMethod]
public async Task ParallelRequests_CallTokenEndpointOnceAsync()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ await app.AcquireTokenSilent(
void LocalLogCallback(LogLevel level, string message, bool containsPii)
{
if (level == LogLevel.Error &&
message.Contains(BackgroundFetch_Failed))
message.Contains(SilentRequestHelper.ProactiveRefreshServiceError))
{
wasErrorLogged = true;
}
Expand Down Expand Up @@ -262,13 +262,51 @@ public void JitterIsAddedToRefreshOn()
Assert.IsTrue(refreshOnWithJitterList.Distinct().Count() >= 8, "Jitter is random, so we can only have 1-2 identical values");
}

[TestMethod]
public async Task ATS_ProactiveRefresh_CancelsSuccessfully_Async()
{
bool wasErrorLogged = false;

// Arrange
using MockHttpAndServiceBundle harness = base.CreateTestHarness();
harness.HttpManager.AddInstanceDiscoveryMockHandler();

PublicClientApplication app = SetupPca(harness, LocalLogCallback);
TestCommon.UpdateATWithRefreshOn(app.UserTokenCacheInternal.Accessor);

var account = new Account(TestConstants.s_userIdentifier, TestConstants.DisplayableId, null);

var cts = new CancellationTokenSource();
var cancellationToken = cts.Token;
cts.Cancel();
cts.Dispose();

// Act
await app.AcquireTokenSilent(
TestConstants.s_scope.ToArray(),
account)
.ExecuteAsync(cancellationToken)
.ConfigureAwait(false);

Assert.IsTrue(TestCommon.YieldTillSatisfied(() => wasErrorLogged));

void LocalLogCallback(LogLevel level, string message, bool containsPii)
{
if (level == LogLevel.Warning &&
message.Contains(SilentRequestHelper.ProactiveRefreshCancellationError))
{
wasErrorLogged = true;
}
}
}

#endregion

#region Client Creds
#region Client Credentials

[TestMethod]
[Description("AT in cache, needs refresh. AAD responds well to Refresh.")]
public async Task ClientCreds_NonExpired_NeedsRefresh_ValidResponse_Async()
public async Task ClientCredentials_NonExpired_NeedsRefresh_ValidResponse_Async()
{
// Arrange
using (MockHttpAndServiceBundle harness = base.CreateTestHarness())
Expand Down Expand Up @@ -309,7 +347,7 @@ public async Task ClientCreds_NonExpired_NeedsRefresh_ValidResponse_Async()

[TestMethod]
[Description("AT in cache, needs refresh. AAD responds well to Refresh.")]
public async Task ClientCreds_OnBehalfOf_NonExpired_NeedsRefresh_ValidResponse_Async()
public async Task ClientCredentials_OnBehalfOf_NonExpired_NeedsRefresh_ValidResponse_Async()
{
// Arrange
using (MockHttpAndServiceBundle harness = base.CreateTestHarness())
Expand Down Expand Up @@ -362,7 +400,7 @@ private static ConfidentialClientApplication SetupCca(MockHttpAndServiceBundle h

[TestMethod]
[Description("AT in cache, needs refresh. AAD is unavailable when refreshing.")]
public async Task ClientCreds_NonExpired_NeedsRefresh_AADUnavailableResponse_Async()
public async Task ClientCredentials_NonExpired_NeedsRefresh_AadUnavailableResponse_Async()
{
// Arrange
using (MockHttpAndServiceBundle harness = base.CreateTestHarness())
Expand All @@ -386,7 +424,7 @@ public async Task ClientCreds_NonExpired_NeedsRefresh_AADUnavailableResponse_Asy
.ConfigureAwait(false);

// Assert
Assert.IsNotNull(result, "ClientCreds should still succeeds even though AAD is unavailable");
Assert.IsNotNull(result, "ClientCredentials should still succeeds even though AAD is unavailable");
TestCommon.YieldTillSatisfied(() => harness.HttpManager.QueueSize == 0);
Assert.AreEqual(0, harness.HttpManager.QueueSize);
cacheAccess.WaitTo_AssertAcessCounts(1, 0); // the refresh failed, no new data is written to the cache
Expand All @@ -404,7 +442,7 @@ public async Task ClientCreds_NonExpired_NeedsRefresh_AADUnavailableResponse_Asy
}

[TestMethod]
public async Task ClientCreds_NonExpired_NeedsRefresh_AADInvalidResponse_Async()
public async Task ClientCredentials_NonExpired_NeedsRefresh_AadInvalidResponse_Async()
{
bool wasErrorLogged = false;
// Arrange
Expand Down Expand Up @@ -434,7 +472,7 @@ await app.AcquireTokenForClient(TestConstants.s_scope)
void LocalLogCallback(LogLevel level, string message, bool containsPii)
{
if (level == LogLevel.Error &&
message.Contains(BackgroundFetch_Failed))
message.Contains(SilentRequestHelper.ProactiveRefreshServiceError))
{
wasErrorLogged = true;
}
Expand All @@ -443,7 +481,7 @@ void LocalLogCallback(LogLevel level, string message, bool containsPii)

[TestMethod]
[Description("AT expired. AAD fails but is available when refreshing.")]
public async Task ClientCreds_Expired_NeedsRefresh_AADInvalidResponse_Async()
public async Task ClientCredentials_Expired_NeedsRefresh_AADInvalidResponse_Async()
{
// Arrange
using (MockHttpAndServiceBundle harness = base.CreateTestHarness())
Expand All @@ -470,8 +508,39 @@ public async Task ClientCreds_Expired_NeedsRefresh_AADInvalidResponse_Async()
}
}

#endregion
[TestMethod]
public async Task ClientCredentials_ProactiveRefresh_CancelsSuccessfully_Async()
{
bool wasErrorLogged = false;

// Arrange
using MockHttpAndServiceBundle harness = CreateTestHarness();
harness.HttpManager.AddInstanceDiscoveryMockHandler();

ConfidentialClientApplication app = SetupCca(harness, LocalLogCallback);
TestCommon.UpdateATWithRefreshOn(app.AppTokenCacheInternal.Accessor);

var cts = new CancellationTokenSource();
var cancellationToken = cts.Token;
cts.Cancel();
cts.Dispose();

// Act
await app.AcquireTokenForClient(TestConstants.s_scope)
.ExecuteAsync(cancellationToken)
.ConfigureAwait(false);

private const string BackgroundFetch_Failed = "Background fetch failed";
Assert.IsTrue(TestCommon.YieldTillSatisfied(() => wasErrorLogged));

void LocalLogCallback(LogLevel level, string message, bool containsPii)
{
if (level == LogLevel.Warning &&
message.Contains(SilentRequestHelper.ProactiveRefreshCancellationError))
{
wasErrorLogged = true;
}
}
}
#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Cache.Items;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Client.Http;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.OAuth2;
using Microsoft.Identity.Test.Common;
using Microsoft.Identity.Test.Common.Core.Helpers;
Expand Down Expand Up @@ -79,6 +81,42 @@ public async Task LongRunningObo_RunsSuccessfully_TestAsync()
}
}

[TestMethod]
public async Task ProactiveRefresh_CancelsSuccessfully_Async()
{
bool wasErrorLogged = false;

using var httpManager = new MockHttpManager();
httpManager.AddInstanceDiscoveryMockHandler();
AddMockHandlerAadSuccess(httpManager);

var cca = BuildCCA(httpManager, LocalLogCallback);

string oboCacheKey = "obo-cache-key";
var result = await cca.InitiateLongRunningProcessInWebApi(TestConstants.s_scope, TestConstants.DefaultAccessToken, ref oboCacheKey)
.ExecuteAsync().ConfigureAwait(false);

TestCommon.UpdateATWithRefreshOn(cca.UserTokenCacheInternal.Accessor);

var cts = new CancellationTokenSource();
var cancellationToken = cts.Token;
cts.Cancel();
cts.Dispose();

result = await cca.AcquireTokenInLongRunningProcess(TestConstants.s_scope, oboCacheKey).ExecuteAsync(cancellationToken).ConfigureAwait(false);

Assert.IsTrue(TestCommon.YieldTillSatisfied(() => wasErrorLogged));

void LocalLogCallback(LogLevel level, string message, bool containsPii)
{
if (level == LogLevel.Warning &&
message.Contains(SilentRequestHelper.ProactiveRefreshCancellationError))
{
wasErrorLogged = true;
}
}
}

[TestMethod]
public async Task InitiateLongRunningObo_WithExistingKeyAndToken_TestAsync()
{
Expand Down Expand Up @@ -913,14 +951,18 @@ private MockHttpMessageHandler AddMockHandlerAadSuccess(
return handler;
}

private ConfidentialClientApplication BuildCCA(IHttpManager httpManager)
private ConfidentialClientApplication BuildCCA(IHttpManager httpManager, LogCallback logCallback = null)
{
return ConfidentialClientApplicationBuilder
.Create(TestConstants.ClientId)
.WithClientSecret(TestConstants.ClientSecret)
.WithAuthority(TestConstants.AuthorityCommonTenant)
.WithHttpManager(httpManager)
.BuildConcrete();
var builder = ConfidentialClientApplicationBuilder
.Create(TestConstants.ClientId)
.WithClientSecret(TestConstants.ClientSecret)
.WithAuthority(TestConstants.AuthorityCommonTenant)
.WithHttpManager(httpManager);
if (logCallback != null)
{
builder.WithLogging(logCallback);
}
return builder.BuildConcrete();
}
}
}
Loading