Skip to content

Commit

Permalink
Fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
Marta Mularczyk committed Jan 31, 2024
1 parent fd8000e commit 97d7df4
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 57 deletions.
2 changes: 1 addition & 1 deletion mls-rs/src/group/external_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
resumption_secret: PreSharedKey::new(vec![]),
sender_data_secret: SenderDataSecret::from(vec![]),
#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
secret_tree: SecretTree::empty(0),
secret_tree: SecretTree::empty(),
};

let (mut group, _) = Group::join_with(
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/src/group/secret_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ pub struct SecretTree<T: MlsCodec> {
}

impl<T: TreeIndex + MlsCodec> SecretTree<T> {
pub(crate) fn empty(zero_leaf_count: T) -> SecretTree<T> {
pub(crate) fn empty() -> SecretTree<T> {
SecretTree {
known_secrets: Default::default(),
leaf_count: zero_leaf_count,
leaf_count: T::zero(),
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions mls-rs/src/tree_kem/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,10 @@ impl<'a> TreeKem<'a> {

let ctxts = ctxts.try_collect().await?;

let (path_index, _) = copath_index
let path_index = copath_index
.parent_sibling(&self.tree_kem_public.total_leaf_count())
.ok_or(MlsError::ExpectedNode)?;
.ok_or(MlsError::ExpectedNode)?
.parent;

Ok(UpdatePathNode {
public_key: self
Expand Down
42 changes: 33 additions & 9 deletions mls-rs/src/tree_kem/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ pub trait TreeIndex: Sized + Send + Sync + Eq + Clone + Debug {
fn left_unchecked(&self) -> Self;
fn right_unchecked(&self) -> Self;

fn parent_sibling(&self, leaf_count: &Self) -> Option<(Self, Self)>;
fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>;
fn is_leaf(&self) -> bool;
fn is_in_tree(&self, root: &Self) -> bool;

fn zero() -> Self;

fn left(&self) -> Option<Self> {
(!self.is_leaf()).then(|| self.left_unchecked())
}
Expand All @@ -35,9 +37,9 @@ pub trait TreeIndex: Sized + Send + Sync + Eq + Clone + Debug {
let mut path = Vec::new();
let mut parent = self.clone();

while let Some((p, s)) = parent.parent_sibling(leaf_count) {
path.push(CopathNode::new(p.clone(), s.clone()));
parent = p;
while let Some(ps) = parent.parent_sibling(leaf_count) {
path.push(CopathNode::new(ps.parent.clone(), ps.sibling));
parent = ps.parent;
}

path
Expand All @@ -56,6 +58,18 @@ impl<T: Clone + PartialEq + Eq + core::fmt::Debug> CopathNode<T> {
}
}

#[derive(Clone, PartialEq, Eq, core::fmt::Debug)]
pub struct ParentSibling<T: Clone + PartialEq + Eq + core::fmt::Debug> {
pub parent: T,
pub sibling: T,
}

impl<T: Clone + PartialEq + Eq + core::fmt::Debug> ParentSibling<T> {
pub fn new(parent: T, sibling: T) -> ParentSibling<T> {
ParentSibling { parent, sibling }
}
}

macro_rules! impl_tree_stdint {
($t:ty) => {
impl TreeIndex for $t {
Expand All @@ -73,7 +87,7 @@ macro_rules! impl_tree_stdint {
*self ^ (0x03 << (level(*self) - 1))
}

fn parent_sibling(&self, leaf_count: &Self) -> Option<(Self, Self)> {
fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>> {
if self == &leaf_count.root() {
return None;
}
Expand All @@ -87,7 +101,7 @@ macro_rules! impl_tree_stdint {
p.left_unchecked()
};

Some((p, s))
Some(ParentSibling::new(p, s))
}

fn is_leaf(&self) -> bool {
Expand All @@ -97,6 +111,10 @@ macro_rules! impl_tree_stdint {
fn is_in_tree(&self, root: &Self) -> bool {
*self <= 2 * root
}

fn zero() -> Self {
0
}
}

fn level(x: $t) -> u32 {
Expand All @@ -111,7 +129,6 @@ impl_tree_stdint!(u32);
mod test_utils {
use super::*;
impl_tree_stdint!(u64);
//impl_tree_stdint!(u16);
}

pub fn leaf_lca_level(x: u32, y: u32) -> u32 {
Expand Down Expand Up @@ -220,7 +237,11 @@ mod tests {
let right = (0..n_nodes).map(|x| x.right()).collect::<Vec<_>>();

let (parent, sibling) = (0..n_nodes)
.map(|x| x.parent_sibling(&n_leaves).unzip())
.map(|x| {
x.parent_sibling(&n_leaves)
.map(|ps| (ps.parent, ps.sibling))
.unzip()
})
.unzip();

test_cases.push(TestCase {
Expand Down Expand Up @@ -253,7 +274,10 @@ mod tests {
assert_eq!(x.left(), case.left[x as usize]);
assert_eq!(x.right(), case.right[x as usize]);

let (p, s) = x.parent_sibling(&case.n_leaves).unzip();
let (p, s) = x
.parent_sibling(&case.n_leaves)
.map(|ps| (ps.parent, ps.sibling))
.unzip();

assert_eq!(p, case.parent[x as usize]);
assert_eq!(s, case.sibling[x as usize]);
Expand Down
8 changes: 0 additions & 8 deletions mls-rs/src/tree_kem/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,6 @@ impl NodeVec {
self.iter().step_by(2).map(|n| n.as_leaf().ok())
}

/*pub fn direct_path(&self, index: LeafIndex) -> Vec<NodeIndex> {
NodeIndex::from(index)
.direct_copath(&self.total_leaf_count())
.into_iter()
.map(|n| n.path)
.collect()
}*/

pub fn direct_copath(&self, index: LeafIndex) -> Vec<CopathNode<NodeIndex>> {
NodeIndex::from(index).direct_copath(&self.total_leaf_count())
}
Expand Down
22 changes: 11 additions & 11 deletions mls-rs/src/tree_kem/parent_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,20 +184,19 @@ impl TreeKemPublic {
for (leaf_index, _) in self.nodes.non_empty_leaves() {
let mut n = NodeIndex::from(leaf_index);

while let Some((mut p, mut s)) = n.parent_sibling(&num_leaves) {
while let Some(mut ps) = n.parent_sibling(&num_leaves) {
// Find the first non-blank ancestor p of n and p's co-path child s.
while self.nodes.is_blank(p)? {
while self.nodes.is_blank(ps.parent)? {
// If we reached the root, we're done with this chain.
let Some((p_parent, p_sibling)) = p.parent_sibling(&num_leaves) else {
let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else {
return Ok(());
};

p = p_parent;
s = p_sibling;
ps = ps_parent;
}

// Check is n's parent_hash field matches the parent hash of p with co-path child s.
let p_parent = self.nodes.borrow_as_parent(p)?;
let p_parent = self.nodes.borrow_as_parent(ps.parent)?;

let n_node = self
.nodes
Expand All @@ -209,17 +208,18 @@ impl TreeKemPublic {
cipher_suite_provider,
&p_parent.public_key,
&p_parent.parent_hash,
&original_hashes[s as usize],
&original_hashes[ps.sibling as usize],
)
.await?;

if n_node.get_parent_hash() == Some(calculated) {
// Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree
// under c is equal to the resolution of c with n removed".
let Some((_, c)) = s.parent_sibling(&num_leaves) else {
let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else {
return Err(MlsError::ParentHashMismatch);
};

let c = cp.sibling;
let c_resolution = self.nodes.get_resolution_index(c)?.into_iter();

#[cfg(feature = "std")]
Expand All @@ -228,7 +228,7 @@ impl TreeKemPublic {
let mut c_resolution = c_resolution.collect::<BTreeSet<_>>();

let p_unmerged_in_c_subtree = self
.unmerged_in_subtree(p, c)?
.unmerged_in_subtree(ps.parent, c)?
.iter()
.copied()
.map(|x| *x * 2);
Expand All @@ -240,10 +240,10 @@ impl TreeKemPublic {

if c_resolution.remove(&n)
&& c_resolution == p_unmerged_in_c_subtree
&& nodes_to_validate.remove(&p)
&& nodes_to_validate.remove(&ps.parent)
{
// If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue.
n = p;
n = ps.parent;
} else {
// If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain").
return Err(MlsError::ParentHashMismatch);
Expand Down
19 changes: 6 additions & 13 deletions mls-rs/src/tree_kem/tree_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,11 @@ impl TreeKemPublic {
let bfs_iter = BfsIterTopDown::new(num_leaves).skip(1);

for n in bfs_iter {
let Some((p, _)) = (n as u32).parent_sibling(&(num_leaves as u32)) else {
let Some(ps) = (n as u32).parent_sibling(&(num_leaves as u32)) else {
break;
};

let p = ps.parent;
filtered_sets[n] = filtered_sets[p as usize].clone();

if self.different_unmerged(*filtered_sets[p as usize].last().unwrap(), p)? {
Expand Down Expand Up @@ -270,17 +271,11 @@ async fn tree_hash<P: CipherSuiteProvider>(

hashes[2 * **l as usize] = TreeHash(hash_for_leaf(*l, leaf, cipher_suite_provider).await?);

if 2 * **l != num_leaves.root() {
let Some((p, _)) = (2 * **l).parent_sibling(&num_leaves) else {
break;
};

node_queue.push_back(p);
if let Some(ps) = (2 * **l).parent_sibling(&num_leaves) {
node_queue.push_back(ps.parent);
}
}

let root = num_leaves.root();

while let Some(n) = node_queue.pop_front() {
let hash = TreeHash(
hash_for_parent(
Expand All @@ -295,10 +290,8 @@ async fn tree_hash<P: CipherSuiteProvider>(

hashes[n as usize] = hash;

if n != root {
if let Some((p, _)) = n.parent_sibling(&num_leaves) {
node_queue.push_back(p);
}
if let Some(ps) = n.parent_sibling(&num_leaves) {
node_queue.push_back(ps.parent);
}
}

Expand Down
17 changes: 6 additions & 11 deletions mls-rs/src/tree_kem/tree_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,27 +128,22 @@ fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> {
// such that for each i=1..k, either L is in the unmerged leaves of P[i], or P[i] is blank. We will
// then check that L is unmerged at each P[1], ..., P[k] and no other node.
let leaf_count = tree.total_leaf_count();
let root = leaf_count.root();

for (index, _) in tree.nodes.non_empty_leaves() {
let mut n = NodeIndex::from(index);

while n != root {
let Some((parent, _)) = n.parent_sibling(&leaf_count) else {
break;
};

if tree.nodes.is_blank(parent)? {
n = parent;
while let Some(ps) = n.parent_sibling(&leaf_count) {
if tree.nodes.is_blank(ps.parent)? {
n = ps.parent;
continue;
}

let parent_node = tree.nodes.borrow_as_parent(parent)?;
let parent_node = tree.nodes.borrow_as_parent(ps.parent)?;

if parent_node.unmerged_leaves.contains(&index) {
unmerged_sets[parent as usize].retain(|i| i != &index);
unmerged_sets[ps.parent as usize].retain(|i| i != &index);

n = parent;
n = ps.parent;
} else {
break;
}
Expand Down

0 comments on commit 97d7df4

Please sign in to comment.