Skip to content

Commit

Permalink
ATO-1424: Add keyUse field and filter based on key use
Browse files Browse the repository at this point in the history
  • Loading branch information
cearl1 committed Feb 7, 2025
1 parent 000d7a7 commit c69a885
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public JwkCacheEntry getOrCreateEntry(URL url, int cacheExpiration) {
"Cache entry does not exist for JWKS URL {}, creating new one with expiration of {} seconds",
url,
cacheExpiration);
jwkCacheEntry = JwkCacheEntry.withUrlAndExpiration(url, cacheExpiration);
jwkCacheEntry = JwkCacheEntry.forEncryptionKeys(url, cacheExpiration);
cacheEntryByUrl.put(url.toString(), jwkCacheEntry);
} else {
LOG.info("Cache entry exists for JWKS URL {}", url);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.KeyUse;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import uk.gov.di.orchestration.shared.utils.JwksUtils;
Expand All @@ -13,20 +14,26 @@

public class JwkCacheEntry {
private static final Logger LOG = LogManager.getLogger(JwkCacheEntry.class);
private final KeyUse keyUse;
private final URL jwksUrl;
private final int expirationInSeconds;
private JWK latestKey;
private Date expireTime;

private JwkCacheEntry(URL jwksUrl, int expirationInSeconds) {
private JwkCacheEntry(URL jwksUrl, int expirationInSeconds, KeyUse keyUse) {
this.jwksUrl = jwksUrl;
this.expirationInSeconds = expirationInSeconds;
this.expireTime = NowHelper.nowPlus(this.expirationInSeconds, ChronoUnit.SECONDS);
this.keyUse = keyUse;
this.latestKey = getKeyFromUrl();
}

public static JwkCacheEntry withUrlAndExpiration(URL url, int expirationInSeconds) {
return new JwkCacheEntry(url, expirationInSeconds);
public static JwkCacheEntry forKeyUse(KeyUse keyUse, URL url, int expirationInSeconds) {
return new JwkCacheEntry(url, expirationInSeconds, keyUse);
}

public static JwkCacheEntry forEncryptionKeys(URL url, int expirationInSeconds) {
return new JwkCacheEntry(url, expirationInSeconds, KeyUse.ENCRYPTION);
}

public JWK getKey() {
Expand All @@ -41,8 +48,11 @@ public JWK getKey() {
private JWK getKeyFromUrl() {
try {
List<JWK> jwks = JwksUtils.retrieveJwksFromUrl(jwksUrl);
LOG.info("Found {} JWKs at {}", jwks.size(), jwksUrl);
return jwks.stream().findFirst().orElse(null);
LOG.info("Found {} {} JWKs at {}", jwks.size(), keyUse, jwksUrl);
return jwks.stream()
.filter(key -> keyUse.equals(key.getKeyUse()))
.findFirst()
.orElse(null);
} catch (KeySourceException e) {
throw new RuntimeException("Key sourcing failed", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void shouldCreateNewJwkCacheEntryIfNotFound() throws Exception {

ENCRYPTION_JWK_CACHE.getOrCreateEntry(testJwksUrl, testExpiry);
mockJwkCacheEntry.verify(
() -> JwkCacheEntry.withUrlAndExpiration(testJwksUrl, testExpiry));
() -> JwkCacheEntry.forEncryptionKeys(testJwksUrl, testExpiry));
}
}

Expand All @@ -37,7 +37,7 @@ void shouldUseExistingEntryIfPresent() throws Exception {
ENCRYPTION_JWK_CACHE.getOrCreateEntry(testJwksUrl, testExpiry);
ENCRYPTION_JWK_CACHE.getOrCreateEntry(testJwksUrl, testExpiry);
mockJwkCacheEntry.verify(
() -> JwkCacheEntry.withUrlAndExpiration(testJwksUrl, testExpiry), times(1));
() -> JwkCacheEntry.forEncryptionKeys(testJwksUrl, testExpiry), times(1));
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package uk.gov.di.orchestration.shared.helpers;

import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.KeyUse;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import uk.gov.di.orchestration.shared.utils.JwksUtils;
Expand All @@ -15,6 +16,7 @@
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;

class JwkCacheEntryTest {
private URL testJwksUrl;
Expand All @@ -24,6 +26,8 @@ class JwkCacheEntryTest {
@BeforeEach
void setup() throws Exception {
testJwksUrl = new URL("http://localhost/.well-known/jwks.json");
when(TEST_KEY_1.getKeyUse()).thenReturn(KeyUse.ENCRYPTION);
when(TEST_KEY_2.getKeyUse()).thenReturn(KeyUse.ENCRYPTION);
}

@Test
Expand Down Expand Up @@ -62,6 +66,19 @@ void shouldCacheFirstKeyIfMultipleKeysArePresent() {
}
}

@Test
void shouldIgnoreFirstKeyIfKeyHasDifferentUse() {
try (var mockJwksUtils = mockStatic(JwksUtils.class)) {
when(TEST_KEY_1.getKeyUse()).thenReturn(KeyUse.SIGNATURE);
mockJwksUtils
.when(() -> JwksUtils.retrieveJwksFromUrl(testJwksUrl))
.thenReturn(List.of(TEST_KEY_1, TEST_KEY_2));

var cacheEntry = createCacheWithNoExpiration();
assertEquals(TEST_KEY_2, cacheEntry.getKey());
}
}

@Test
void shouldRefreshCacheIfExpirationHasPassed() {
try (var mockJwksUtils = mockStatic(JwksUtils.class);
Expand Down Expand Up @@ -106,10 +123,14 @@ void shouldNotRefreshCacheIfExpirationHasNotPassedYet() {
}

private JwkCacheEntry createCacheWithNoExpiration() {
return createCacheWithExpiration(Integer.MAX_VALUE);
return createCacheWithExpiration(KeyUse.ENCRYPTION, Integer.MAX_VALUE);
}

private JwkCacheEntry createCacheWithExpiration(int expiration) {
return JwkCacheEntry.withUrlAndExpiration(testJwksUrl, expiration);
return JwkCacheEntry.forKeyUse(KeyUse.ENCRYPTION, testJwksUrl, expiration);
}

private JwkCacheEntry createCacheWithExpiration(KeyUse keyUse, int expiration) {
return JwkCacheEntry.forKeyUse(keyUse, testJwksUrl, expiration);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ void shouldUseJwkCacheToGetKeyWhenFeatureFlagEnabled() throws Exception {

try (var mockJwksUtils = mockStatic(JwksUtils.class)) {
JWK testKey1 = mock(JWK.class);
when(testKey1.getKeyUse()).thenReturn(KeyUse.ENCRYPTION);
mockJwksUtils
.when(() -> JwksUtils.retrieveJwksFromUrl(testJwksUrl))
.thenReturn(List.of(testKey1));
Expand Down

0 comments on commit c69a885

Please sign in to comment.