From 97d7df423add9fd427f4cdbc7c5ea79bd69e6436 Mon Sep 17 00:00:00 2001 From: Marta Mularczyk Date: Wed, 31 Jan 2024 16:41:11 +0100 Subject: [PATCH] Fixup --- mls-rs/src/group/external_commit.rs | 2 +- mls-rs/src/group/secret_tree.rs | 4 +-- mls-rs/src/tree_kem/kem.rs | 5 ++-- mls-rs/src/tree_kem/math.rs | 42 +++++++++++++++++++++------ mls-rs/src/tree_kem/node.rs | 8 ----- mls-rs/src/tree_kem/parent_hash.rs | 22 +++++++------- mls-rs/src/tree_kem/tree_hash.rs | 19 ++++-------- mls-rs/src/tree_kem/tree_validator.rs | 17 ++++------- 8 files changed, 62 insertions(+), 57 deletions(-) diff --git a/mls-rs/src/group/external_commit.rs b/mls-rs/src/group/external_commit.rs index 88488a7b..32bd09b3 100644 --- a/mls-rs/src/group/external_commit.rs +++ b/mls-rs/src/group/external_commit.rs @@ -190,7 +190,7 @@ impl ExternalCommitBuilder { resumption_secret: PreSharedKey::new(vec![]), sender_data_secret: SenderDataSecret::from(vec![]), #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] - secret_tree: SecretTree::empty(0), + secret_tree: SecretTree::empty(), }; let (mut group, _) = Group::join_with( diff --git a/mls-rs/src/group/secret_tree.rs b/mls-rs/src/group/secret_tree.rs index a60ae35c..4ea5d676 100644 --- a/mls-rs/src/group/secret_tree.rs +++ b/mls-rs/src/group/secret_tree.rs @@ -110,10 +110,10 @@ pub struct SecretTree { } impl SecretTree { - pub(crate) fn empty(zero_leaf_count: T) -> SecretTree { + pub(crate) fn empty() -> SecretTree { SecretTree { known_secrets: Default::default(), - leaf_count: zero_leaf_count, + leaf_count: T::zero(), } } } diff --git a/mls-rs/src/tree_kem/kem.rs b/mls-rs/src/tree_kem/kem.rs index b1116f81..cedeb0e3 100644 --- a/mls-rs/src/tree_kem/kem.rs +++ b/mls-rs/src/tree_kem/kem.rs @@ -355,9 +355,10 @@ impl<'a> TreeKem<'a> { let ctxts = ctxts.try_collect().await?; - let (path_index, _) = copath_index + let path_index = copath_index .parent_sibling(&self.tree_kem_public.total_leaf_count()) - .ok_or(MlsError::ExpectedNode)?; + .ok_or(MlsError::ExpectedNode)? + .parent; Ok(UpdatePathNode { public_key: self diff --git a/mls-rs/src/tree_kem/math.rs b/mls-rs/src/tree_kem/math.rs index 11324c88..86963d52 100644 --- a/mls-rs/src/tree_kem/math.rs +++ b/mls-rs/src/tree_kem/math.rs @@ -13,10 +13,12 @@ pub trait TreeIndex: Sized + Send + Sync + Eq + Clone + Debug { fn left_unchecked(&self) -> Self; fn right_unchecked(&self) -> Self; - fn parent_sibling(&self, leaf_count: &Self) -> Option<(Self, Self)>; + fn parent_sibling(&self, leaf_count: &Self) -> Option>; fn is_leaf(&self) -> bool; fn is_in_tree(&self, root: &Self) -> bool; + fn zero() -> Self; + fn left(&self) -> Option { (!self.is_leaf()).then(|| self.left_unchecked()) } @@ -35,9 +37,9 @@ pub trait TreeIndex: Sized + Send + Sync + Eq + Clone + Debug { let mut path = Vec::new(); let mut parent = self.clone(); - while let Some((p, s)) = parent.parent_sibling(leaf_count) { - path.push(CopathNode::new(p.clone(), s.clone())); - parent = p; + while let Some(ps) = parent.parent_sibling(leaf_count) { + path.push(CopathNode::new(ps.parent.clone(), ps.sibling)); + parent = ps.parent; } path @@ -56,6 +58,18 @@ impl CopathNode { } } +#[derive(Clone, PartialEq, Eq, core::fmt::Debug)] +pub struct ParentSibling { + pub parent: T, + pub sibling: T, +} + +impl ParentSibling { + pub fn new(parent: T, sibling: T) -> ParentSibling { + ParentSibling { parent, sibling } + } +} + macro_rules! impl_tree_stdint { ($t:ty) => { impl TreeIndex for $t { @@ -73,7 +87,7 @@ macro_rules! impl_tree_stdint { *self ^ (0x03 << (level(*self) - 1)) } - fn parent_sibling(&self, leaf_count: &Self) -> Option<(Self, Self)> { + fn parent_sibling(&self, leaf_count: &Self) -> Option> { if self == &leaf_count.root() { return None; } @@ -87,7 +101,7 @@ macro_rules! impl_tree_stdint { p.left_unchecked() }; - Some((p, s)) + Some(ParentSibling::new(p, s)) } fn is_leaf(&self) -> bool { @@ -97,6 +111,10 @@ macro_rules! impl_tree_stdint { fn is_in_tree(&self, root: &Self) -> bool { *self <= 2 * root } + + fn zero() -> Self { + 0 + } } fn level(x: $t) -> u32 { @@ -111,7 +129,6 @@ impl_tree_stdint!(u32); mod test_utils { use super::*; impl_tree_stdint!(u64); - //impl_tree_stdint!(u16); } pub fn leaf_lca_level(x: u32, y: u32) -> u32 { @@ -220,7 +237,11 @@ mod tests { let right = (0..n_nodes).map(|x| x.right()).collect::>(); let (parent, sibling) = (0..n_nodes) - .map(|x| x.parent_sibling(&n_leaves).unzip()) + .map(|x| { + x.parent_sibling(&n_leaves) + .map(|ps| (ps.parent, ps.sibling)) + .unzip() + }) .unzip(); test_cases.push(TestCase { @@ -253,7 +274,10 @@ mod tests { assert_eq!(x.left(), case.left[x as usize]); assert_eq!(x.right(), case.right[x as usize]); - let (p, s) = x.parent_sibling(&case.n_leaves).unzip(); + let (p, s) = x + .parent_sibling(&case.n_leaves) + .map(|ps| (ps.parent, ps.sibling)) + .unzip(); assert_eq!(p, case.parent[x as usize]); assert_eq!(s, case.sibling[x as usize]); diff --git a/mls-rs/src/tree_kem/node.rs b/mls-rs/src/tree_kem/node.rs index 68a77423..a8aa9ba7 100644 --- a/mls-rs/src/tree_kem/node.rs +++ b/mls-rs/src/tree_kem/node.rs @@ -221,14 +221,6 @@ impl NodeVec { self.iter().step_by(2).map(|n| n.as_leaf().ok()) } - /*pub fn direct_path(&self, index: LeafIndex) -> Vec { - NodeIndex::from(index) - .direct_copath(&self.total_leaf_count()) - .into_iter() - .map(|n| n.path) - .collect() - }*/ - pub fn direct_copath(&self, index: LeafIndex) -> Vec> { NodeIndex::from(index).direct_copath(&self.total_leaf_count()) } diff --git a/mls-rs/src/tree_kem/parent_hash.rs b/mls-rs/src/tree_kem/parent_hash.rs index 5da7c7bf..f40b383f 100644 --- a/mls-rs/src/tree_kem/parent_hash.rs +++ b/mls-rs/src/tree_kem/parent_hash.rs @@ -184,20 +184,19 @@ impl TreeKemPublic { for (leaf_index, _) in self.nodes.non_empty_leaves() { let mut n = NodeIndex::from(leaf_index); - while let Some((mut p, mut s)) = n.parent_sibling(&num_leaves) { + while let Some(mut ps) = n.parent_sibling(&num_leaves) { // Find the first non-blank ancestor p of n and p's co-path child s. - while self.nodes.is_blank(p)? { + while self.nodes.is_blank(ps.parent)? { // If we reached the root, we're done with this chain. - let Some((p_parent, p_sibling)) = p.parent_sibling(&num_leaves) else { + let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else { return Ok(()); }; - p = p_parent; - s = p_sibling; + ps = ps_parent; } // Check is n's parent_hash field matches the parent hash of p with co-path child s. - let p_parent = self.nodes.borrow_as_parent(p)?; + let p_parent = self.nodes.borrow_as_parent(ps.parent)?; let n_node = self .nodes @@ -209,17 +208,18 @@ impl TreeKemPublic { cipher_suite_provider, &p_parent.public_key, &p_parent.parent_hash, - &original_hashes[s as usize], + &original_hashes[ps.sibling as usize], ) .await?; if n_node.get_parent_hash() == Some(calculated) { // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree // under c is equal to the resolution of c with n removed". - let Some((_, c)) = s.parent_sibling(&num_leaves) else { + let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else { return Err(MlsError::ParentHashMismatch); }; + let c = cp.sibling; let c_resolution = self.nodes.get_resolution_index(c)?.into_iter(); #[cfg(feature = "std")] @@ -228,7 +228,7 @@ impl TreeKemPublic { let mut c_resolution = c_resolution.collect::>(); let p_unmerged_in_c_subtree = self - .unmerged_in_subtree(p, c)? + .unmerged_in_subtree(ps.parent, c)? .iter() .copied() .map(|x| *x * 2); @@ -240,10 +240,10 @@ impl TreeKemPublic { if c_resolution.remove(&n) && c_resolution == p_unmerged_in_c_subtree - && nodes_to_validate.remove(&p) + && nodes_to_validate.remove(&ps.parent) { // If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue. - n = p; + n = ps.parent; } else { // If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain"). return Err(MlsError::ParentHashMismatch); diff --git a/mls-rs/src/tree_kem/tree_hash.rs b/mls-rs/src/tree_kem/tree_hash.rs index ebce7bdd..46d102ab 100644 --- a/mls-rs/src/tree_kem/tree_hash.rs +++ b/mls-rs/src/tree_kem/tree_hash.rs @@ -176,10 +176,11 @@ impl TreeKemPublic { let bfs_iter = BfsIterTopDown::new(num_leaves).skip(1); for n in bfs_iter { - let Some((p, _)) = (n as u32).parent_sibling(&(num_leaves as u32)) else { + let Some(ps) = (n as u32).parent_sibling(&(num_leaves as u32)) else { break; }; + let p = ps.parent; filtered_sets[n] = filtered_sets[p as usize].clone(); if self.different_unmerged(*filtered_sets[p as usize].last().unwrap(), p)? { @@ -270,17 +271,11 @@ async fn tree_hash( hashes[2 * **l as usize] = TreeHash(hash_for_leaf(*l, leaf, cipher_suite_provider).await?); - if 2 * **l != num_leaves.root() { - let Some((p, _)) = (2 * **l).parent_sibling(&num_leaves) else { - break; - }; - - node_queue.push_back(p); + if let Some(ps) = (2 * **l).parent_sibling(&num_leaves) { + node_queue.push_back(ps.parent); } } - let root = num_leaves.root(); - while let Some(n) = node_queue.pop_front() { let hash = TreeHash( hash_for_parent( @@ -295,10 +290,8 @@ async fn tree_hash( hashes[n as usize] = hash; - if n != root { - if let Some((p, _)) = n.parent_sibling(&num_leaves) { - node_queue.push_back(p); - } + if let Some(ps) = n.parent_sibling(&num_leaves) { + node_queue.push_back(ps.parent); } } diff --git a/mls-rs/src/tree_kem/tree_validator.rs b/mls-rs/src/tree_kem/tree_validator.rs index 04f8637b..dcfee6e8 100644 --- a/mls-rs/src/tree_kem/tree_validator.rs +++ b/mls-rs/src/tree_kem/tree_validator.rs @@ -128,27 +128,22 @@ fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> { // such that for each i=1..k, either L is in the unmerged leaves of P[i], or P[i] is blank. We will // then check that L is unmerged at each P[1], ..., P[k] and no other node. let leaf_count = tree.total_leaf_count(); - let root = leaf_count.root(); for (index, _) in tree.nodes.non_empty_leaves() { let mut n = NodeIndex::from(index); - while n != root { - let Some((parent, _)) = n.parent_sibling(&leaf_count) else { - break; - }; - - if tree.nodes.is_blank(parent)? { - n = parent; + while let Some(ps) = n.parent_sibling(&leaf_count) { + if tree.nodes.is_blank(ps.parent)? { + n = ps.parent; continue; } - let parent_node = tree.nodes.borrow_as_parent(parent)?; + let parent_node = tree.nodes.borrow_as_parent(ps.parent)?; if parent_node.unmerged_leaves.contains(&index) { - unmerged_sets[parent as usize].retain(|i| i != &index); + unmerged_sets[ps.parent as usize].retain(|i| i != &index); - n = parent; + n = ps.parent; } else { break; }