From 39a711eede6eff5545866de6f1a38b7912402266 Mon Sep 17 00:00:00 2001 From: Bryan Donlan Date: Thu, 5 Apr 2018 09:46:46 -0700 Subject: [PATCH] Fix bare aliases not using default region The default region was not actually being consulted when presented with a regionless key ID (such as a bare UUID or an "alias/foo" value). Fixes #50. --- .../kms/KmsMasterKeyProvider.java | 5 ++ .../kms/KMSProviderBuilderMockTests.java | 47 +++++++++++++++++++ .../amazonaws/services/kms/MockKMSClient.java | 27 ++++++----- 3 files changed, 68 insertions(+), 11 deletions(-) diff --git a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java index 8989e4a83..c85286881 100644 --- a/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java +++ b/src/main/java/com/amazonaws/encryptionsdk/kms/KmsMasterKeyProvider.java @@ -450,6 +450,11 @@ public KmsMasterKey getMasterKey(final String provider, final String keyId) thro } String regionName = parseRegionfromKeyArn(keyId); + + if (regionName == null && defaultRegion_ != null) { + regionName = defaultRegion_; + } + AWSKMS kms = regionalClientSupplier_.getClient(regionName); if (kms == null) { throw new AwsCryptoException("Can't use keys from region " + regionName); diff --git a/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java b/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java index 62f512d20..093faf8fb 100644 --- a/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java +++ b/src/test/java/com/amazonaws/services/kms/KMSProviderBuilderMockTests.java @@ -2,12 +2,14 @@ import static com.amazonaws.encryptionsdk.multi.MultipleProviderFactory.buildMultiProvider; import static com.amazonaws.regions.Region.getRegion; +import static com.amazonaws.regions.Regions.DEFAULT_REGION; import static com.amazonaws.regions.Regions.fromName; import static java.util.Collections.singletonList; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.notNull; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -31,11 +33,56 @@ import com.amazonaws.encryptionsdk.kms.KmsMasterKey; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider; import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier; +import com.amazonaws.regions.Region; +import com.amazonaws.regions.Regions; +import com.amazonaws.services.kms.model.CreateAliasRequest; import com.amazonaws.services.kms.model.DecryptRequest; import com.amazonaws.services.kms.model.EncryptRequest; import com.amazonaws.services.kms.model.GenerateDataKeyRequest; public class KMSProviderBuilderMockTests { + @Test + public void testBareAliasMapping() { + MockKMSClient client = spy(new MockKMSClient()); + + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(notNull())).thenReturn(client); + + String key1 = client.createKey().getKeyMetadata().getKeyId(); + client.createAlias(new CreateAliasRequest() + .withAliasName("foo") + .withTargetKeyId(key1) + ); + + KmsMasterKeyProvider mkp0 = KmsMasterKeyProvider.builder() + .withKeysForEncryption("alias/foo") + .withCustomClientFactory(supplier) + .withDefaultRegion("us-west-2") + .build(); + + new AwsCrypto().encryptData(mkp0, new byte[0]); + } + + @Test + public void testBareAliasMapping_withLegacyCtor() { + MockKMSClient client = spy(new MockKMSClient()); + + RegionalClientSupplier supplier = mock(RegionalClientSupplier.class); + when(supplier.getClient(any())).thenReturn(client); + + String key1 = client.createKey().getKeyMetadata().getKeyId(); + client.createAlias(new CreateAliasRequest() + .withAliasName("foo") + .withTargetKeyId(key1) + ); + + KmsMasterKeyProvider mkp0 = new KmsMasterKeyProvider( + client, Region.getRegion(Regions.DEFAULT_REGION), Arrays.asList("alias/foo") + ); + + new AwsCrypto().encryptData(mkp0, new byte[0]); + } + @Test public void testGrantTokenPassthrough_usingMKsetCall() throws Exception { MockKMSClient client = spy(new MockKMSClient()); diff --git a/src/test/java/com/amazonaws/services/kms/MockKMSClient.java b/src/test/java/com/amazonaws/services/kms/MockKMSClient.java index 7dbdf350b..f072c4744 100644 --- a/src/test/java/com/amazonaws/services/kms/MockKMSClient.java +++ b/src/test/java/com/amazonaws/services/kms/MockKMSClient.java @@ -88,13 +88,20 @@ public class MockKMSClient extends AWSKMSClient { private static final SecureRandom rnd = new SecureRandom(); private static final String ACCOUNT_ID = "01234567890"; private final Map results_ = new HashMap<>(); - private final Map idToArnMap = new HashMap<>(); private final Set activeKeys = new HashSet<>(); + private final Map keyAliases = new HashMap<>(); private Region region_ = Region.getRegion(Regions.DEFAULT_REGION); @Override public CreateAliasResult createAlias(CreateAliasRequest arg0) throws AmazonServiceException, AmazonClientException { - throw new java.lang.UnsupportedOperationException(); + assertExists(arg0.getTargetKeyId()); + + keyAliases.put( + "alias/" + arg0.getAliasName(), + keyAliases.get(arg0.getTargetKeyId()) + ); + + return new CreateAliasResult(); } @Override @@ -111,8 +118,9 @@ public CreateKeyResult createKey() throws AmazonServiceException, AmazonClientEx public CreateKeyResult createKey(CreateKeyRequest req) throws AmazonServiceException, AmazonClientException { String keyId = UUID.randomUUID().toString(); String arn = "arn:aws:kms:" + region_.getName() + ":" + ACCOUNT_ID + ":key/" + keyId; - idToArnMap.put(keyId, arn); activeKeys.add(arn); + keyAliases.put(keyId, arn); + keyAliases.put(arn, arn); CreateKeyResult result = new CreateKeyResult(); result.setKeyMetadata(new KeyMetadata().withAWSAccountId(ACCOUNT_ID).withCreationDate(new Date()) .withDescription(req.getDescription()).withEnabled(true).withKeyId(keyId) @@ -183,7 +191,7 @@ private EncryptResult encrypt0(EncryptRequest req) throws AmazonServiceException final byte[] cipherText = new byte[512]; rnd.nextBytes(cipherText); DecryptResult dec = new DecryptResult(); - dec.withKeyId(req.getKeyId()).withPlaintext(req.getPlaintext().asReadOnlyBuffer()); + dec.withKeyId(retrieveArn(req.getKeyId())).withPlaintext(req.getPlaintext().asReadOnlyBuffer()); ByteBuffer ctBuff = ByteBuffer.wrap(cipherText); results_.put(new DecryptMapKey(ctBuff, req.getEncryptionContext()), dec); @@ -336,20 +344,17 @@ public void deleteKey(final String keyId) { } private String retrieveArn(final String keyId) { - String arn = keyId; - if (keyId.contains("arn:") == false) { - arn = idToArnMap.get(keyId); - } + String arn = keyAliases.get(keyId); assertExists(arn); return arn; } private void assertExists(String keyId) { - if (idToArnMap.containsKey(keyId)) { - keyId = idToArnMap.get(keyId); + if (keyAliases.containsKey(keyId)) { + keyId = keyAliases.get(keyId); } if (keyId == null || !activeKeys.contains(keyId)) { - throw new NotFoundException("Key doesn't exist"); + throw new NotFoundException("Key doesn't exist: " + keyId); } }