diff --git a/mls-rs-provider-sqlite/src/application.rs b/mls-rs-provider-sqlite/src/application.rs index 620cc634..10b330cf 100644 --- a/mls-rs-provider-sqlite/src/application.rs +++ b/mls-rs-provider-sqlite/src/application.rs @@ -11,6 +11,9 @@ use rusqlite::{params, Connection, OptionalExtension}; use crate::SqLiteDataStorageError; +const INSERT_SQL: &str = + "INSERT INTO kvs (key, value) VALUES (?,?) ON CONFLICT(key) DO UPDATE SET value=excluded.value"; + #[derive(Debug, Clone)] /// SQLite key-value storage for application specific data. pub struct SqLiteApplicationStorage { @@ -32,12 +35,27 @@ impl SqLiteApplicationStorage { // Upsert into the database connection - .execute( - "INSERT INTO kvs (key, value) VALUES (?,?) ON CONFLICT(key) DO UPDATE SET value=excluded.value", - params![key, value], - ) + .execute(INSERT_SQL, params![key, value]) .map(|_| ()) - .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) + .map_err(sql_engine_error) + } + + /// Execute multiple [`SqLiteApplicationStorage::insert`] operations in a transaction. + pub fn transact_insert(&self, items: Vec) -> Result<(), SqLiteDataStorageError> { + let mut connection = self.connection.lock().unwrap(); + + // Upsert into the database + let tx = connection.transaction().map_err(sql_engine_error)?; + + items.into_iter().try_for_each(|item| { + tx.execute(INSERT_SQL, params![item.key, item.value]) + .map_err(sql_engine_error) + .map(|_| ()) + })?; + + tx.commit().map_err(sql_engine_error)?; + + Ok(()) } /// Get a value from storage based on its `key`. @@ -49,7 +67,7 @@ impl SqLiteApplicationStorage { row.get(0) }) .optional() - .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) + .map_err(sql_engine_error) } /// Delete a value from storage based on its `key`. @@ -59,7 +77,7 @@ impl SqLiteApplicationStorage { connection .execute("DELETE FROM kvs WHERE key = ?", params![key]) .map(|_| ()) - .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) + .map_err(sql_engine_error) } /// Get all keys and values from storage for which key starts with `key_prefix`. @@ -70,15 +88,14 @@ impl SqLiteApplicationStorage { let mut stmt = connection .prepare("SELECT key, value FROM kvs WHERE key LIKE ? ESCAPE '$'") - .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?; + .map_err(sql_engine_error)?; let rows = stmt .query(params![key_prefix]) - .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))? + .map_err(sql_engine_error)? .mapped(|row| Ok(Item::new(row.get(0)?, row.get(1)?))); - rows.collect::>() - .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) + rows.collect::>().map_err(sql_engine_error) } /// Delete all values from storage for which key starts with `key_prefix`. @@ -93,7 +110,7 @@ impl SqLiteApplicationStorage { params![key_prefix], ) .map(|_| ()) - .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into())) + .map_err(sql_engine_error) } } @@ -101,6 +118,10 @@ fn sanitize(string: &str) -> String { string.replace('_', "$_").replace('%', "$%") } +fn sql_engine_error(e: rusqlite::Error) -> SqLiteDataStorageError { + SqLiteDataStorageError::SqlEngineError(e.into()) +} + #[derive(Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct Item { pub key: String, @@ -242,4 +263,20 @@ mod tests { let keys = items.into_iter().map(|i| i.key).collect::>(); assert_eq!(vec!["%$_ƕ❤_$%".to_string()], keys); } + + #[test] + fn batch_insert() { + let storage = test_storage(); + let items = vec![test_item(), test_item(), test_item()]; + + storage.transact_insert(items.clone()).unwrap(); + + for item in items { + assert_eq!(storage.get(&item.key).unwrap(), Some(item.value)); + } + } + + fn test_item() -> Item { + Item::new(hex::encode(gen_rand_bytes(5)), gen_rand_bytes(5)) + } }