Skip to content

Commit

Permalink
Add transaction insert to SQLite storage (#152)
Browse files Browse the repository at this point in the history
Co-authored-by: mulmarta <mulmarta@amazon.com>
  • Loading branch information
mulmarta and mulmarta authored Apr 25, 2024
1 parent c38cdef commit 3662e44
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions mls-rs-provider-sqlite/src/application.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ use rusqlite::{params, Connection, OptionalExtension};

use crate::SqLiteDataStorageError;

const INSERT_SQL: &str =
"INSERT INTO kvs (key, value) VALUES (?,?) ON CONFLICT(key) DO UPDATE SET value=excluded.value";

#[derive(Debug, Clone)]
/// SQLite key-value storage for application specific data.
pub struct SqLiteApplicationStorage {
Expand All @@ -32,12 +35,27 @@ impl SqLiteApplicationStorage {

// Upsert into the database
connection
.execute(
"INSERT INTO kvs (key, value) VALUES (?,?) ON CONFLICT(key) DO UPDATE SET value=excluded.value",
params![key, value],
)
.execute(INSERT_SQL, params![key, value])
.map(|_| ())
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
.map_err(sql_engine_error)
}

/// Execute multiple [`SqLiteApplicationStorage::insert`] operations in a transaction.
pub fn transact_insert(&self, items: Vec<Item>) -> Result<(), SqLiteDataStorageError> {
let mut connection = self.connection.lock().unwrap();

// Upsert into the database
let tx = connection.transaction().map_err(sql_engine_error)?;

items.into_iter().try_for_each(|item| {
tx.execute(INSERT_SQL, params![item.key, item.value])
.map_err(sql_engine_error)
.map(|_| ())
})?;

tx.commit().map_err(sql_engine_error)?;

Ok(())
}

/// Get a value from storage based on its `key`.
Expand All @@ -49,7 +67,7 @@ impl SqLiteApplicationStorage {
row.get(0)
})
.optional()
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
.map_err(sql_engine_error)
}

/// Delete a value from storage based on its `key`.
Expand All @@ -59,7 +77,7 @@ impl SqLiteApplicationStorage {
connection
.execute("DELETE FROM kvs WHERE key = ?", params![key])
.map(|_| ())
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
.map_err(sql_engine_error)
}

/// Get all keys and values from storage for which key starts with `key_prefix`.
Expand All @@ -70,15 +88,14 @@ impl SqLiteApplicationStorage {

let mut stmt = connection
.prepare("SELECT key, value FROM kvs WHERE key LIKE ? ESCAPE '$'")
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
.map_err(sql_engine_error)?;

let rows = stmt
.query(params![key_prefix])
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?
.map_err(sql_engine_error)?
.mapped(|row| Ok(Item::new(row.get(0)?, row.get(1)?)));

rows.collect::<Result<_, _>>()
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
rows.collect::<Result<_, _>>().map_err(sql_engine_error)
}

/// Delete all values from storage for which key starts with `key_prefix`.
Expand All @@ -93,14 +110,18 @@ impl SqLiteApplicationStorage {
params![key_prefix],
)
.map(|_| ())
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
.map_err(sql_engine_error)
}
}

fn sanitize(string: &str) -> String {
string.replace('_', "$_").replace('%', "$%")
}

fn sql_engine_error(e: rusqlite::Error) -> SqLiteDataStorageError {
SqLiteDataStorageError::SqlEngineError(e.into())
}

#[derive(Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Item {
pub key: String,
Expand Down Expand Up @@ -242,4 +263,20 @@ mod tests {
let keys = items.into_iter().map(|i| i.key).collect::<Vec<_>>();
assert_eq!(vec!["%$_ƕ❤_$%".to_string()], keys);
}

#[test]
fn batch_insert() {
let storage = test_storage();
let items = vec![test_item(), test_item(), test_item()];

storage.transact_insert(items.clone()).unwrap();

for item in items {
assert_eq!(storage.get(&item.key).unwrap(), Some(item.value));
}
}

fn test_item() -> Item {
Item::new(hex::encode(gen_rand_bytes(5)), gen_rand_bytes(5))
}
}

0 comments on commit 3662e44

Please sign in to comment.