diff --git a/mls-rs-provider-sqlite/Cargo.toml b/mls-rs-provider-sqlite/Cargo.toml index 9cc12341..94e1819c 100644 --- a/mls-rs-provider-sqlite/Cargo.toml +++ b/mls-rs-provider-sqlite/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mls-rs-provider-sqlite" -version = "0.13.2" +version = "0.13.3" edition = "2021" description = "SQLite based state storage for mls-rs" homepage = "https://github.com/awslabs/mls-rs" diff --git a/mls-rs-provider-sqlite/src/key_package.rs b/mls-rs-provider-sqlite/src/key_package.rs index 82a028aa..a54253fb 100644 --- a/mls-rs-provider-sqlite/src/key_package.rs +++ b/mls-rs-provider-sqlite/src/key_package.rs @@ -75,10 +75,13 @@ impl SqLiteKeyPackageStorage { .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) } + /// Delete key packages that are expired based on the current system clock time. pub fn delete_expired(&self) -> Result<(), SqLiteDataStorageError> { self.delete_expired_by_time(MlsTime::now().seconds_since_epoch()) } + /// Delete key packages that are expired based on an application provided time in seconds since + /// unix epoch. pub fn delete_expired_by_time(&self, time: u64) -> Result<(), SqLiteDataStorageError> { let connection = self.connection.lock().unwrap(); @@ -91,6 +94,7 @@ impl SqLiteKeyPackageStorage { .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) } + /// Total number of key packages held in storage. pub fn count(&self) -> Result { let connection = self.connection.lock().unwrap(); @@ -100,6 +104,21 @@ impl SqLiteKeyPackageStorage { }) .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) } + + /// Total number of key packages that will still remain in storage at a specific application provided + /// time in seconds since unix epoch. This assumes that the application would also be calling + /// [SqLiteKeyPackageStorage::delete_expired] at a reasonable cadence to be accurate. + pub fn count_at_time(&self, time: u64) -> Result { + let connection = self.connection.lock().unwrap(); + + connection + .query_row( + "SELECT count(*) FROM key_package where expiration >= ?", + params![time], + |row| row.get(0), + ) + .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) + } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] @@ -240,4 +259,22 @@ mod tests { assert_eq!(storage.count().unwrap(), 10); } + + #[test] + fn key_count_at_time() { + let mut storage = test_storage(); + + let mut kp_1 = test_key_package(); + kp_1.1.expiration = 1; + storage.insert(&kp_1.0, kp_1.1).unwrap(); + + let mut kp_2 = test_key_package(); + kp_2.1.expiration = 2; + storage.insert(&kp_2.0, kp_2.1).unwrap(); + + assert_eq!(storage.count_at_time(3).unwrap(), 0); + assert_eq!(storage.count_at_time(2).unwrap(), 1); + assert_eq!(storage.count_at_time(1).unwrap(), 2); + assert_eq!(storage.count_at_time(0).unwrap(), 2); + } }