Skip to content

Commit

Permalink
feat: implements concurrent Smt::compute_mutations (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
krushimir authored Feb 7, 2025
1 parent d569c71 commit 1b77fa8
Show file tree
Hide file tree
Showing 8 changed files with 771 additions and 461 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
## 0.13.1 (2024-12-26)

- Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355).
- Added parallel implementation of `Smt::compute_mutations` with better performance (#365).
- Implemented parallel leaf hashing in `Smt::process_sorted_pairs_to_leaves` (#365).

## 0.13.0 (2024-11-24)

Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ name = "store"
harness = false

[features]
concurrent = ["dep:rayon"]
concurrent = ["dep:rayon", "hashbrown?/rayon"]
default = ["std", "concurrent"]
executable = ["dep:clap", "dep:rand-utils", "std"]
smt_hashmaps = ["dep:hashbrown"]
Expand Down
77 changes: 42 additions & 35 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ use rand_utils::rand_value;
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
pub struct BenchmarkCmd {
/// Size of the tree
#[clap(short = 's', long = "size")]
#[clap(short = 's', long = "size", default_value = "1000000")]
size: usize,
/// Number of insertions
#[clap(short = 'i', long = "insertions", default_value = "1000")]
insertions: usize,
/// Number of updates
#[clap(short = 'u', long = "updates", default_value = "1000")]
updates: usize,
}

fn main() {
Expand All @@ -25,7 +31,10 @@ fn main() {
pub fn benchmark_smt() {
let args = BenchmarkCmd::parse();
let tree_size = args.size;
let insertions = args.insertions;
let updates = args.updates;

assert!(updates <= tree_size, "Cannot update more than `size`");
// prepare the `leaves` vector for tree creation
let mut entries = Vec::new();
for i in 0..tree_size {
Expand All @@ -35,9 +44,9 @@ pub fn benchmark_smt() {
}

let mut tree = construction(entries.clone(), tree_size).unwrap();
insertion(&mut tree).unwrap();
batched_insertion(&mut tree).unwrap();
batched_update(&mut tree, entries).unwrap();
insertion(&mut tree.clone(), insertions).unwrap();
batched_insertion(&mut tree.clone(), insertions).unwrap();
batched_update(&mut tree.clone(), entries, updates).unwrap();
proof_generation(&mut tree).unwrap();
}

Expand All @@ -47,23 +56,20 @@ pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result<Smt,
let now = Instant::now();
let tree = Smt::with_entries(entries)?;
let elapsed = now.elapsed().as_secs_f32();

println!("Constructed a SMT with {size} key-value pairs in {elapsed:.1} seconds");
println!("Constructed an SMT with {size} key-value pairs in {elapsed:.1} seconds");
println!("Number of leaf nodes: {}\n", tree.leaves().count());

Ok(tree)
}

/// Runs the insertion benchmark for the [`Smt`].
pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_INSERTIONS: usize = 1_000;

pub fn insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
println!("Running an insertion benchmark:");

let size = tree.num_leaves();
let mut insertion_times = Vec::new();

for i in 0..NUM_INSERTIONS {
for i in 0..insertions {
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];

Expand All @@ -74,22 +80,20 @@ pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
}

println!(
"An average insertion time measured by {NUM_INSERTIONS} inserts into an SMT with {size} leaves is {:.0} μs\n",
"The average insertion time measured by {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n",
// calculate the average
insertion_times.iter().sum::<u128>() as f64 / (NUM_INSERTIONS as f64),
insertion_times.iter().sum::<u128>() as f64 / (insertions as f64),
);

Ok(())
}

pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_INSERTIONS: usize = 1_000;

pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
println!("Running a batched insertion benchmark:");

let size = tree.num_leaves();

let new_pairs: Vec<(RpoDigest, Word)> = (0..NUM_INSERTIONS)
let new_pairs: Vec<(RpoDigest, Word)> = (0..insertions)
.map(|i| {
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
Expand All @@ -101,24 +105,24 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

println!(
"An average insert-batch computation time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
"The average insert-batch computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
compute_elapsed * 1000_f64 / insertions as f64, // time in μs
);

let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

println!(
"An average insert-batch application time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
"The average insert-batch application time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
apply_elapsed * 1000_f64 / insertions as f64, // time in μs
);

println!(
"An average batch insertion time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
"The average batch insertion time measured by a {insertions}-batch into an SMT with {size} leaves totals to {:.1} ms",
(compute_elapsed + apply_elapsed),
);

Expand All @@ -127,8 +131,11 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
Ok(())
}

pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result<(), MerkleError> {
const NUM_UPDATES: usize = 1_000;
pub fn batched_update(
tree: &mut Smt,
entries: Vec<(RpoDigest, Word)>,
updates: usize,
) -> Result<(), MerkleError> {
const REMOVAL_PROBABILITY: f64 = 0.2;

println!("Running a batched update benchmark:");
Expand All @@ -139,7 +146,7 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result
let new_pairs =
entries
.into_iter()
.choose_multiple(&mut rng, NUM_UPDATES)
.choose_multiple(&mut rng, updates)
.into_iter()
.map(|(key, _)| {
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
Expand All @@ -151,7 +158,7 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result
(key, value)
});

assert_eq!(new_pairs.len(), NUM_UPDATES);
assert_eq!(new_pairs.len(), updates);

let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
Expand All @@ -162,19 +169,19 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

println!(
"An average update-batch computation time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
"The average update-batch computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
compute_elapsed * 1000_f64 / updates as f64, // time in μs
);

println!(
"An average update-batch application time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
"The average update-batch application time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
apply_elapsed * 1000_f64 / updates as f64, // time in μs
);

println!(
"An average batch update time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
"The average batch update time measured by a {updates}-batch into an SMT with {size} leaves totals to {:.1} ms",
(compute_elapsed + apply_elapsed),
);

Expand Down Expand Up @@ -203,7 +210,7 @@ pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> {
}

println!(
"An average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs",
"The average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs",
// calculate the average
insertion_times.iter().sum::<u128>() as f64 / (NUM_PROOFS as f64),
);
Expand Down
4 changes: 2 additions & 2 deletions src/merkle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ pub use path::{MerklePath, RootPath, ValuePath};

mod smt;
#[cfg(feature = "internal")]
pub use smt::build_subtree_for_bench;
pub use smt::{build_subtree_for_bench, SubtreeLeaf};
pub use smt::{
InnerNode, LeafIndex, MutationSet, NodeMutation, SimpleSmt, Smt, SmtLeaf, SmtLeafError,
SmtProof, SmtProofError, SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
SmtProof, SmtProofError, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};

mod mmr;
Expand Down
Loading

0 comments on commit 1b77fa8

Please sign in to comment.