diff --git a/mls-rs-uniffi/src/config.rs b/mls-rs-uniffi/src/config.rs index 72e9e6b7..bd9f6001 100644 --- a/mls-rs-uniffi/src/config.rs +++ b/mls-rs-uniffi/src/config.rs @@ -64,6 +64,20 @@ pub type UniFFIConfig = client_builder::WithIdentityProvider< #[derive(Debug, Clone, uniffi::Record)] pub struct ClientConfig { pub group_state_storage: Arc, + /// Use the ratchet tree extension. If this is false, then you + /// must supply `ratchet_tree` out of band to clients. + pub use_ratchet_tree_extension: bool, +} + +impl Default for ClientConfig { + fn default() -> Self { + Self { + group_state_storage: Arc::new(GroupStateStorageAdapter::new( + InMemoryGroupStateStorage::new(), + )), + use_ratchet_tree_extension: true, + } + } } // TODO(mgeisler): turn into an associated function when UniFFI @@ -71,9 +85,5 @@ pub struct ClientConfig { /// Create a client config with an in-memory group state storage. #[uniffi::export] pub fn client_config_default() -> ClientConfig { - ClientConfig { - group_state_storage: Arc::new(GroupStateStorageAdapter::new( - InMemoryGroupStateStorage::new(), - )), - } + ClientConfig::default() } diff --git a/mls-rs-uniffi/src/lib.rs b/mls-rs-uniffi/src/lib.rs index 21ce968a..77947d01 100644 --- a/mls-rs-uniffi/src/lib.rs +++ b/mls-rs-uniffi/src/lib.rs @@ -30,6 +30,7 @@ use tokio::sync::Mutex; use mls_rs::error::{IntoAnyError, MlsError}; use mls_rs::group; use mls_rs::identity::basic; +use mls_rs::mls_rules; use mls_rs::{CipherSuiteProvider, CryptoProvider}; use mls_rs_core::identity; use mls_rs_core::identity::{BasicCredential, IdentityProvider}; @@ -312,12 +313,16 @@ impl Client { let basic_credential = BasicCredential::new(id); let signing_identity = identity::SigningIdentity::new(basic_credential.into_credential(), public_key.into()); - + let mls_rules = mls_rules::DefaultMlsRules::new().with_commit_options( + mls_rules::CommitOptions::default() + .with_ratchet_tree_extension(client_config.use_ratchet_tree_extension), + ); let client = mls_rs::Client::builder() .crypto_provider(crypto_provider) .identity_provider(basic::BasicIdentityProvider::new()) .signing_identity(signing_identity, secret_key.into(), cipher_suite.into()) .group_state_storage(client_config.group_state_storage.into()) + .mls_rules(mls_rules) .build(); Client { inner: client } @@ -360,9 +365,20 @@ impl Client { /// Join an existing group. /// + /// You must supply `ratchet_tree` if the client that created + /// `welcome_message` did not set `use_ratchet_tree_extension`. + /// /// See [`mls_rs::Client::join_group`] for details. - pub async fn join_group(&self, welcome_message: &Message) -> Result { - let (group, new_member_info) = self.inner.join_group(None, &welcome_message.inner).await?; + pub async fn join_group( + &self, + ratchet_tree: Option, + welcome_message: &Message, + ) -> Result { + let ratchet_tree = ratchet_tree.map(TryInto::try_into).transpose()?; + let (group, new_member_info) = self + .inner + .join_group(ratchet_tree, &welcome_message.inner) + .await?; let group = Arc::new(Group { inner: Arc::new(Mutex::new(group)), @@ -388,20 +404,28 @@ impl Client { } } -#[derive(Clone, Debug, uniffi::Record)] +#[derive(Clone, Debug, PartialEq, uniffi::Record)] pub struct RatchetTree { pub bytes: Vec, } -impl TryFrom> for RatchetTree { +impl TryFrom> for RatchetTree { type Error = Error; - fn try_from(exported_tree: mls_rs::group::ExportedTree<'static>) -> Result { + fn try_from(exported_tree: mls_rs::group::ExportedTree<'_>) -> Result { let bytes = exported_tree.to_bytes()?; Ok(Self { bytes }) } } +impl TryFrom for group::ExportedTree<'static> { + type Error = Error; + + fn try_from(ratchet_tree: RatchetTree) -> Result { + group::ExportedTree::from_bytes(&ratchet_tree.bytes).map_err(Into::into) + } +} + #[derive(Clone, Debug, uniffi::Record)] pub struct CommitOutput { /// Commit message to send to other group members. @@ -516,6 +540,15 @@ impl Group { group.write_to_storage().await.map_err(Into::into) } + /// Export the current epoch's ratchet tree in serialized format. + /// + /// This function is used to provide the current group tree to new + /// members when `use_ratchet_tree_extension` is set to false in + /// `ClientConfig`. + pub fn export_tree(&self) -> Result { + self.inner().export_tree().try_into() + } + /// Perform a commit of received proposals (or an empty commit). /// /// TODO: ensure `path_required` is always set in @@ -756,12 +789,14 @@ mod tests { let alice_config = ClientConfig { group_state_storage: Arc::new(CustomGroupStateStorage::new()), + ..Default::default() }; let alice_keypair = generate_signature_keypair(CipherSuite::Curve25519Aes128)?; let alice = Client::new(b"alice".to_vec(), alice_keypair, alice_config); let bob_config = ClientConfig { group_state_storage: Arc::new(CustomGroupStateStorage::new()), + ..Default::default() }; let bob_keypair = generate_signature_keypair(CipherSuite::Curve25519Aes128)?; let bob = Client::new(b"bob".to_vec(), bob_keypair, bob_config); @@ -771,7 +806,7 @@ mod tests { let commit = alice_group.add_members(vec![Arc::new(bob_key_package)])?; alice_group.process_incoming_message(commit.commit_message)?; - let bob_group = bob.join_group(&commit.welcome_messages[0])?.group; + let bob_group = bob.join_group(None, &commit.welcome_messages[0])?.group; let message = alice_group.encrypt_application_message(b"hello, bob")?; let received_message = bob_group.process_incoming_message(Arc::new(message))?; @@ -784,4 +819,40 @@ mod tests { Ok(()) } + + #[test] + #[cfg(not(mls_build_async))] + fn test_ratchet_tree_not_included() -> Result<(), Error> { + let alice_config = ClientConfig { + use_ratchet_tree_extension: true, + ..ClientConfig::default() + }; + + let alice_keypair = generate_signature_keypair(CipherSuite::Curve25519Aes128)?; + let alice = Client::new(b"alice".to_vec(), alice_keypair, alice_config); + let group = alice.create_group(None)?; + + assert_eq!(group.commit()?.ratchet_tree, None); + Ok(()) + } + + #[test] + #[cfg(not(mls_build_async))] + fn test_ratchet_tree_included() -> Result<(), Error> { + let alice_config = ClientConfig { + use_ratchet_tree_extension: false, + ..ClientConfig::default() + }; + + let alice_keypair = generate_signature_keypair(CipherSuite::Curve25519Aes128)?; + let alice = Client::new(b"alice".to_vec(), alice_keypair, alice_config); + let group = alice.create_group(None)?; + + let ratchet_tree: group::ExportedTree = + group.commit()?.ratchet_tree.unwrap().try_into().unwrap(); + group.inner().apply_pending_commit()?; + + assert_eq!(ratchet_tree, group.inner().export_tree()); + Ok(()) + } } diff --git a/mls-rs-uniffi/tests/ratchet_tree_sync.py b/mls-rs-uniffi/tests/ratchet_tree_sync.py new file mode 100644 index 00000000..677225df --- /dev/null +++ b/mls-rs-uniffi/tests/ratchet_tree_sync.py @@ -0,0 +1,13 @@ +from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client, \ + client_config_default + +client_config = client_config_default() +client_config.use_ratchet_tree_extension = False + +key = generate_signature_keypair(CipherSuite.CURVE25519_AES128) +alice = Client(b'alice', key, client_config) + +group = alice.create_group(None) +commit = group.commit() + +assert commit.ratchet_tree is not None diff --git a/mls-rs-uniffi/tests/scenarios.rs b/mls-rs-uniffi/tests/scenarios.rs index c60ce8b3..b438e5ad 100644 --- a/mls-rs-uniffi/tests/scenarios.rs +++ b/mls-rs-uniffi/tests/scenarios.rs @@ -49,3 +49,4 @@ generate_python_tests!(client_config_default_sync, None); // supported in the next UniFFI release // TODO(mgeisler): add back simple_scenario_async generate_python_tests!(simple_scenario_sync, None); +generate_python_tests!(ratchet_tree_sync, None); diff --git a/mls-rs-uniffi/tests/simple_scenario_sync.py b/mls-rs-uniffi/tests/simple_scenario_sync.py index e19ecd51..1b6bb332 100644 --- a/mls-rs-uniffi/tests/simple_scenario_sync.py +++ b/mls-rs-uniffi/tests/simple_scenario_sync.py @@ -60,7 +60,7 @@ def max_epoch_id(self, group_id: bytes): return last.id group_state_storage = PythonGroupStateStorage() -client_config = ClientConfig(group_state_storage) +client_config = ClientConfig(group_state_storage, use_ratchet_tree_extension=True) key = generate_signature_keypair(CipherSuite.CURVE25519_AES128) alice = Client(b'alice', key, client_config) @@ -73,7 +73,7 @@ def max_epoch_id(self, group_id: bytes): commit = alice.add_members([message]) alice.process_incoming_message(commit.commit_message) -bob = bob.join_group(commit.welcome_messages[0]).group +bob = bob.join_group(None, commit.welcome_messages[0]).group msg = alice.encrypt_application_message(b'hello, bob') output = bob.process_incoming_message(msg)