Skip to content

Commit

Permalink
Merge pull request #51 from bdonlan/bare-region
Browse files Browse the repository at this point in the history
Fix bare aliases not using default region
  • Loading branch information
SalusaSecondus authored Apr 5, 2018
2 parents 5af4b07 + 39a711e commit 7d5ab6e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ public KmsMasterKey getMasterKey(final String provider, final String keyId) thro
}

String regionName = parseRegionfromKeyArn(keyId);

if (regionName == null && defaultRegion_ != null) {
regionName = defaultRegion_;
}

AWSKMS kms = regionalClientSupplier_.getClient(regionName);
if (kms == null) {
throw new AwsCryptoException("Can't use keys from region " + regionName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import static com.amazonaws.encryptionsdk.multi.MultipleProviderFactory.buildMultiProvider;
import static com.amazonaws.regions.Region.getRegion;
import static com.amazonaws.regions.Regions.DEFAULT_REGION;
import static com.amazonaws.regions.Regions.fromName;
import static java.util.Collections.singletonList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
Expand All @@ -31,11 +33,56 @@
import com.amazonaws.encryptionsdk.kms.KmsMasterKey;
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider;
import com.amazonaws.encryptionsdk.kms.KmsMasterKeyProvider.RegionalClientSupplier;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.kms.model.CreateAliasRequest;
import com.amazonaws.services.kms.model.DecryptRequest;
import com.amazonaws.services.kms.model.EncryptRequest;
import com.amazonaws.services.kms.model.GenerateDataKeyRequest;

public class KMSProviderBuilderMockTests {
@Test
public void testBareAliasMapping() {
MockKMSClient client = spy(new MockKMSClient());

RegionalClientSupplier supplier = mock(RegionalClientSupplier.class);
when(supplier.getClient(notNull())).thenReturn(client);

String key1 = client.createKey().getKeyMetadata().getKeyId();
client.createAlias(new CreateAliasRequest()
.withAliasName("foo")
.withTargetKeyId(key1)
);

KmsMasterKeyProvider mkp0 = KmsMasterKeyProvider.builder()
.withKeysForEncryption("alias/foo")
.withCustomClientFactory(supplier)
.withDefaultRegion("us-west-2")
.build();

new AwsCrypto().encryptData(mkp0, new byte[0]);
}

@Test
public void testBareAliasMapping_withLegacyCtor() {
MockKMSClient client = spy(new MockKMSClient());

RegionalClientSupplier supplier = mock(RegionalClientSupplier.class);
when(supplier.getClient(any())).thenReturn(client);

String key1 = client.createKey().getKeyMetadata().getKeyId();
client.createAlias(new CreateAliasRequest()
.withAliasName("foo")
.withTargetKeyId(key1)
);

KmsMasterKeyProvider mkp0 = new KmsMasterKeyProvider(
client, Region.getRegion(Regions.DEFAULT_REGION), Arrays.asList("alias/foo")
);

new AwsCrypto().encryptData(mkp0, new byte[0]);
}

@Test
public void testGrantTokenPassthrough_usingMKsetCall() throws Exception {
MockKMSClient client = spy(new MockKMSClient());
Expand Down
27 changes: 16 additions & 11 deletions src/test/java/com/amazonaws/services/kms/MockKMSClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,20 @@ public class MockKMSClient extends AWSKMSClient {
private static final SecureRandom rnd = new SecureRandom();
private static final String ACCOUNT_ID = "01234567890";
private final Map<DecryptMapKey, DecryptResult> results_ = new HashMap<>();
private final Map<String, String> idToArnMap = new HashMap<>();
private final Set<String> activeKeys = new HashSet<>();
private final Map<String, String> keyAliases = new HashMap<>();
private Region region_ = Region.getRegion(Regions.DEFAULT_REGION);

@Override
public CreateAliasResult createAlias(CreateAliasRequest arg0) throws AmazonServiceException, AmazonClientException {
throw new java.lang.UnsupportedOperationException();
assertExists(arg0.getTargetKeyId());

keyAliases.put(
"alias/" + arg0.getAliasName(),
keyAliases.get(arg0.getTargetKeyId())
);

return new CreateAliasResult();
}

@Override
Expand All @@ -111,8 +118,9 @@ public CreateKeyResult createKey() throws AmazonServiceException, AmazonClientEx
public CreateKeyResult createKey(CreateKeyRequest req) throws AmazonServiceException, AmazonClientException {
String keyId = UUID.randomUUID().toString();
String arn = "arn:aws:kms:" + region_.getName() + ":" + ACCOUNT_ID + ":key/" + keyId;
idToArnMap.put(keyId, arn);
activeKeys.add(arn);
keyAliases.put(keyId, arn);
keyAliases.put(arn, arn);
CreateKeyResult result = new CreateKeyResult();
result.setKeyMetadata(new KeyMetadata().withAWSAccountId(ACCOUNT_ID).withCreationDate(new Date())
.withDescription(req.getDescription()).withEnabled(true).withKeyId(keyId)
Expand Down Expand Up @@ -183,7 +191,7 @@ private EncryptResult encrypt0(EncryptRequest req) throws AmazonServiceException
final byte[] cipherText = new byte[512];
rnd.nextBytes(cipherText);
DecryptResult dec = new DecryptResult();
dec.withKeyId(req.getKeyId()).withPlaintext(req.getPlaintext().asReadOnlyBuffer());
dec.withKeyId(retrieveArn(req.getKeyId())).withPlaintext(req.getPlaintext().asReadOnlyBuffer());
ByteBuffer ctBuff = ByteBuffer.wrap(cipherText);

results_.put(new DecryptMapKey(ctBuff, req.getEncryptionContext()), dec);
Expand Down Expand Up @@ -336,20 +344,17 @@ public void deleteKey(final String keyId) {
}

private String retrieveArn(final String keyId) {
String arn = keyId;
if (keyId.contains("arn:") == false) {
arn = idToArnMap.get(keyId);
}
String arn = keyAliases.get(keyId);
assertExists(arn);
return arn;
}

private void assertExists(String keyId) {
if (idToArnMap.containsKey(keyId)) {
keyId = idToArnMap.get(keyId);
if (keyAliases.containsKey(keyId)) {
keyId = keyAliases.get(keyId);
}
if (keyId == null || !activeKeys.contains(keyId)) {
throw new NotFoundException("Key doesn't exist");
throw new NotFoundException("Key doesn't exist: " + keyId);
}
}

Expand Down

0 comments on commit 7d5ab6e

Please sign in to comment.