From c69a885f93954459d0d81767d541e61465acdc79 Mon Sep 17 00:00:00 2001 From: Craig Earl Date: Fri, 7 Feb 2025 14:44:07 +0000 Subject: [PATCH] ATO-1424: Add keyUse field and filter based on key use --- .../shared/helpers/EncryptionJwkCache.java | 2 +- .../shared/helpers/JwkCacheEntry.java | 20 +++++++++++---- .../helpers/EncryptionJwkCacheTest.java | 4 +-- .../shared/helpers/JwkCacheEntryTest.java | 25 +++++++++++++++++-- .../shared/services/JwksServiceTest.java | 1 + 5 files changed, 42 insertions(+), 10 deletions(-) diff --git a/orchestration-shared/src/main/java/uk/gov/di/orchestration/shared/helpers/EncryptionJwkCache.java b/orchestration-shared/src/main/java/uk/gov/di/orchestration/shared/helpers/EncryptionJwkCache.java index d11ca143d3..a91a09f852 100644 --- a/orchestration-shared/src/main/java/uk/gov/di/orchestration/shared/helpers/EncryptionJwkCache.java +++ b/orchestration-shared/src/main/java/uk/gov/di/orchestration/shared/helpers/EncryptionJwkCache.java @@ -27,7 +27,7 @@ public JwkCacheEntry getOrCreateEntry(URL url, int cacheExpiration) { "Cache entry does not exist for JWKS URL {}, creating new one with expiration of {} seconds", url, cacheExpiration); - jwkCacheEntry = JwkCacheEntry.withUrlAndExpiration(url, cacheExpiration); + jwkCacheEntry = JwkCacheEntry.forEncryptionKeys(url, cacheExpiration); cacheEntryByUrl.put(url.toString(), jwkCacheEntry); } else { LOG.info("Cache entry exists for JWKS URL {}", url); diff --git a/orchestration-shared/src/main/java/uk/gov/di/orchestration/shared/helpers/JwkCacheEntry.java b/orchestration-shared/src/main/java/uk/gov/di/orchestration/shared/helpers/JwkCacheEntry.java index c1793bd853..aa9b0b4809 100644 --- a/orchestration-shared/src/main/java/uk/gov/di/orchestration/shared/helpers/JwkCacheEntry.java +++ b/orchestration-shared/src/main/java/uk/gov/di/orchestration/shared/helpers/JwkCacheEntry.java @@ -2,6 +2,7 @@ import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.KeyUse; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import uk.gov.di.orchestration.shared.utils.JwksUtils; @@ -13,20 +14,26 @@ public class JwkCacheEntry { private static final Logger LOG = LogManager.getLogger(JwkCacheEntry.class); + private final KeyUse keyUse; private final URL jwksUrl; private final int expirationInSeconds; private JWK latestKey; private Date expireTime; - private JwkCacheEntry(URL jwksUrl, int expirationInSeconds) { + private JwkCacheEntry(URL jwksUrl, int expirationInSeconds, KeyUse keyUse) { this.jwksUrl = jwksUrl; this.expirationInSeconds = expirationInSeconds; this.expireTime = NowHelper.nowPlus(this.expirationInSeconds, ChronoUnit.SECONDS); + this.keyUse = keyUse; this.latestKey = getKeyFromUrl(); } - public static JwkCacheEntry withUrlAndExpiration(URL url, int expirationInSeconds) { - return new JwkCacheEntry(url, expirationInSeconds); + public static JwkCacheEntry forKeyUse(KeyUse keyUse, URL url, int expirationInSeconds) { + return new JwkCacheEntry(url, expirationInSeconds, keyUse); + } + + public static JwkCacheEntry forEncryptionKeys(URL url, int expirationInSeconds) { + return new JwkCacheEntry(url, expirationInSeconds, KeyUse.ENCRYPTION); } public JWK getKey() { @@ -41,8 +48,11 @@ public JWK getKey() { private JWK getKeyFromUrl() { try { List jwks = JwksUtils.retrieveJwksFromUrl(jwksUrl); - LOG.info("Found {} JWKs at {}", jwks.size(), jwksUrl); - return jwks.stream().findFirst().orElse(null); + LOG.info("Found {} {} JWKs at {}", jwks.size(), keyUse, jwksUrl); + return jwks.stream() + .filter(key -> keyUse.equals(key.getKeyUse())) + .findFirst() + .orElse(null); } catch (KeySourceException e) { throw new RuntimeException("Key sourcing failed", e); } diff --git a/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/helpers/EncryptionJwkCacheTest.java b/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/helpers/EncryptionJwkCacheTest.java index b68a2b2351..2a2b5e2f16 100644 --- a/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/helpers/EncryptionJwkCacheTest.java +++ b/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/helpers/EncryptionJwkCacheTest.java @@ -24,7 +24,7 @@ void shouldCreateNewJwkCacheEntryIfNotFound() throws Exception { ENCRYPTION_JWK_CACHE.getOrCreateEntry(testJwksUrl, testExpiry); mockJwkCacheEntry.verify( - () -> JwkCacheEntry.withUrlAndExpiration(testJwksUrl, testExpiry)); + () -> JwkCacheEntry.forEncryptionKeys(testJwksUrl, testExpiry)); } } @@ -37,7 +37,7 @@ void shouldUseExistingEntryIfPresent() throws Exception { ENCRYPTION_JWK_CACHE.getOrCreateEntry(testJwksUrl, testExpiry); ENCRYPTION_JWK_CACHE.getOrCreateEntry(testJwksUrl, testExpiry); mockJwkCacheEntry.verify( - () -> JwkCacheEntry.withUrlAndExpiration(testJwksUrl, testExpiry), times(1)); + () -> JwkCacheEntry.forEncryptionKeys(testJwksUrl, testExpiry), times(1)); } } } diff --git a/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/helpers/JwkCacheEntryTest.java b/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/helpers/JwkCacheEntryTest.java index 4f9a4103b0..2e3ca0f87d 100644 --- a/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/helpers/JwkCacheEntryTest.java +++ b/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/helpers/JwkCacheEntryTest.java @@ -1,6 +1,7 @@ package uk.gov.di.orchestration.shared.helpers; import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.KeyUse; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import uk.gov.di.orchestration.shared.utils.JwksUtils; @@ -15,6 +16,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; class JwkCacheEntryTest { private URL testJwksUrl; @@ -24,6 +26,8 @@ class JwkCacheEntryTest { @BeforeEach void setup() throws Exception { testJwksUrl = new URL("http://localhost/.well-known/jwks.json"); + when(TEST_KEY_1.getKeyUse()).thenReturn(KeyUse.ENCRYPTION); + when(TEST_KEY_2.getKeyUse()).thenReturn(KeyUse.ENCRYPTION); } @Test @@ -62,6 +66,19 @@ void shouldCacheFirstKeyIfMultipleKeysArePresent() { } } + @Test + void shouldIgnoreFirstKeyIfKeyHasDifferentUse() { + try (var mockJwksUtils = mockStatic(JwksUtils.class)) { + when(TEST_KEY_1.getKeyUse()).thenReturn(KeyUse.SIGNATURE); + mockJwksUtils + .when(() -> JwksUtils.retrieveJwksFromUrl(testJwksUrl)) + .thenReturn(List.of(TEST_KEY_1, TEST_KEY_2)); + + var cacheEntry = createCacheWithNoExpiration(); + assertEquals(TEST_KEY_2, cacheEntry.getKey()); + } + } + @Test void shouldRefreshCacheIfExpirationHasPassed() { try (var mockJwksUtils = mockStatic(JwksUtils.class); @@ -106,10 +123,14 @@ void shouldNotRefreshCacheIfExpirationHasNotPassedYet() { } private JwkCacheEntry createCacheWithNoExpiration() { - return createCacheWithExpiration(Integer.MAX_VALUE); + return createCacheWithExpiration(KeyUse.ENCRYPTION, Integer.MAX_VALUE); } private JwkCacheEntry createCacheWithExpiration(int expiration) { - return JwkCacheEntry.withUrlAndExpiration(testJwksUrl, expiration); + return JwkCacheEntry.forKeyUse(KeyUse.ENCRYPTION, testJwksUrl, expiration); + } + + private JwkCacheEntry createCacheWithExpiration(KeyUse keyUse, int expiration) { + return JwkCacheEntry.forKeyUse(keyUse, testJwksUrl, expiration); } } diff --git a/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/services/JwksServiceTest.java b/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/services/JwksServiceTest.java index 4851c0b175..8f9c672a8b 100644 --- a/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/services/JwksServiceTest.java +++ b/orchestration-shared/src/test/java/uk/gov/di/orchestration/shared/services/JwksServiceTest.java @@ -100,6 +100,7 @@ void shouldUseJwkCacheToGetKeyWhenFeatureFlagEnabled() throws Exception { try (var mockJwksUtils = mockStatic(JwksUtils.class)) { JWK testKey1 = mock(JWK.class); + when(testKey1.getKeyUse()).thenReturn(KeyUse.ENCRYPTION); mockJwksUtils .when(() -> JwksUtils.retrieveJwksFromUrl(testJwksUrl)) .thenReturn(List.of(testKey1));