Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements concurrent Smt::compute_mutations #365

Merged
merged 15 commits into from
Feb 7, 2025
Merged
139 changes: 95 additions & 44 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ use rand_utils::rand_value;
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
pub struct BenchmarkCmd {
/// Size of the tree
#[clap(short = 's', long = "size")]
#[clap(short = 's', long = "size", default_value = "10000")]
size: usize,
/// Number of insertions
#[clap(short = 'i', long = "insertions", default_value = "10000")]
insertions: usize,
/// Number of updates
#[clap(short = 'u', long = "updates", default_value = "10000")]
updates: usize,
}

fn main() {
Expand All @@ -25,7 +31,10 @@ fn main() {
pub fn benchmark_smt() {
let args = BenchmarkCmd::parse();
let tree_size = args.size;
let insertions = args.insertions;
let updates = args.updates;

assert!(updates <= insertions + tree_size, "Cannot update more than insertions + size");
// prepare the `leaves` vector for tree creation
let mut entries = Vec::new();
for i in 0..tree_size {
Expand All @@ -35,35 +44,41 @@ pub fn benchmark_smt() {
}

let mut tree = construction(entries.clone(), tree_size).unwrap();
insertion(&mut tree).unwrap();
batched_insertion(&mut tree).unwrap();
batched_update(&mut tree, entries).unwrap();
insertion(&mut tree, insertions).unwrap();
batched_insertion(&mut tree, insertions).unwrap();
batched_update(&mut tree, entries, updates).unwrap();
proof_generation(&mut tree).unwrap();
}

/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result<Smt, MerkleError> {
let cloned_entries = entries.clone();
println!("Running a construction benchmark:");
let now = Instant::now();
let tree = Smt::with_entries(entries)?;
let tree = Smt::with_entries(cloned_entries)?;
let elapsed = now.elapsed().as_secs_f32();

println!("Constructed a SMT with {size} key-value pairs in {elapsed:.1} seconds");
println!("Number of leaf nodes: {}\n", tree.leaves().count());

let now = Instant::now();
let tree_sequential = Smt::with_entries_sequential(entries)?;
let compute_elapsed_sequential = now.elapsed().as_secs_f32();

assert_eq!(tree.root(), tree_sequential.root());
println!("Constructed a SMT sequentially with {size} key-value pairs in {elapsed:.1} seconds");
let factor = compute_elapsed_sequential / elapsed;
println!("Parallel implementation is {factor}x times faster.");
Ok(tree)
}

/// Runs the insertion benchmark for the [`Smt`].
pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_INSERTIONS: usize = 1_000;

pub fn insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
println!("Running an insertion benchmark:");

let size = tree.num_leaves();
let mut insertion_times = Vec::new();

for i in 0..NUM_INSERTIONS {
for i in 0..insertions {
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];

Expand All @@ -74,47 +89,62 @@ pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
}

println!(
"An average insertion time measured by {NUM_INSERTIONS} inserts into an SMT with {size} leaves is {:.0} μs\n",
"An average insertion time measured by {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n",
// calculate the average
insertion_times.iter().sum::<u128>() as f64 / (NUM_INSERTIONS as f64),
insertion_times.iter().sum::<u128>() as f64 / (insertions as f64),
);

Ok(())
}

pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_INSERTIONS: usize = 1_000;

pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
println!("Running a batched insertion benchmark:");

let size = tree.num_leaves();

let new_pairs: Vec<(RpoDigest, Word)> = (0..NUM_INSERTIONS)
let new_pairs: Vec<(RpoDigest, Word)> = (0..insertions)
.map(|i| {
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
(key, value)
})
.collect();

let cloned_new_pairs = new_pairs.clone();
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let mutations = tree.compute_mutations(cloned_new_pairs);
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
let mutations_sequential = tree.compute_mutations_sequential(new_pairs);
let compute_elapsed_sequential = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

assert_eq!(mutations.root(), mutations_sequential.root());
assert_eq!(mutations.node_mutations(), mutations_sequential.node_mutations());
assert_eq!(mutations.new_pairs(), mutations_sequential.new_pairs());

println!(
"An average insert-batch computation time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
"An average insert-batch computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
compute_elapsed * 1000_f64 / insertions as f64, // time in μs
);

println!(
"An average insert-batch sequential computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed_sequential,
compute_elapsed_sequential * 1000_f64 / insertions as f64, // time in μs
);
let parallel_factor = compute_elapsed_sequential / compute_elapsed;
println!("Parallel implementation is {parallel_factor}x times faster.");

let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

println!(
"An average insert-batch application time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
"An average insert-batch application time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
apply_elapsed * 1000_f64 / insertions as f64, // time in μs
);

println!(
Expand All @@ -127,50 +157,71 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
Ok(())
}

pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result<(), MerkleError> {
const NUM_UPDATES: usize = 1_000;
pub fn batched_update(
tree: &mut Smt,
entries: Vec<(RpoDigest, Word)>,
updates: usize,
) -> Result<(), MerkleError> {
const REMOVAL_PROBABILITY: f64 = 0.2;

println!("Running a batched update benchmark:");

let size = tree.num_leaves();
let mut rng = thread_rng();

let new_pairs =
entries
.into_iter()
.choose_multiple(&mut rng, NUM_UPDATES)
.into_iter()
.map(|(key, _)| {
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
EMPTY_WORD
} else {
[ONE, ONE, ONE, Felt::new(rng.gen())]
};
let new_pairs: Vec<(RpoDigest, Word)> = entries
.into_iter()
.choose_multiple(&mut rng, updates)
.into_iter()
.map(|(key, _)| {
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
EMPTY_WORD
} else {
[ONE, ONE, ONE, Felt::new(rng.gen())]
};

(key, value)
});
(key, value)
})
.collect();

assert_eq!(new_pairs.len(), NUM_UPDATES);
assert_eq!(new_pairs.len(), updates);

let cloned_new_pairs = new_pairs.clone();
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let mutations = tree.compute_mutations(cloned_new_pairs);
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

let now = Instant::now();
let mutations_sequential = tree.compute_mutations_sequential(new_pairs);
let compute_elapsed_sequential = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

assert_eq!(mutations.root(), mutations_sequential.root());
assert_eq!(mutations.node_mutations(), mutations_sequential.node_mutations());
assert_eq!(mutations.new_pairs(), mutations_sequential.new_pairs());

let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms

println!(
"An average update-batch computation time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
"An average update-batch computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
compute_elapsed * 1000_f64 / updates as f64, // time in μs
);

println!(
"An average update-batch sequential computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed_sequential,
compute_elapsed_sequential * 1000_f64 / updates as f64, // time in μs
);

let factor = compute_elapsed_sequential / compute_elapsed;
println!("Parallel implementaton is {factor}x times faster.");

println!(
"An average update-batch application time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
"An average update-batch application time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
apply_elapsed * 1000_f64 / updates as f64, // time in μs
);

println!(
Expand Down
12 changes: 11 additions & 1 deletion src/merkle/smt/full/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,17 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
}

/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
/// Sequential implementation of [`Smt::compute_mutations()`].
pub fn compute_mutations_sequential(
&self,
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations_sequential(self, kv_pairs)
}

/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to
/// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the
/// updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
Expand Down
Loading
Loading