From bdfa5fedf94d683e6baf3a38d4d90b5f683f3b16 Mon Sep 17 00:00:00 2001 From: mulmarta Date: Thu, 2 May 2024 11:58:33 +0200 Subject: [PATCH] Support receiving own proposals without an error --- mls-rs-uniffi/src/lib.rs | 3 + mls-rs/src/external_client/group.rs | 16 ++-- mls-rs/src/group/commit.rs | 41 ++-------- mls-rs/src/group/external_commit.rs | 4 +- mls-rs/src/group/message_hash.rs | 36 +++++++++ mls-rs/src/group/message_processor.rs | 2 + mls-rs/src/group/message_verifier.rs | 36 +-------- mls-rs/src/group/mod.rs | 83 +++++++++++++++++++-- mls-rs/src/group/proposal_cache.rs | 33 ++++++++ mls-rs/src/group/snapshot.rs | 10 ++- mls-rs/test_harness_integration/src/main.rs | 13 +--- 11 files changed, 176 insertions(+), 101 deletions(-) create mode 100644 mls-rs/src/group/message_hash.rs diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs index 4dc45e1d..193f0102 100644 --- a/mls-rs-uniffi/src/lib.rs +++ b/mls-rs-uniffi/src/lib.rs @@ -271,6 +271,8 @@ pub enum ReceivedMessage { proposal: Arc, }, + /// A proposal previously sent by this member was received. + OwnProposal, /// Validated GroupInfo object. GroupInfo, /// Validated welcome message. @@ -771,6 +773,7 @@ impl Group { let proposal = Arc::new(proposal_message.proposal.into()); Ok(ReceivedMessage::ReceivedProposal { sender, proposal }) } + group::ReceivedMessage::OwnProposal => Ok(ReceivedMessage::OwnProposal), // TODO: group::ReceivedMessage::GroupInfo does not have any // public methods (unless the "ffi" Cargo feature is set). // So perhaps we don't need it? diff --git a/mls-rs/src/external_client/group.rs b/mls-rs/src/external_client/group.rs index 89399480..95050d75 100644 --- a/mls-rs/src/external_client/group.rs +++ b/mls-rs/src/external_client/group.rs @@ -450,11 +450,8 @@ impl ExternalGroup { ) .await?; - self.state.proposals.insert( - ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?, - proposal, - sender, - ); + let proposal_ref = + ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?; let plaintext = PublicMessage { content: auth_content.content, @@ -462,10 +459,14 @@ impl ExternalGroup { membership_tag: None, }; - Ok(MlsMessage::new( + let message = MlsMessage::new( self.group_context().version(), MlsMessagePayload::Plain(plaintext), - )) + ); + + self.state.proposals.insert(proposal_ref, proposal, sender); + + Ok(message) } /// Delete all sent and received proposals cached for commit. @@ -587,7 +588,6 @@ where &self.cipher_suite_provider, message, None, - None, &self.state, ) .await?; diff --git a/mls-rs/src/group/commit.rs b/mls-rs/src/group/commit.rs index c201057c..093dd979 100644 --- a/mls-rs/src/group/commit.rs +++ b/mls-rs/src/group/commit.rs @@ -4,12 +4,9 @@ use alloc::vec; use alloc::vec::Vec; -use core::fmt::{self, Debug}; +use core::fmt::Debug; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; -use mls_rs_core::{ - crypto::{CipherSuiteProvider, SignatureSecretKey}, - error::IntoAnyError, -}; +use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError}; use crate::{ cipher_suite::CipherSuite, @@ -43,6 +40,7 @@ use super::{ confirmation_tag::ConfirmationTag, framing::{Content, MlsMessage, MlsMessagePayload, Sender}, key_schedule::{KeySchedule, WelcomeSecret}, + message_hash::MessageHash, message_processor::{path_update_required, MessageProcessor}, message_signature::AuthenticatedContent, mls_rules::CommitDirection, @@ -71,36 +69,7 @@ pub(super) struct CommitGeneration { pub content: AuthenticatedContent, pub pending_private_tree: TreeKemPrivate, pub pending_commit_secret: PathSecret, - pub commit_message_hash: CommitHash, -} - -#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub(crate) struct CommitHash( - #[mls_codec(with = "mls_rs_codec::byte_vec")] - #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] - Vec, -); - -impl Debug for CommitHash { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - mls_rs_core::debug::pretty_bytes(&self.0) - .named("CommitHash") - .fmt(f) - } -} - -impl CommitHash { - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - pub(crate) async fn compute( - cs: &CS, - commit: &MlsMessage, - ) -> Result { - cs.hash(&commit.mls_encode_to_vec()?) - .await - .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) - .map(Self) - } + pub commit_message_hash: MessageHash, } #[cfg_attr( @@ -760,7 +729,7 @@ where content: auth_content, pending_private_tree: provisional_private_tree, pending_commit_secret: commit_secret, - commit_message_hash: CommitHash::compute(&self.cipher_suite_provider, &commit_message) + commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message) .await?, }; diff --git a/mls-rs/src/group/external_commit.rs b/mls-rs/src/group/external_commit.rs index 34b10427..a39767b0 100644 --- a/mls-rs/src/group/external_commit.rs +++ b/mls-rs/src/group/external_commit.rs @@ -233,9 +233,7 @@ impl ExternalCommitBuilder { }; let auth_content = AuthenticatedContent::from(plaintext.clone()); - - verify_plaintext_authentication(&cipher_suite, plaintext, None, None, &group.state) - .await?; + verify_plaintext_authentication(&cipher_suite, plaintext, None, &group.state).await?; group .process_event_or_content(EventOrContent::Content(auth_content), true, None) diff --git a/mls-rs/src/group/message_hash.rs b/mls-rs/src/group/message_hash.rs new file mode 100644 index 00000000..2834a4a8 --- /dev/null +++ b/mls-rs/src/group/message_hash.rs @@ -0,0 +1,36 @@ +use core::fmt::Debug; + +use core::fmt; +use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; +use mls_rs_core::crypto::CipherSuiteProvider; + +use crate::{client::MlsError, error::IntoAnyError, MlsMessage}; + +#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub(crate) struct MessageHash( + #[mls_codec(with = "mls_rs_codec::byte_vec")] + #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] + Vec, +); + +impl Debug for MessageHash { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + mls_rs_core::debug::pretty_bytes(&self.0) + .named("CommitHash") + .fmt(f) + } +} + +impl MessageHash { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub(crate) async fn compute( + cs: &CS, + message: &MlsMessage, + ) -> Result { + cs.hash(&message.mls_encode_to_vec()?) + .await + .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) + .map(Self) + } +} diff --git a/mls-rs/src/group/message_processor.rs b/mls-rs/src/group/message_processor.rs index 8084a583..07449bb7 100644 --- a/mls-rs/src/group/message_processor.rs +++ b/mls-rs/src/group/message_processor.rs @@ -187,6 +187,8 @@ pub enum ReceivedMessage { Commit(CommitMessageDescription), /// A proposal was received. Proposal(ProposalMessageDescription), + /// A proposal previously sent by this member was received. + OwnProposal, /// Validated GroupInfo object GroupInfo(GroupInfo), /// Validated welcome message diff --git a/mls-rs/src/group/message_verifier.rs b/mls-rs/src/group/message_verifier.rs index 7a2bc59b..2462822f 100644 --- a/mls-rs/src/group/message_verifier.rs +++ b/mls-rs/src/group/message_verifier.rs @@ -38,7 +38,6 @@ pub(crate) async fn verify_plaintext_authentication( cipher_suite_provider: &P, plaintext: PublicMessage, key_schedule: Option<&KeySchedule>, - self_index: Option, state: &GroupState, ) -> Result { let tag = plaintext.membership_tag.clone(); @@ -52,7 +51,7 @@ pub(crate) async fn verify_plaintext_authentication( // Verify the membership tag if needed match &auth_content.content.sender { - Sender::Member(index) => { + Sender::Member(_) => { if let Some(key_schedule) = key_schedule { let expected_tag = &key_schedule .get_membership_tag(&auth_content, context, cipher_suite_provider) @@ -64,10 +63,6 @@ pub(crate) async fn verify_plaintext_authentication( return Err(MlsError::InvalidMembershipTag); } } - - if self_index == Some(LeafIndex(*index)) { - return Err(MlsError::CantProcessMessageFromSelf); - } } _ => { tag.is_none() @@ -333,7 +328,6 @@ mod tests { &env.bob.group.cipher_suite_provider, message, Some(&env.bob.group.key_schedule), - None, &env.bob.group.state, ) .await @@ -381,7 +375,6 @@ mod tests { &env.bob.group.cipher_suite_provider, message, Some(&env.bob.group.key_schedule), - None, &env.bob.group.state, ) .await; @@ -399,7 +392,6 @@ mod tests { &env.bob.group.cipher_suite_provider, message, Some(&env.bob.group.key_schedule), - None, &env.bob.group.state, ) .await; @@ -417,7 +409,6 @@ mod tests { &env.bob.group.cipher_suite_provider, message, Some(&env.bob.group.key_schedule), - None, &env.bob.group.state, ) .await; @@ -485,7 +476,6 @@ mod tests { &test_group.group.cipher_suite_provider, message, Some(&test_group.group.key_schedule), - None, &test_group.group.state, ) .await @@ -506,7 +496,6 @@ mod tests { &test_group.group.cipher_suite_provider, message, Some(&test_group.group.key_schedule), - None, &test_group.group.state, ) .await; @@ -532,7 +521,6 @@ mod tests { &test_group.group.cipher_suite_provider, message, Some(&test_group.group.key_schedule), - None, &test_group.group.state, ) .await; @@ -556,7 +544,6 @@ mod tests { &test_group.group.cipher_suite_provider, message, Some(&test_group.group.key_schedule), - None, &test_group.group.state, ) .await; @@ -601,7 +588,6 @@ mod tests { &test_group.group.cipher_suite_provider, message, Some(&test_group.group.key_schedule), - None, &test_group.group.state, ) .await @@ -625,7 +611,6 @@ mod tests { &test_group.group.cipher_suite_provider, message, Some(&test_group.group.key_schedule), - None, &test_group.group.state, ) .await; @@ -652,29 +637,10 @@ mod tests { &test_group.group.cipher_suite_provider, message, Some(&test_group.group.key_schedule), - None, &test_group.group.state, ) .await; assert_matches!(res, Err(MlsError::MembershipTagForNonMember)); } - - #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] - async fn plaintext_from_self_fails_verification() { - let mut env = TestEnv::new().await; - - let message = make_signed_plaintext(&mut env.alice.group).await; - - let res = verify_plaintext_authentication( - &env.alice.group.cipher_suite_provider, - message, - Some(&env.alice.group.key_schedule), - Some(LeafIndex::new(env.alice.group.current_member_index())), - &env.alice.group.state, - ) - .await; - - assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf)) - } } diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 183baf8b..06cf51a3 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -39,6 +39,7 @@ use crate::crypto::{HpkePublicKey, HpkeSecretKey}; use crate::extension::ExternalPubExt; +use self::message_hash::MessageHash; #[cfg(feature = "private_message")] use self::mls_rules::{EncryptionOptions, MlsRules}; @@ -117,6 +118,7 @@ pub(crate) mod framing; mod group_info; pub(crate) mod key_schedule; mod membership_tag; +pub(crate) mod message_hash; pub(crate) mod message_processor; pub(crate) mod message_signature; pub(crate) mod message_verifier; @@ -703,14 +705,25 @@ where ) .await?; + let sender = auth_content.content.sender; + let proposal_ref = ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?; + let message = self.format_for_wire(auth_content).await?; + self.state .proposals - .insert(proposal_ref, proposal, auth_content.content.sender); + .insert_own( + proposal_ref, + proposal, + sender, + &message, + &self.cipher_suite_provider, + ) + .await?; - self.format_for_wire(auth_content).await + Ok(message) } /// Unique identifier for this group. @@ -1290,7 +1303,7 @@ where message: MlsMessage, ) -> Result { if let Some(pending) = &self.pending_commit { - let message_hash = CommitHash::compute(&self.cipher_suite_provider, &message).await?; + let message_hash = MessageHash::compute(&self.cipher_suite_provider, &message).await?; if message_hash == pending.commit_message_hash { let message_description = self.apply_pending_commit().await?; @@ -1299,6 +1312,18 @@ where } } + if message.wire_format() == WireFormat::PrivateMessage { + let cached_own_proposal = self + .state + .proposals + .contains_own(&self.cipher_suite_provider, &message) + .await?; + + if cached_own_proposal { + return Ok(ReceivedMessage::OwnProposal); + } + } + MessageProcessor::process_incoming_message( self, message, @@ -1627,7 +1652,6 @@ where &self.cipher_suite_provider, message, Some(&self.key_schedule), - Some(self.private_tree.self_index), &self.state, ) .await?; @@ -1847,8 +1871,8 @@ pub(crate) mod test_utils; mod tests { use crate::{ client::test_utils::{ - test_client_with_key_pkg, TestClientBuilder, TEST_CIPHER_SUITE, - TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION, + test_client_with_key_pkg, test_client_with_key_pkg_custom, TestClientBuilder, + TEST_CIPHER_SUITE, TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION, }, client_builder::{test_utils::TestClientConfig, ClientBuilder, MlsConfig}, crypto::test_utils::TestCryptoProvider, @@ -4255,4 +4279,51 @@ mod tests { let res = groups[1].group.apply_pending_commit().await; assert_matches!(res, Err(MlsError::PendingCommitNotFound)); } + + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn can_process_own_plaintext_proposal() { + can_process_own_roposal(false).await; + } + + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn can_process_own_ciphertext_proposal() { + can_process_own_roposal(true).await; + } + + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + async fn can_process_own_roposal(encrypt_proposal: bool) { + let (alice, _) = test_client_with_key_pkg_custom( + TEST_PROTOCOL_VERSION, + TEST_CIPHER_SUITE, + "alice", + |c| c.0.mls_rules.encryption_options.encrypt_control_messages = encrypt_proposal, + ) + .await; + + let mut alice = TestGroup { + group: alice.create_group(Default::default()).await.unwrap(), + }; + + let mut bob = alice.join("bob").await.0.group; + let mut alice = alice.group; + + let upd = alice.propose_update(vec![]).await.unwrap(); + alice.process_incoming_message(upd.clone()).await.unwrap(); + + bob.process_incoming_message(upd).await.unwrap(); + let commit = bob.commit(vec![]).await.unwrap().commit_message; + let update = alice.process_incoming_message(commit).await.unwrap(); + + let ReceivedMessage::Commit(update) = update else { + panic!("expected commit") + }; + + // Check that proposal was applied i.e. alice's index 0 is updated + assert!(update + .state_update + .roster_update + .updated() + .iter() + .any(|member| member.index() == 0)); + } } diff --git a/mls-rs/src/group/proposal_cache.rs b/mls-rs/src/group/proposal_cache.rs index 17acf79d..665bafa0 100644 --- a/mls-rs/src/group/proposal_cache.rs +++ b/mls-rs/src/group/proposal_cache.rs @@ -5,6 +5,7 @@ use alloc::vec::Vec; use super::{ + message_hash::MessageHash, message_processor::ProvisionalState, mls_rules::{CommitDirection, CommitSource, MlsRules}, GroupState, ProposalOrRef, @@ -16,6 +17,7 @@ use crate::{ Proposal, Sender, }, time::MlsTime, + MlsMessage, }; #[cfg(feature = "by_ref_proposal")] @@ -54,6 +56,7 @@ pub(crate) struct ProposalCache { pub(crate) proposals: HashMap, #[cfg(not(feature = "std"))] pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>, + pub(crate) own_proposals: Vec, } #[cfg(feature = "by_ref_proposal")] @@ -77,6 +80,7 @@ impl ProposalCache { protocol_version, group_id, proposals: Default::default(), + own_proposals: Default::default(), } } @@ -85,11 +89,13 @@ impl ProposalCache { group_id: Vec, #[cfg(feature = "std")] proposals: HashMap, #[cfg(not(feature = "std"))] proposals: Vec<(ProposalRef, CachedProposal)>, + own_proposals: Vec, ) -> Self { Self { protocol_version, group_id, proposals, + own_proposals, } } @@ -115,6 +121,22 @@ impl ProposalCache { self.proposals.push((proposal_ref, cached_proposal)); } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn insert_own( + &mut self, + proposal_ref: ProposalRef, + proposal: Proposal, + sender: Sender, + message: &MlsMessage, + cs: &CS, + ) -> Result<(), MlsError> { + self.insert(proposal_ref, proposal, sender); + let message_hash = MessageHash::compute(cs, message).await?; + self.own_proposals.push(message_hash); + + Ok(()) + } + pub fn prepare_commit( &self, sender: Sender, @@ -169,6 +191,17 @@ impl ProposalCache { Ok(proposals) } + + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn contains_own( + &self, + cs: &CS, + message: &MlsMessage, + ) -> Result { + let message_hash = MessageHash::compute(cs, message).await?; + + Ok(self.own_proposals.iter().any(|op| op == &message_hash)) + } } #[cfg(not(feature = "by_ref_proposal"))] diff --git a/mls-rs/src/group/snapshot.rs b/mls-rs/src/group/snapshot.rs index dca64f8f..f5a6e0d1 100644 --- a/mls-rs/src/group/snapshot.rs +++ b/mls-rs/src/group/snapshot.rs @@ -30,10 +30,12 @@ use mls_rs_core::identity::IdentityProvider; #[cfg(all(feature = "std", feature = "by_ref_proposal"))] use std::collections::HashMap; -#[cfg(all(feature = "by_ref_proposal", not(feature = "std")))] use alloc::vec::Vec; -use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository}; +use super::{ + cipher_suite_provider, epoch::EpochSecrets, message_hash::MessageHash, + state_repo::GroupStateRepository, +}; #[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -59,6 +61,7 @@ pub(crate) struct RawGroupState { pub(crate) proposals: HashMap, #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))] pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>, + pub(crate) own_proposals: Vec, pub(crate) public_tree: TreeKemPublic, pub(crate) interim_transcript_hash: InterimTranscriptHash, pub(crate) pending_reinit: Option, @@ -81,6 +84,7 @@ impl RawGroupState { context: state.context.clone(), #[cfg(feature = "by_ref_proposal")] proposals: state.proposals.proposals.clone(), + own_proposals: state.proposals.own_proposals.clone(), public_tree, interim_transcript_hash: state.interim_transcript_hash.clone(), pending_reinit: state.pending_reinit.clone(), @@ -101,6 +105,7 @@ impl RawGroupState { context.protocol_version, context.group_id.clone(), self.proposals, + self.own_proposals.clone(), ); let mut public_tree = self.public_tree; @@ -238,6 +243,7 @@ pub(crate) mod test_utils { context: get_test_group_context(epoch_id, cipher_suite).await, #[cfg(feature = "by_ref_proposal")] proposals: Default::default(), + own_proposals: Default::default(), public_tree: Default::default(), interim_transcript_hash: InterimTranscriptHash::from(vec![]), pending_reinit: None, diff --git a/mls-rs/test_harness_integration/src/main.rs b/mls-rs/test_harness_integration/src/main.rs index 5b1a472f..61d242e6 100644 --- a/mls-rs/test_harness_integration/src/main.rs +++ b/mls-rs/test_harness_integration/src/main.rs @@ -17,7 +17,6 @@ use mls_rs::{ BaseInMemoryConfig, ClientBuilder, WithCryptoProvider, WithIdentityProvider, WithMlsRules, }, crypto::SignatureSecretKey, - error::MlsError, external_client::ExternalClient, group::{ExportedTree, Member, ReceivedMessage, Roster, StateUpdate}, identity::{ @@ -692,11 +691,7 @@ impl MlsClientImpl { for proposal_bytes in &request.by_reference { let proposal = MlsMessage::from_bytes(proposal_bytes).map_err(abort)?; - - match group.process_incoming_message(proposal) { - Ok(_) | Err(MlsError::CantProcessMessageFromSelf) => Ok(()), - Err(e) => Err(abort(e)), - }?; + group.process_incoming_message(proposal).map_err(abort)?; } { @@ -788,11 +783,7 @@ impl MlsClientImpl { for proposal in &request.proposal { let proposal = MlsMessage::from_bytes(proposal).map_err(abort)?; - - match group.process_incoming_message(proposal) { - Ok(_) | Err(MlsError::CantProcessMessageFromSelf) => Ok(()), - Err(e) => Err(abort(e)), - }?; + group.process_incoming_message(proposal).map_err(abort)?; } let commit = MlsMessage::from_bytes(&request.commit).map_err(abort)?;