From fd8000ea8ec18f28f463218541e9b52091bbc6b7 Mon Sep 17 00:00:00 2001 From: Marta Mularczyk Date: Wed, 31 Jan 2024 10:35:19 +0100 Subject: [PATCH] Make secret tree work with arbitrary indices, not only usize --- mls-rs/src/group/ciphertext_processor.rs | 9 +- mls-rs/src/group/epoch.rs | 4 +- mls-rs/src/group/external_commit.rs | 2 +- .../group/interop_test_vectors/tree_kem.rs | 4 +- mls-rs/src/group/mod.rs | 8 +- mls-rs/src/group/secret_tree.rs | 248 ++++++++---------- mls-rs/src/tree_kem/kem.rs | 55 ++-- mls-rs/src/tree_kem/math.rs | 218 ++++++++------- mls-rs/src/tree_kem/mod.rs | 40 ++- mls-rs/src/tree_kem/node.rs | 57 ++-- mls-rs/src/tree_kem/parent_hash.rs | 27 +- mls-rs/src/tree_kem/private.rs | 11 +- mls-rs/src/tree_kem/tree_hash.rs | 28 +- mls-rs/src/tree_kem/tree_utils.rs | 14 +- mls-rs/src/tree_kem/tree_validator.rs | 8 +- 15 files changed, 364 insertions(+), 369 deletions(-) diff --git a/mls-rs/src/group/ciphertext_processor.rs b/mls-rs/src/group/ciphertext_processor.rs index 70d735e0..111de4ea 100644 --- a/mls-rs/src/group/ciphertext_processor.rs +++ b/mls-rs/src/group/ciphertext_processor.rs @@ -16,7 +16,10 @@ use super::{ secret_tree::{KeyType, MessageKeyData}, GroupContext, }; -use crate::{client::MlsError, tree_kem::node::LeafIndex}; +use crate::{ + client::MlsError, + tree_kem::node::{LeafIndex, NodeIndex}, +}; use mls_rs_codec::MlsEncode; use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError}; use zeroize::Zeroizing; @@ -67,7 +70,7 @@ where &mut self, key_type: KeyType, ) -> Result { - let self_index = self.group_state.self_index(); + let self_index = NodeIndex::from(self.group_state.self_index()); self.group_state .epoch_secrets_mut() @@ -83,6 +86,8 @@ where key_type: KeyType, generation: u32, ) -> Result { + let sender = NodeIndex::from(sender); + self.group_state .epoch_secrets_mut() .secret_tree diff --git a/mls-rs/src/group/epoch.rs b/mls-rs/src/group/epoch.rs index 7360bf62..7f27cb13 100644 --- a/mls-rs/src/group/epoch.rs +++ b/mls-rs/src/group/epoch.rs @@ -4,6 +4,8 @@ #[cfg(feature = "psk")] use crate::psk::PreSharedKey; +#[cfg(any(feature = "secret_tree_access", feature = "private_message"))] +use crate::tree_kem::node::NodeIndex; #[cfg(feature = "prior_epoch")] use crate::{crypto::SignaturePublicKey, group::GroupContext, tree_kem::node::LeafIndex}; use alloc::vec::Vec; @@ -66,7 +68,7 @@ pub(crate) struct EpochSecrets { #[mls_codec(with = "mls_rs_codec::byte_vec")] pub(crate) sender_data_secret: SenderDataSecret, #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] - pub(crate) secret_tree: SecretTree, + pub(crate) secret_tree: SecretTree, } #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)] diff --git a/mls-rs/src/group/external_commit.rs b/mls-rs/src/group/external_commit.rs index 32bd09b3..88488a7b 100644 --- a/mls-rs/src/group/external_commit.rs +++ b/mls-rs/src/group/external_commit.rs @@ -190,7 +190,7 @@ impl ExternalCommitBuilder { resumption_secret: PreSharedKey::new(vec![]), sender_data_secret: SenderDataSecret::from(vec![]), #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] - secret_tree: SecretTree::empty(), + secret_tree: SecretTree::empty(0), }; let (mut group, _) = Group::join_with( diff --git a/mls-rs/src/group/interop_test_vectors/tree_kem.rs b/mls-rs/src/group/interop_test_vectors/tree_kem.rs index 9ae0aee8..0a04312c 100644 --- a/mls-rs/src/group/interop_test_vectors/tree_kem.rs +++ b/mls-rs/src/group/interop_test_vectors/tree_kem.rs @@ -102,11 +102,13 @@ async fn tree_kem() { let mut tree_private = TreeKemPrivate::new(LeafIndex(leaf.index)); // Set and validate HPKE keys on direct path - let path = tree.nodes.direct_path(tree_private.self_index).unwrap(); + let path = tree.nodes.direct_copath(tree_private.self_index); tree_private.secret_keys = Vec::new(); for dp in path { + let dp = dp.path; + let secret = leaf .path_secrets .iter() diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 1e4c8cbe..add67186 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -719,14 +719,14 @@ where let path = provisional_state .public_tree .nodes - .direct_path(self_index)?; + .direct_copath(self_index); provisional_private_tree .secret_keys .resize(path.len() + 1, None); for (i, n) in path.iter().enumerate() { - if provisional_state.public_tree.nodes.is_blank(*n)? { + if provisional_state.public_tree.nodes.is_blank(n.path)? { provisional_private_tree.secret_keys[i + 1] = None; } } @@ -1512,7 +1512,7 @@ where pub fn next_encryption_key(&mut self) -> Result { self.epoch_secrets.secret_tree.next_message_key( &self.cipher_suite_provider, - self.private_tree.self_index, + crate::tree_kem::node::NodeIndex::from(self.private_tree.self_index), KeyType::Application, ) } @@ -1525,7 +1525,7 @@ where ) -> Result { self.epoch_secrets.secret_tree.message_key_generation( &self.cipher_suite_provider, - LeafIndex(sender), + crate::tree_kem::node::NodeIndex::from(sender), KeyType::Application, generation, ) diff --git a/mls-rs/src/group/secret_tree.rs b/mls-rs/src/group/secret_tree.rs index 0f4e186f..a60ae35c 100644 --- a/mls-rs/src/group/secret_tree.rs +++ b/mls-rs/src/group/secret_tree.rs @@ -4,25 +4,35 @@ use crate::client::MlsError; use crate::tree_kem::math as tree_math; -use crate::tree_kem::node::{LeafIndex, NodeIndex}; use crate::CipherSuiteProvider; -use alloc::vec; use alloc::vec::Vec; -use core::ops::{Deref, DerefMut}; +use core::hash::Hash; +use core::{ + fmt::Debug, + ops::{Deref, DerefMut}, +}; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::error::IntoAnyError; +use tree_math::TreeIndex; use zeroize::Zeroizing; -#[cfg(all(feature = "std", feature = "out_of_order"))] +#[cfg(feature = "std")] use std::collections::HashMap; -#[cfg(all(not(feature = "std"), feature = "out_of_order"))] +#[cfg(not(feature = "std"))] use alloc::collections::BTreeMap; use super::key_schedule::kdf_expand_with_label; pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024; +pub trait MlsCodec: + Clone + Debug + PartialEq + Eq + Default + MlsEncode + MlsDecode + MlsSize + Hash + Ord +{ +} + +impl MlsCodec for u32 {} + #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] #[repr(u8)] enum SecretTreeNode { @@ -31,14 +41,6 @@ enum SecretTreeNode { } impl SecretTreeNode { - fn into_ratchet(self) -> Option { - if let SecretTreeNode::Ratchet(ratchets) = self { - Some(ratchets) - } else { - None - } - } - fn into_secret(self) -> Option { if let SecretTreeNode::Secret(secret) = self { Some(secret) @@ -83,54 +85,35 @@ impl From>> for TreeSecret { } } -#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)] -struct TreeSecretsVec(Vec>); - -impl Deref for TreeSecretsVec { - type Target = Vec>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for TreeSecretsVec { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } +#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)] +struct TreeSecretsVec { + #[cfg(feature = "std")] + inner: HashMap, + #[cfg(not(feature = "std"))] + inner: BTreeMap, } -impl TreeSecretsVec { - fn replace_node( - &mut self, - index: NodeIndex, - value: Option, - ) -> Result<(), MlsError> { - self.get_mut(index as usize) - .ok_or(MlsError::InvalidNodeIndex(index)) - .map(|n| *n = value) - } - - fn get_secret(&self, index: NodeIndex) -> Option { - self.get(index as usize).and_then(|n| n.clone()) +impl TreeSecretsVec { + fn set_node(&mut self, index: T, value: SecretTreeNode) { + self.inner.insert(index, value); } - fn total_leaf_count(&self) -> u32 { - ((self.len() / 2 + 1) as u32).next_power_of_two() + fn take_node(&mut self, index: &T) -> Option { + self.inner.remove(index) } } #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)] -pub struct SecretTree { - known_secrets: TreeSecretsVec, - leaf_count: u32, +pub struct SecretTree { + known_secrets: TreeSecretsVec, + leaf_count: T, } -impl SecretTree { - pub(crate) fn empty() -> SecretTree { +impl SecretTree { + pub(crate) fn empty(zero_leaf_count: T) -> SecretTree { SecretTree { - known_secrets: TreeSecretsVec(vec![]), - leaf_count: 0, + known_secrets: Default::default(), + leaf_count: zero_leaf_count, } } } @@ -176,12 +159,12 @@ impl SecretRatchets { } } -impl SecretTree { - pub fn new(leaf_count: u32, encryption_secret: Zeroizing>) -> SecretTree { - let mut known_secrets = TreeSecretsVec(vec![None; (leaf_count * 2 - 1) as usize]); +impl SecretTree { + pub fn new(leaf_count: T, encryption_secret: Zeroizing>) -> SecretTree { + let mut known_secrets = TreeSecretsVec::default(); - known_secrets[tree_math::root(leaf_count) as usize] = - Some(SecretTreeNode::Secret(TreeSecret::from(encryption_secret))); + let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret)); + known_secrets.set_node(leaf_count.root(), root_secret); Self { known_secrets, @@ -193,11 +176,13 @@ impl SecretTree { async fn consume_node( &mut self, cipher_suite_provider: &P, - index: NodeIndex, + index: &T, ) -> Result<(), MlsError> { - if let Some(secret) = self.read_node(index)?.and_then(|n| n.into_secret()) { - let left_index = tree_math::left(index)?; - let right_index = tree_math::right(index)?; + let node = self.known_secrets.take_node(index); + + if let Some(secret) = node.and_then(|n| n.into_secret()) { + let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?; + let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?; let left_secret = kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None) @@ -207,61 +192,45 @@ impl SecretTree { kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None) .await?; - self.write_node(left_index, Some(SecretTreeNode::Secret(left_secret.into())))?; + self.known_secrets + .set_node(left_index, SecretTreeNode::Secret(left_secret.into())); - self.write_node( - right_index, - Some(SecretTreeNode::Secret(right_secret.into())), - )?; - - self.write_node(index, None) - } else { - Ok(()) // If the node is empty we can just skip it + self.known_secrets + .set_node(right_index, SecretTreeNode::Secret(right_secret.into())); } - } - fn read_node(&self, index: NodeIndex) -> Result, MlsError> { - Ok(self.known_secrets.get_secret(index)) - } - - fn write_node( - &mut self, - index: NodeIndex, - value: Option, - ) -> Result<(), MlsError> { - self.known_secrets.replace_node(index, value) + Ok(()) } - // Start at the root node and work your way down consuming any intermediates needed #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - async fn leaf_secret_ratchets( + async fn take_leaf_ratchet( &mut self, cipher_suite: &P, - leaf_index: LeafIndex, + leaf_index: &T, ) -> Result { - if let Some(ratchet) = self - .read_node(leaf_index.into())? - .and_then(|n| n.into_ratchet()) - { - return Ok(ratchet); - } - - let path = leaf_index.direct_path(self.known_secrets.total_leaf_count())?; - - for i in path.into_iter().rev() { - self.consume_node(cipher_suite, i).await?; - } - - let secret = self - .read_node(leaf_index.into())? - .and_then(|n| n.into_secret()) - .ok_or(MlsError::InvalidLeafConsumption)?; + let node_index = leaf_index; + + let node = match self.known_secrets.take_node(node_index) { + Some(node) => node, + None => { + // Start at the root node and work your way down consuming any intermediates needed + for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() { + self.consume_node(cipher_suite, &i.path).await?; + } - self.write_node(leaf_index.into(), None)?; + self.known_secrets + .take_node(node_index) + .ok_or(MlsError::InvalidLeafConsumption)? + } + }; - Ok(SecretRatchets { - application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application).await?, - handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?, + Ok(match node { + SecretTreeNode::Ratchet(ratchet) => ratchet, + SecretTreeNode::Secret(secret) => SecretRatchets { + application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application) + .await?, + handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?, + }, }) } @@ -269,12 +238,14 @@ impl SecretTree { pub async fn next_message_key( &mut self, cipher_suite: &P, - leaf_index: LeafIndex, + leaf_index: T, key_type: KeyType, ) -> Result { - let mut ratchet = self.leaf_secret_ratchets(cipher_suite, leaf_index).await?; + let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?; let res = ratchet.next_message_key(cipher_suite, key_type).await?; - self.write_node(leaf_index.into(), Some(SecretTreeNode::Ratchet(ratchet)))?; + + self.known_secrets + .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet)); Ok(res) } @@ -283,17 +254,18 @@ impl SecretTree { pub async fn message_key_generation( &mut self, cipher_suite: &P, - leaf_index: LeafIndex, + leaf_index: T, key_type: KeyType, generation: u32, ) -> Result { - let mut ratchet = self.leaf_secret_ratchets(cipher_suite, leaf_index).await?; + let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?; let res = ratchet .message_key_generation(cipher_suite, generation, key_type) .await?; - self.write_node(leaf_index.into(), Some(SecretTreeNode::Ratchet(ratchet)))?; + self.known_secrets + .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet)); Ok(res) } @@ -538,22 +510,27 @@ pub(crate) mod test_utils { use mls_rs_core::crypto::CipherSuiteProvider; use zeroize::Zeroizing; - use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem}; + use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex}; + + use super::{KeyType, MlsCodec, SecretKeyRatchet, SecretTree}; - use super::{KeyType, SecretKeyRatchet, SecretTree}; + impl MlsCodec for u64 {} - pub(crate) fn get_test_tree(secret: Vec, leaf_count: u32) -> SecretTree { + pub(crate) fn get_test_tree( + secret: Vec, + leaf_count: T, + ) -> SecretTree { SecretTree::new(leaf_count, Zeroizing::new(secret)) } - impl SecretTree { + impl SecretTree { pub(crate) fn get_root_secret(&self) -> Vec { - self.read_node(tree_kem::math::root(self.leaf_count)) - .unwrap() + self.known_secrets + .clone() + .take_node(&self.leaf_count.root()) .unwrap() .into_secret() .unwrap() - .0 .to_vec() } } @@ -609,12 +586,15 @@ pub(crate) mod test_utils { #[cfg(test)] mod tests { + use alloc::vec; + use crate::{ cipher_suite::CipherSuite, client::test_utils::TEST_CIPHER_SUITE, crypto::test_utils::{ test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider, }, + tree_kem::node::NodeIndex, }; #[cfg(not(mls_build_async))] @@ -629,18 +609,27 @@ mod tests { #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_secret_tree() { + test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await; + test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await; + } + + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + async fn test_secret_tree_custom( + leaf_count: T, + leaves_to_check: Vec, + all_deleted: bool, + ) { for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { let cs_provider = test_cipher_suite_provider(cipher_suite); let test_secret = vec![0u8; cs_provider.kdf_extract_size()]; - - let mut test_tree = get_test_tree(test_secret.clone(), 16); + let mut test_tree = get_test_tree(test_secret, leaf_count.clone()); let mut secrets = Vec::::new(); - for i in 0..16 { + for i in &leaves_to_check { let secret = test_tree - .leaf_secret_ratchets(&test_cipher_suite_provider(cipher_suite), LeafIndex(i)) + .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i) .await .unwrap(); @@ -648,12 +637,7 @@ mod tests { } // Verify the tree is now completely empty - let full = test_tree - .known_secrets - .iter() - .filter(|n| n.is_some()) - .count(); - assert_eq!(full, 0); + assert!(!all_deleted || test_tree.known_secrets.inner.is_empty()); // Verify that all the secrets are unique let count = secrets.len(); @@ -844,7 +828,7 @@ mod tests { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn get_ratchet_data( - secret_tree: &mut SecretTree, + secret_tree: &mut SecretTree, cipher_suite: CipherSuite, ) -> Vec { let provider = test_cipher_suite_provider(cipher_suite); @@ -852,7 +836,7 @@ mod tests { for index in 0..16 { let mut ratchets = secret_tree - .leaf_secret_ratchets(&provider, LeafIndex(index)) + .take_leaf_ratchet(&provider, &(index * 2)) .await .unwrap(); @@ -943,7 +927,7 @@ mod interop_tests { use crate::{ crypto::test_utils::try_test_cipher_suite_provider, - group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType, LeafIndex}, + group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType}, }; use super::SecretTree; @@ -970,7 +954,7 @@ mod interop_tests { let key = tree .message_key_generation( &cs, - LeafIndex(index as u32), + (index as u32) * 2, KeyType::Application, leaf.generation, ) @@ -983,7 +967,7 @@ mod interop_tests { let key = tree .message_key_generation( &cs, - LeafIndex(index as u32), + (index as u32) * 2, KeyType::Handshake, leaf.generation, ) @@ -1045,7 +1029,7 @@ mod interop_tests { .map(|leaf| { gens.into_iter() .map(|gen| { - let index = LeafIndex(leaf); + let index = leaf * 2u32; let handshake_key = tree .message_key_generation(&cs, index, KeyType::Handshake, gen) diff --git a/mls-rs/src/tree_kem/kem.rs b/mls-rs/src/tree_kem/kem.rs index 3a42205c..b1116f81 100644 --- a/mls-rs/src/tree_kem/kem.rs +++ b/mls-rs/src/tree_kem/kem.rs @@ -12,6 +12,7 @@ use alloc::vec; use alloc::vec::Vec; use itertools::Itertools; use mls_rs_codec::MlsEncode; +use tree_math::{CopathNode, TreeIndex}; #[cfg(all(not(mls_build_async), feature = "rayon"))] use {crate::iter::ParallelIteratorExt, rayon::prelude::*}; @@ -72,7 +73,7 @@ impl<'a> TreeKem<'a> { P: CipherSuiteProvider + Send + Sync, { let self_index = self.private_key.self_index; - let path = self.tree_kem_public.nodes.direct_path_copath(self_index)?; + let path = self.tree_kem_public.nodes.direct_copath(self_index); let filtered = self.tree_kem_public.nodes.filtered(self_index)?; self.private_key.secret_keys.resize(path.len() + 1, None); @@ -80,7 +81,7 @@ impl<'a> TreeKem<'a> { let mut secret_generator = PathSecretGenerator::new(cipher_suite_provider); let mut path_secrets = vec![]; - for (i, ((dp, _), f)) in path.iter().zip(&filtered).enumerate() { + for (i, (node, f)) in path.iter().zip(&filtered).enumerate() { if !f { let secret = secret_generator.next_secret().await?; @@ -88,7 +89,7 @@ impl<'a> TreeKem<'a> { secret.to_hpke_key_pair(cipher_suite_provider).await?; self.private_key.secret_keys[i + 1] = Some(secret_key); - self.tree_kem_public.update_node(public_key, *dp)?; + self.tree_kem_public.update_node(public_key, node.path)?; path_secrets.push(Some(secret)); } else { self.private_key.secret_keys[i + 1] = None; @@ -174,7 +175,7 @@ impl<'a> TreeKem<'a> { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn encrypt_path_secrets( &self, - path: Vec<(u32, u32)>, + path: Vec>, path_secrets: &[Option], context_bytes: &[u8], cipher_suite: &P, @@ -189,13 +190,13 @@ impl<'a> TreeKem<'a> { let mut node_updates = Vec::new(); - for ((_, copath_index), path_secret) in path.into_iter().zip(path_secrets.iter()) { + for (index, path_secret) in path.into_iter().zip(path_secrets.iter()) { if let Some(path_secret) = path_secret { node_updates.push( self.encrypt_copath_node_resolution( cipher_suite, path_secret, - copath_index, + index.copath, context_bytes, &excluding, ) @@ -210,7 +211,7 @@ impl<'a> TreeKem<'a> { #[cfg(all(not(mls_build_async), feature = "rayon"))] fn encrypt_path_secrets( &self, - path: Vec<(u32, u32)>, + path: Vec>, path_secrets: &[Option], context_bytes: &[u8], cipher_suite: &P, @@ -225,12 +226,12 @@ impl<'a> TreeKem<'a> { path.into_par_iter() .zip(path_secrets.par_iter()) - .filter_map(|((_, copath_index), path_secret)| { + .filter_map(|(node, path_secret)| { path_secret.as_ref().map(|path_secret| { self.encrypt_copath_node_resolution( cipher_suite, path_secret, - copath_index, + node.copath, context_bytes, &excluding, ) @@ -256,10 +257,13 @@ impl<'a> TreeKem<'a> { let lca_index = tree_math::leaf_lca_level(self_index.into(), sender_index.into()) as usize - 2; - let mut path = self.tree_kem_public.nodes.direct_path(self_index)?; - path.insert(0, self_index.into()); + let mut path = self.tree_kem_public.nodes.direct_copath(self_index); + let leaf = CopathNode::new(self_index.into(), 0); + path.insert(0, leaf); let resolved_pos = self.find_resolved_pos(&path, lca_index)?; - let ct_pos = self.find_ciphertext_pos(path[lca_index], path[resolved_pos], added_leaves)?; + + let ct_pos = + self.find_ciphertext_pos(path[lca_index].path, path[resolved_pos].path, added_leaves)?; let lca_node = update_path.nodes[lca_index] .as_ref() @@ -277,7 +281,7 @@ impl<'a> TreeKem<'a> { let public = self .tree_kem_public .nodes - .borrow_node(path[resolved_pos])? + .borrow_node(path[resolved_pos].path)? .as_ref() .ok_or(MlsError::UpdateErrorNoSecretKey)? .public_key(); @@ -351,7 +355,9 @@ impl<'a> TreeKem<'a> { let ctxts = ctxts.try_collect().await?; - let path_index = tree_math::parent(copath_index); + let (path_index, _) = copath_index + .parent_sibling(&self.tree_kem_public.total_leaf_count()) + .ok_or(MlsError::ExpectedNode)?; Ok(UpdatePathNode { public_key: self @@ -367,10 +373,10 @@ impl<'a> TreeKem<'a> { #[inline] fn find_resolved_pos( &self, - path: &[NodeIndex], + path: &[CopathNode], mut lca_index: usize, ) -> Result { - while self.tree_kem_public.nodes.is_blank(path[lca_index])? { + while self.tree_kem_public.nodes.is_blank(path[lca_index].path)? { lca_index -= 1; } @@ -425,6 +431,7 @@ mod tests { use alloc::{format, vec, vec::Vec}; use mls_rs_codec::MlsEncode; use mls_rs_core::crypto::CipherSuiteProvider; + use tree_math::TreeIndex; // Verify that the tree is in the correct state after generating an update path fn verify_tree_update_path( @@ -435,12 +442,13 @@ mod tests { extensions: Option, ) { // Make sure the update path is based on the direct path of the sender - let direct_path = tree.nodes.direct_path(index).unwrap(); - for (i, &dpi) in direct_path.iter().enumerate() { + let direct_path = tree.nodes.direct_copath(index); + + for (i, n) in direct_path.iter().enumerate() { assert_eq!( *tree .nodes - .borrow_node(dpi) + .borrow_node(n.path) .unwrap() .as_ref() .unwrap() @@ -466,7 +474,7 @@ mod tests { } // Verify that we have a public keys up to the root - let root = tree_math::root(tree.total_leaf_count()); + let root = tree.total_leaf_count().root(); assert!(tree.nodes.borrow_node(root).unwrap().is_some()); } @@ -484,17 +492,16 @@ mod tests { // Make sure we have private values along the direct path, and the public keys match let path_iter = public_tree .nodes - .direct_path(index) - .unwrap() + .direct_copath(index) .into_iter() .enumerate(); - for (i, dp) in path_iter { + for (i, n) in path_iter { let secret_key = private_tree.secret_keys[i + 1].as_ref().unwrap(); let public_key = public_tree .nodes - .borrow_node(dp) + .borrow_node(n.path) .unwrap() .as_ref() .unwrap() diff --git a/mls-rs/src/tree_kem/math.rs b/mls-rs/src/tree_kem/math.rs index 2567fa36..11324c88 100644 --- a/mls-rs/src/tree_kem/math.rs +++ b/mls-rs/src/tree_kem/math.rs @@ -3,121 +3,115 @@ // SPDX-License-Identifier: (Apache-2.0 OR MIT) use alloc::vec::Vec; - -use crate::client::MlsError; +use core::fmt::Debug; use super::node::LeafIndex; -pub fn level(x: u32) -> u32 { - x.trailing_ones() -} +pub trait TreeIndex: Sized + Send + Sync + Eq + Clone + Debug { + fn root(&self) -> Self; -pub fn root(n: u32) -> u32 { - n - 1 -} + fn left_unchecked(&self) -> Self; + fn right_unchecked(&self) -> Self; -#[cfg(any(feature = "secret_tree_access", feature = "private_message", test))] -pub fn left(x: u32) -> Result { - if x & 1 == 0 { - Err(MlsError::LeafNodeNoChildren) - } else { - Ok(left_unchecked(x)) - } -} + fn parent_sibling(&self, leaf_count: &Self) -> Option<(Self, Self)>; + fn is_leaf(&self) -> bool; + fn is_in_tree(&self, root: &Self) -> bool; -/// Panicks if `x` is even. -pub fn left_unchecked(x: u32) -> u32 { - x ^ (0x01 << (level(x) - 1)) -} + fn left(&self) -> Option { + (!self.is_leaf()).then(|| self.left_unchecked()) + } -#[cfg(any(feature = "secret_tree_access", feature = "private_message", test))] -pub fn right(x: u32) -> Result { - if x & 1 == 0 { - Err(MlsError::LeafNodeNoChildren) - } else { - Ok(right_unchecked(x)) + fn right(&self) -> Option { + (!self.is_leaf()).then(|| self.right_unchecked()) } -} -/// Panicks if `x` is even. -pub fn right_unchecked(x: u32) -> u32 { - x ^ (0x03 << (level(x) - 1)) -} + fn direct_copath(&self, leaf_count: &Self) -> Vec> { + let root = leaf_count.root(); -pub fn parent(x: u32) -> u32 { - let lvl = level(x); - (x & !(1 << (lvl + 1))) | (1 << lvl) -} + if !self.is_in_tree(&root) { + return Vec::new(); + } -pub fn sibling(x: u32) -> u32 { - let p = parent(x); + let mut path = Vec::new(); + let mut parent = self.clone(); - if x < p { - right_unchecked(p) - } else { - left_unchecked(p) - } -} + while let Some((p, s)) = parent.parent_sibling(leaf_count) { + path.push(CopathNode::new(p.clone(), s.clone())); + parent = p; + } -pub fn direct_path(x: u32, n: u32) -> Result, MlsError> { - if x > 2 * n - 1 { - return Err(MlsError::InvalidTreeIndex); + path } +} - let mut d = Vec::new(); - let mut m = 1 << (level(x) + 1); +#[derive(Clone, PartialEq, Eq, core::fmt::Debug)] +pub struct CopathNode { + pub path: T, + pub copath: T, +} - while m <= n { - d.push((x & !m) | (m - 1)); - m <<= 1; +impl CopathNode { + pub fn new(path: T, copath: T) -> CopathNode { + CopathNode { path, copath } } - - Ok(d) } -pub fn copath(mut x: u32, n: u32) -> Result, MlsError> { - if x > 2 * n - 1 { - return Err(MlsError::InvalidTreeIndex); - } +macro_rules! impl_tree_stdint { + ($t:ty) => { + impl TreeIndex for $t { + fn root(&self) -> $t { + *self - 1 + } - let mut d = Vec::new(); + /// Panicks if `x` is even. + fn left_unchecked(&self) -> Self { + *self ^ (0x01 << (level(*self) - 1)) + } - while x != root(n) { - let p = parent(x); + /// Panicks if `x` is even. + fn right_unchecked(&self) -> Self { + *self ^ (0x03 << (level(*self) - 1)) + } - d.push(if x < p { - right_unchecked(p) - } else { - left_unchecked(p) - }); + fn parent_sibling(&self, leaf_count: &Self) -> Option<(Self, Self)> { + if self == &leaf_count.root() { + return None; + } - x = p; - } + let lvl = level(*self); + let p = (self & !(1 << (lvl + 1))) | (1 << lvl); - Ok(d) -} + let s = if *self < p { + p.right_unchecked() + } else { + p.left_unchecked() + }; -pub fn path_copath(mut x: u32, n: u32) -> Result, MlsError> { - if x > 2 * n - 1 { - return Err(MlsError::InvalidTreeIndex); - } + Some((p, s)) + } - let mut d = Vec::new(); + fn is_leaf(&self) -> bool { + self & 1 == 0 + } - while x != root(n) { - let p = parent(x); + fn is_in_tree(&self, root: &Self) -> bool { + *self <= 2 * root + } + } - let s = if x < p { - right_unchecked(p) - } else { - left_unchecked(p) - }; + fn level(x: $t) -> u32 { + x.trailing_ones() + } + }; +} - d.push((p, s)); - x = p; - } +impl_tree_stdint!(u32); - Ok(d) +#[cfg(test)] +mod test_utils { + use super::*; + impl_tree_stdint!(u64); + //impl_tree_stdint!(u16); } pub fn leaf_lca_level(x: u32, y: u32) -> u32 { @@ -183,6 +177,7 @@ impl Iterator for BfsIterTopDown { #[cfg(test)] mod tests { use super::*; + use itertools::Itertools; use serde::{Deserialize, Serialize}; #[cfg(target_arch = "wasm32")] @@ -221,21 +216,17 @@ mod tests { for log_n_leaves in 0..8 { let n_leaves = 1 << log_n_leaves; let n_nodes = node_width(n_leaves); - let left = (0..n_nodes).map(|x| left(x).ok()).collect::>(); - let right = (0..n_nodes).map(|x| right(x).ok()).collect::>(); + let left = (0..n_nodes).map(|x| x.left()).collect::>(); + let right = (0..n_nodes).map(|x| x.right()).collect::>(); - let parent = (0..n_nodes) - .map(|x| (x != root(n_leaves)).then_some(parent(x))) - .collect::>(); - - let sibling = (0..n_nodes) - .map(|x| (x != root(n_leaves)).then_some(sibling(x))) - .collect::>(); + let (parent, sibling) = (0..n_nodes) + .map(|x| x.parent_sibling(&n_leaves).unzip()) + .unzip(); test_cases.push(TestCase { n_leaves, n_nodes, - root: root(n_leaves), + root: n_leaves.root(), left, right, parent, @@ -256,21 +247,16 @@ mod tests { for case in test_cases { assert_eq!(node_width(case.n_leaves), case.n_nodes); - assert_eq!(root(case.n_leaves), case.root); + assert_eq!(case.n_leaves.root(), case.root); for x in 0..case.n_nodes { - assert_eq!(left(x).ok(), case.left[x as usize]); - assert_eq!(right(x).ok(), case.right[x as usize]); - - assert_eq!( - (x != root(case.n_leaves)).then_some(sibling(x)), - case.sibling[x as usize] - ); - - assert_eq!( - (x != root(case.n_leaves)).then_some(parent(x)), - case.parent[x as usize] - ); + assert_eq!(x.left(), case.left[x as usize]); + assert_eq!(x.right(), case.right[x as usize]); + + let (p, s) = x.parent_sibling(&case.n_leaves).unzip(); + + assert_eq!(p, case.parent[x as usize]); + assert_eq!(s, case.sibling[x as usize]); } } } @@ -313,7 +299,13 @@ mod tests { .to_vec(); for (i, item) in expected.iter().enumerate() { - assert_eq!(item, &direct_path(i as u32, 16).unwrap()) + let path = (i as u32) + .direct_copath(&16) + .into_iter() + .map(|cp| cp.path) + .collect_vec(); + + assert_eq!(item, &path) } } @@ -355,7 +347,13 @@ mod tests { .to_vec(); for (i, item) in expected.iter().enumerate() { - assert_eq!(item, &copath(i as u32, 16).unwrap()) + let copath = (i as u32) + .direct_copath(&16) + .into_iter() + .map(|cp| cp.copath) + .collect_vec(); + + assert_eq!(item, &copath) } } } diff --git a/mls-rs/src/tree_kem/mod.rs b/mls-rs/src/tree_kem/mod.rs index 64b34822..e9b52888 100644 --- a/mls-rs/src/tree_kem/mod.rs +++ b/mls-rs/src/tree_kem/mod.rs @@ -286,11 +286,11 @@ impl TreeKemPublic { *existing_leaf = update_path.leaf_node.clone(); // Update the rest of the nodes on the direct path - let path = self.nodes.direct_path(sender)?; + let path = self.nodes.direct_copath(sender); - for (node, dp) in update_path.nodes.iter().zip(path) { + for (node, pn) in update_path.nodes.iter().zip(path) { node.as_ref() - .map(|n| self.update_node(n.public_key.clone(), dp)) + .map(|n| self.update_node(n.public_key.clone(), pn.path)) .transpose()?; } @@ -319,8 +319,8 @@ impl TreeKemPublic { fn update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError> { // For a given leaf index, find parent nodes and add the leaf to the unmerged leaf - self.nodes.direct_path(index)?.into_iter().for_each(|i| { - if let Ok(p) = self.nodes.borrow_as_parent_mut(i) { + self.nodes.direct_copath(index).into_iter().for_each(|i| { + if let Ok(p) = self.nodes.borrow_as_parent_mut(i.path) { p.unmerged_leaves.push(index) } }); @@ -856,13 +856,13 @@ pub(crate) mod test_utils { ) { let committer = LeafIndex(committer); - let path = self.tree.nodes.direct_path(committer).unwrap(); + let path = self.tree.nodes.direct_copath(committer); let filtered = self.tree.nodes.filtered(committer).unwrap(); - for (i, f) in path.into_iter().zip(filtered) { + for (n, f) in path.into_iter().zip(filtered) { if !f { self.tree - .update_node(cs.kem_generate().await.unwrap().1, i) + .update_node(cs.kem_generate().await.unwrap().1, n.path) .unwrap(); } } @@ -1171,15 +1171,11 @@ mod tests { .unwrap(); // Add in parent nodes so we can detect them clearing after update - tree.nodes - .direct_path(LeafIndex(0)) - .unwrap() - .iter() - .for_each(|&i| { - tree.nodes - .borrow_or_fill_node_as_parent(i, &b"pub_key".to_vec().into()) - .unwrap(); - }); + tree.nodes.direct_copath(LeafIndex(0)).iter().for_each(|n| { + tree.nodes + .borrow_or_fill_node_as_parent(n.path, &b"pub_key".to_vec().into()) + .unwrap(); + }); let original_size = tree.occupied_leaf_count(); let original_leaf_index = LeafIndex(1); @@ -1209,13 +1205,9 @@ mod tests { ); // Verify that the direct path has been cleared - tree.nodes - .direct_path(LeafIndex(0)) - .unwrap() - .iter() - .for_each(|&i| { - assert!(tree.nodes[i as usize].is_none()); - }); + tree.nodes.direct_copath(LeafIndex(0)).iter().for_each(|n| { + assert!(tree.nodes[n.path as usize].is_none()); + }); } #[cfg(feature = "by_ref_proposal")] diff --git a/mls-rs/src/tree_kem/node.rs b/mls-rs/src/tree_kem/node.rs index 3330641f..68a77423 100644 --- a/mls-rs/src/tree_kem/node.rs +++ b/mls-rs/src/tree_kem/node.rs @@ -12,6 +12,7 @@ use alloc::vec::Vec; use core::hash::Hash; use core::ops::{Deref, DerefMut}; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; +use tree_math::{CopathNode, TreeIndex}; #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] pub(crate) struct Parent { @@ -52,12 +53,6 @@ impl From for NodeIndex { } } -impl LeafIndex { - pub(crate) fn direct_path(&self, leaf_count: u32) -> Result, MlsError> { - tree_math::direct_path(NodeIndex::from(self), leaf_count) - } -} - pub(crate) type NodeIndex = u32; #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] @@ -226,29 +221,27 @@ impl NodeVec { self.iter().step_by(2).map(|n| n.as_leaf().ok()) } - #[inline] - pub fn direct_path(&self, index: LeafIndex) -> Result, MlsError> { - // Direct path from leaf to root - index.direct_path(self.total_leaf_count()) - } + /*pub fn direct_path(&self, index: LeafIndex) -> Vec { + NodeIndex::from(index) + .direct_copath(&self.total_leaf_count()) + .into_iter() + .map(|n| n.path) + .collect() + }*/ - pub fn direct_path_copath( - &self, - index: LeafIndex, - ) -> Result, MlsError> { - tree_math::path_copath(NodeIndex::from(index), self.total_leaf_count()) + pub fn direct_copath(&self, index: LeafIndex) -> Vec> { + NodeIndex::from(index).direct_copath(&self.total_leaf_count()) } // Section 8.4 // The filtered direct path of a node is obtained from the node's direct path by removing // all nodes whose child on the nodes's copath has an empty resolution pub fn filtered(&self, index: LeafIndex) -> Result, MlsError> { - Ok( - tree_math::copath(NodeIndex::from(index), self.total_leaf_count())? - .into_iter() - .map(|cp| self.is_resolution_empty(cp)) - .collect(), - ) + Ok(NodeIndex::from(index) + .direct_copath(&self.total_leaf_count()) + .into_iter() + .map(|cp| self.is_resolution_empty(cp.copath)) + .collect()) } #[inline] @@ -272,8 +265,8 @@ impl NodeVec { } pub fn blank_direct_path(&mut self, leaf: LeafIndex) -> Result<(), MlsError> { - for i in self.direct_path(leaf)? { - if let Some(n) = self.get_mut(i as usize) { + for i in self.direct_copath(leaf) { + if let Some(n) = self.get_mut(i.path as usize) { *n = None } } @@ -351,9 +344,9 @@ impl NodeVec { if let Node::Parent(p) = node { resolution.extend(p.unmerged_leaves.iter().map(NodeIndex::from)); } - } else if index & 1 == 1 { - indexes.push(tree_math::right_unchecked(index)); - indexes.push(tree_math::left_unchecked(index)); + } else if !index.is_leaf() { + indexes.push(index.right_unchecked()); + indexes.push(index.left_unchecked()); } } @@ -379,9 +372,9 @@ impl NodeVec { if let Node::Parent(p) = node { indexes.extend(p.unmerged_leaves.iter().map(NodeIndex::from)); } - } else if index & 1 == 1 { - indexes.push(tree_math::right_unchecked(index)); - indexes.push(tree_math::left_unchecked(index)); + } else if !index.is_leaf() { + indexes.push(index.right_unchecked()); + indexes.push(index.left_unchecked()); } } @@ -490,8 +483,8 @@ mod tests { async fn test_direct_path() { let test_vec = get_test_node_vec().await; // Tree math is already tested in that module, just ensure equality - let expected = tree_math::direct_path(0, 4).unwrap(); - let actual = test_vec.direct_path(LeafIndex(0)).unwrap(); + let expected = 0.direct_copath(&4); + let actual = test_vec.direct_copath(LeafIndex(0)); assert_eq!(actual, expected); } diff --git a/mls-rs/src/tree_kem/parent_hash.rs b/mls-rs/src/tree_kem/parent_hash.rs index 2e8ad3e2..5da7c7bf 100644 --- a/mls-rs/src/tree_kem/parent_hash.rs +++ b/mls-rs/src/tree_kem/parent_hash.rs @@ -11,6 +11,7 @@ use alloc::vec::Vec; use core::ops::Deref; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::error::IntoAnyError; +use tree_math::TreeIndex; use super::leaf_node::LeafNodeSource; @@ -103,18 +104,18 @@ impl TreeKemPublic { ) -> Result { let mut hash = ParentHash::empty(); - for (dp, cp) in self.nodes.direct_path_copath(index)?.into_iter().rev() { - if self.nodes.is_resolution_empty(cp) { + for node in self.nodes.direct_copath(index).into_iter().rev() { + if self.nodes.is_resolution_empty(node.copath) { continue; } - let parent = self.nodes.borrow_as_parent_mut(dp)?; + let parent = self.nodes.borrow_as_parent_mut(node.path)?; let calculated = ParentHash::new( cipher_suite_provider, &parent.public_key, &hash, - &self.tree_hashes.current[cp as usize], + &self.tree_hashes.current[node.copath as usize], ) .await?; @@ -178,25 +179,21 @@ impl TreeKemPublic { let mut nodes_to_validate = nodes_to_validate.collect::>(); let num_leaves = self.total_leaf_count(); - let root = tree_math::root(num_leaves); // For each leaf l, validate all non-blank nodes on the chain from l up the tree. for (leaf_index, _) in self.nodes.non_empty_leaves() { let mut n = NodeIndex::from(leaf_index); - while n != root { + while let Some((mut p, mut s)) = n.parent_sibling(&num_leaves) { // Find the first non-blank ancestor p of n and p's co-path child s. - let mut p = tree_math::parent(n); - let mut s = tree_math::sibling(n); - while self.nodes.is_blank(p)? { // If we reached the root, we're done with this chain. - if p == root { + let Some((p_parent, p_sibling)) = p.parent_sibling(&num_leaves) else { return Ok(()); - } + }; - s = tree_math::sibling(p); - p = tree_math::parent(p); + p = p_parent; + s = p_sibling; } // Check is n's parent_hash field matches the parent hash of p with co-path child s. @@ -219,7 +216,9 @@ impl TreeKemPublic { if n_node.get_parent_hash() == Some(calculated) { // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree // under c is equal to the resolution of c with n removed". - let c = tree_math::sibling(s); + let Some((_, c)) = s.parent_sibling(&num_leaves) else { + return Err(MlsError::ParentHashMismatch); + }; let c_resolution = self.nodes.get_resolution_index(c)?.into_iter(); diff --git a/mls-rs/src/tree_kem/private.rs b/mls-rs/src/tree_kem/private.rs index bd1cf0e4..fbd7b738 100644 --- a/mls-rs/src/tree_kem/private.rs +++ b/mls-rs/src/tree_kem/private.rs @@ -50,11 +50,11 @@ impl TreeKemPrivate { let mut node_secret_gen = PathSecretGenerator::starting_with(cipher_suite_provider, path_secret); - let path = public_tree.nodes.direct_path(self.self_index)?; + let path = public_tree.nodes.direct_copath(self.self_index); let filtered = &public_tree.nodes.filtered(self.self_index)?; self.secret_keys.resize(path.len() + 1, None); - for (i, (dp, f)) in path.iter().zip(filtered).enumerate().skip(lca_index) { + for (i, (n, f)) in path.iter().zip(filtered).enumerate().skip(lca_index) { if *f { continue; } @@ -63,7 +63,7 @@ impl TreeKemPrivate { let expected_pub_key = public_tree .nodes - .borrow_node(*dp)? + .borrow_node(n.path)? .as_ref() .map(|n| n.public_key()) .ok_or(MlsError::PubKeyMismatch)?; @@ -113,6 +113,7 @@ mod tests { leaf_node::test_utils::{ default_properties, get_basic_test_node, get_basic_test_node_sig_key, }, + math::TreeIndex, node::LeafIndex, }, }; @@ -249,7 +250,7 @@ mod tests { // Sabotage the public tree public_tree .nodes - .borrow_as_parent_mut(tree_math::root(public_tree.total_leaf_count())) + .borrow_as_parent_mut(public_tree.total_leaf_count().root()) .unwrap() .public_key = random_bytes(32).into(); @@ -273,7 +274,7 @@ mod tests { let mut private_key = TreeKemPrivate::new_self_leaf(self_index, secret.clone()); - private_key.secret_keys = (0..tree_math::direct_path(0, leaf_count).unwrap().len() + 1) + private_key.secret_keys = (0..0.direct_copath(&leaf_count).len() + 1) .map(|_| Some(secret.clone())) .collect(); diff --git a/mls-rs/src/tree_kem/tree_hash.rs b/mls-rs/src/tree_kem/tree_hash.rs index db0629c7..ebce7bdd 100644 --- a/mls-rs/src/tree_kem/tree_hash.rs +++ b/mls-rs/src/tree_kem/tree_hash.rs @@ -16,6 +16,7 @@ use alloc::vec::Vec; use itertools::Itertools; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::error::IntoAnyError; +use tree_math::TreeIndex; use core::ops::Deref; @@ -68,7 +69,7 @@ impl TreeKemPublic { P: CipherSuiteProvider, { self.initialize_hashes(cipher_suite_provider).await?; - let root = tree_math::root(self.total_leaf_count()); + let root = self.total_leaf_count().root(); Ok(self.tree_hashes.current[root as usize].to_vec()) } @@ -164,7 +165,7 @@ impl TreeKemPublic { cipher_suite: &P, ) -> Result, MlsError> { let num_leaves = self.nodes.total_leaf_count() as usize; - let root = tree_math::root(num_leaves as u32); + let root = (num_leaves as u32).root(); // The value `filtered_sets[n]` is a list of all ancestors `a` of `n` s.t. we have to compute // the tree hash of `n` with the unmerged leaves of `a` filtered out. @@ -175,7 +176,10 @@ impl TreeKemPublic { let bfs_iter = BfsIterTopDown::new(num_leaves).skip(1); for n in bfs_iter { - let p = tree_math::parent(n as u32); + let Some((p, _)) = (n as u32).parent_sibling(&(num_leaves as u32)) else { + break; + }; + filtered_sets[n] = filtered_sets[p as usize].clone(); if self.different_unmerged(*filtered_sets[p as usize].last().unwrap(), p)? { @@ -266,12 +270,16 @@ async fn tree_hash( hashes[2 * **l as usize] = TreeHash(hash_for_leaf(*l, leaf, cipher_suite_provider).await?); - if 2 * **l != tree_math::root(num_leaves) { - node_queue.push_back(tree_math::parent(2 * **l)); + if 2 * **l != num_leaves.root() { + let Some((p, _)) = (2 * **l).parent_sibling(&num_leaves) else { + break; + }; + + node_queue.push_back(p); } } - let root = tree_math::root(num_leaves); + let root = num_leaves.root(); while let Some(n) = node_queue.pop_front() { let hash = TreeHash( @@ -279,8 +287,8 @@ async fn tree_hash( nodes.borrow_as_parent(n).ok(), cipher_suite_provider, filtered_leaves, - &hashes[tree_math::left_unchecked(n) as usize], - &hashes[tree_math::right_unchecked(n) as usize], + &hashes[n.left_unchecked() as usize], + &hashes[n.right_unchecked() as usize], ) .await?, ); @@ -288,7 +296,9 @@ async fn tree_hash( hashes[n as usize] = hash; if n != root { - node_queue.push_back(tree_math::parent(n)); + if let Some((p, _)) = n.parent_sibling(&num_leaves) { + node_queue.push_back(p); + } } } diff --git a/mls-rs/src/tree_kem/tree_utils.rs b/mls-rs/src/tree_kem/tree_utils.rs index b32e5b41..e7cdeb12 100644 --- a/mls-rs/src/tree_kem/tree_utils.rs +++ b/mls-rs/src/tree_kem/tree_utils.rs @@ -8,10 +8,8 @@ use core::borrow::BorrowMut; use debug_tree::TreeBuilder; -use super::node::NodeIndex; -use super::tree_math::{left_unchecked, right_unchecked}; -use super::{math::root, node::NodeVec}; -use crate::client::MlsError; +use super::node::{NodeIndex, NodeVec}; +use crate::{client::MlsError, tree_kem::math::TreeIndex}; pub(crate) fn build_tree( tree: &mut TreeBuilder, @@ -30,7 +28,7 @@ pub(crate) fn build_tree( // Parent Leaf let mut parent_tag = format!("{blank_tag}Parent ({idx})"); - if root(nodes.total_leaf_count()) == idx { + if nodes.total_leaf_count().root() == idx { parent_tag = format!("{blank_tag}Root ({idx})"); } @@ -56,8 +54,8 @@ pub(crate) fn build_tree( let mut branch = tree.add_branch(&parent_tag); //This cannot panic, as we already checked that idx is not a leaf - build_tree(tree, nodes, left_unchecked(idx))?; - build_tree(tree, nodes, right_unchecked(idx))?; + build_tree(tree, nodes, idx.left_unchecked())?; + build_tree(tree, nodes, idx.right_unchecked())?; branch.release(); @@ -67,7 +65,7 @@ pub(crate) fn build_tree( pub(crate) fn build_ascii_tree(nodes: &NodeVec) -> String { let leaves_count: u32 = nodes.total_leaf_count(); let mut tree = TreeBuilder::new(); - build_tree(tree.borrow_mut(), nodes, root(leaves_count)).unwrap(); + build_tree(tree.borrow_mut(), nodes, leaves_count.root()).unwrap(); tree.string() } diff --git a/mls-rs/src/tree_kem/tree_validator.rs b/mls-rs/src/tree_kem/tree_validator.rs index 86ffdb85..04f8637b 100644 --- a/mls-rs/src/tree_kem/tree_validator.rs +++ b/mls-rs/src/tree_kem/tree_validator.rs @@ -7,6 +7,7 @@ use std::collections::HashSet; #[cfg(not(feature = "std"))] use alloc::{vec, vec::Vec}; +use tree_math::TreeIndex; use super::node::{Node, NodeIndex}; use crate::client::MlsError; @@ -126,13 +127,16 @@ fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> { // For each leaf L, we search for the longest prefix P[1], P[2], ..., P[k] of the direct path of L // such that for each i=1..k, either L is in the unmerged leaves of P[i], or P[i] is blank. We will // then check that L is unmerged at each P[1], ..., P[k] and no other node. - let root = tree_math::root(tree.total_leaf_count()); + let leaf_count = tree.total_leaf_count(); + let root = leaf_count.root(); for (index, _) in tree.nodes.non_empty_leaves() { let mut n = NodeIndex::from(index); while n != root { - let parent = tree_math::parent(n); + let Some((parent, _)) = n.parent_sibling(&leaf_count) else { + break; + }; if tree.nodes.is_blank(parent)? { n = parent;