diff --git a/CHANGELOG.md b/CHANGELOG.md index caa5ec00..040f97d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ ## 0.13.1 (2024-12-26) - Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355). +- Added parallel implementation of `Smt::compute_mutations` with better performance (#365). +- Implemented parallel leaf hashing in `Smt::process_sorted_pairs_to_leaves` (#365). ## 0.13.0 (2024-11-24) diff --git a/Cargo.toml b/Cargo.toml index be393abf..febb4715 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ name = "store" harness = false [features] -concurrent = ["dep:rayon"] +concurrent = ["dep:rayon", "hashbrown?/rayon"] default = ["std", "concurrent"] executable = ["dep:clap", "dep:rand-utils", "std"] smt_hashmaps = ["dep:hashbrown"] diff --git a/src/main.rs b/src/main.rs index 5ee3834c..83daef67 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,8 +13,14 @@ use rand_utils::rand_value; #[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")] pub struct BenchmarkCmd { /// Size of the tree - #[clap(short = 's', long = "size")] + #[clap(short = 's', long = "size", default_value = "1000000")] size: usize, + /// Number of insertions + #[clap(short = 'i', long = "insertions", default_value = "1000")] + insertions: usize, + /// Number of updates + #[clap(short = 'u', long = "updates", default_value = "1000")] + updates: usize, } fn main() { @@ -25,7 +31,10 @@ fn main() { pub fn benchmark_smt() { let args = BenchmarkCmd::parse(); let tree_size = args.size; + let insertions = args.insertions; + let updates = args.updates; + assert!(updates <= tree_size, "Cannot update more than `size`"); // prepare the `leaves` vector for tree creation let mut entries = Vec::new(); for i in 0..tree_size { @@ -35,9 +44,9 @@ pub fn benchmark_smt() { } let mut tree = construction(entries.clone(), tree_size).unwrap(); - insertion(&mut tree).unwrap(); - batched_insertion(&mut tree).unwrap(); - batched_update(&mut tree, entries).unwrap(); + insertion(&mut tree.clone(), insertions).unwrap(); + batched_insertion(&mut tree.clone(), insertions).unwrap(); + batched_update(&mut tree.clone(), entries, updates).unwrap(); proof_generation(&mut tree).unwrap(); } @@ -47,23 +56,20 @@ pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result Result<(), MerkleError> { - const NUM_INSERTIONS: usize = 1_000; - +pub fn insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> { println!("Running an insertion benchmark:"); let size = tree.num_leaves(); let mut insertion_times = Vec::new(); - for i in 0..NUM_INSERTIONS { + for i in 0..insertions { let test_key = Rpo256::hash(&rand_value::().to_be_bytes()); let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)]; @@ -74,22 +80,20 @@ pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> { } println!( - "An average insertion time measured by {NUM_INSERTIONS} inserts into an SMT with {size} leaves is {:.0} μs\n", + "The average insertion time measured by {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n", // calculate the average - insertion_times.iter().sum::() as f64 / (NUM_INSERTIONS as f64), + insertion_times.iter().sum::() as f64 / (insertions as f64), ); Ok(()) } -pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> { - const NUM_INSERTIONS: usize = 1_000; - +pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> { println!("Running a batched insertion benchmark:"); let size = tree.num_leaves(); - let new_pairs: Vec<(RpoDigest, Word)> = (0..NUM_INSERTIONS) + let new_pairs: Vec<(RpoDigest, Word)> = (0..insertions) .map(|i| { let key = Rpo256::hash(&rand_value::().to_be_bytes()); let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)]; @@ -101,24 +105,24 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> { let mutations = tree.compute_mutations(new_pairs); let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms - let now = Instant::now(); - tree.apply_mutations(mutations)?; - let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms - println!( - "An average insert-batch computation time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "The average insert-batch computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", compute_elapsed, - compute_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs + compute_elapsed * 1000_f64 / insertions as f64, // time in μs ); + let now = Instant::now(); + tree.apply_mutations(mutations)?; + let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms + println!( - "An average insert-batch application time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "The average insert-batch application time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", apply_elapsed, - apply_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs + apply_elapsed * 1000_f64 / insertions as f64, // time in μs ); println!( - "An average batch insertion time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms", + "The average batch insertion time measured by a {insertions}-batch into an SMT with {size} leaves totals to {:.1} ms", (compute_elapsed + apply_elapsed), ); @@ -127,8 +131,11 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> { Ok(()) } -pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result<(), MerkleError> { - const NUM_UPDATES: usize = 1_000; +pub fn batched_update( + tree: &mut Smt, + entries: Vec<(RpoDigest, Word)>, + updates: usize, +) -> Result<(), MerkleError> { const REMOVAL_PROBABILITY: f64 = 0.2; println!("Running a batched update benchmark:"); @@ -139,7 +146,7 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result let new_pairs = entries .into_iter() - .choose_multiple(&mut rng, NUM_UPDATES) + .choose_multiple(&mut rng, updates) .into_iter() .map(|(key, _)| { let value = if rng.gen_bool(REMOVAL_PROBABILITY) { @@ -151,7 +158,7 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result (key, value) }); - assert_eq!(new_pairs.len(), NUM_UPDATES); + assert_eq!(new_pairs.len(), updates); let now = Instant::now(); let mutations = tree.compute_mutations(new_pairs); @@ -162,19 +169,19 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms println!( - "An average update-batch computation time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "The average update-batch computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", compute_elapsed, - compute_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs + compute_elapsed * 1000_f64 / updates as f64, // time in μs ); println!( - "An average update-batch application time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", + "The average update-batch application time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs", apply_elapsed, - apply_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs + apply_elapsed * 1000_f64 / updates as f64, // time in μs ); println!( - "An average batch update time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms", + "The average batch update time measured by a {updates}-batch into an SMT with {size} leaves totals to {:.1} ms", (compute_elapsed + apply_elapsed), ); @@ -203,7 +210,7 @@ pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> { } println!( - "An average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs", + "The average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs", // calculate the average insertion_times.iter().sum::() as f64 / (NUM_PROOFS as f64), ); diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 95a404dc..54ced4dc 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -22,10 +22,10 @@ pub use path::{MerklePath, RootPath, ValuePath}; mod smt; #[cfg(feature = "internal")] -pub use smt::build_subtree_for_bench; +pub use smt::{build_subtree_for_bench, SubtreeLeaf}; pub use smt::{ InnerNode, LeafIndex, MutationSet, NodeMutation, SimpleSmt, Smt, SmtLeaf, SmtLeafError, - SmtProof, SmtProofError, SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, + SmtProof, SmtProofError, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; mod mmr; diff --git a/src/merkle/smt/full/concurrent/mod.rs b/src/merkle/smt/full/concurrent/mod.rs new file mode 100644 index 00000000..ea890d48 --- /dev/null +++ b/src/merkle/smt/full/concurrent/mod.rs @@ -0,0 +1,580 @@ +use alloc::{collections::BTreeSet, vec::Vec}; +use core::mem; + +use num::Integer; + +use super::{ + EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet, + NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH, +}; +use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap}; + +#[cfg(test)] +mod tests; + +type MutatedSubtreeLeaves = Vec>; + +impl Smt { + /// Parallel implementation of [`Smt::with_entries()`]. + /// + /// This method constructs a new sparse Merkle tree concurrently by processing subtrees in + /// parallel, working from the bottom up. The process works as follows: + /// + /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf + /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. + /// + /// 2. The subtrees are then processed in parallel: + /// - For each subtree, compute the inner nodes from depth D down to depth D-8. + /// - Each subtree computation yields a new subtree root and its associated inner nodes. + /// + /// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration, + /// which processes the next 8 levels up. This continues until the final root of the tree is + /// computed at depth 0. + pub(crate) fn with_entries_concurrent( + entries: impl IntoIterator, + ) -> Result { + let mut seen_keys = BTreeSet::new(); + let entries: Vec<_> = entries + .into_iter() + .map(|(key, value)| { + if seen_keys.insert(key) { + Ok((key, value)) + } else { + Err(MerkleError::DuplicateValuesForIndex( + LeafIndex::::from(key).value(), + )) + } + }) + .collect::>()?; + if entries.is_empty() { + return Ok(Self::default()); + } + let (inner_nodes, leaves) = Self::build_subtrees(entries); + let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); + >::from_raw_parts(inner_nodes, leaves, root) + } + + /// Parallel implementation of [`Smt::compute_mutations()`]. + /// + /// This method computes mutations by recursively processing subtrees in parallel, working from + /// the bottom up. The process works as follows: + /// + /// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf + /// indices. Each subtree covers a range of 256 (2^8) possible leaf positions. + /// + /// 2. The subtrees containing modifications are then processed in parallel: + /// - For each modified subtree, compute node mutations from depth D up to depth D-8 + /// - Each subtree computation yields a new root at depth D-8 and its associated mutations + /// + /// 3. These subtree roots become the "leaves" for the next iteration, which processes the next + /// 8 levels up. This continues until reaching the tree's root at depth 0. + pub(crate) fn compute_mutations_concurrent( + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet + where + Self: Sized + Sync, + { + use rayon::prelude::*; + + // Collect and sort key-value pairs by their corresponding leaf index + let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); + sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); + + // Convert sorted pairs into mutated leaves and capture any new pairs + let (mut subtree_leaves, new_pairs) = + self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs); + let mut node_mutations = NodeMutations::default(); + + // Process each depth level in reverse, stepping by the subtree depth + for depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // Parallel processing of each subtree to generate mutations and roots + let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted() && !subtree.is_empty()); + self.build_subtree_mutations(subtree, SMT_DEPTH, depth) + }) + .unzip(); + + // Prepare leaves for the next depth level + subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + + // Aggregate all node mutations + node_mutations.extend(mutations_per_subtree.into_iter().flatten()); + + debug_assert!(!subtree_leaves.is_empty()); + } + + // Finalize the mutation set with updated roots and mutations + MutationSet { + old_root: self.root(), + new_root: subtree_leaves[0][0].hash, + node_mutations, + new_pairs, + } + } + + /// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing + /// subtrees. In other words, this function takes the key-value inputs to the tree, and produces + /// the inputs to feed into [`build_subtree()`]. + /// + /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If + /// `pairs` is not correctly sorted, the returned computations will be incorrect. + /// + /// # Panics + /// With debug assertions on, this function panics if it detects that `pairs` is not correctly + /// sorted. Without debug assertions, the returned computations will be incorrect. + fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations { + Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf) + } + + /// Computes leaves from a set of key-value pairs and current leaf values. + /// Derived from `sorted_pairs_to_leaves` + fn sorted_pairs_to_mutated_subtree_leaves( + &self, + pairs: Vec<(RpoDigest, Word)>, + ) -> (MutatedSubtreeLeaves, UnorderedMap) { + // Map to track new key-value pairs for mutated leaves + let mut new_pairs = UnorderedMap::new(); + + let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| { + let mut leaf = self.get_leaf(&leaf_pairs[0].0); + + for (key, value) in leaf_pairs { + // Check if the value has changed + let old_value = + new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); + + // Skip if the value hasn't changed + if value == old_value { + continue; + } + + // Otherwise, update the leaf and track the new key-value pair + leaf = self.construct_prospective_leaf(leaf, &key, &value); + new_pairs.insert(key, value); + } + + leaf + }); + (accumulator.leaves, new_pairs) + } + + /// Computes the node mutations and the root of a subtree + fn build_subtree_mutations( + &self, + mut leaves: Vec, + tree_depth: u8, + bottom_depth: u8, + ) -> (NodeMutations, SubtreeLeaf) + where + Self: Sized, + { + debug_assert!(bottom_depth <= tree_depth); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); + debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); + + let subtree_root_depth = bottom_depth - SUBTREE_DEPTH; + let mut node_mutations: NodeMutations = Default::default(); + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + + for current_depth in (subtree_root_depth..bottom_depth).rev() { + debug_assert!(current_depth <= bottom_depth); + + let next_depth = current_depth + 1; + let mut iter = leaves.drain(..).peekable(); + + while let Some(first_leaf) = iter.next() { + // This constructs a valid index because next_depth will never exceed the depth of + // the tree. + let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent(); + let parent_node = self.get_inner_node(parent_index); + let combined_node = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node); + let combined_hash = combined_node.hash(); + + let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth); + + // Add the parent node even if it is empty for proper upward updates + next_leaves.push(SubtreeLeaf { + col: parent_index.value(), + hash: combined_hash, + }); + + node_mutations.insert( + parent_index, + if combined_hash != empty_hash { + NodeMutation::Addition(combined_node) + } else { + NodeMutation::Removal + }, + ); + } + drop(iter); + leaves = mem::take(&mut next_leaves); + } + + debug_assert_eq!(leaves.len(), 1); + let root_leaf = leaves.pop().unwrap(); + (node_mutations, root_leaf) + } + + /// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part: + /// - If `first_leaf` is a right child, the left child is copied from the `parent_node`. + /// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also + /// mutated or copied from the `parent_node`. + /// + /// Returns the `InnerNode` containing the hashes of the sibling pair. + fn fetch_sibling_pair( + iter: &mut core::iter::Peekable>, + first_leaf: SubtreeLeaf, + parent_node: InnerNode, + ) -> InnerNode { + let is_right_node = first_leaf.col.is_odd(); + + if is_right_node { + let left_leaf = SubtreeLeaf { + col: first_leaf.col - 1, + hash: parent_node.left, + }; + InnerNode { + left: left_leaf.hash, + right: first_leaf.hash, + } + } else { + let right_col = first_leaf.col + 1; + let right_leaf = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(), + _ => SubtreeLeaf { col: right_col, hash: parent_node.right }, + }; + InnerNode { + left: first_leaf.hash, + right: right_leaf.hash, + } + } + } + + /// Processes sorted key-value pairs to compute leaves for a subtree. + /// + /// This function groups key-value pairs by their corresponding column index and processes each + /// group to construct leaves. The actual construction of the leaf is delegated to the + /// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating + /// new leaves or mutating existing ones). + /// + /// # Parameters + /// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index + /// column (not simply by key). If the input is not sorted correctly, the function will + /// produce incorrect results and may panic in debug mode. + /// - `process_leaf`: A callback function used to process each group of key-value pairs + /// corresponding to the same column index. The callback takes a vector of key-value pairs for + /// a single column and returns the constructed leaf for that column. + /// + /// # Returns + /// A `PairComputations` containing: + /// - `nodes`: A mapping of column indices to the constructed leaves. + /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each + /// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf. + /// + /// # Panics + /// This function will panic in debug mode if the input `pairs` are not sorted by column index. + fn process_sorted_pairs_to_leaves( + pairs: Vec<(RpoDigest, Word)>, + mut process_leaf: F, + ) -> PairComputations + where + F: FnMut(Vec<(RpoDigest, Word)>) -> SmtLeaf, + { + use rayon::prelude::*; + debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); + + let mut accumulator: PairComputations = Default::default(); + + // As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a + // single leaf. When we see a pair that's in a different leaf, we'll swap these pairs + // out and store them in our accumulated leaves. + let mut current_leaf_buffer: Vec<(RpoDigest, Word)> = Default::default(); + + let mut iter = pairs.into_iter().peekable(); + while let Some((key, value)) = iter.next() { + let col = Self::key_to_leaf_index(&key).index.value(); + let peeked_col = iter.peek().map(|(key, _v)| { + let index = Self::key_to_leaf_index(key); + let next_col = index.index.value(); + // We panic if `pairs` is not sorted by column. + debug_assert!(next_col >= col); + next_col + }); + current_leaf_buffer.push((key, value)); + + // If the next pair is the same column as this one, then we're done after adding this + // pair to the buffer. + if peeked_col == Some(col) { + continue; + } + + // Otherwise, the next pair is a different column, or there is no next pair. Either way + // it's time to swap out our buffer. + let leaf_pairs = mem::take(&mut current_leaf_buffer); + let leaf = process_leaf(leaf_pairs); + + accumulator.nodes.insert(col, leaf); + + debug_assert!(current_leaf_buffer.is_empty()); + } + + // Compute the leaves from the nodes concurrently + let mut accumulated_leaves: Vec = accumulator + .nodes + .clone() + .into_par_iter() + .map(|(col, leaf)| SubtreeLeaf { col, hash: Self::hash_leaf(&leaf) }) + .collect(); + + // Sort the leaves by column + accumulated_leaves.par_sort_by_key(|leaf| leaf.col); + + // TODO: determine is there is any notable performance difference between computing + // subtree boundaries after the fact as an iterator adapter (like this), versus computing + // subtree boundaries as we go. Either way this function is only used at the beginning of a + // parallel construction, so it should not be a critical path. + accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); + accumulator + } + + /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. + /// + /// `entries` need not be sorted. This function will sort them. + fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { + entries.sort_by_key(|item| { + let index = Self::key_to_leaf_index(&item.0); + index.value() + }); + Self::build_subtrees_from_sorted_entries(entries) + } + + /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. + /// + /// This function is mostly an implementation detail of + /// [`Smt::with_entries_concurrent()`]. + fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { + use rayon::prelude::*; + + let mut accumulated_nodes: InnerNodes = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: initial_leaves, + } = Self::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + let (nodes, mut subtree_roots): (Vec>, Vec) = + leaf_subtrees + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted()); + debug_assert!(!subtree.is_empty()); + let (nodes, subtree_root) = + build_subtree(subtree, SMT_DEPTH, current_depth); + (nodes, subtree_root) + }) + .unzip(); + + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + debug_assert!(!leaf_subtrees.is_empty()); + } + (accumulated_nodes, initial_leaves) + } +} + +// SUBTREES +// ================================================================================================ + +/// A subtree is of depth 8. +const SUBTREE_DEPTH: u8 = 8; + +/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. +const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32); + +/// Helper struct for organizing the data we care about when computing Merkle subtrees. +/// +/// Note that these represet "conceptual" leaves of some subtree, not necessarily +/// the leaf type for the sparse Merkle tree. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct SubtreeLeaf { + /// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known. + pub col: u64, + /// The hash of the node this `SubtreeLeaf` represents. + pub hash: RpoDigest, +} + +/// Helper struct to organize the return value of [`Smt::sorted_pairs_to_leaves()`]. +#[derive(Debug, Clone)] +pub(crate) struct PairComputations { + /// Literal leaves to be added to the sparse Merkle tree's internal mapping. + pub nodes: UnorderedMap, + /// "Conceptual" leaves that will be used for computations. + pub leaves: Vec>, +} + +// Derive requires `L` to impl Default, even though we don't actually need that. +impl Default for PairComputations { + fn default() -> Self { + Self { + nodes: Default::default(), + leaves: Default::default(), + } + } +} + +#[derive(Debug)] +pub(crate) struct SubtreeLeavesIter<'s> { + leaves: core::iter::Peekable>, +} +impl<'s> SubtreeLeavesIter<'s> { + fn from_leaves(leaves: &'s mut Vec) -> Self { + // TODO: determine if there is any notable performance difference between taking a Vec, + // which many need flattening first, vs storing a `Box>`. + // The latter may have self-referential properties that are impossible to express in purely + // safe Rust Rust. + Self { leaves: leaves.drain(..).peekable() } + } +} +impl Iterator for SubtreeLeavesIter<'_> { + type Item = Vec; + + /// Each `next()` collects an entire subtree. + fn next(&mut self) -> Option> { + let mut subtree: Vec = Default::default(); + + let mut last_subtree_col = 0; + + while let Some(leaf) = self.leaves.peek() { + last_subtree_col = u64::max(1, last_subtree_col); + let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE); + let next_subtree_col = if is_exact_multiple { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) + }; + + last_subtree_col = leaf.col; + if leaf.col < next_subtree_col { + subtree.push(self.leaves.next().unwrap()); + } else if subtree.is_empty() { + continue; + } else { + break; + } + } + + if subtree.is_empty() { + debug_assert!(self.leaves.peek().is_none()); + return None; + } + + Some(subtree) + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and +/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and +/// `leaves` must not contain more than one depth-8 subtree's worth of leaves. +/// +/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as +/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into +/// itself. +/// +/// # Panics +/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains +/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to +/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified +/// maximum depth (`DEPTH`), or if `leaves` is not sorted. +#[cfg(feature = "concurrent")] +pub(crate) fn build_subtree( + mut leaves: Vec, + tree_depth: u8, + bottom_depth: u8, +) -> (UnorderedMap, SubtreeLeaf) { + debug_assert!(bottom_depth <= tree_depth); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); + debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); + let subtree_root = bottom_depth - SUBTREE_DEPTH; + let mut inner_nodes: UnorderedMap = Default::default(); + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + for next_depth in (subtree_root..bottom_depth).rev() { + debug_assert!(next_depth <= bottom_depth); + // `next_depth` is the stuff we're making. + // `current_depth` is the stuff we have. + let current_depth = next_depth + 1; + let mut iter = leaves.drain(..).peekable(); + while let Some(first) = iter.next() { + // On non-continuous iterations, including the first iteration, `first_column` may + // be a left or right node. On subsequent continuous iterations, we will always call + // `iter.next()` twice. + // On non-continuous iterations (including the very first iteration), this column + // could be either on the left or the right. If the next iteration is not + // discontinuous with our right node, then the next iteration's + let is_right = first.col.is_odd(); + let (left, right) = if is_right { + // Discontinuous iteration: we have no left node, so it must be empty. + let left = SubtreeLeaf { + col: first.col - 1, + hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), + }; + let right = first; + (left, right) + } else { + let left = first; + let right_col = first.col + 1; + let right = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => { + // Our inputs must be sorted. + debug_assert!(left.col <= col); + // The next leaf in the iterator is our sibling. Use it and consume it! + iter.next().unwrap() + }, + // Otherwise, the leaves don't contain our sibling, so our sibling must be + // empty. + _ => SubtreeLeaf { + col: right_col, + hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), + }, + }; + (left, right) + }; + let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); + let node = InnerNode { left: left.hash, right: right.hash }; + let hash = node.hash(); + let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth); + // If this hash is empty, then it doesn't become a new inner node, nor does it count + // as a leaf for the next depth. + if hash != equivalent_empty_hash { + inner_nodes.insert(index, node); + next_leaves.push(SubtreeLeaf { col: index.value(), hash }); + } + } + // Stop borrowing `leaves`, so we can swap it. + // The iterator is empty at this point anyway. + drop(iter); + // After each depth, consider the stuff we just made the new "leaves", and empty the + // other collection. + mem::swap(&mut leaves, &mut next_leaves); + } + debug_assert_eq!(leaves.len(), 1); + let root = leaves.pop().unwrap(); + (inner_nodes, root) +} + +#[cfg(feature = "internal")] +pub fn build_subtree_for_bench( + leaves: Vec, + tree_depth: u8, + bottom_depth: u8, +) -> (UnorderedMap, SubtreeLeaf) { + build_subtree(leaves, tree_depth, bottom_depth) +} diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/full/concurrent/tests.rs similarity index 75% rename from src/merkle/smt/tests.rs rename to src/merkle/smt/full/concurrent/tests.rs index 23794c67..c000a245 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/full/concurrent/tests.rs @@ -1,14 +1,16 @@ -use alloc::{collections::BTreeMap, vec::Vec}; +use alloc::{ + collections::{BTreeMap, BTreeSet}, + vec::Vec, +}; + +use rand::{prelude::IteratorRandom, thread_rng, Rng}; use super::{ - build_subtree, InnerNode, LeafIndex, NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree, - SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, SUBTREE_DEPTH, -}; -use crate::{ - hash::rpo::RpoDigest, - merkle::{Smt, SMT_DEPTH}, - Felt, Word, ONE, + build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest, + Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE, + SMT_DEPTH, SUBTREE_DEPTH, }; +use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE}; fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { SubtreeLeaf { @@ -32,9 +34,7 @@ fn test_sorted_pairs_to_leaves() { // Subtree 2. Another normal leaf. (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]), ]; - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - let control_leaves: Vec = { let mut entries_iter = entries.iter().cloned(); let mut next_entry = || entries_iter.next().unwrap(); @@ -52,11 +52,9 @@ fn test_sorted_pairs_to_leaves() { assert_eq!(entries_iter.next(), None); control_leaves }; - let control_subtree_leaves: Vec> = { let mut control_leaves_iter = control_leaves.iter(); let mut next_leaf = || control_leaves_iter.next().unwrap(); - let control_subtree_leaves: Vec> = [ // Subtree 0. vec![next_leaf(), next_leaf(), next_leaf()], @@ -70,22 +68,18 @@ fn test_sorted_pairs_to_leaves() { assert_eq!(control_leaves_iter.next(), None); control_subtree_leaves }; - let subtrees: PairComputations = Smt::sorted_pairs_to_leaves(entries); // This will check that the hashes, columns, and subtree assignments all match. assert_eq!(subtrees.leaves, control_subtree_leaves); - // Flattening and re-separating out the leaves into subtrees should have the same result. let mut all_leaves: Vec = subtrees.leaves.clone().into_iter().flatten().collect(); let re_grouped: Vec> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); assert_eq!(subtrees.leaves, re_grouped); - // Then finally we might as well check the computed leaf nodes too. let control_leaves: BTreeMap = control .leaves() .map(|(index, value)| (index.index.value(), value.clone())) .collect(); - for (column, test_leaf) in subtrees.nodes { if test_leaf.is_empty() { continue; @@ -96,7 +90,6 @@ fn test_sorted_pairs_to_leaves() { assert_eq!(control_leaf, &test_leaf); } } - // Helper for the below tests. fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { (0..pair_count) @@ -108,23 +101,41 @@ fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { }) .collect() } - +fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> { + const REMOVAL_PROBABILITY: f64 = 0.2; + let mut rng = thread_rng(); + // Assertion to ensure input keys are unique + assert!( + entries.iter().map(|(key, _)| key).collect::>().len() == entries.len(), + "Input entries contain duplicate keys!" + ); + let mut sorted_entries: Vec<(RpoDigest, Word)> = entries + .into_iter() + .choose_multiple(&mut rng, updates) + .into_iter() + .map(|(key, _)| { + let value = if rng.gen_bool(REMOVAL_PROBABILITY) { + EMPTY_WORD + } else { + [ONE, ONE, ONE, Felt::new(rng.gen())] + }; + (key, value) + }) + .collect(); + sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value()); + sorted_entries +} #[test] fn test_single_subtree() { // A single subtree's worth of leaves. const PAIR_COUNT: u64 = COLS_PER_SUBTREE; - let entries = generate_entries(PAIR_COUNT); - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - // `entries` should already be sorted by nature of how we constructed it. let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; let leaves = leaves.into_iter().next().unwrap(); - let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH); assert!(!first_subtree.is_empty()); - // The inner nodes computed from that subtree should match the nodes in our control tree. for (index, node) in first_subtree.into_iter() { let control = control.get_inner_node(index); @@ -133,7 +144,6 @@ fn test_single_subtree() { "subtree-computed node at index {index:?} does not match control", ); } - // The root returned should also match the equivalent node in the control tree. let control_root_index = NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index"); @@ -144,7 +154,6 @@ fn test_single_subtree() { "Subtree-computed root at index {control_root_index:?} does not match control" ); } - // Test that not just can we compute a subtree correctly, but we can feed the results of one // subtree into computing another. In other words, test that `build_subtree()` is correctly // composable. @@ -152,30 +161,22 @@ fn test_single_subtree() { fn test_two_subtrees() { // Two subtrees' worth of leaves. const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2; - let entries = generate_entries(PAIR_COUNT); - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); // With two subtrees' worth of leaves, we should have exactly two subtrees. let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap(); assert_eq!(first.len() as u64, PAIR_COUNT / 2); assert_eq!(first.len(), second.len()); - let mut current_depth = SMT_DEPTH; let mut next_leaves: Vec = Default::default(); - let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth); next_leaves.push(first_root); - let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth); next_leaves.push(second_root); - // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len(); assert_eq!(total_computed as u64, PAIR_COUNT); - // Verify the computed nodes of both subtrees. let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes); for (index, test_node) in computed_nodes { @@ -185,13 +186,10 @@ fn test_two_subtrees() { "subtree-computed node at index {index:?} does not match control", ); } - current_depth -= SUBTREE_DEPTH; - let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth); assert_eq!(nodes.len(), SUBTREE_DEPTH as usize); assert_eq!(root_leaf.col, 0); - for (index, test_node) in nodes { let control_node = control.get_inner_node(index); assert_eq!( @@ -199,30 +197,23 @@ fn test_two_subtrees() { "subtree-computed node at index {index:?} does not match control", ); } - let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap(); let control_root = control.get_inner_node(index).hash(); assert_eq!(control_root, root_leaf.hash, "Root mismatch"); } - #[test] fn test_singlethreaded_subtrees() { const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - let entries = generate_entries(PAIR_COUNT); - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - let mut accumulated_nodes: BTreeMap = Default::default(); - let PairComputations { leaves: mut leaf_subtrees, nodes: test_leaves, } = Smt::sorted_pairs_to_leaves(entries); - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { // There's no flat_map_unzip(), so this is the best we can do. - let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees + let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees .into_iter() .enumerate() .map(|(i, subtree)| { @@ -235,10 +226,8 @@ fn test_singlethreaded_subtrees() { !subtree.is_empty(), "subtree {i} at bottom-depth {current_depth} is empty!", ); - // Do actual things. let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth); - // Post-assertions. for (&index, test_node) in nodes.iter() { let control_node = control.get_inner_node(index); @@ -248,19 +237,14 @@ fn test_singlethreaded_subtrees() { current_depth, i, index, ); } - (nodes, subtree_root) }) .unzip(); - // Update state between each depth iteration. - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); accumulated_nodes.extend(nodes.into_iter().flatten()); - assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); } - // Make sure the true leaves match, first checking length and then checking each individual // leaf. let control_leaves: BTreeMap<_, _> = control.leaves().collect(); @@ -272,7 +256,6 @@ fn test_singlethreaded_subtrees() { let &control_leaf = control_leaves.get(&index).unwrap(); assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control"); } - // Make sure the inner nodes match, checking length first and then each individual leaf. let control_nodes_len = control.inner_nodes().count(); let test_nodes_len = accumulated_nodes.len(); @@ -281,20 +264,16 @@ fn test_singlethreaded_subtrees() { let control_node = control.get_inner_node(index); assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); } - // After the last iteration of the above for loop, we should have the new root node actually // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from // `build_subtree()`. So let's check both! - let control_root = control.get_inner_node(NodeIndex::root()); - // That for loop should have left us with only one leaf subtree... let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap(); // which itself contains only one 'leaf'... let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap(); // which matches the expected root. assert_eq!(control.root(), root_leaf.hash); - // Likewise `accumulated_nodes` should contain a node at the root index... assert!(accumulated_nodes.contains_key(&NodeIndex::root())); // and it should match our actual root. @@ -303,28 +282,20 @@ fn test_singlethreaded_subtrees() { // And of course the root we got from each place should match. assert_eq!(control.root(), root_leaf.hash); } - /// The parallel version of `test_singlethreaded_subtree()`. #[test] -#[cfg(feature = "concurrent")] fn test_multithreaded_subtrees() { use rayon::prelude::*; - const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - let entries = generate_entries(PAIR_COUNT); - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - let mut accumulated_nodes: BTreeMap = Default::default(); - let PairComputations { leaves: mut leaf_subtrees, nodes: test_leaves, } = Smt::sorted_pairs_to_leaves(entries); - for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees + let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees .into_par_iter() .enumerate() .map(|(i, subtree)| { @@ -337,9 +308,7 @@ fn test_multithreaded_subtrees() { !subtree.is_empty(), "subtree {i} at bottom-depth {current_depth} is empty!", ); - let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth); - // Post-assertions. for (&index, test_node) in nodes.iter() { let control_node = control.get_inner_node(index); @@ -349,17 +318,13 @@ fn test_multithreaded_subtrees() { current_depth, i, index, ); } - (nodes, subtree_root) }) .unzip(); - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); accumulated_nodes.extend(nodes.into_iter().flatten()); - assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); } - // Make sure the true leaves match, checking length first and then each individual leaf. let control_leaves: BTreeMap<_, _> = control.leaves().collect(); let control_leaves_len = control_leaves.len(); @@ -370,7 +335,6 @@ fn test_multithreaded_subtrees() { let &control_leaf = control_leaves.get(&index).unwrap(); assert_eq!(test_leaf, control_leaf); } - // Make sure the inner nodes match, checking length first and then each individual leaf. let control_nodes_len = control.inner_nodes().count(); let test_nodes_len = accumulated_nodes.len(); @@ -379,20 +343,16 @@ fn test_multithreaded_subtrees() { let control_node = control.get_inner_node(index); assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); } - // After the last iteration of the above for loop, we should have the new root node actually // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from // `build_subtree()`. So let's check both! - let control_root = control.get_inner_node(NodeIndex::root()); - // That for loop should have left us with only one leaf subtree... let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); // which itself contains only one 'leaf'... let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap(); // which matches the expected root. assert_eq!(control.root(), root_leaf.hash); - // Likewise `accumulated_nodes` should contain a node at the root index... assert!(accumulated_nodes.contains_key(&NodeIndex::root())); // and it should match our actual root. @@ -401,17 +361,86 @@ fn test_multithreaded_subtrees() { // And of course the root we got from each place should match. assert_eq!(control.root(), root_leaf.hash); } - #[test] -#[cfg(feature = "concurrent")] -fn test_with_entries_parallel() { +fn test_with_entries_concurrent() { const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; - let entries = generate_entries(PAIR_COUNT); - let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - let smt = Smt::with_entries(entries.clone()).unwrap(); assert_eq!(smt.root(), control.root()); assert_eq!(smt, control); } +/// Concurrent mutations +#[test] +fn test_singlethreaded_subtree_mutations() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + let entries = generate_entries(PAIR_COUNT); + let updates = generate_updates(entries.clone(), 1000); + let tree = Smt::with_entries_sequential(entries.clone()).unwrap(); + let control = tree.compute_mutations_sequential(updates.clone()); + let mut node_mutations = NodeMutations::default(); + let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates); + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // There's no flat_map_unzip(), so this is the best we can do. + let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves + .into_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + // Calculate the mutations for this subtree. + let (mutations_per_subtree, subtree_root) = + tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth); + // Check that the mutations match the control tree. + for (&index, mutation) in mutations_per_subtree.iter() { + let control_mutation = control.node_mutations().get(&index).unwrap(); + assert_eq!( + control_mutation, mutation, + "depth {} subtree {}: mutation does not match control at index {:?}", + current_depth, i, index, + ); + } + (mutations_per_subtree, subtree_root) + }) + .unzip(); + subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); + node_mutations.extend(mutations_per_subtree.into_iter().flatten()); + assert!(!subtree_leaves.is_empty(), "on depth {current_depth}"); + } + let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap(); + let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap(); + // Check that the new root matches the control. + assert_eq!(control.new_root, root_leaf.hash); + // Check that the node mutations match the control. + assert_eq!(control.node_mutations().len(), node_mutations.len()); + for (&index, mutation) in control.node_mutations().iter() { + let test_mutation = node_mutations.get(&index).unwrap(); + assert_eq!(test_mutation, mutation); + } + // Check that the new pairs match the control + assert_eq!(control.new_pairs.len(), new_pairs.len()); + for (&key, &value) in control.new_pairs.iter() { + let test_value = new_pairs.get(&key).unwrap(); + assert_eq!(test_value, &value); + } +} +#[test] +fn test_compute_mutations_parallel() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + let entries = generate_entries(PAIR_COUNT); + let tree = Smt::with_entries(entries.clone()).unwrap(); + let updates = generate_updates(entries, 1000); + let control = tree.compute_mutations_sequential(updates.clone()); + let mutations = tree.compute_mutations(updates); + assert_eq!(mutations.root(), control.root()); + assert_eq!(mutations.old_root(), control.old_root()); + assert_eq!(mutations.node_mutations(), control.node_mutations()); + assert_eq!(mutations.new_pairs(), control.new_pairs()); +} diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 5cd641e4..70f69a58 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -1,4 +1,4 @@ -use alloc::{collections::BTreeSet, string::ToString, vec::Vec}; +use alloc::{string::ToString, vec::Vec}; use super::{ EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError, @@ -15,6 +15,12 @@ mod proof; pub use proof::SmtProof; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +// Concurrent implementation +#[cfg(feature = "concurrent")] +mod concurrent; +#[cfg(feature = "internal")] +pub use concurrent::{build_subtree_for_bench, SubtreeLeaf}; + #[cfg(test)] mod tests; @@ -81,23 +87,7 @@ impl Smt { ) -> Result { #[cfg(feature = "concurrent")] { - let mut seen_keys = BTreeSet::new(); - let entries: Vec<_> = entries - .into_iter() - .map(|(key, value)| { - if seen_keys.insert(key) { - Ok((key, value)) - } else { - Err(MerkleError::DuplicateValuesForIndex( - LeafIndex::::from(key).value(), - )) - } - }) - .collect::>()?; - if entries.is_empty() { - return Ok(Self::default()); - } - >::with_entries_par(entries) + Self::with_entries_concurrent(entries) } #[cfg(not(feature = "concurrent"))] { @@ -112,9 +102,12 @@ impl Smt { /// /// # Errors /// Returns an error if the provided entries contain multiple values for the same key. - pub fn with_entries_sequential( + #[cfg(any(not(feature = "concurrent"), test))] + fn with_entries_sequential( entries: impl IntoIterator, ) -> Result { + use alloc::collections::BTreeSet; + // create an empty tree let mut tree = Self::new(); @@ -252,7 +245,14 @@ impl Smt { &self, kv_pairs: impl IntoIterator, ) -> MutationSet { - >::compute_mutations(self, kv_pairs) + #[cfg(feature = "concurrent")] + { + self.compute_mutations_concurrent(kv_pairs) + } + #[cfg(not(feature = "concurrent"))] + { + >::compute_mutations(self, kv_pairs) + } } /// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree. diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index ffa60a2d..aae0ea96 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,7 +1,6 @@ -use alloc::{collections::BTreeMap, vec::Vec}; -use core::{hash::Hash, mem}; +use alloc::vec::Vec; +use core::hash::Hash; -use num::Integer; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; @@ -11,6 +10,8 @@ use crate::{ }; mod full; +#[cfg(feature = "internal")] +pub use full::{build_subtree_for_bench, SubtreeLeaf}; pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH}; mod simple; @@ -75,17 +76,6 @@ pub(crate) trait SparseMerkleTree { // PROVIDED METHODS // --------------------------------------------------------------------------------------------- - /// Creates a new sparse Merkle tree from an existing set of key-value pairs, in parallel. - #[cfg(feature = "concurrent")] - fn with_entries_par(entries: Vec<(Self::Key, Self::Value)>) -> Result - where - Self: Sized, - { - let (inner_nodes, leaves) = Self::build_subtrees(entries); - let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); - Self::from_raw_parts(inner_nodes, leaves, root) - } - /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// path to the leaf, as well as the leaf itself. fn open(&self, key: &Self::Key) -> Self::Opening { @@ -178,6 +168,15 @@ pub(crate) trait SparseMerkleTree { fn compute_mutations( &self, kv_pairs: impl IntoIterator, + ) -> MutationSet { + self.compute_mutations_sequential(kv_pairs) + } + + /// Sequential version of [`SparseMerkleTree::compute_mutations()`]. + /// This is the default implementation. + fn compute_mutations_sequential( + &self, + kv_pairs: impl IntoIterator, ) -> MutationSet { use NodeMutation::*; @@ -457,118 +456,6 @@ pub(crate) trait SparseMerkleTree { /// /// The length `path` is guaranteed to be equal to `DEPTH` fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening; - - /// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing - /// subtrees. In other words, this function takes the key-value inputs to the tree, and produces - /// the inputs to feed into [`build_subtree()`]. - /// - /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If - /// `pairs` is not correctly sorted, the returned computations will be incorrect. - /// - /// # Panics - /// With debug assertions on, this function panics if it detects that `pairs` is not correctly - /// sorted. Without debug assertions, the returned computations will be incorrect. - fn sorted_pairs_to_leaves( - pairs: Vec<(Self::Key, Self::Value)>, - ) -> PairComputations { - debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); - - let mut accumulator: PairComputations = Default::default(); - let mut accumulated_leaves: Vec = Vec::with_capacity(pairs.len() / 2); - - // As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a - // single leaf. When we see a pair that's in a different leaf, we'll swap these pairs - // out and store them in our accumulated leaves. - let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default(); - - let mut iter = pairs.into_iter().peekable(); - while let Some((key, value)) = iter.next() { - let col = Self::key_to_leaf_index(&key).index.value(); - let peeked_col = iter.peek().map(|(key, _v)| { - let index = Self::key_to_leaf_index(key); - let next_col = index.index.value(); - // We panic if `pairs` is not sorted by column. - debug_assert!(next_col >= col); - next_col - }); - current_leaf_buffer.push((key, value)); - - // If the next pair is the same column as this one, then we're done after adding this - // pair to the buffer. - if peeked_col == Some(col) { - continue; - } - - // Otherwise, the next pair is a different column, or there is no next pair. Either way - // it's time to swap out our buffer. - let leaf_pairs = mem::take(&mut current_leaf_buffer); - let leaf = Self::pairs_to_leaf(leaf_pairs); - let hash = Self::hash_leaf(&leaf); - - accumulator.nodes.insert(col, leaf); - accumulated_leaves.push(SubtreeLeaf { col, hash }); - - debug_assert!(current_leaf_buffer.is_empty()); - } - - // TODO: determine is there is any notable performance difference between computing - // subtree boundaries after the fact as an iterator adapter (like this), versus computing - // subtree boundaries as we go. Either way this function is only used at the beginning of a - // parallel construction, so it should not be a critical path. - accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); - accumulator - } - - /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. - /// - /// `entries` need not be sorted. This function will sort them. - #[cfg(feature = "concurrent")] - fn build_subtrees( - mut entries: Vec<(Self::Key, Self::Value)>, - ) -> (InnerNodes, Leaves) { - entries.sort_by_key(|item| { - let index = Self::key_to_leaf_index(&item.0); - index.value() - }); - Self::build_subtrees_from_sorted_entries(entries) - } - - /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. - /// - /// This function is mostly an implementation detail of - /// [`SparseMerkleTree::with_entries_par()`]. - #[cfg(feature = "concurrent")] - fn build_subtrees_from_sorted_entries( - entries: Vec<(Self::Key, Self::Value)>, - ) -> (InnerNodes, Leaves) { - use rayon::prelude::*; - - let mut accumulated_nodes: InnerNodes = Default::default(); - - let PairComputations { - leaves: mut leaf_subtrees, - nodes: initial_leaves, - } = Self::sorted_pairs_to_leaves(entries); - - for current_depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { - let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees - .into_par_iter() - .map(|subtree| { - debug_assert!(subtree.is_sorted()); - debug_assert!(!subtree.is_empty()); - - let (nodes, subtree_root) = build_subtree(subtree, DEPTH, current_depth); - (nodes, subtree_root) - }) - .unzip(); - - leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect(); - accumulated_nodes.extend(nodes.into_iter().flatten()); - - debug_assert!(!leaf_subtrees.is_empty()); - } - (accumulated_nodes, initial_leaves) - } } // INNER NODE @@ -820,198 +707,3 @@ impl De }) } } - -// SUBTREES -// ================================================================================================ - -/// A subtree is of depth 8. -const SUBTREE_DEPTH: u8 = 8; - -/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. -const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32); - -/// Helper struct for organizing the data we care about when computing Merkle subtrees. -/// -/// Note that these represet "conceptual" leaves of some subtree, not necessarily -/// the leaf type for the sparse Merkle tree. -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] -pub struct SubtreeLeaf { - /// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known. - pub col: u64, - /// The hash of the node this `SubtreeLeaf` represents. - pub hash: RpoDigest, -} - -/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. -#[derive(Debug, Clone)] -pub(crate) struct PairComputations { - /// Literal leaves to be added to the sparse Merkle tree's internal mapping. - pub nodes: UnorderedMap, - /// "Conceptual" leaves that will be used for computations. - pub leaves: Vec>, -} - -// Derive requires `L` to impl Default, even though we don't actually need that. -impl Default for PairComputations { - fn default() -> Self { - Self { - nodes: Default::default(), - leaves: Default::default(), - } - } -} - -#[derive(Debug)] -struct SubtreeLeavesIter<'s> { - leaves: core::iter::Peekable>, -} -impl<'s> SubtreeLeavesIter<'s> { - fn from_leaves(leaves: &'s mut Vec) -> Self { - // TODO: determine if there is any notable performance difference between taking a Vec, - // which many need flattening first, vs storing a `Box>`. - // The latter may have self-referential properties that are impossible to express in purely - // safe Rust Rust. - Self { leaves: leaves.drain(..).peekable() } - } -} -impl Iterator for SubtreeLeavesIter<'_> { - type Item = Vec; - - /// Each `next()` collects an entire subtree. - fn next(&mut self) -> Option> { - let mut subtree: Vec = Default::default(); - - let mut last_subtree_col = 0; - - while let Some(leaf) = self.leaves.peek() { - last_subtree_col = u64::max(1, last_subtree_col); - let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE); - let next_subtree_col = if is_exact_multiple { - u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) - } else { - last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) - }; - - last_subtree_col = leaf.col; - if leaf.col < next_subtree_col { - subtree.push(self.leaves.next().unwrap()); - } else if subtree.is_empty() { - continue; - } else { - break; - } - } - - if subtree.is_empty() { - debug_assert!(self.leaves.peek().is_none()); - return None; - } - - Some(subtree) - } -} - -// HELPER FUNCTIONS -// ================================================================================================ - -/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and -/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and -/// `leaves` must not contain more than one depth-8 subtree's worth of leaves. -/// -/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as -/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into -/// itself. -/// -/// # Panics -/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains -/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to -/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified -/// maximum depth (`DEPTH`), or if `leaves` is not sorted. -fn build_subtree( - mut leaves: Vec, - tree_depth: u8, - bottom_depth: u8, -) -> (BTreeMap, SubtreeLeaf) { - debug_assert!(bottom_depth <= tree_depth); - debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); - debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); - let subtree_root = bottom_depth - SUBTREE_DEPTH; - let mut inner_nodes: BTreeMap = Default::default(); - let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); - for next_depth in (subtree_root..bottom_depth).rev() { - debug_assert!(next_depth <= bottom_depth); - // `next_depth` is the stuff we're making. - // `current_depth` is the stuff we have. - let current_depth = next_depth + 1; - let mut iter = leaves.drain(..).peekable(); - while let Some(first) = iter.next() { - // On non-continuous iterations, including the first iteration, `first_column` may - // be a left or right node. On subsequent continuous iterations, we will always call - // `iter.next()` twice. - // On non-continuous iterations (including the very first iteration), this column - // could be either on the left or the right. If the next iteration is not - // discontinuous with our right node, then the next iteration's - let is_right = first.col.is_odd(); - let (left, right) = if is_right { - // Discontinuous iteration: we have no left node, so it must be empty. - let left = SubtreeLeaf { - col: first.col - 1, - hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), - }; - let right = first; - (left, right) - } else { - let left = first; - let right_col = first.col + 1; - let right = match iter.peek().copied() { - Some(SubtreeLeaf { col, .. }) if col == right_col => { - // Our inputs must be sorted. - debug_assert!(left.col <= col); - // The next leaf in the iterator is our sibling. Use it and consume it! - iter.next().unwrap() - }, - // Otherwise, the leaves don't contain our sibling, so our sibling must be - // empty. - _ => SubtreeLeaf { - col: right_col, - hash: *EmptySubtreeRoots::entry(tree_depth, current_depth), - }, - }; - (left, right) - }; - let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); - let node = InnerNode { left: left.hash, right: right.hash }; - let hash = node.hash(); - let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth); - // If this hash is empty, then it doesn't become a new inner node, nor does it count - // as a leaf for the next depth. - if hash != equivalent_empty_hash { - inner_nodes.insert(index, node); - next_leaves.push(SubtreeLeaf { col: index.value(), hash }); - } - } - // Stop borrowing `leaves`, so we can swap it. - // The iterator is empty at this point anyway. - drop(iter); - // After each depth, consider the stuff we just made the new "leaves", and empty the - // other collection. - mem::swap(&mut leaves, &mut next_leaves); - } - debug_assert_eq!(leaves.len(), 1); - let root = leaves.pop().unwrap(); - (inner_nodes, root) -} - -#[cfg(feature = "internal")] -pub fn build_subtree_for_bench( - leaves: Vec, - tree_depth: u8, - bottom_depth: u8, -) -> (BTreeMap, SubtreeLeaf) { - build_subtree(leaves, tree_depth, bottom_depth) -} - -// TESTS -// ================================================================================================ -#[cfg(test)] -mod tests;