diff --git a/vector/src/androidTest/java/im/vector/app/features/ReportedDecryptionFailurePersistenceTest.kt b/vector/src/androidTest/java/im/vector/app/features/ReportedDecryptionFailurePersistenceTest.kt new file mode 100644 index 00000000000..35d4bd5a502 --- /dev/null +++ b/vector/src/androidTest/java/im/vector/app/features/ReportedDecryptionFailurePersistenceTest.kt @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package im.vector.app.features + +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.platform.app.InstrumentationRegistry +import im.vector.app.InstrumentedTest +import im.vector.app.features.analytics.ReportedDecryptionFailurePersistence +import kotlinx.coroutines.test.runTest +import org.amshove.kluent.shouldBeEqualTo +import org.junit.Test +import org.junit.runner.RunWith + +@RunWith(AndroidJUnit4::class) +class ReportedDecryptionFailurePersistenceTest : InstrumentedTest { + + private val context = InstrumentationRegistry.getInstrumentation().targetContext + + @Test + fun shouldPersistReportedUtds() = runTest { + val persistence = ReportedDecryptionFailurePersistence(context) + persistence.load() + + val eventIds = listOf("$0000", "$0001", "$0002", "$0003") + eventIds.forEach { + persistence.markAsReported(it) + } + + eventIds.forEach { + persistence.hasBeenReported(it) shouldBeEqualTo true + } + + persistence.hasBeenReported("$0004") shouldBeEqualTo false + + persistence.persist() + + // Load a new one + val persistence2 = ReportedDecryptionFailurePersistence(context) + persistence2.load() + + eventIds.forEach { + persistence2.hasBeenReported(it) shouldBeEqualTo true + } + } + + @Test + fun testSaturation() = runTest { + val persistence = ReportedDecryptionFailurePersistence(context) + + for (i in 1..6000) { + persistence.markAsReported("000$i") + } + + // This should have saturated the bloom filter, making the rate of false positives too high. + // A new bloom filter should have been created to avoid that and the recent reported events should still be in the new filter. + for (i in 5800..6000) { + persistence.hasBeenReported("000$i") shouldBeEqualTo true + } + + // Old ones should not be there though + for (i in 1..1000) { + persistence.hasBeenReported("000$i") shouldBeEqualTo false + } + } +} diff --git a/vector/src/main/java/im/vector/app/features/analytics/DecryptionFailureTracker.kt b/vector/src/main/java/im/vector/app/features/analytics/DecryptionFailureTracker.kt index fcbc67169eb..5ec7b6a63c6 100644 --- a/vector/src/main/java/im/vector/app/features/analytics/DecryptionFailureTracker.kt +++ b/vector/src/main/java/im/vector/app/features/analytics/DecryptionFailureTracker.kt @@ -63,6 +63,7 @@ private const val MAX_WAIT_MILLIS = 60_000 class DecryptionFailureTracker @Inject constructor( private val analyticsTracker: AnalyticsTracker, private val sessionDataSource: ActiveSessionDataSource, + private val decryptionFailurePersistence: ReportedDecryptionFailurePersistence, private val clock: Clock ) : Session.Listener, LiveEventListener { @@ -76,9 +77,6 @@ class DecryptionFailureTracker @Inject constructor( // Only accessed on a `post` call, ensuring sequential access private val trackedEventsMap = mutableMapOf() - // List of eventId that have been reported, to avoid double reporting - private val alreadyReported = mutableListOf() - // Mutex to ensure sequential access to internal state private val mutex = Mutex() @@ -98,10 +96,16 @@ class DecryptionFailureTracker @Inject constructor( this.scope = scope } observeActiveSession() + post { + decryptionFailurePersistence.load() + } } fun stop() { Timber.v("Stop DecryptionFailureTracker") + post { + decryptionFailurePersistence.persist() + } activeSessionSourceDisposable.cancel(CancellationException("Closing DecryptionFailureTracker")) activeSession?.removeListener(this) @@ -123,6 +127,7 @@ class DecryptionFailureTracker @Inject constructor( delay(CHECK_INTERVAL) post { checkFailures() + decryptionFailurePersistence.persist() currentTicker = null if (trackedEventsMap.isNotEmpty()) { // Reschedule @@ -136,7 +141,7 @@ class DecryptionFailureTracker @Inject constructor( .distinctUntilChanged() .onEach { Timber.v("Active session changed ${it.getOrNull()?.myUserId}") - it.orNull()?.let { session -> + it.getOrNull()?.let { session -> post { onSessionActive(session) } @@ -144,7 +149,7 @@ class DecryptionFailureTracker @Inject constructor( }.launchIn(scope) } - private fun onSessionActive(session: Session) { + private suspend fun onSessionActive(session: Session) { Timber.v("onSessionActive ${session.myUserId} previous: ${activeSession?.myUserId}") val sessionId = session.sessionId if (sessionId == activeSession?.sessionId) { @@ -201,7 +206,8 @@ class DecryptionFailureTracker @Inject constructor( // already tracked return } - if (alreadyReported.contains(eventId)) { + if (decryptionFailurePersistence.hasBeenReported(eventId)) { + Timber.v("Event $eventId already reported") // already reported return } @@ -236,7 +242,7 @@ class DecryptionFailureTracker @Inject constructor( } } - private fun handleEventDecrypted(eventId: String) { + private suspend fun handleEventDecrypted(eventId: String) { Timber.v("Handle event decrypted $eventId time: ${clock.epochMillis()}") // Only consider if it was tracked as a failure val trackedFailure = trackedEventsMap[eventId] ?: return @@ -269,7 +275,7 @@ class DecryptionFailureTracker @Inject constructor( } // This will mutate the trackedEventsMap, so don't call it while iterating on it. - private fun reportFailure(decryptionFailure: DecryptionFailure) { + private suspend fun reportFailure(decryptionFailure: DecryptionFailure) { Timber.v("Report failure for event ${decryptionFailure.failedEventId}") val error = decryptionFailure.toAnalyticsEvent() @@ -278,10 +284,10 @@ class DecryptionFailureTracker @Inject constructor( // now remove from tracked trackedEventsMap.remove(decryptionFailure.failedEventId) // mark as already reported - alreadyReported.add(decryptionFailure.failedEventId) + decryptionFailurePersistence.markAsReported(decryptionFailure.failedEventId) } - private fun checkFailures() { + private suspend fun checkFailures() { val now = clock.epochMillis() Timber.v("Check failures now $now") // report the definitely failed diff --git a/vector/src/main/java/im/vector/app/features/analytics/ReportedDecryptionFailurePersistence.kt b/vector/src/main/java/im/vector/app/features/analytics/ReportedDecryptionFailurePersistence.kt new file mode 100644 index 00000000000..731c6f33d2c --- /dev/null +++ b/vector/src/main/java/im/vector/app/features/analytics/ReportedDecryptionFailurePersistence.kt @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2024 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package im.vector.app.features.analytics + +import android.content.Context +import android.util.LruCache +import com.google.common.hash.BloomFilter +import com.google.common.hash.Funnels +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import timber.log.Timber +import java.io.File +import java.io.FileOutputStream +import javax.inject.Inject + +private const val REPORTED_UTD_FILE_NAME = "im.vector.analytics.reported_utd" +private const val EXPECTED_INSERTIONS = 5000 + +/** + * This class is used to keep track of the reported decryption failures to avoid double reporting. + * It uses a bloom filter to limit the memory/disk usage. + */ +class ReportedDecryptionFailurePersistence @Inject constructor( + private val context: Context, +) { + + // Keep a cache of recent reported failures in memory. + // They will be persisted to the a new bloom filter if the previous one is getting saturated. + // Should be around 30KB max in memory. + // Also allows to have 0% false positive rate for recent failures. + private val inMemoryReportedFailures: LruCache = LruCache(300) + + // Thread-safe and lock-free. + // The expected insertions is 5000, and expected false positive probability of 3% when close to max capability. + // The persisted size is expected to be around 5KB (100 times less than if it was raw strings). + private var bloomFilter: BloomFilter = BloomFilter.create(Funnels.stringFunnel(Charsets.UTF_8), EXPECTED_INSERTIONS) + + /** + * Mark an event as reported. + * @param eventId the event id to mark as reported. + */ + suspend fun markAsReported(eventId: String) { + // Add to in memory cache. + inMemoryReportedFailures.put(eventId, Unit) + bloomFilter.put(eventId) + + // check if the filter is getting saturated? and then replace + if (bloomFilter.approximateElementCount() > EXPECTED_INSERTIONS - 500) { + // The filter is getting saturated, and the false positive rate is increasing. + // It's time to replace the filter with a new one. And move the in-memory cache to the new filter. + bloomFilter = BloomFilter.create(Funnels.stringFunnel(Charsets.UTF_8), EXPECTED_INSERTIONS) + inMemoryReportedFailures.snapshot().keys.forEach { + bloomFilter.put(it) + } + persist() + } + Timber.v("## Bloom filter stats: expectedFpp: ${bloomFilter.expectedFpp()}, size: ${bloomFilter.approximateElementCount()}") + } + + /** + * Check if an event has been reported. + * @param eventId the event id to check. + * @return true if the event has been reported. + */ + fun hasBeenReported(eventId: String): Boolean { + // First check in memory cache. + if (inMemoryReportedFailures.get(eventId) != null) { + return true + } + return bloomFilter.mightContain(eventId) + } + + /** + * Load the reported failures from disk. + */ + suspend fun load() { + withContext(Dispatchers.IO) { + try { + val file = File(context.applicationContext.cacheDir, REPORTED_UTD_FILE_NAME) + if (file.exists()) { + file.inputStream().use { + bloomFilter = BloomFilter.readFrom(it, Funnels.stringFunnel(Charsets.UTF_8)) + } + } + } catch (e: Throwable) { + Timber.e(e, "## Failed to load reported failures") + } + } + } + + /** + * Persist the reported failures to disk. + */ + suspend fun persist() { + withContext(Dispatchers.IO) { + try { + val file = File(context.applicationContext.cacheDir, REPORTED_UTD_FILE_NAME) + if (!file.exists()) file.createNewFile() + FileOutputStream(file).buffered().use { + bloomFilter.writeTo(it) + } + Timber.v("## Successfully saved reported failures, size: ${file.length()}") + } catch (e: Throwable) { + Timber.e(e, "## Failed to save reported failures") + } + } + } +} diff --git a/vector/src/test/java/im/vector/app/features/analytics/DecryptionFailureTrackerTest.kt b/vector/src/test/java/im/vector/app/features/analytics/DecryptionFailureTrackerTest.kt index 2f11d4c2eb1..cbdd758a41a 100644 --- a/vector/src/test/java/im/vector/app/features/analytics/DecryptionFailureTrackerTest.kt +++ b/vector/src/test/java/im/vector/app/features/analytics/DecryptionFailureTrackerTest.kt @@ -23,6 +23,7 @@ import im.vector.app.test.fakes.FakeAnalyticsTracker import im.vector.app.test.fakes.FakeClock import im.vector.app.test.fakes.FakeSession import im.vector.app.test.shared.createTimberTestRule +import io.mockk.coEvery import io.mockk.every import io.mockk.just import io.mockk.mockk @@ -60,9 +61,24 @@ class DecryptionFailureTrackerTest { private val fakeClock = FakeClock() + val reportedEvents = mutableSetOf() + + private val fakePersistence = mockk { + + coEvery { load() } just runs + coEvery { persist() } just runs + coEvery { markAsReported(any()) } coAnswers { + reportedEvents.add(firstArg()) + } + every { hasBeenReported(any()) } answers { + reportedEvents.contains(firstArg()) + } + } + private val decryptionFailureTracker = DecryptionFailureTracker( fakeAnalyticsTracker, fakeActiveSessionDataSource.instance, + fakePersistence, fakeClock ) @@ -101,6 +117,7 @@ class DecryptionFailureTrackerTest { @Before fun setupTest() { + reportedEvents.clear() fakeMxOrgTestSession.fakeCryptoService.fakeCrossSigningService.givenIsCrossSigningVerifiedReturns(false) }