diff --git a/mls-rs/src/group/proposal_cache.rs b/mls-rs/src/group/proposal_cache.rs index 4a346972..7932b946 100644 --- a/mls-rs/src/group/proposal_cache.rs +++ b/mls-rs/src/group/proposal_cache.rs @@ -24,8 +24,8 @@ use mls_rs_core::{error::IntoAnyError, psk::PreSharedKeyStorage}; #[cfg(feature = "by_ref_proposal")] #[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)] pub struct CachedProposal { - proposal: Proposal, - sender: Sender, + pub(crate) proposal: Proposal, + pub(crate) sender: Sender, } #[cfg(feature = "by_ref_proposal")] @@ -89,21 +89,21 @@ impl ProposalCache { sender: Sender, additional_proposals: Vec, ) -> ProposalBundle { - let mut proposals = ProposalBundle::default(); - - for (r, p) in &self.proposals { - proposals.add( - p.proposal.clone(), - p.sender, - ProposalSource::ByReference(r.clone()), - ); - } - - for p in additional_proposals.into_iter() { - proposals.add(p, sender, ProposalSource::ByValue); - } - - proposals + self.proposals + .iter() + .map(|(r, p)| { + ( + p.proposal.clone(), + p.sender, + ProposalSource::ByReference(r.clone()), + ) + }) + .chain( + additional_proposals + .into_iter() + .map(|p| (p, sender, ProposalSource::ByValue)), + ) + .collect() } pub fn resolve_for_commit( @@ -249,8 +249,13 @@ impl GroupState { .await?; #[cfg(feature = "by_ref_proposal")] - let rejected_proposals = - rejected_proposals(all_proposals, &applier_output.applied_proposals); + let rejected_proposals = rejected_proposals( + match direction { + CommitDirection::Send => all_proposals, + CommitDirection::Receive => self.proposals.proposals.iter().collect(), + }, + &applier_output.applied_proposals, + ); let mut group_context = self.context.clone(); group_context.epoch += 1; @@ -508,13 +513,14 @@ pub(crate) mod test_utils { context.extensions = group_extensions.clone(); - let state = GroupState::new( + let mut state = GroupState::new( context, public_tree.clone(), Vec::new().into(), ConfirmationTag::empty(cipher_suite_provider).await, ); + state.proposals.proposals = self.proposals.clone(); let proposals = self.resolve_for_commit(sender, proposal_list)?; state @@ -4194,7 +4200,10 @@ mod tests { .unwrap(); let [p] = &state.rejected_proposals[..] else { - panic!("Expected single rejected proposal"); + panic!( + "Expected single rejected proposal but got {:?}", + state.rejected_proposals + ); }; assert_eq!(p.proposal_ref(), Some(&proposal_ref)); diff --git a/mls-rs/src/group/proposal_filter/bundle.rs b/mls-rs/src/group/proposal_filter/bundle.rs index a1783625..5f091cbb 100644 --- a/mls-rs/src/group/proposal_filter/bundle.rs +++ b/mls-rs/src/group/proposal_filter/bundle.rs @@ -17,7 +17,7 @@ use crate::{ }; #[cfg(feature = "by_ref_proposal")] -use crate::group::{LeafIndex, ProposalRef, UpdateProposal}; +use crate::group::{proposal_cache::CachedProposal, LeafIndex, ProposalRef, UpdateProposal}; #[cfg(feature = "psk")] use crate::group::PreSharedKeyProposal; @@ -434,6 +434,47 @@ impl ProposalBundle { } } +impl FromIterator<(Proposal, Sender, ProposalSource)> for ProposalBundle { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut bundle = ProposalBundle::default(); + for (proposal, sender, source) in iter { + bundle.add(proposal, sender, source); + } + bundle + } +} + +#[cfg(feature = "by_ref_proposal")] +impl<'a> FromIterator<(&'a ProposalRef, &'a CachedProposal)> for ProposalBundle { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + iter.into_iter() + .map(|(r, p)| { + ( + p.proposal.clone(), + p.sender, + ProposalSource::ByReference(r.clone()), + ) + }) + .collect() + } +} + +#[cfg(feature = "by_ref_proposal")] +impl<'a> FromIterator<&'a (ProposalRef, CachedProposal)> for ProposalBundle { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + iter.into_iter().map(|pair| (&pair.0, &pair.1)).collect() + } +} + #[cfg_attr( all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(clone, opaque)