diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs
index 9babc107..e24fba10 100644
--- a/mls-rs-uniffi/src/lib.rs
+++ b/mls-rs-uniffi/src/lib.rs
@@ -586,9 +586,9 @@ async fn signing_identity_to_identifier(
 impl Group {
     /// Write the current state of the group to storage defined by
     /// [`ClientConfig::group_state_storage`]
-    pub async fn write_to_storage(&self) -> Result<(), Error> {
+    pub async fn write_to_storage(&self) -> Result<Vec<u8>, Error> {
         let mut group = self.inner().await;
-        group.write_to_storage().await.map_err(Into::into)
+        Ok(group.write_to_storage().await?.to_vec())
     }
 
     /// Perform a commit of received proposals (or an empty commit).
diff --git a/mls-rs/src/client.rs b/mls-rs/src/client.rs
index 5cd2c6c5..73ebf74e 100644
--- a/mls-rs/src/client.rs
+++ b/mls-rs/src/client.rs
@@ -606,11 +606,11 @@ where
     /// this client was configured to use.
     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
     #[inline(never)]
-    pub async fn load_group(&self, group_id: &[u8]) -> Result<Group<C>, MlsError> {
+    pub async fn load_group(&self, group_state_id: &[u8]) -> Result<Group<C>, MlsError> {
         let snapshot = self
             .config
             .group_state_storage()
-            .state(group_id)
+            .state(group_state_id)
             .await
             .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
             .ok_or(MlsError::GroupNotFound)?;
diff --git a/mls-rs/src/group/external_commit.rs b/mls-rs/src/group/external_commit.rs
index 34b10427..7f545a4a 100644
--- a/mls-rs/src/group/external_commit.rs
+++ b/mls-rs/src/group/external_commit.rs
@@ -259,6 +259,13 @@ impl ExternalCommitBuilder {
         )
         .await?;
 
+        #[cfg(feature = "prior_epoch")]
+        {
+            let repo = &mut group.state_repo;
+            let id = &group.state.context.group_id;
+            repo.set_group_state_id(id, group.private_tree.self_index)?;
+        }
+
         group.apply_pending_commit().await?;
 
         Ok((group, commit_output.commit_message))
diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs
index 97c124ee..a1e95ff3 100644
--- a/mls-rs/src/group/mod.rs
+++ b/mls-rs/src/group/mod.rs
@@ -341,12 +341,13 @@
 
         let state_repo = GroupStateRepository::new(
             #[cfg(feature = "prior_epoch")]
-            context.group_id.clone(),
+            &context.group_id,
+            #[cfg(feature = "prior_epoch")]
+            private_tree.self_index,
             config.group_state_storage(),
             config.key_package_repo(),
             None,
-        )
-        .await?;
+        )?;
 
         let key_schedule_result = KeySchedule::from_random_epoch_secret(
             &cipher_suite_provider,
@@ -604,12 +605,13 @@
 
         let state_repo = GroupStateRepository::new(
             #[cfg(feature = "prior_epoch")]
-            group_info.group_context.group_id.clone(),
+            &group_info.group_context.group_id,
+            #[cfg(feature = "prior_epoch")]
+            private_tree.self_index,
             config.group_state_storage(),
             config.key_package_repo(),
             used_key_package_ref,
-        )
-        .await?;
+        )?;
 
         let group = Group {
             config,
@@ -1838,8 +1840,8 @@ pub(crate) mod test_utils;
 mod tests {
     use crate::{
         client::test_utils::{
-            test_client_with_key_pkg, TestClientBuilder, TEST_CIPHER_SUITE,
-            TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION,
+            test_client_with_key_pkg, test_client_with_key_pkg_custom, TestClientBuilder,
+            TEST_CIPHER_SUITE, TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION,
         },
         client_builder::{test_utils::TestClientConfig, ClientBuilder, MlsConfig},
         crypto::test_utils::TestCryptoProvider,
@@ -3910,7 +3912,7 @@ mod tests {
             .unwrap()
             .0;
 
-        bob.write_to_storage().await.unwrap();
+        let state_id = bob.write_to_storage().await.unwrap().to_vec();
 
         // Bob reloads his group data, but with parameters that will cause his generated leaves to
         // not support the mandatory extension.
@@ -3919,7 +3921,7 @@ mod tests {
             .key_package_repo(bob.config.key_package_repo())
             .group_state_storage(bob.config.group_state_storage())
             .build()
-            .load_group(alice.group_id())
+            .load_group(&state_id)
             .await
             .unwrap();
 
@@ -4225,4 +4227,42 @@ mod tests {
 
         assert_eq!(update.committer, *group.private_tree.self_index);
     }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn can_have_two_members_in_one_group() {
+        let mut group1 = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+            .await
+            .group;
+
+        let (client2, _) = test_client_with_key_pkg_custom(
+            TEST_PROTOCOL_VERSION,
+            TEST_CIPHER_SUITE,
+            "arnold",
+            |c| {
+                c.0.group_state_storage = group1.config.group_state_storage();
+                c.0.key_package_repo = group1.config.key_package_repo();
+            },
+        )
+        .await;
+
+        let commit = group1
+            .commit_builder()
+            .add_member(client2.generate_key_package_message().await.unwrap())
+            .unwrap()
+            .build()
+            .await
+            .unwrap();
+
+        let (mut group2, _) = client2
+            .join_group(None, &commit.welcome_messages[0])
+            .await
+            .unwrap();
+
+        group1.apply_pending_commit().await.unwrap();
+
+        group1
+            .process_incoming_message(group2.commit(vec![]).await.unwrap().commit_message)
+            .await
+            .unwrap();
+    }
 }
diff --git a/mls-rs/src/group/snapshot.rs b/mls-rs/src/group/snapshot.rs
index 5d56c36b..645da978 100644
--- a/mls-rs/src/group/snapshot.rs
+++ b/mls-rs/src/group/snapshot.rs
@@ -40,7 +40,7 @@ use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRe
 pub(crate) struct Snapshot {
     version: u16,
     pub(crate) state: RawGroupState,
-    private_tree: TreeKemPrivate,
+    pub(crate) private_tree: TreeKemPrivate,
     epoch_secrets: EpochSecrets,
     key_schedule: KeySchedule,
     #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
@@ -51,6 +51,12 @@ pub(crate) struct Snapshot {
     signer: SignatureSecretKey,
 }
 
+impl Snapshot {
+    pub(crate) fn group_state_id(&self) -> Result<Vec<u8>, mls_rs_codec::Error> {
+        (&self.state.context.group_id, self.private_tree.self_index).mls_encode_to_vec()
+    }
+}
+
 #[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
 pub(crate) struct RawGroupState {
@@ -150,10 +156,17 @@ where
 {
     /// Write the current state of the group to the
     /// [`GroupStorageProvider`](crate::GroupStateStorage)
-    /// that is currently in use by the group.
+    /// that is currently in use by the group. Returns an identifier
+    /// of the stored state that can be later used to load the group
+    /// using [`load_group`](crate::Client::load_group).
     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
-    pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
-        self.state_repo.write_to_storage(self.snapshot()).await
+    pub async fn write_to_storage(&mut self) -> Result<Vec<u8>, MlsError> {
+        let snapshot = self.snapshot();
+        let state_id = snapshot.group_state_id()?;
+
+        self.state_repo.write_to_storage(snapshot).await?;
+
+        Ok(state_id)
     }
 
     pub(crate) fn snapshot(&self) -> Snapshot {
@@ -182,12 +195,13 @@
 
         let state_repo = GroupStateRepository::new(
             #[cfg(feature = "prior_epoch")]
-            snapshot.state.context.group_id.clone(),
+            &snapshot.state.context.group_id,
+            #[cfg(feature = "prior_epoch")]
+            snapshot.private_tree.self_index,
             config.group_state_storage(),
             config.key_package_repo(),
             None,
-        )
-        .await?;
+        )?;
 
         Ok(Group {
             config,
diff --git a/mls-rs/src/group/state_repo.rs b/mls-rs/src/group/state_repo.rs
index f6c0c18d..c79c46ac 100644
--- a/mls-rs/src/group/state_repo.rs
+++ b/mls-rs/src/group/state_repo.rs
@@ -3,6 +3,7 @@
 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
 
 use crate::client::MlsError;
+use crate::tree_kem::node::LeafIndex;
 use crate::{group::PriorEpoch, key_package::KeyPackageRef};
 
 use alloc::collections::VecDeque;
@@ -36,7 +37,7 @@ where
 {
     pending_commit: EpochStorageCommit,
     pending_key_package_removal: Option<KeyPackageRef>,
-    group_id: Vec<u8>,
+    group_state_id: Vec<u8>,
     storage: S,
     key_package_repo: K,
 }
@@ -55,7 +56,7 @@
             )
             .field(
                 "group_id",
-                &mls_rs_core::debug::pretty_group_id(&self.group_id),
+                &mls_rs_core::debug::pretty_group_id(&self.group_state_id),
             )
             .field("storage", &self.storage)
             .field("key_package_repo", &self.key_package_repo)
@@ -68,21 +69,35 @@ where
     S: GroupStateStorage,
     K: KeyPackageStorage,
 {
-    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
-    pub async fn new(
-        group_id: Vec<u8>,
+    pub fn new(
+        group_id: &[u8],
+        leaf_index: LeafIndex,
         storage: S,
         key_package_repo: K,
         // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
         key_package_to_remove: Option<KeyPackageRef>,
     ) -> Result<GroupStateRepository<S, K>, MlsError> {
-        Ok(GroupStateRepository {
-            group_id,
+        let mut repo = GroupStateRepository {
+            group_state_id: Default::default(),
             storage,
             pending_key_package_removal: key_package_to_remove,
             pending_commit: Default::default(),
             key_package_repo,
-        })
+        };
+
+        repo.set_group_state_id(group_id, leaf_index)?;
+
+        Ok(repo)
+    }
+
+    pub(crate) fn set_group_state_id(
+        &mut self,
+        group_id: &[u8],
+        leaf_index: LeafIndex,
+    ) -> Result<(), MlsError> {
+        self.group_state_id = (group_id, leaf_index).mls_encode_to_vec()?;
+
+        Ok(())
     }
 
     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
@@ -91,7 +106,7 @@
             Ok(Some(max))
         } else {
             self.storage
-                .max_epoch_id(&self.group_id)
+                .max_epoch_id(&self.group_state_id)
                 .await
                 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
         }
@@ -157,7 +172,7 @@
             Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
             None => self
                 .storage
-                .epoch(&self.group_id, epoch_id)
+                .epoch(&self.group_state_id, epoch_id)
                 .await
                 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
                 .and_then(|epoch| {
@@ -175,7 +190,7 @@
 
     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
     pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
-        if epoch.group_id() != self.group_id {
+        if (epoch.group_id(), epoch.self_index).mls_encode_to_vec()? != self.group_state_id {
             return Err(MlsError::GroupIdMismatch);
         }
 
@@ -210,7 +225,7 @@
 
         let group_state = GroupState {
             data: group_snapshot.mls_encode_to_vec()?,
-            id: group_snapshot.state.context.group_id,
+            id: group_snapshot.group_state_id()?,
         };
 
         self.storage
@@ -257,19 +272,18 @@ mod tests {
 
     use super::*;
 
-    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
-    async fn test_group_state_repo(
+    fn test_group_state_repo(
         retention_limit: usize,
     ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
         GroupStateRepository::new(
-            TEST_GROUP.to_vec(),
+            TEST_GROUP,
+            LeafIndex(0),
             InMemoryGroupStateStorage::new()
                 .with_max_epoch_retention(retention_limit)
                 .unwrap(),
             InMemoryKeyPackageStorage::default(),
             None,
         )
-        .await
        .unwrap()
     }
 
@@ -284,7 +298,7 @@
 
     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
     async fn test_epoch_inserts() {
-        let mut test_repo = test_group_state_repo(1).await;
+        let mut test_repo = test_group_state_repo(1);
         let test_epoch = test_epoch(0);
 
         test_repo.insert(test_epoch.clone()).await.unwrap();
@@ -304,7 +318,7 @@
 
         let psk_id = ResumptionPsk {
             psk_epoch: 0,
-            psk_group_id: PskGroupId(test_repo.group_id.clone()),
+            psk_group_id: PskGroupId(test_repo.group_state_id.clone()),
             usage: ResumptionPSKUsage::Application,
         };
 
@@ -335,7 +349,7 @@
 
         assert_eq!(storage.len(), 1);
 
-        let stored = storage.get(TEST_GROUP).unwrap();
+        let stored = storage.get(&test_repo.group_state_id).unwrap();
 
         assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
 
@@ -352,7 +366,7 @@
 
     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
     async fn test_updates() {
-        let mut test_repo = test_group_state_repo(2).await;
+        let mut test_repo = test_group_state_repo(2);
         let test_epoch_0 = test_epoch(0);
 
         test_repo.insert(test_epoch_0.clone()).await.unwrap();
@@ -381,7 +395,7 @@
         // Make sure you can access an epoch pending update
         let psk_id = ResumptionPsk {
             psk_epoch: 0,
-            psk_group_id: PskGroupId(test_repo.group_id.clone()),
+            psk_group_id: PskGroupId(test_repo.group_state_id.clone()),
             usage: ResumptionPSKUsage::Application,
         };
 
@@ -403,7 +417,7 @@
 
         assert_eq!(storage.len(), 1);
 
-        let stored = storage.get(TEST_GROUP).unwrap();
+        let stored = storage.get(&test_repo.group_state_id).unwrap();
 
         assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
 
@@ -417,7 +431,7 @@
 
     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
     async fn test_insert_and_update() {
-        let mut test_repo = test_group_state_repo(2).await;
+        let mut test_repo = test_group_state_repo(2);
         let test_epoch_0 = test_epoch(0);
 
         test_repo.insert(test_epoch_0).await.unwrap();
@@ -453,7 +467,7 @@
 
         assert_eq!(storage.len(), 1);
 
-        let stored = storage.get(TEST_GROUP).unwrap();
+        let stored = storage.get(&test_repo.group_state_id).unwrap();
 
         assert_eq!(stored.epoch_data.len(), 2);
 
@@ -475,7 +489,7 @@
     async fn test_many_epochs_in_storage() {
         let epochs = (0..10).map(test_epoch).collect::<Vec<_>>();
 
-        let mut test_repo = test_group_state_repo(10).await;
+        let mut test_repo = test_group_state_repo(10);
 
         for epoch in epochs.iter().cloned() {
             test_repo.insert(epoch).await.unwrap()
@@ -495,7 +509,7 @@
 
     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
     async fn test_stored_groups_list() {
-        let mut test_repo = test_group_state_repo(2).await;
+        let mut test_repo = test_group_state_repo(2);
         let test_epoch_0 = test_epoch(0);
 
         test_repo.insert(test_epoch_0.clone()).await.unwrap();
@@ -507,13 +521,13 @@
 
         assert_eq!(
             test_repo.storage.stored_groups(),
-            vec![test_epoch_0.context.group_id]
+            vec![test_repo.group_state_id]
         )
     }
 
     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
     async fn reducing_retention_limit_takes_effect_on_epoch_access() {
-        let mut repo = test_group_state_repo(1).await;
+        let mut repo = test_group_state_repo(1);
 
         repo.insert(test_epoch(0)).await.unwrap();
         repo.insert(test_epoch(1)).await.unwrap();
@@ -522,7 +536,7 @@
 
         let mut repo = GroupStateRepository {
             storage: repo.storage,
-            ..test_group_state_repo(1).await
+            ..test_group_state_repo(1)
         };
 
         let res = repo.get_epoch_mut(0).await.unwrap();
@@ -532,7 +546,7 @@
 
     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
     async fn in_memory_storage_obeys_retention_limit_after_saving() {
-        let mut repo = test_group_state_repo(1).await;
+        let mut repo = test_group_state_repo(1);
 
         repo.insert(test_epoch(0)).await.unwrap();
         repo.write_to_storage(test_snapshot(0).await).await.unwrap();
@@ -544,7 +558,7 @@
         #[cfg(not(feature = "std"))]
         let lock = repo.storage.inner.lock();
 
-        assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1);
+        assert_eq!(lock.get(&repo.group_state_id).unwrap().epoch_data.len(), 1);
     }
 
     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
@@ -560,12 +574,12 @@
         key_package_repo.insert(id, data);
 
         let mut repo = GroupStateRepository::new(
-            TEST_GROUP.to_vec(),
+            TEST_GROUP,
+            LeafIndex(0),
             InMemoryGroupStateStorage::new(),
             key_package_repo,
             Some(key_package.reference.clone()),
         )
-        .await
         .unwrap();
 
         repo.key_package_repo.get(&key_package.reference).unwrap();
@@ -574,4 +588,31 @@
 
         assert!(repo.key_package_repo.get(&key_package.reference).is_none());
     }
+
+    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+    async fn can_have_two_members_in_one_group() {
+        let mut repo1 = test_group_state_repo(1);
+        let state1 = test_snapshot(1).await;
+
+        let mut repo2 = repo1.clone();
+        repo2.set_group_state_id(TEST_GROUP, LeafIndex(15)).unwrap();
+
+        let mut state2 = state1.clone();
+        state2.state.interim_transcript_hash = b"different transcript hash".to_vec().into();
+        state2.private_tree.self_index = LeafIndex(15);
+
+        for (repo, state) in [&mut repo1, &mut repo2].into_iter().zip([&state1, &state2]) {
+            repo.write_to_storage(state.clone()).await.unwrap();
+        }
+
+        for (repo, state) in [&repo1, &repo2].iter().zip([&state1, &state2]) {
+            #[cfg(feature = "std")]
+            let storage = repo.storage.inner.lock().unwrap();
+            #[cfg(not(feature = "std"))]
+            let storage = repo.storage.inner.lock();
+
+            let stored = storage.get(&state.group_state_id().unwrap()).unwrap();
+            assert_eq!(&stored.state_data, &state.mls_encode_to_vec().unwrap());
+        }
+    }
 }
diff --git a/mls-rs/src/group/state_repo_light.rs b/mls-rs/src/group/state_repo_light.rs
index ef823738..846f4f02 100644
--- a/mls-rs/src/group/state_repo_light.rs
+++ b/mls-rs/src/group/state_repo_light.rs
@@ -31,8 +31,7 @@ where
     S: GroupStateStorage,
     K: KeyPackageStorage,
 {
-    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
-    pub async fn new(
+    pub fn new(
         storage: S,
         key_package_repo: K,
         // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
@@ -49,7 +48,7 @@ where
     pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
         let group_state = GroupState {
             data: group_snapshot.mls_encode_to_vec()?,
-            id: group_snapshot.state.context.group_id,
+            id: group_snapshot.group_state_id()?,
         };
 
         self.storage
@@ -95,15 +94,15 @@ mod tests {
             InMemoryKeyPackageStorage::default(),
             None,
         )
-        .await
         .unwrap();
 
-        test_repo
-            .write_to_storage(test_snapshot(0).await)
-            .await
-            .unwrap();
+        let snapshot = test_snapshot(0).await;
+        test_repo.write_to_storage(snapshot.clone()).await.unwrap();
 
-        assert_eq!(test_repo.storage.stored_groups(), vec![TEST_GROUP])
+        assert_eq!(
+            test_repo.storage.stored_groups(),
+            vec![snapshot.group_state_id().unwrap()]
+        )
     }
 
     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
@@ -123,7 +122,6 @@
             key_package_repo,
             Some(key_package.reference.clone()),
         )
-        .await
         .unwrap();
 
         repo.key_package_repo.get(&key_package.reference).unwrap();
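
Two usage notes on the API change above. First, the caller-visible part: `Group::write_to_storage` now returns the storage identifier of the saved state, and `Client::load_group` expects that identifier rather than the raw group id. A minimal round-trip sketch under the sync (`not(mls_build_async)`) configuration; the `save_and_reload` helper is illustrative only, not part of this diff, and re-export paths may differ slightly between mls-rs versions:

```rust
use mls_rs::{client_builder::MlsConfig, error::MlsError, Client, Group};

// Persist `group`, then reload it through a client built with the same
// `group_state_storage` provider. Illustrative helper, not part of the patch.
fn save_and_reload<C: MlsConfig>(
    client: &Client<C>,
    group: &mut Group<C>,
) -> Result<Group<C>, MlsError> {
    // `write_to_storage` now returns the storage id of the saved state...
    let state_id: Vec<u8> = group.write_to_storage()?;

    // ...and `load_group` takes that id instead of the raw group id, since
    // the storage key also encodes the member's leaf index.
    client.load_group(&state_id)
}
```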
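Second, the shape of the derived identifier: as `Snapshot::group_state_id` above shows, it is simply the MLS serialization of the `(group_id, leaf_index)` pair, which is what lets two members of the same group share one `GroupStateStorage` without overwriting each other's state. A standalone sketch against `mls_rs_codec`; the plain `u32` stands in for the `LeafIndex` newtype, and the tuple and byte-slice `MlsEncode` impls are assumed to be available, as the patch itself relies on them:

```rust
use mls_rs_codec::MlsEncode;

// Derive a storage key the way the patch does: MLS-encode the
// (group_id, leaf_index) pair into a single byte vector.
fn derived_state_id(group_id: &[u8], leaf_index: u32) -> Vec<u8> {
    (group_id, leaf_index)
        .mls_encode_to_vec()
        .expect("encoding in-memory values should not fail")
}

fn main() {
    // Same group, different members: the storage keys no longer collide.
    let alice = derived_state_id(b"group", 0);
    let bob = derived_state_id(b"group", 15);
    assert_ne!(alice, bob);
}
```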