Skip to content

Commit

Permalink
Modify state storage to handle multiple members in the same group
Browse files Browse the repository at this point in the history
  • Loading branch information
Marta Mularczyk committed Mar 19, 2024
1 parent 7a58743 commit 37940ad
Show file tree
Hide file tree
Showing 7 changed files with 166 additions and 67 deletions.
4 changes: 2 additions & 2 deletions mls-rs-uniffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,9 @@ async fn signing_identity_to_identifier(
impl Group {
/// Write the current state of the group to storage defined by
/// [`ClientConfig::group_state_storage`]
pub async fn write_to_storage(&self) -> Result<(), Error> {
pub async fn write_to_storage(&self) -> Result<Vec<u8>, Error> {
let mut group = self.inner().await;
group.write_to_storage().await.map_err(Into::into)
Ok(group.write_to_storage().await?.to_vec())
}

/// Perform a commit of received proposals (or an empty commit).
Expand Down
4 changes: 2 additions & 2 deletions mls-rs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,11 +606,11 @@ where
/// this client was configured to use.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[inline(never)]
pub async fn load_group(&self, group_id: &[u8]) -> Result<Group<C>, MlsError> {
pub async fn load_group(&self, group_state_id: &[u8]) -> Result<Group<C>, MlsError> {
let snapshot = self
.config
.group_state_storage()
.state(group_id)
.state(group_state_id)
.await
.map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
.ok_or(MlsError::GroupNotFound)?;
Expand Down
7 changes: 7 additions & 0 deletions mls-rs/src/group/external_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,13 @@ impl<C: ClientConfig> ExternalCommitBuilder<C> {
)
.await?;

#[cfg(feature = "prior_epoch")]
{
let repo = &mut group.state_repo;
let id = &group.state.context.group_id;
repo.set_group_state_id(id, group.private_tree.self_index)?;
}

group.apply_pending_commit().await?;

Ok((group, commit_output.commit_message))
Expand Down
60 changes: 50 additions & 10 deletions mls-rs/src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,13 @@ where

let state_repo = GroupStateRepository::new(
#[cfg(feature = "prior_epoch")]
context.group_id.clone(),
&context.group_id,
#[cfg(feature = "prior_epoch")]
private_tree.self_index,
config.group_state_storage(),
config.key_package_repo(),
None,
)
.await?;
)?;

let key_schedule_result = KeySchedule::from_random_epoch_secret(
&cipher_suite_provider,
Expand Down Expand Up @@ -604,12 +605,13 @@ where

let state_repo = GroupStateRepository::new(
#[cfg(feature = "prior_epoch")]
group_info.group_context.group_id.clone(),
&group_info.group_context.group_id,
#[cfg(feature = "prior_epoch")]
private_tree.self_index,
config.group_state_storage(),
config.key_package_repo(),
used_key_package_ref,
)
.await?;
)?;

let group = Group {
config,
Expand Down Expand Up @@ -1838,8 +1840,8 @@ pub(crate) mod test_utils;
mod tests {
use crate::{
client::test_utils::{
test_client_with_key_pkg, TestClientBuilder, TEST_CIPHER_SUITE,
TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION,
test_client_with_key_pkg, test_client_with_key_pkg_custom, TestClientBuilder,
TEST_CIPHER_SUITE, TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION,
},
client_builder::{test_utils::TestClientConfig, ClientBuilder, MlsConfig},
crypto::test_utils::TestCryptoProvider,
Expand Down Expand Up @@ -3910,7 +3912,7 @@ mod tests {
.unwrap()
.0;

bob.write_to_storage().await.unwrap();
let state_id = bob.write_to_storage().await.unwrap().to_vec();

// Bob reloads his group data, but with parameters that will cause his generated leaves to
// not support the mandatory extension.
Expand All @@ -3919,7 +3921,7 @@ mod tests {
.key_package_repo(bob.config.key_package_repo())
.group_state_storage(bob.config.group_state_storage())
.build()
.load_group(alice.group_id())
.load_group(&state_id)
.await
.unwrap();

Expand Down Expand Up @@ -4225,4 +4227,42 @@ mod tests {

assert_eq!(update.committer, *group.private_tree.self_index);
}

#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn can_have_two_members_in_one_group() {
let mut group1 = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
.await
.group;

let (client2, _) = test_client_with_key_pkg_custom(
TEST_PROTOCOL_VERSION,
TEST_CIPHER_SUITE,
"arnold",
|c| {
c.0.group_state_storage = group1.config.group_state_storage();
c.0.key_package_repo = group1.config.key_package_repo();
},
)
.await;

let commit = group1
.commit_builder()
.add_member(client2.generate_key_package_message().await.unwrap())
.unwrap()
.build()
.await
.unwrap();

let (mut group2, _) = client2
.join_group(None, &commit.welcome_messages[0])
.await
.unwrap();

group1.apply_pending_commit().await.unwrap();

group1
.process_incoming_message(group2.commit(vec![]).await.unwrap().commit_message)
.await
.unwrap();
}
}
33 changes: 23 additions & 10 deletions mls-rs/src/group/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use alloc::vec::Vec;

use crate::{
client::MlsError,
client_config::ClientConfig,
Expand Down Expand Up @@ -30,17 +32,14 @@ use mls_rs_core::identity::IdentityProvider;
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
use std::collections::HashMap;

#[cfg(all(feature = "by_ref_proposal", not(feature = "std")))]
use alloc::vec::Vec;

use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository};

#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Snapshot {
version: u16,
pub(crate) state: RawGroupState,
private_tree: TreeKemPrivate,
pub(crate) private_tree: TreeKemPrivate,
epoch_secrets: EpochSecrets,
key_schedule: KeySchedule,
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
Expand All @@ -51,6 +50,12 @@ pub(crate) struct Snapshot {
signer: SignatureSecretKey,
}

impl Snapshot {
pub(crate) fn group_state_id(&self) -> Result<Vec<u8>, mls_rs_codec::Error> {
(&self.state.context.group_id, self.private_tree.self_index).mls_encode_to_vec()
}
}

#[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct RawGroupState {
Expand Down Expand Up @@ -150,10 +155,17 @@ where
{
/// Write the current state of the group to the
/// [`GroupStorageProvider`](crate::GroupStateStorage)
/// that is currently in use by the group.
/// that is currently in use by the group. Returns an identifier
/// of the stored state that can be later used to load the group
/// using [`load_group`](crate::Client::load_group).
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
self.state_repo.write_to_storage(self.snapshot()).await
pub async fn write_to_storage(&mut self) -> Result<Vec<u8>, MlsError> {
let snapshot = self.snapshot();
let state_id = snapshot.group_state_id()?;

self.state_repo.write_to_storage(snapshot).await?;

Ok(state_id)
}

pub(crate) fn snapshot(&self) -> Snapshot {
Expand Down Expand Up @@ -182,12 +194,13 @@ where

let state_repo = GroupStateRepository::new(
#[cfg(feature = "prior_epoch")]
snapshot.state.context.group_id.clone(),
&snapshot.state.context.group_id,
#[cfg(feature = "prior_epoch")]
snapshot.private_tree.self_index,
config.group_state_storage(),
config.key_package_repo(),
None,
)
.await?;
)?;

Ok(Group {
config,
Expand Down
Loading

0 comments on commit 37940ad

Please sign in to comment.