Skip to content

Commit

Permalink
Update the client and server stacks to automatically restore the auth…
Browse files Browse the repository at this point in the history
…entication properties and attach them to the authentication context
  • Loading branch information
kevinchalet committed Jan 5, 2024
1 parent 29d7197 commit 3753229
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Globalization;
using System.Security.Claims;
using System.Text.Encodings.Web;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using static OpenIddict.Client.AspNetCore.OpenIddictClientAspNetCoreConstants;
Expand Down Expand Up @@ -169,14 +168,20 @@ protected override async Task<AuthenticateResult> HandleAuthenticateAsync()

else
{
// Restore or create a new authentication properties collection and populate it.
var properties = CreateProperties(context.StateTokenPrincipal);
properties.ExpiresUtc = context.StateTokenPrincipal?.GetExpirationDate();
properties.IssuedUtc = context.StateTokenPrincipal?.GetCreationDate();
var properties = new AuthenticationProperties
{
ExpiresUtc = context.StateTokenPrincipal?.GetExpirationDate(),
IssuedUtc = context.StateTokenPrincipal?.GetCreationDate(),

// Restore the target link URI that was stored in the state
// token when the challenge operation started, if available.
RedirectUri = context.StateTokenPrincipal?.GetClaim(Claims.TargetLinkUri)
};

// Restore the target link URI that was stored in the state
// token when the challenge operation started, if available.
properties.RedirectUri = context.StateTokenPrincipal?.GetClaim(Claims.TargetLinkUri);
foreach (var property in context.Properties)
{
properties.Items[property.Key] = property.Value;
}

List<AuthenticationToken>? tokens = null;

Expand Down Expand Up @@ -334,29 +339,6 @@ protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
return AuthenticateResult.Success(new AuthenticationTicket(
context.MergedPrincipal ?? new ClaimsPrincipal(new ClaimsIdentity()), properties,
OpenIddictClientAspNetCoreDefaults.AuthenticationScheme));

static AuthenticationProperties CreateProperties(ClaimsPrincipal? principal)
{
// Note: the principal may be null if no value was extracted from the corresponding token.
if (principal is not null)
{
var value = principal.GetClaim(Claims.Private.HostProperties);
if (!string.IsNullOrEmpty(value))
{
var dictionary = new Dictionary<string, string?>(comparer: StringComparer.Ordinal);
using var document = JsonDocument.Parse(value);

foreach (var property in document.RootElement.EnumerateObject())
{
dictionary[property.Name] = property.Value.GetString();
}

return new AuthenticationProperties(dictionary);
}
}

return new AuthenticationProperties();
}
}
}

Expand Down
44 changes: 13 additions & 31 deletions src/OpenIddict.Client.Owin/OpenIddictClientOwinHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Security.Claims;
using System.Text.Json;
using Microsoft.Owin.Security.Infrastructure;
using static OpenIddict.Client.Owin.OpenIddictClientOwinConstants;
using Properties = OpenIddict.Client.Owin.OpenIddictClientOwinConstants.Properties;
Expand Down Expand Up @@ -168,14 +167,20 @@ public override async Task<bool> InvokeAsync()

else
{
// Restore or create a new authentication properties collection and populate it.
var properties = CreateProperties(context.StateTokenPrincipal);
properties.ExpiresUtc = context.StateTokenPrincipal?.GetExpirationDate();
properties.IssuedUtc = context.StateTokenPrincipal?.GetCreationDate();
var properties = new AuthenticationProperties
{
ExpiresUtc = context.StateTokenPrincipal?.GetExpirationDate(),
IssuedUtc = context.StateTokenPrincipal?.GetCreationDate(),

// Restore the target link URI that was stored in the state
// token when the challenge operation started, if available.
RedirectUri = context.StateTokenPrincipal?.GetClaim(Claims.TargetLinkUri)
};

// Restore the target link URI that was stored in the state
// token when the challenge operation started, if available.
properties.RedirectUri = context.StateTokenPrincipal?.GetClaim(Claims.TargetLinkUri);
foreach (var property in context.Properties)
{
properties.Dictionary[property.Key] = property.Value;
}

// Attach the tokens to allow any OWIN component (e.g a controller)
// to retrieve them (e.g to make an API request to another application).
Expand Down Expand Up @@ -236,29 +241,6 @@ public override async Task<bool> InvokeAsync()
}

return new AuthenticationTicket(context.MergedPrincipal?.Identity as ClaimsIdentity, properties);

static AuthenticationProperties CreateProperties(ClaimsPrincipal? principal)
{
// Note: the principal may be null if no value was extracted from the corresponding token.
if (principal is not null)
{
var value = principal.GetClaim(Claims.Private.HostProperties);
if (!string.IsNullOrEmpty(value))
{
var dictionary = new Dictionary<string, string?>(comparer: StringComparer.Ordinal);
using var document = JsonDocument.Parse(value);

foreach (var property in document.RootElement.EnumerateObject())
{
dictionary[property.Name] = property.Value.GetString();
}

return new AuthenticationProperties(dictionary);
}
}

return new AuthenticationProperties();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public static partial class OpenIddictClientSystemIntegrationHandlers

RestoreStateTokenFromMarshalledAuthentication.Descriptor,
RestoreStateTokenPrincipalFromMarshalledAuthentication.Descriptor,
RestoreHostAuthenticationPropertiesFromMarshalledAuthentication.Descriptor,
RestoreClientRegistrationFromMarshalledContext.Descriptor,

RedirectProtocolActivation.Descriptor,
Expand Down Expand Up @@ -699,6 +700,53 @@ OpenIddictClientEndpointType.Unknown when _marshal.TryGetResult(context.Nonce, o
}
}

/// <summary>
/// Contains the logic responsible for restoring the host authentication
/// properties from the marshalled authentication context, if applicable.
/// </summary>
public sealed class RestoreHostAuthenticationPropertiesFromMarshalledAuthentication : IOpenIddictClientHandler<ProcessAuthenticationContext>
{
private readonly OpenIddictClientSystemIntegrationMarshal _marshal;

public RestoreHostAuthenticationPropertiesFromMarshalledAuthentication(OpenIddictClientSystemIntegrationMarshal marshal)
=> _marshal = marshal ?? throw new ArgumentNullException(nameof(marshal));

/// <summary>
/// Gets the default descriptor definition assigned to this handler.
/// </summary>
public static OpenIddictClientHandlerDescriptor Descriptor { get; }
= OpenIddictClientHandlerDescriptor.CreateBuilder<ProcessAuthenticationContext>()
.AddFilter<RequireAuthenticationNonce>()
.UseSingletonHandler<RestoreHostAuthenticationPropertiesFromMarshalledAuthentication>()
.SetOrder(ResolveHostAuthenticationPropertiesFromStateToken.Descriptor.Order + 500)
.SetType(OpenIddictClientHandlerType.BuiltIn)
.Build();

/// <inheritdoc/>
public ValueTask HandleAsync(ProcessAuthenticationContext context)
{
if (context is null)
{
throw new ArgumentNullException(nameof(context));
}

Debug.Assert(!string.IsNullOrEmpty(context.Nonce), SR.GetResourceString(SR.ID4019));

// When the authentication context is marshalled, restore the
// host authentication properties from the other instance.
if (context.EndpointType is OpenIddictClientEndpointType.Unknown &&
_marshal.TryGetResult(context.Nonce, out var notification))
{
foreach (var property in notification.Properties)
{
context.Properties[property.Key] = property.Value;
}
}

return default;
}
}

/// <summary>
/// Contains the logic responsible for restoring the client registration and
/// configuration from the marshalled authentication context, if applicable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,9 @@ public ValueTask HandleAsync(ProcessAuthenticationContext context)
return default;
}

// Resolve the shop name from the authentication properties stored in the state token principal.
if (context.StateTokenPrincipal.FindFirst(Claims.Private.HostProperties)?.Value is not string value ||
JsonSerializer.Deserialize<JsonElement>(value) is not { ValueKind: JsonValueKind.Object } properties ||
!properties.TryGetProperty(Shopify.Properties.ShopName, out JsonElement name))
// Resolve the shop name from the authentication properties.
if (!context.Properties.TryGetValue(Shopify.Properties.ShopName, out string? name) ||
string.IsNullOrEmpty(name))
{
throw new InvalidOperationException(SR.GetResourceString(SR.ID0412));
}
Expand Down Expand Up @@ -352,13 +351,9 @@ public ValueTask HandleAsync(ProcessAuthenticationContext context)
//
// For more information, see
// https://shopify.dev/docs/apps/auth/oauth/getting-started#step-5-get-an-access-token.
ProviderTypes.Shopify when context.GrantType is GrantTypes.AuthorizationCode =>
context.StateTokenPrincipal is ClaimsPrincipal principal &&
principal.FindFirst(Claims.Private.HostProperties)?.Value is string value &&
JsonSerializer.Deserialize<JsonElement>(value) is { ValueKind: JsonValueKind.Object } properties &&
properties.TryGetProperty(Shopify.Properties.ShopName, out JsonElement name) ?
new Uri($"https://{name}.myshopify.com/admin/oauth/access_token", UriKind.Absolute) :
throw new InvalidOperationException(SR.GetResourceString(SR.ID0412)),
ProviderTypes.Shopify when context.GrantType is GrantTypes.AuthorizationCode &&
context.Properties[Shopify.Properties.ShopName] is var name =>
new Uri($"https://{name}.myshopify.com/admin/oauth/access_token", UriKind.Absolute),

// Trovo uses a different token endpoint for the refresh token grant.
//
Expand Down Expand Up @@ -1233,9 +1228,8 @@ public ValueTask HandleAsync(ProcessChallengeContext context)
//
// For more information, see
// https://shopify.dev/docs/apps/auth/oauth/getting-started#step-3-ask-for-permission.
ProviderTypes.Shopify => context.Properties.TryGetValue(Shopify.Properties.ShopName, out string? name) ?
new Uri($"https://{name}.myshopify.com/admin/oauth/authorize", UriKind.Absolute) :
throw new InvalidOperationException(SR.GetResourceString(SR.ID0412)),
ProviderTypes.Shopify when context.Properties[Shopify.Properties.ShopName] is var name =>
new Uri($"https://{name}.myshopify.com/admin/oauth/authorize", UriKind.Absolute),

// Stripe uses a different authorization endpoint for express accounts.
//
Expand Down
45 changes: 44 additions & 1 deletion src/OpenIddict.Client/OpenIddictClientHandlers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Runtime.InteropServices;
using System.Security.Claims;
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using Microsoft.IdentityModel.Tokens;
Expand All @@ -36,6 +37,7 @@ public static partial class OpenIddictClientHandlers
ResolveValidatedStateToken.Descriptor,
ValidateRequiredStateToken.Descriptor,
ValidateStateToken.Descriptor,
ResolveHostAuthenticationPropertiesFromStateToken.Descriptor,
ResolveNonceFromStateToken.Descriptor,
RedeemStateTokenEntry.Descriptor,
ValidateStateTokenEndpointType.Descriptor,
Expand Down Expand Up @@ -658,6 +660,47 @@ public async ValueTask HandleAsync(ProcessAuthenticationContext context)
}
}

/// <summary>
/// Contains the logic responsible for resolving the host authentication properties from the state token principal.
/// </summary>
public sealed class ResolveHostAuthenticationPropertiesFromStateToken : IOpenIddictClientHandler<ProcessAuthenticationContext>
{
/// <summary>
/// Gets the default descriptor definition assigned to this handler.
/// </summary>
public static OpenIddictClientHandlerDescriptor Descriptor { get; }
= OpenIddictClientHandlerDescriptor.CreateBuilder<ProcessAuthenticationContext>()
.AddFilter<RequireStateTokenPrincipal>()
.AddFilter<RequireStateTokenValidated>()
.UseSingletonHandler<ResolveHostAuthenticationPropertiesFromStateToken>()
.SetOrder(ValidateStateToken.Descriptor.Order + 1_000)
.Build();

/// <inheritdoc/>
public ValueTask HandleAsync(ProcessAuthenticationContext context)
{
if (context is null)
{
throw new ArgumentNullException(nameof(context));
}

Debug.Assert(context.StateTokenPrincipal is { Identity: ClaimsIdentity }, SR.GetResourceString(SR.ID4006));

var properties = context.StateTokenPrincipal.GetClaim(Claims.Private.HostProperties);
if (!string.IsNullOrEmpty(properties))
{
using var document = JsonDocument.Parse(properties);

foreach (var property in document.RootElement.EnumerateObject())
{
context.Properties[property.Name] = property.Value.GetString();
}
}

return default;
}
}

/// <summary>
/// Contains the logic responsible for resolving the nonce identifying
/// the authentication operation from the state token principal.
Expand All @@ -672,7 +715,7 @@ public sealed class ResolveNonceFromStateToken : IOpenIddictClientHandler<Proces
.AddFilter<RequireStateTokenPrincipal>()
.AddFilter<RequireStateTokenValidated>()
.UseSingletonHandler<ResolveNonceFromStateToken>()
.SetOrder(ValidateStateToken.Descriptor.Order + 1_000)
.SetOrder(ResolveHostAuthenticationPropertiesFromStateToken.Descriptor.Order + 1_000)
.Build();

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.ComponentModel;
using System.Security.Claims;
using System.Text.Encodings.Web;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using static OpenIddict.Server.AspNetCore.OpenIddictServerAspNetCoreConstants;
Expand Down Expand Up @@ -200,10 +199,16 @@ OpenIddictServerEndpointType.Token when context.Request.IsRefreshTokenGrantType(
_ => null
};

// Restore or create a new authentication properties collection and populate it.
var properties = CreateProperties(principal);
properties.ExpiresUtc = principal?.GetExpirationDate();
properties.IssuedUtc = principal?.GetCreationDate();
var properties = new AuthenticationProperties
{
ExpiresUtc = principal?.GetExpirationDate(),
IssuedUtc = principal?.GetCreationDate()
};

foreach (var property in context.Properties)
{
properties.Items[property.Key] = property.Value;
}

List<AuthenticationToken>? tokens = null;

Expand Down Expand Up @@ -324,29 +329,6 @@ OpenIddictServerEndpointType.Token when context.Request.IsRefreshTokenGrantType(
principal ?? new ClaimsPrincipal(new ClaimsIdentity()), properties,
OpenIddictServerAspNetCoreDefaults.AuthenticationScheme));
}

static AuthenticationProperties CreateProperties(ClaimsPrincipal? principal)
{
// Note: the principal may be null if no value was extracted from the corresponding token.
if (principal is not null)
{
var value = principal.GetClaim(Claims.Private.HostProperties);
if (!string.IsNullOrEmpty(value))
{
var dictionary = new Dictionary<string, string?>(comparer: StringComparer.Ordinal);
using var document = JsonDocument.Parse(value);

foreach (var property in document.RootElement.EnumerateObject())
{
dictionary[property.Name] = property.Value.GetString();
}

return new AuthenticationProperties(dictionary);
}
}

return new AuthenticationProperties();
}
}

/// <inheritdoc/>
Expand Down
Loading

0 comments on commit 3753229

Please sign in to comment.