diff --git a/README.md b/README.md index 07d9e29e..cc8f38fb 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ Priority order: | identityProviders.*.negativeCacheExpirationMs | 10000 | No |How long to retain JWKS response in the cache in case of failed response. | identityProviders.*.issuerPattern | - | No |Regexp to match the claim "iss" to identity provider. | identityProviders.*.disableJwtVerification | false | No |The flag disables JWT verification. *Note*. `userInfoEndpoint` must be unset if the flag is set to `true`. +| identityProviders.*.audience | - | No |If the setting is set it will be validated against the claim `aud` in JWT | vertx.* | - | No |Vertx settings. Refer to [vertx.io](https://vertx.io/docs/apidocs/io/vertx/core/VertxOptions.html) to learn more. | server.* | - | No |Vertx HTTP server settings for incoming requests. | client.* | - | No |Vertx HTTP client settings for outbound requests. diff --git a/sample/aidial.settings.json b/sample/aidial.settings.json index 69d2f8dd..ff0a5465 100644 --- a/sample/aidial.settings.json +++ b/sample/aidial.settings.json @@ -13,12 +13,14 @@ "jwksUrl": "https://login.microsoftonline.com/path/discovery/keys", "rolePath": "groups", "projectPath": "aud", + "audience": "dial", "issuerPattern": "^https:\\/\\/some\\.windows\\.net.+$" }, "keycloak": { "jwksUrl": "https://host.com/realms/your/protocol/openid-connect/certs", "rolePath": "resource_access.your.roles", "projectPath": "azp", + "audience": "dial", "issuerPattern": "^https:\\/\\/some-keycloak.com.+$" }, "google": { @@ -26,6 +28,7 @@ "projectPath": "aud", "userInfoEndpoint": "https://openidconnect.googleapis.com/v1/userinfo", "loggingKey": "email", + "audience": "dial", "loggingSalt": "salt" }, "cognito": { @@ -33,12 +36,14 @@ "issuerPattern": "^https:\\/\\/cognito-idp\\.eu-north-1\\.amazonaws\\.com.+$", "rolePath": "roles", "projectPath": "aud", + "audience": "dial", "jwksUrl": "https://cognito-idp.eu-north-1.amazonaws.com/eu-north-1_PWSAjo4OY/.well-known/jwks.json", "loggingSalt": "loggingSalt" }, "gitlab": { "rolePath": "groups", "projectPath": "aud", + "audience": "dial", "userInfoEndpoint": "https://gitlab.com/oauth/userinfo", "loggingKey": "email", "loggingSalt": "salt" @@ -48,6 +53,7 @@ "issuerPattern": "^https:\\/\\/chatbot-ui-staging\\.eu\\.auth0\\.com.+$", "rolePath": "dial_roles", "projectPath": "aud", + "audience": "dial", "jwksUrl": "https://.auth0.com/.well-known/jwks.json", "loggingSalt": "loggingSalt" }, @@ -56,6 +62,7 @@ "issuerPattern": "^https:\\/\\/\\.okta\\.com.*$", "rolePath": "Groups", "projectPath": "aud", + "audience": "dial", "jwksUrl": "https://.okta.com/oauth2/default/v1/keys", "loggingSalt": "loggingSalt" }, diff --git a/server/src/main/java/com/epam/aidial/core/server/security/IdentityProvider.java b/server/src/main/java/com/epam/aidial/core/server/security/IdentityProvider.java index 8a79ddfd..9a94a22a 100644 --- a/server/src/main/java/com/epam/aidial/core/server/security/IdentityProvider.java +++ b/server/src/main/java/com/epam/aidial/core/server/security/IdentityProvider.java @@ -7,6 +7,7 @@ import com.auth0.jwt.algorithms.Algorithm; import com.auth0.jwt.interfaces.Claim; import com.auth0.jwt.interfaces.DecodedJWT; +import com.auth0.jwt.interfaces.Verification; import io.vertx.core.Future; import io.vertx.core.Promise; import io.vertx.core.Vertx; @@ -84,6 +85,8 @@ public class IdentityProvider { private final GetUserRoleFn getUserRoleFn; + private final String audience; + public IdentityProvider(JsonObject settings, Vertx vertx, HttpClient client, Function jwkProviderSupplier, GetUserRoleFunctionFactory factory) { if (settings == null) { @@ -153,6 +156,8 @@ public IdentityProvider(JsonObject settings, Vertx vertx, HttpClient client, } obfuscateUserEmail = settings.getBoolean("obfuscateUserEmail", true); + audience = settings.getString("audience", null); + long period = Math.min(negativeCacheExpirationMs, positiveCacheExpirationMs); vertx.setPeriodic(0, period, event -> evictExpiredJwks()); } @@ -235,7 +240,11 @@ private DecodedJWT verifyJwt(DecodedJWT jwt, JwkResult jwkResult) { } Jwk jwk = jwkResult.jwk(); try { - return JWT.require(Algorithm.RSA256((RSAPublicKey) jwk.getPublicKey(), null)).build().verify(jwt); + Verification verification = JWT.require(Algorithm.RSA256((RSAPublicKey) jwk.getPublicKey(), null)); + if (audience != null) { + verification.withAudience(audience); + } + return verification.build().verify(jwt); } catch (JwkException e) { throw new RuntimeException(e); } diff --git a/server/src/test/java/com/epam/aidial/core/server/security/IdentityProviderTest.java b/server/src/test/java/com/epam/aidial/core/server/security/IdentityProviderTest.java index b040a8b6..1e1e82c6 100644 --- a/server/src/test/java/com/epam/aidial/core/server/security/IdentityProviderTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/security/IdentityProviderTest.java @@ -822,6 +822,57 @@ public void testExtractClaims_29() throws JwkException { }); } + @Test + public void testExtractClaims_30() throws JwkException { + settings.put("audience", "dial"); + IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory); + Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate()); + + String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("aud", "dial").withClaim("roles", List.of("manager")).sign(algorithm); + Jwk jwk = mock(Jwk.class); + when(jwk.getPublicKey()).thenReturn(keyPair.getPublic()); + when(jwkProvider.get(eq("kid1"))).thenReturn(jwk); + when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> { + Callable callable = invocation.getArgument(0); + return Future.succeededFuture(callable.call()); + }); + + Future result = identityProvider.extractClaimsFromJwt(JWT.decode(token)); + + assertNotNull(result); + result.onComplete(res -> { + assertTrue(res.succeeded()); + ExtractedClaims claims = res.result(); + assertNotNull(claims); + assertEquals(List.of("manager"), claims.userRoles()); + }); + } + + @Test + public void testExtractClaims_31() throws JwkException { + settings.put("audience", "dial"); + IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory); + Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate()); + + String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("aud", "wrong_aud").withClaim("roles", List.of("manager")).sign(algorithm); + Jwk jwk = mock(Jwk.class); + when(jwk.getPublicKey()).thenReturn(keyPair.getPublic()); + when(jwkProvider.get(eq("kid1"))).thenReturn(jwk); + when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> { + Callable callable = invocation.getArgument(0); + return Future.succeededFuture(callable.call()); + }); + + Future result = identityProvider.extractClaimsFromJwt(JWT.decode(token)); + + assertNotNull(result); + result.onComplete(res -> { + assertFalse(res.succeeded()); + ExtractedClaims claims = res.result(); + assertNull(claims); + }); + } + @Test public void testExtractClaims_FromUserInfo_01() { settings.remove("jwksUrl");