Skip to content

Commit

Permalink
Added validation on Cryptor id.
Browse files Browse the repository at this point in the history
  • Loading branch information
marcin-cebo committed Nov 3, 2023
1 parent bbe06c1 commit e4985d2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 11 deletions.
25 changes: 17 additions & 8 deletions src/main/kotlin/com/pubnub/api/crypto/CryptoModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import java.io.InputStream
import java.io.SequenceInputStream
import java.lang.Integer.min

private const val SIZE_OF_CRYPTOR_ID = 4

class CryptoModule internal constructor(
internal val primaryCryptor: Cryptor,
internal val cryptorsForDecryptionOnly: List<Cryptor> = listOf(),
Expand Down Expand Up @@ -49,13 +51,15 @@ class CryptoModule internal constructor(
}

fun encrypt(data: ByteArray): ByteArray {
val cryptorId = primaryCryptor.id()
validateData(data)
validateCryptorIdSize(cryptorId)
val (metadata, encryptedData) = primaryCryptor.encrypt(data)

return if (primaryCryptor.id().contentEquals(LEGACY_CRYPTOR_ID)) {
return if (cryptorId.contentEquals(LEGACY_CRYPTOR_ID)) {
encryptedData
} else {
val cryptorHeader = headerParser.createCryptorHeader(primaryCryptor.id(), metadata)
val cryptorHeader = headerParser.createCryptorHeader(cryptorId, metadata)
cryptorHeader + encryptedData
}
}
Expand Down Expand Up @@ -110,8 +114,17 @@ class CryptoModule internal constructor(
}
}

private fun validateCryptorIdSize(cryptorId: ByteArray) {
if (cryptorId.size != SIZE_OF_CRYPTOR_ID) {
throw PubNubException(
errorMessage = "CryptorId should be exactly 4 bytes long",
pubnubError = PubNubError.UNKNOWN_CRYPTOR
)
}
}

private fun getDecryptedDataForLegacyCryptor(encryptedData: ByteArray): ByteArray {
return getLegacyCryptor()?.decrypt(EncryptedData(data = encryptedData)) ?: throw PubNubException(
return getCryptorById(LEGACY_CRYPTOR_ID)?.decrypt(EncryptedData(data = encryptedData)) ?: throw PubNubException(
errorMessage = "LegacyCryptor not available",
pubnubError = PubNubError.UNKNOWN_CRYPTOR
)
Expand All @@ -129,12 +142,8 @@ class CryptoModule internal constructor(
return decryptedData
}

private fun getLegacyCryptor(): Cryptor? {
val idOfLegacyCryptor = ByteArray(4) { 0.toByte() }
return getCryptorById(idOfLegacyCryptor)
}

private fun getCryptorById(cryptorId: ByteArray): Cryptor? {
validateCryptorIdSize(cryptorId)
return cryptorsForDecryptionOnly.firstOrNull { it.id().contentEquals(cryptorId) }
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/com/pubnub/api/crypto/cryptor/Cryptor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import com.pubnub.api.crypto.data.EncryptedStreamData
import java.io.InputStream

interface Cryptor {
fun id(): ByteArray // Assuming 4 bytes,
fun id(): ByteArray // Should return a ByteArray of exactly 4 bytes.
fun encrypt(data: ByteArray): EncryptedData
fun decrypt(encryptedData: EncryptedData): ByteArray
fun encryptStream(stream: InputStream): EncryptedStreamData
Expand Down
20 changes: 18 additions & 2 deletions src/test/kotlin/com/pubnub/api/crypto/CryptoModuleTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,21 @@ class CryptoModuleTest {
assertArrayEquals(msgToEncrypt, decryptedMsg)
}

@Test
fun `can decrypt encrypted message using cryptoModule with custom cryptor`() {
// given
val customCryptor = myCustomCryptor()
val cryptoModule = CryptoModule.createNewCryptoModule(defaultCryptor = customCryptor)
val msgToEncrypt = "Hello world".toByteArray()

// when
val encryptedMsg = cryptoModule.encrypt(msgToEncrypt)
val decryptedMsg = cryptoModule.decrypt(encryptedMsg)

// then
assertArrayEquals(msgToEncrypt, decryptedMsg)
}

@ParameterizedTest
@MethodSource("legacyAndAesCbcCryptors")
fun `should throw exception when encrypting empty data`(cryptoModule: CryptoModule) {
Expand Down Expand Up @@ -283,6 +298,7 @@ class CryptoModuleTest {

private fun myCustomCryptor() = object : Cryptor {
override fun id(): ByteArray {
// Should return a ByteArray of exactly 4 bytes.
return byteArrayOf('C'.code.toByte(), 'U'.code.toByte(), 'S'.code.toByte(), 'T'.code.toByte())
}

Expand All @@ -295,11 +311,11 @@ class CryptoModuleTest {
}

override fun encryptStream(stream: InputStream): EncryptedStreamData {
throw NotImplementedError()
return EncryptedStreamData(metadata = null, stream = stream)
}

override fun decryptStream(encryptedData: EncryptedStreamData): InputStream {
throw NotImplementedError()
return encryptedData.stream
}
}

Expand Down

0 comments on commit e4985d2

Please sign in to comment.