diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml index 8bf1a619..91d82887 100644 --- a/.github/workflows/fuzz.yml +++ b/.github/workflows/fuzz.yml @@ -20,7 +20,7 @@ jobs: run: cargo fmt --all -- --check - name: Clippy on fuzz targets working-directory: mls-rs/fuzz - run: cargo clippy --all-targets --all-features --workspace -- -D warnings + run: cargo clippy --all-targets -- -D warnings - name: Run Fuzz Targets working-directory: mls-rs run: | diff --git a/.github/workflows/native_build.yml b/.github/workflows/native_build.yml index fb49af77..33c24c73 100644 --- a/.github/workflows/native_build.yml +++ b/.github/workflows/native_build.yml @@ -31,21 +31,21 @@ jobs: - name: Rust Fmt run: cargo fmt --all -- --check - name: Clippy Full RFC Compliance - run: cargo clippy --all-targets --all-features --workspace -- -D warnings + run: cargo clippy --all-targets --features external_client,grease --workspace -- -D warnings - name: Clippy Bare Bones run: cargo clippy --all-targets --no-default-features --features std,test_util --workspace -- -D warnings - name: Test Full RFC Compliance - run: cargo test --all-features --verbose --workspace + run: cargo test --features external_client,grease --verbose --workspace - name: Test Bare Bones run: cargo test --no-default-features --features std,test_util --verbose --workspace - name: Test Async Full RFC run: cargo test --lib --test '*' --verbose --features test_util -p mls-rs env: - RUSTFLAGS: '--cfg mls_build_async' + RUSTFLAGS: "--cfg mls_build_async" - name: Test Async Bare Bones run: cargo test --no-default-features --lib --test '*' --features std,test_util --verbose -p mls-rs env: - RUSTFLAGS: '--cfg mls_build_async' + RUSTFLAGS: "--cfg mls_build_async" - name: Examples working-directory: mls-rs run: cargo run --example basic_usage diff --git a/mls-rs-core/Cargo.toml b/mls-rs-core/Cargo.toml index 9976ec9f..a28de3a4 100644 --- a/mls-rs-core/Cargo.toml +++ b/mls-rs-core/Cargo.toml @@ -9,9 +9,8 @@ keywords = ["mls", "mls-rs"] license = "Apache-2.0 OR MIT" exclude = ["test_data"] - [features] -default = ["std", "rfc_compliant", "fast_serialize"] +default = ["std", "rfc_compliant", "fast_serialize", "rayon"] arbitrary = ["std", "dep:arbitrary"] fast_serialize = ["mls-rs-codec/preallocate"] std = ["mls-rs-codec/std", "zeroize/std", "safer-ffi-gen?/std", "dep:thiserror"] @@ -19,6 +18,7 @@ rfc_compliant = ["x509"] ffi = ["dep:safer-ffi", "dep:safer-ffi-gen"] x509 = [] test_suite = ["dep:serde", "dep:serde_json", "dep:hex", "dep:itertools"] +rayon = ["std", "dep:rayon"] [dependencies] mls-rs-codec = { version = "0.5.0", path = "../mls-rs-codec", default-features = false} @@ -28,6 +28,8 @@ thiserror = { version = "1.0.40", optional = true } safer-ffi = { version = "0.1.3", default-features = false, optional = true } safer-ffi-gen = { version = "0.9.2", default-features = false, optional = true } maybe-async = "0.2.7" +# TODO: https://github.com/GoogleChromeLabs/wasm-bindgen-rayon +rayon = { version = "1", optional = true } serde = { version = "1.0", default-features = false, features = ["alloc", "derive"], optional = true } serde_json = { version = "^1.0", optional = true } diff --git a/mls-rs-core/src/crypto.rs b/mls-rs-core/src/crypto.rs index 4594eeb9..ad93aef5 100644 --- a/mls-rs-core/src/crypto.rs +++ b/mls-rs-core/src/crypto.rs @@ -10,11 +10,14 @@ use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use zeroize::{ZeroizeOnDrop, Zeroizing}; mod cipher_suite; -pub use self::cipher_suite::*; +mod mm_hpke; #[cfg(feature = "test_suite")] pub mod test_suite; +pub use self::cipher_suite::*; +use self::mm_hpke::{mm_hpke_open, mm_hpke_seal}; + #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] /// Ciphertext produced by [`CipherSuiteProvider::hpke_seal`] @@ -242,8 +245,8 @@ pub trait CryptoProvider: Send + Sync { all(not(target_arch = "wasm32"), mls_build_async), maybe_async::must_be_async )] -pub trait CipherSuiteProvider: Send + Sync { - type Error: IntoAnyError; +pub trait CipherSuiteProvider: Send + Sync + Sized { + type Error: IntoAnyError + Send + Sync; type HpkeContextS: HpkeContextS + Send + Sync; type HpkeContextR: HpkeContextR + Send + Sync; @@ -380,7 +383,6 @@ pub trait CipherSuiteProvider: Send + Sync { kem_output: &[u8], local_secret: &HpkeSecretKey, local_public: &HpkePublicKey, - info: &[u8], ) -> Result; @@ -436,4 +438,26 @@ pub trait CipherSuiteProvider: Send + Sync { signature: &[u8], data: &[u8], ) -> Result<(), Self::Error>; + + async fn mm_hpke_seal( + &self, + info: &[u8], + aad: Option<&[u8]>, + pt: &[&[u8]], + remote_keys: &[Vec<&HpkePublicKey>], + ) -> Result>, Self::Error> { + mm_hpke_seal(self, info, aad, pt, remote_keys).await + } + + async fn mm_hpke_open( + &self, + ct: &[&[HpkeCiphertext]], + self_index: (usize, usize), + local_secret: &HpkeSecretKey, + local_public: &HpkePublicKey, + info: &[u8], + aad: Option<&[u8]>, + ) -> Result>, Self::Error> { + mm_hpke_open(self, ct, self_index, local_secret, local_public, info, aad).await + } } diff --git a/mls-rs-core/src/crypto/mm_hpke.rs b/mls-rs-core/src/crypto/mm_hpke.rs new file mode 100644 index 00000000..8a3b6659 --- /dev/null +++ b/mls-rs-core/src/crypto/mm_hpke.rs @@ -0,0 +1,72 @@ +use alloc::vec::Vec; + +#[cfg(all(not(mls_build_async), feature = "rayon"))] +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +use super::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey, HpkeSecretKey}; + +#[cfg(all(not(mls_build_async), feature = "rayon"))] +pub(crate) fn mm_hpke_seal( + cs: &P, + info: &[u8], + aad: Option<&[u8]>, + pt: &[&[u8]], + remote_keys: &[Vec<&HpkePublicKey>], +) -> Result>, P::Error> { + use rayon::iter::IndexedParallelIterator; + + pt.par_iter() + .zip(remote_keys.par_iter()) + .map(|(pt, remote_keys)| { + remote_keys + .par_iter() + .map(|remote_pub| cs.hpke_seal(remote_pub, info, aad, pt)) + .collect::>() + }) + .collect() +} + +#[cfg(any(mls_build_async, not(feature = "rayon")))] +#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] +pub(crate) async fn mm_hpke_seal( + cs: &P, + info: &[u8], + aad: Option<&[u8]>, + pt: &[&[u8]], + remote_keys: &[Vec<&HpkePublicKey>], +) -> Result>, P::Error> { + let mut ct = Vec::new(); + + for (pt, remote_keys) in pt.iter().zip(remote_keys.iter()) { + ct.push(Vec::new()); + + for remote_pub in remote_keys { + if let Some(ct) = ct.last_mut() { + ct.push(cs.hpke_seal(remote_pub, info, aad, pt).await?); + } + } + } + + Ok(ct) +} + +#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] +pub(crate) async fn mm_hpke_open( + cs: &P, + ct: &[&[HpkeCiphertext]], + self_index: (usize, usize), + local_secret: &HpkeSecretKey, + local_public: &HpkePublicKey, + info: &[u8], + aad: Option<&[u8]>, +) -> Result>, P::Error> { + let (i, j) = self_index; + + match ct.get(i).and_then(|ct| ct.get(j)) { + Some(ct) => Ok(Some( + cs.hpke_open(ct, local_secret, local_public, info, aad) + .await?, + )), + None => Ok(None), + } +} diff --git a/mls-rs-crypto-awslc/Cargo.toml b/mls-rs-crypto-awslc/Cargo.toml index 9befaf3f..afc43de2 100644 --- a/mls-rs-crypto-awslc/Cargo.toml +++ b/mls-rs-crypto-awslc/Cargo.toml @@ -8,12 +8,17 @@ repository = "https://github.com/awslabs/mls-rs" keywords = ["mls", "mls-rs", "aws-lc"] license = "Apache-2.0 OR MIT" +[features] +default = ["rayon"] +rayon = ["mls-rs-crypto-hpke/rayon"] +mmpke = ["mls-rs-crypto-hpke/mmpke"] + [dependencies] aws-lc-rs = "1.5.1" aws-lc-sys = { version = "0.12.0" } -mls-rs-core = { path = "../mls-rs-core", version = "0.17.0" } -mls-rs-crypto-hpke = { path = "../mls-rs-crypto-hpke", version = "0.8.0" } -mls-rs-crypto-traits = { path = "../mls-rs-crypto-traits", version = "0.9.0" } +mls-rs-core = { path = "../mls-rs-core", version = "0.17.0", default-features = false } +mls-rs-crypto-hpke = { path = "../mls-rs-crypto-hpke", version = "0.8.0", default-features = false, features = ["std"] } +mls-rs-crypto-traits = { path = "../mls-rs-crypto-traits", version = "0.9.0", default-features = false } mls-rs-identity-x509 = { path = "../mls-rs-identity-x509", version = "0.10.0" } thiserror = "1.0.40" zeroize = { version = "1", features = ["zeroize_derive"] } diff --git a/mls-rs-crypto-awslc/src/lib.rs b/mls-rs-crypto-awslc/src/lib.rs index 0fbcae68..52219728 100644 --- a/mls-rs-crypto-awslc/src/lib.rs +++ b/mls-rs-crypto-awslc/src/lib.rs @@ -319,6 +319,33 @@ impl CipherSuiteProvider for AwsLcCipherSuite { ) -> Result<(), Self::Error> { self.signing.verify(public_key, signature, data) } + + #[cfg(feature = "mmpke")] + async fn mm_hpke_seal( + &self, + info: &[u8], + aad: Option<&[u8]>, + pt: &[&[u8]], + remote_keys: &[Vec<&HpkePublicKey>], + ) -> Result>, Self::Error> { + Ok(self.hpke.mm_hpke_seal(info, aad, pt, remote_keys).await?) + } + + #[cfg(feature = "mmpke")] + async fn mm_hpke_open( + &self, + ct: &[&[HpkeCiphertext]], + self_index: (usize, usize), + local_secret: &HpkeSecretKey, + local_public: &HpkePublicKey, + info: &[u8], + aad: Option<&[u8]>, + ) -> Result>, Self::Error> { + Ok(self + .hpke + .mm_hpke_open(ct, self_index, local_secret, local_public, info, aad) + .await?) + } } pub fn sha256(data: &[u8]) -> [u8; 32] { diff --git a/mls-rs-crypto-hpke/Cargo.toml b/mls-rs-crypto-hpke/Cargo.toml index 45fe069c..21080235 100644 --- a/mls-rs-crypto-hpke/Cargo.toml +++ b/mls-rs-crypto-hpke/Cargo.toml @@ -10,9 +10,11 @@ categories = ["no-std", "cryptography"] license = "Apache-2.0 OR MIT" [features] -default = ["std"] +default = ["std", "rayon"] std = ["mls-rs-core/std", "mls-rs-crypto-traits/std", "dep:thiserror", "zeroize/std"] test_utils = ["mls-rs-core/test_suite"] +rayon = ["std", "dep:rayon", "mls-rs-core/rayon"] +mmpke = [] [dependencies] mls-rs-core = { path = "../mls-rs-core", default-features = false, version = "0.17.0" } @@ -21,6 +23,8 @@ thiserror = { version = "1.0.40", optional = true } zeroize = { version = "1", default-features = false, features = ["alloc", "zeroize_derive"] } cfg-if = "^1" maybe-async = "0.2.7" +# TODO: https://github.com/GoogleChromeLabs/wasm-bindgen-rayon +rayon = { version = "1", optional = true } [dev-dependencies] serde = { version = "1.0", features = ["derive"] } diff --git a/mls-rs-crypto-hpke/src/dhkem.rs b/mls-rs-crypto-hpke/src/dhkem.rs index 9113aede..b16e086d 100644 --- a/mls-rs-crypto-hpke/src/dhkem.rs +++ b/mls-rs-crypto-hpke/src/dhkem.rs @@ -14,6 +14,12 @@ use crate::kdf::HpkeKdf; use alloc::vec::Vec; +#[cfg(all(feature = "mmpke", feature = "rayon"))] +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +#[cfg(feature = "mmpke")] +use mls_rs_crypto_traits::MmKemOutput; + #[derive(Debug)] #[cfg_attr(feature = "std", derive(thiserror::Error))] pub enum DhKemError { @@ -103,21 +109,7 @@ impl KemType for DhKem { async fn encap(&self, remote_pk: &HpkePublicKey) -> Result { let (ephemeral_sk, ephemeral_pk) = self.generate().await?; - - let ecdh_ss = self - .dh - .dh(&ephemeral_sk, remote_pk) - .await - .map(Zeroizing::new) - .map_err(|e| DhKemError::DhError(e.into_any_error()))?; - - let kem_context = [ephemeral_pk.as_ref(), remote_pk.as_ref()].concat(); - - let shared_secret = self - .kdf - .labeled_extract_then_expand(&ecdh_ss, &kem_context, self.n_secret) - .await - .map_err(|e| DhKemError::KdfError(e.into_any_error()))?; + let shared_secret = self.encap(remote_pk, &ephemeral_sk, &ephemeral_pk).await?; Ok(KemResult::new(shared_secret, ephemeral_pk.into())) } @@ -150,9 +142,97 @@ impl KemType for DhKem { .public_key_validate(key) .map_err(|e| DhKemError::DhError(e.into_any_error())) } + + #[cfg(all(not(mls_build_async), feature = "rayon", feature = "mmpke"))] + async fn mm_encap( + &self, + remote_keys: &[Vec<&HpkePublicKey>], + ) -> Result { + let (ephemeral_sk, ephemeral_pk) = self.generate().await?; + + let kem_results = remote_keys + .par_iter() + .map(|rk| { + rk.par_iter() + .map(|rk| { + Ok(KemResult::new( + self.encap(rk, &ephemeral_sk, &ephemeral_pk)?, + Vec::new(), + )) + }) + .collect::>() + }) + .collect::>()?; + + Ok(MmKemOutput { + header: ephemeral_pk.into(), + kem_results, + }) + } + + #[cfg(all(any(mls_build_async, not(feature = "rayon")), feature = "mmpke"))] + async fn mm_encap<'a>( + &self, + remote_keys: &'a [Vec<&'a HpkePublicKey>], + ) -> Result { + let (ephemeral_sk, ephemeral_pk) = self.generate().await?; + + let mut kem_results = Vec::new(); + + for rk in remote_keys { + kem_results.push(Vec::new()); + + for rk in rk { + if let Some(kem_results) = kem_results.last_mut() { + kem_results.push(KemResult::new( + self.encap(rk, &ephemeral_sk, &ephemeral_pk).await?, + Vec::new(), + )); + } + } + } + + Ok(MmKemOutput { + header: ephemeral_pk.into(), + kem_results, + }) + } + + #[cfg(feature = "mmpke")] + async fn mm_decap( + &self, + header: &[u8], + _enc: &[u8], + secret_key: &HpkeSecretKey, + local_public: &HpkePublicKey, + ) -> Result, Self::Error> { + self.decap(header, secret_key, local_public).await + } } impl DhKem { + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + async fn encap( + &self, + remote_pk: &HpkePublicKey, + ephemeral_sk: &HpkeSecretKey, + ephemeral_pk: &HpkePublicKey, + ) -> Result, DhKemError> { + let ecdh_ss = self + .dh + .dh(ephemeral_sk, remote_pk) + .await + .map(Zeroizing::new) + .map_err(|e| DhKemError::DhError(e.into_any_error()))?; + + let kem_context = [ephemeral_pk.as_ref(), remote_pk.as_ref()].concat(); + + self.kdf + .labeled_extract_then_expand(&ecdh_ss, &kem_context, self.n_secret) + .await + .map_err(|e| DhKemError::KdfError(e.into_any_error())) + } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn derive_with_rejection_sampling( &self, diff --git a/mls-rs-crypto-hpke/src/hpke.rs b/mls-rs-crypto-hpke/src/hpke.rs index 7b585be8..a48b7349 100644 --- a/mls-rs-crypto-hpke/src/hpke.rs +++ b/mls-rs-crypto-hpke/src/hpke.rs @@ -11,9 +11,15 @@ use mls_rs_core::{ HpkeCiphertext, HpkeContextR, HpkeContextS, HpkeModeId, HpkePublicKey, HpkeSecretKey, }, error::{AnyError, IntoAnyError}, + mls_rs_codec::{MlsDecode, MlsEncode}, }; -use mls_rs_crypto_traits::{AeadType, KdfType, KemType, AEAD_ID_EXPORT_ONLY}; +use mls_rs_crypto_traits::{AeadType, KdfType, KemResult, KemType, AEAD_ID_EXPORT_ONLY}; + +#[cfg(all(not(mls_build_async), feature = "rayon"))] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; use zeroize::Zeroizing; @@ -56,6 +62,8 @@ pub enum HpkeError { /// Max sequence number exceeded, currently allowed up to MAX u64 #[cfg_attr(feature = "std", error("Sequence number overflow"))] SequenceNumberOverflow, + #[cfg_attr(feature = "std", error(transparent))] + MlsCodecError(AnyError), } impl IntoAnyError for HpkeError { @@ -324,4 +332,141 @@ where HpkeModeId::Base } } + + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn mm_hpke_seal( + &self, + info: &[u8], + aad: Option<&[u8]>, + pt: &[&[u8]], + remote_keys: &[Vec<&HpkePublicKey>], + ) -> Result>, HpkeError> { + let kem_results = self + .kem + .mm_encap(remote_keys) + .await + .map_err(|e| HpkeError::KemError(e.into_any_error()))?; + + let mode = self.base_mode(&None); + + let mut out = self + .mm_seal(kem_results.kem_results, info, aad, pt, mode) + .await?; + + if let Some(out) = out.first_mut().and_then(|out| out.first_mut()) { + out.kem_output = (&out.kem_output, kem_results.header) + .mls_encode_to_vec() + .map_err(|e| HpkeError::MlsCodecError(e.into_any_error()))?; + } + + Ok(out) + } + + #[cfg(all(not(mls_build_async), feature = "rayon"))] + pub fn mm_seal( + &self, + kem_results: Vec>, + info: &[u8], + aad: Option<&[u8]>, + pt: &[&[u8]], + mode: HpkeModeId, + ) -> Result>, HpkeError> { + kem_results + .into_par_iter() + .zip(pt.par_iter()) + .map(|(kem_res, pt)| { + kem_res + .into_par_iter() + .map(|kem_res| { + let ct = self + .key_schedule(mode, kem_res.shared_secret(), info, None) + .map(ContextS)? + .seal(aad, pt)?; + + Ok(HpkeCiphertext { + kem_output: kem_res.enc, + ciphertext: ct, + }) + }) + .collect::, HpkeError>>() + }) + .collect() + } + + #[cfg(not(all(not(mls_build_async), feature = "rayon")))] + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn mm_seal( + &self, + kem_results: Vec>, + info: &[u8], + aad: Option<&[u8]>, + pt: &[&[u8]], + mode: HpkeModeId, + ) -> Result>, HpkeError> { + let mut out = Vec::new(); + + for (kem_res, pt) in kem_results.into_iter().zip(pt.iter()) { + out.push(Vec::new()); + + for kem_res in kem_res { + if let Some(out) = out.last_mut() { + let ct = self + .key_schedule(mode, kem_res.shared_secret(), info, None) + .await + .map(ContextS)? + .seal(aad, pt) + .await?; + + out.push(HpkeCiphertext { + kem_output: kem_res.enc, + ciphertext: ct, + }); + } + } + } + + Ok(out) + } + + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + pub async fn mm_hpke_open( + &self, + ct: &[&[HpkeCiphertext]], + self_index: (usize, usize), + local_secret: &HpkeSecretKey, + local_public: &HpkePublicKey, + info: &[u8], + aad: Option<&[u8]>, + ) -> Result>, HpkeError> { + let Some(mut first_ct) = ct.first().and_then(|ct| ct.first()).cloned() else { + return Ok(None); + }; + + let (first_enc, header) = <(Vec, Vec)>::mls_decode(&mut &*first_ct.kem_output) + .map_err(|e| HpkeError::MlsCodecError(e.into_any_error()))?; + + first_ct.kem_output = first_enc; + + let ct = match self_index { + (0, 0) => Some(&first_ct), + (i, j) => ct.get(i).and_then(|ct| ct.get(j)), + }; + + let Some(ct) = ct else { return Ok(None) }; + + let shared_secret = self + .kem + .mm_decap(&header, &ct.kem_output, local_secret, local_public) + .await + .map_err(|e| HpkeError::KemError(e.into_any_error()))?; + + let pt = self + .key_schedule(self.base_mode(&None), &shared_secret, info, None) + .await + .map(ContextR)? + .open(aad, &ct.ciphertext) + .await?; + + Ok(Some(pt)) + } } diff --git a/mls-rs-crypto-openssl/Cargo.toml b/mls-rs-crypto-openssl/Cargo.toml index 44aac1e8..ec7dfff7 100644 --- a/mls-rs-crypto-openssl/Cargo.toml +++ b/mls-rs-crypto-openssl/Cargo.toml @@ -9,21 +9,25 @@ keywords = ["mls", "mls-rs", "openssl"] license = "Apache-2.0 OR MIT" [features] +default = ["x509", "rayon"] x509 = ["mls-rs-identity-x509"] -default = ["x509"] +rayon = ["mls-rs-crypto-hpke/rayon"] + +# This feature is NOT compliant with RFC 9420. +mmpke = ["mls-rs-crypto-hpke/mmpke"] [dependencies] openssl = { version = "0.10.40" } -mls-rs-core = { path = "../mls-rs-core", version = "0.17.0" } +mls-rs-core = { path = "../mls-rs-core", version = "0.17.0", default-features = false } mls-rs-identity-x509 = { path = "../mls-rs-identity-x509", optional = true, version = "0.10.0" } -mls-rs-crypto-hpke = { path = "../mls-rs-crypto-hpke", version = "0.8.0" } +mls-rs-crypto-hpke = { path = "../mls-rs-crypto-hpke", default-features = false, features = ["std"], version = "0.8.0" } mls-rs-crypto-traits = { path = "../mls-rs-crypto-traits", version = "0.9.0" } thiserror = "1.0.40" zeroize = { version = "1", features = ["zeroize_derive"] } maybe-async = "0.2.7" +hex = { version = "^0.4.3", features = ["serde"] } [dev-dependencies] -hex = { version = "^0.4.3", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "^1.0" } assert_matches = "1.5.0" diff --git a/mls-rs-crypto-openssl/src/lib.rs b/mls-rs-crypto-openssl/src/lib.rs index 34f375de..0e6c74a1 100644 --- a/mls-rs-crypto-openssl/src/lib.rs +++ b/mls-rs-crypto-openssl/src/lib.rs @@ -358,21 +358,99 @@ where ) -> Result { Ok(self.ec_signer.signature_key_derive_public(secret_key)?) } + + #[cfg(feature = "mmpke")] + async fn mm_hpke_seal( + &self, + info: &[u8], + aad: Option<&[u8]>, + pt: &[&[u8]], + remote_keys: &[Vec<&HpkePublicKey>], + ) -> Result>, Self::Error> { + Ok(self.hpke.mm_hpke_seal(info, aad, pt, remote_keys).await?) + } + + #[cfg(feature = "mmpke")] + async fn mm_hpke_open( + &self, + ct: &[&[HpkeCiphertext]], + self_index: (usize, usize), + local_secret: &HpkeSecretKey, + local_public: &HpkePublicKey, + info: &[u8], + aad: Option<&[u8]>, + ) -> Result>, Self::Error> { + Ok(self + .hpke + .mm_hpke_open(ct, self_index, local_secret, local_public, info, aad) + .await?) + } } -#[cfg(not(mls_build_async))] -#[test] -fn mls_core_tests() { - // Uncomment this to generate the tests instead. - // mls_rs_core::crypto::test_suite::generate_tests(&OpensslCryptoProvider::new()); - let provider = OpensslCryptoProvider::new(); +#[cfg(all(not(mls_build_async), test))] +mod tests { + use mls_rs_core::crypto::{CipherSuiteProvider, CryptoProvider}; + + use crate::OpensslCryptoProvider; - mls_rs_core::crypto::test_suite::verify_tests(&provider, true); + #[test] + fn mls_core_tests() { + // Uncomment this to generate the tests instead. + // mls_rs_core::crypto::test_suite::generate_tests(&OpensslCryptoProvider::new()); + let provider = OpensslCryptoProvider::new(); - for cs in OpensslCryptoProvider::all_supported_cipher_suites() { - let mut hpke = provider.cipher_suite_provider(cs).unwrap().hpke; + mls_rs_core::crypto::test_suite::verify_tests(&provider, true); - mls_rs_core::crypto::test_suite::verify_hpke_context_tests(&hpke, cs); - mls_rs_core::crypto::test_suite::verify_hpke_encap_tests(&mut hpke, cs); + for cs in OpensslCryptoProvider::all_supported_cipher_suites() { + let mut hpke = provider.cipher_suite_provider(cs).unwrap().hpke; + + mls_rs_core::crypto::test_suite::verify_hpke_context_tests(&hpke, cs); + mls_rs_core::crypto::test_suite::verify_hpke_encap_tests(&mut hpke, cs); + } + } + + #[test] + fn mm_hpke() { + let provider = OpensslCryptoProvider::new(); + + for cs in OpensslCryptoProvider::all_supported_cipher_suites() { + let cs = provider.cipher_suite_provider(cs).unwrap(); + + let (sk1, pk1) = cs.kem_generate().unwrap(); + let (sk2, pk2) = cs.kem_generate().unwrap(); + let (sk3, pk3) = cs.kem_generate().unwrap(); + + let ct = cs + .mm_hpke_seal( + &[], + None, + &[b"pt1", b"pt2"], + &[vec![&pk1, &pk2], vec![&pk3]], + ) + .unwrap(); + + let ct = ct.iter().map(|ct| ct.as_slice()).collect::>(); + + let pt = cs + .mm_hpke_open(&ct, (0, 0), &sk1, &pk1, &[], None) + .unwrap() + .unwrap(); + + assert_eq!(pt, b"pt1".to_vec()); + + let pt = cs + .mm_hpke_open(&ct, (0, 1), &sk2, &pk2, &[], None) + .unwrap() + .unwrap(); + + assert_eq!(pt, b"pt1".to_vec()); + + let pt = cs + .mm_hpke_open(&ct, (1, 0), &sk3, &pk3, &[], None) + .unwrap() + .unwrap(); + + assert_eq!(pt, b"pt2".to_vec()); + } } } diff --git a/mls-rs-crypto-traits/src/kem.rs b/mls-rs-crypto-traits/src/kem.rs index 63ffacf4..850198bf 100644 --- a/mls-rs-crypto-traits/src/kem.rs +++ b/mls-rs-crypto-traits/src/kem.rs @@ -38,6 +38,38 @@ pub trait KemType: Send + Sync { secret_key: &HpkeSecretKey, local_public: &HpkePublicKey, ) -> Result, Self::Error>; + + async fn mm_encap<'a>( + &self, + remote_keys: &'a [Vec<&'a HpkePublicKey>], + ) -> Result { + let mut kem_results = Vec::new(); + + for rk in remote_keys { + kem_results.push(Vec::new()); + + for rk in rk { + if let Some(kem_results) = kem_results.last_mut() { + kem_results.push(self.encap(rk).await?); + } + } + } + + Ok(MmKemOutput { + header: Vec::new(), + kem_results, + }) + } + + async fn mm_decap( + &self, + _header: &[u8], + enc: &[u8], + secret_key: &HpkeSecretKey, + local_public: &HpkePublicKey, + ) -> Result, Self::Error> { + self.decap(enc, secret_key, local_public).await + } } /// Struct to represent the output of the kem [encap](KemType::encap) function @@ -61,6 +93,12 @@ impl KemResult { } } +pub struct MmKemOutput { + pub header: Vec, + // Kyber : use modified KemResult with option where none indicates we dont have to encrypt + pub kem_results: Vec>, +} + /// Kem identifiers for HPKE #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[repr(u16)] diff --git a/mls-rs-crypto-traits/src/lib.rs b/mls-rs-crypto-traits/src/lib.rs index 4e2ef455..1aa79443 100644 --- a/mls-rs-crypto-traits/src/lib.rs +++ b/mls-rs-crypto-traits/src/lib.rs @@ -15,7 +15,7 @@ pub use aead::{AeadId, AeadType, AEAD_ID_EXPORT_ONLY, AES_TAG_LEN}; pub use dh::DhType; pub use ec::Curve; pub use kdf::{KdfId, KdfType}; -pub use kem::{KemId, KemResult, KemType}; +pub use kem::{KemId, KemResult, KemType, MmKemOutput}; #[cfg(feature = "mock")] pub mod mock; diff --git a/mls-rs/Cargo.toml b/mls-rs/Cargo.toml index 633dc7c8..4795669d 100644 --- a/mls-rs/Cargo.toml +++ b/mls-rs/Cargo.toml @@ -18,7 +18,7 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = ["std", "rayon", "rfc_compliant", "tree_index", "fast_serialize"] arbitrary = ["std", "dep:arbitrary", "mls-rs-core/arbitrary"] -rayon = ["std", "dep:rayon"] +rayon = ["std", "dep:rayon", "mls-rs-crypto-openssl?/rayon", "mls-rs-core/rayon"] external_client = ["std"] grease = ["std"] fast_serialize = ["mls-rs-core/fast_serialize"] @@ -38,6 +38,8 @@ std = ["mls-rs-core/std", "mls-rs-codec/std", "mls-rs-identity-x509?/std", "hex/ ffi = ["dep:safer-ffi", "dep:safer-ffi-gen", "mls-rs-core/ffi"] +mmpke = ["mls-rs-crypto-openssl?/mmpke"] + # SQLite support sqlite = ["std", "mls-rs-provider-sqlite/sqlite"] sqlite-bundled = ["sqlite", "mls-rs-provider-sqlite/sqlite-bundled"] diff --git a/mls-rs/benches/group_receive_commit.rs b/mls-rs/benches/group_receive_commit.rs index 9a8a765f..e4ee8e40 100644 --- a/mls-rs/benches/group_receive_commit.rs +++ b/mls-rs/benches/group_receive_commit.rs @@ -6,7 +6,7 @@ use criterion::{BatchSize, BenchmarkId, Criterion}; use mls_rs::{test_utils::benchmarks::load_group_states, CipherSuite}; fn bench(c: &mut Criterion) { - let cipher_suite = CipherSuite::CURVE25519_AES128; + let cipher_suite = CipherSuite::P384_AES256; let group_states = load_group_states(cipher_suite); let mut bench_group = c.benchmark_group("group_receive_commit"); diff --git a/mls-rs/src/group/mod.rs b/mls-rs/src/group/mod.rs index 7fa0bd7a..f5fb1a9c 100644 --- a/mls-rs/src/group/mod.rs +++ b/mls-rs/src/group/mod.rs @@ -25,7 +25,7 @@ use crate::protocol_version::ProtocolVersion; use crate::psk::secret::PskSecret; use crate::psk::PreSharedKeyID; use crate::signer::Signable; -use crate::tree_kem::hpke_encryption::HpkeEncryptable; +use crate::tree_kem::hpke_encryption::HpkeInfo; use crate::tree_kem::kem::TreeKem; use crate::tree_kem::node::LeafIndex; use crate::tree_kem::path_secret::PathSecret; @@ -159,7 +159,7 @@ pub(crate) mod secret_tree; #[cfg(any(feature = "secret_tree_access", feature = "private_message"))] pub use secret_tree::MessageKeyData as MessageKey; -#[cfg(all(test, feature = "rfc_compliant"))] +#[cfg(all(test, feature = "rfc_compliant", not(feature = "mmpke")))] mod interop_test_vectors; mod exported_tree; @@ -173,16 +173,8 @@ struct GroupSecrets { psks: Vec, } -impl HpkeEncryptable for GroupSecrets { +impl HpkeInfo for GroupSecrets { const ENCRYPT_LABEL: &'static str = "Welcome"; - - fn from_bytes(bytes: Vec) -> Result { - Self::mls_decode(&mut bytes.as_slice()).map_err(Into::into) - } - - fn get_bytes(&self) -> Result, MlsError> { - self.mls_encode_to_vec().map_err(Into::into) - } } #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] @@ -432,14 +424,18 @@ where // cipher suite and the HPKE private key corresponding to the GroupSecrets. If a // PreSharedKeyID is part of the GroupSecrets and the client is not in possession of // the corresponding PSK, return an error - let group_secrets = GroupSecrets::decrypt( - &cipher_suite_provider, - &key_package_generation.init_secret_key, - &key_package_generation.key_package.hpke_init_key, - &welcome.encrypted_group_info, - &encrypted_group_secrets.encrypted_group_secrets, - ) - .await?; + let group_secrets = cipher_suite_provider + .hpke_open( + &encrypted_group_secrets.encrypted_group_secrets, + &key_package_generation.init_secret_key, + &key_package_generation.key_package.hpke_init_key, + &GroupSecrets::hpke_info(&welcome.encrypted_group_info)?, + None, + ) + .await + .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; + + let group_secrets = GroupSecrets::mls_decode(&mut &*group_secrets)?; #[cfg(feature = "psk")] let psk_secret = if let Some(psk) = additional_psk { @@ -798,13 +794,15 @@ where psks, }; - let encrypted_group_secrets = group_secrets - .encrypt( - &self.cipher_suite_provider, - &key_package.hpke_init_key, - encrypted_group_info, - ) - .await?; + let info = GroupSecrets::hpke_info(encrypted_group_info)?; + let init_key = &key_package.hpke_init_key; + let pt = group_secrets.mls_encode_to_vec()?; + + let encrypted_group_secrets = self + .cipher_suite_provider + .hpke_seal(init_key, &info, None, &pt) + .await + .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; Ok(EncryptedGroupSecrets { new_member: key_package @@ -1988,6 +1986,7 @@ mod tests { ); assert_eq!(update.leaf_node.ungreased_extensions(), extension_list); + assert_eq!( update.leaf_node.ungreased_capabilities(), Capabilities { @@ -4182,4 +4181,34 @@ mod tests { allowed.then_some(proposals).ok_or(MlsError::InvalidSender) } } + + #[cfg(feature = "mmpke")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn mmpke() { + let mut group = get_test_groups_with_features(16, Default::default(), Default::default()) + .await + .remove(0); + + let commit = group.commit(vec![]).unwrap().commit_message; + let commit = commit.into_plaintext().unwrap().content.content; + + let Content::Commit(commit) = commit else { + panic!("expected commit") + }; + + let path = commit.path.unwrap().nodes; + + let mut ciphertexts = path + .into_iter() + .map(|n| n.encrypted_path_secret.into_iter()) + .flatten(); + + let first_ct = ciphertexts.next().unwrap(); + assert!(!first_ct.ciphertext.is_empty()); + assert!(!first_ct.kem_output.is_empty()); + + // All remining ciphertexts don't have the kem_output component as they should be using + // be using the first kem_output + assert!(ciphertexts.all(|ct| !ct.ciphertext.is_empty() && ct.kem_output.is_empty())); + } } diff --git a/mls-rs/src/tree_kem/hpke_encryption.rs b/mls-rs/src/tree_kem/hpke_encryption.rs index a29fd8f8..943e18aa 100644 --- a/mls-rs/src/tree_kem/hpke_encryption.rs +++ b/mls-rs/src/tree_kem/hpke_encryption.rs @@ -4,11 +4,6 @@ use alloc::vec::Vec; use mls_rs_codec::{MlsEncode, MlsSize}; -use mls_rs_core::{ - crypto::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey, HpkeSecretKey}, - error::IntoAnyError, -}; -use zeroize::Zeroizing; use crate::client::MlsError; @@ -29,53 +24,12 @@ impl<'a> EncryptContext<'a> { } } -#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] -#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))] -#[cfg_attr( - all(not(target_arch = "wasm32"), mls_build_async), - maybe_async::must_be_async -)] - -pub(crate) trait HpkeEncryptable: Sized { +pub(crate) trait HpkeInfo { const ENCRYPT_LABEL: &'static str; - async fn encrypt( - &self, - cipher_suite_provider: &P, - public_key: &HpkePublicKey, - context: &[u8], - ) -> Result { - let context = EncryptContext::new(Self::ENCRYPT_LABEL, context) - .mls_encode_to_vec() - .map(Zeroizing::new)?; - - let content = self.get_bytes().map(Zeroizing::new)?; - - cipher_suite_provider - .hpke_seal(public_key, &context, None, &content) - .await - .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) - } - - async fn decrypt( - cipher_suite_provider: &P, - secret_key: &HpkeSecretKey, - public_key: &HpkePublicKey, - context: &[u8], - ciphertext: &HpkeCiphertext, - ) -> Result { - let context = EncryptContext::new(Self::ENCRYPT_LABEL, context).mls_encode_to_vec()?; - - let plaintext = cipher_suite_provider - .hpke_open(ciphertext, secret_key, public_key, &context, None) - .await - .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; - - Self::from_bytes(plaintext.to_vec()) + fn hpke_info(context: &[u8]) -> Result, MlsError> { + Ok(EncryptContext::new(Self::ENCRYPT_LABEL, context).mls_encode_to_vec()?) } - - fn from_bytes(bytes: Vec) -> Result; - fn get_bytes(&self) -> Result, MlsError>; } #[cfg(test)] @@ -84,9 +38,9 @@ pub(crate) mod test_utils { use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::crypto::{CipherSuiteProvider, HpkeCiphertext}; - use crate::{client::MlsError, crypto::test_utils::try_test_cipher_suite_provider}; + use crate::crypto::test_utils::try_test_cipher_suite_provider; - use super::HpkeEncryptable; + use super::HpkeInfo; #[derive(Debug, serde::Serialize, serde::Deserialize)] pub struct HpkeInteropTestCase { @@ -125,19 +79,10 @@ pub(crate) mod test_utils { } #[derive(Clone, Debug, MlsSize, MlsEncode, MlsDecode)] - struct TestEncryptable(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec); + struct TestInfo; - impl HpkeEncryptable for TestEncryptable { + impl HpkeInfo for TestInfo { const ENCRYPT_LABEL: &'static str = "EncryptWithLabel"; - - fn from_bytes(bytes: Vec) -> Result { - Ok(Self(bytes)) - } - - #[cfg_attr(coverage_nightly, coverage(off))] - fn get_bytes(&self) -> Result, MlsError> { - Ok(self.0.clone()) - } } impl HpkeInteropTestCase { @@ -151,12 +96,14 @@ pub(crate) mod test_utils { ciphertext: self.ciphertext.clone(), }; - let computed_plaintext = - TestEncryptable::decrypt(cs, &secret, &public, &self.context, &ciphertext) - .await - .unwrap(); + let info = TestInfo::hpke_info(&self.context).unwrap(); + + let computed_plaintext = cs + .hpke_open(&ciphertext, &secret, &public, &info, None) + .await + .unwrap(); - assert_eq!(&computed_plaintext.0, &self.plaintext) + assert_eq!(&computed_plaintext, &self.plaintext) } } } diff --git a/mls-rs/src/tree_kem/kem.rs b/mls-rs/src/tree_kem/kem.rs index 3a42205c..1333be1a 100644 --- a/mls-rs/src/tree_kem/kem.rs +++ b/mls-rs/src/tree_kem/kem.rs @@ -6,15 +6,13 @@ use crate::client::MlsError; use crate::crypto::{CipherSuiteProvider, SignatureSecretKey}; use crate::group::GroupContext; use crate::identity::SigningIdentity; -use crate::iter::wrap_iter; -use crate::tree_kem::math as tree_math; +use crate::tree_kem::{hpke_encryption::HpkeInfo, math as tree_math}; + use alloc::vec; use alloc::vec::Vec; use itertools::Itertools; use mls_rs_codec::MlsEncode; - -#[cfg(all(not(mls_build_async), feature = "rayon"))] -use {crate::iter::ParallelIteratorExt, rayon::prelude::*}; +use mls_rs_core::{crypto::HpkePublicKey, error::IntoAnyError}; #[cfg(mls_build_async)] use futures::{StreamExt, TryStreamExt}; @@ -22,7 +20,6 @@ use futures::{StreamExt, TryStreamExt}; #[cfg(feature = "std")] use std::collections::HashSet; -use super::hpke_encryption::HpkeEncryptable; use super::leaf_node::ConfigProperties; use super::node::NodeTypeResolver; use super::{ @@ -170,7 +167,6 @@ impl<'a> TreeKem<'a> { }) } - #[cfg(any(mls_build_async, not(feature = "rayon")))] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn encrypt_path_secrets( &self, @@ -187,58 +183,68 @@ impl<'a> TreeKem<'a> { #[cfg(not(feature = "std"))] let excluding = excluding.collect::>(); - let mut node_updates = Vec::new(); - - for ((_, copath_index), path_secret) in path.into_iter().zip(path_secrets.iter()) { - if let Some(path_secret) = path_secret { - node_updates.push( - self.encrypt_copath_node_resolution( - cipher_suite, - path_secret, - copath_index, - context_bytes, - &excluding, - ) - .await?, - ); - } + let filtered_path = path + .iter() + .copied() + .zip(path_secrets.iter()) + .filter_map(|((dp, cp), s)| s.as_ref().map(|s| (dp, cp, s))) + .collect_vec(); + + let mut pt = Vec::new(); + let mut recipients = Vec::new(); + + for (_, cp, s) in &filtered_path { + pt.push(s.as_ref()); + recipients.push(self.filtered_recipients(*cp, &excluding)?); } - Ok(node_updates) - } + let info = PathSecret::hpke_info(context_bytes)?; - #[cfg(all(not(mls_build_async), feature = "rayon"))] - fn encrypt_path_secrets( - &self, - path: Vec<(u32, u32)>, - path_secrets: &[Option], - context_bytes: &[u8], - cipher_suite: &P, - excluding: &[LeafIndex], - ) -> Result, MlsError> { - let excluding = excluding.iter().copied().map(NodeIndex::from); + let ct = cipher_suite + .mm_hpke_seal(&info, None, &pt, &recipients) + .await + .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; - #[cfg(feature = "std")] - let excluding = excluding.collect::>(); - #[cfg(not(feature = "std"))] - let excluding = excluding.collect::>(); + (ct.len() == filtered_path.len()) + .then_some(()) + .ok_or(MlsError::WrongPathLen)?; - path.into_par_iter() - .zip(path_secrets.par_iter()) - .filter_map(|((_, copath_index), path_secret)| { - path_secret.as_ref().map(|path_secret| { - self.encrypt_copath_node_resolution( - cipher_suite, - path_secret, - copath_index, - context_bytes, - &excluding, - ) + filtered_path + .iter() + .zip(ct) + .map(|((dp, _, _), ct)| { + Ok(UpdatePathNode { + public_key: self.node_public_key(*dp)?.clone(), + encrypted_path_secret: ct, }) }) .collect() } + fn filtered_recipients( + &self, + node_index: NodeIndex, + #[cfg(feature = "std")] excluding: &HashSet, + #[cfg(not(feature = "std"))] excluding: &[NodeIndex], + ) -> Result, MlsError> { + self.tree_kem_public + .nodes + .get_resolution_index(node_index)? + .into_iter() + .filter(|idx| !excluding.contains(idx)) + .map(|idx| self.node_public_key(idx)) + .collect() + } + + fn node_public_key(&self, node_index: NodeIndex) -> Result<&HpkePublicKey, MlsError> { + Ok(self + .tree_kem_public + .nodes + .borrow_node(node_index)? + .as_non_empty()? + .public_key()) + } + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn decap( self, @@ -246,7 +252,7 @@ impl<'a> TreeKem<'a> { update_path: &ValidatedUpdatePath, added_leaves: &[LeafIndex], context_bytes: &[u8], - cipher_suite_provider: &CP, + cipher_suite: &CP, ) -> Result where CP: CipherSuiteProvider, @@ -261,33 +267,41 @@ impl<'a> TreeKem<'a> { let resolved_pos = self.find_resolved_pos(&path, lca_index)?; let ct_pos = self.find_ciphertext_pos(path[lca_index], path[resolved_pos], added_leaves)?; - let lca_node = update_path.nodes[lca_index] - .as_ref() - .ok_or(MlsError::LcaNotFoundInDirectPath)?; - - let ct = lca_node - .encrypted_path_secret - .get(ct_pos) - .ok_or(MlsError::LcaNotFoundInDirectPath)?; + let info = PathSecret::hpke_info(context_bytes)?; let secret = self.private_key.secret_keys[resolved_pos] .as_ref() .ok_or(MlsError::UpdateErrorNoSecretKey)?; - let public = self - .tree_kem_public + let public = self.node_public_key(path[resolved_pos])?; + + let ct = update_path .nodes - .borrow_node(path[resolved_pos])? - .as_ref() - .ok_or(MlsError::UpdateErrorNoSecretKey)? - .public_key(); + .iter() + .filter_map(|node| { + node.as_ref() + .map(|node| node.encrypted_path_secret.as_slice()) + }) + .collect_vec(); + + let count_blanks = update_path + .nodes + .iter() + .take(lca_index) + .filter(|n| n.is_none()) + .count(); - let lca_path_secret = - PathSecret::decrypt(cipher_suite_provider, secret, public, context_bytes, ct).await?; + let self_ct_index = (lca_index - count_blanks, ct_pos); + + let lca_path_secret = cipher_suite + .mm_hpke_open(&ct, self_ct_index, secret, public, &info, None) + .await + .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))? + .ok_or(MlsError::WrongPathLen)? + .into(); // Derive the rest of the secrets for the tree and assign to the proper nodes - let mut node_secret_gen = - PathSecretGenerator::starting_with(cipher_suite_provider, lca_path_secret); + let mut node_secret_gen = PathSecretGenerator::starting_with(cipher_suite, lca_path_secret); // Update secrets based on the decrypted path secret in the update self.private_key.secret_keys.resize(path.len() + 1, None); @@ -298,8 +312,7 @@ impl<'a> TreeKem<'a> { // Verify the private key we calculated properly matches the public key we inserted into the tree. This guarantees // that we will be able to decrypt later. - let (hpke_private, hpke_public) = - secret.to_hpke_key_pair(cipher_suite_provider).await?; + let (hpke_private, hpke_public) = secret.to_hpke_key_pair(cipher_suite).await?; if hpke_public != update.public_key { return Err(MlsError::PubKeyMismatch); @@ -314,56 +327,6 @@ impl<'a> TreeKem<'a> { node_secret_gen.next_secret().await } - #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] - async fn encrypt_copath_node_resolution( - &self, - cipher_suite_provider: &P, - path_secret: &PathSecret, - copath_index: NodeIndex, - context: &[u8], - #[cfg(feature = "std")] excluding: &HashSet, - #[cfg(not(feature = "std"))] excluding: &[NodeIndex], - ) -> Result { - let reso = self - .tree_kem_public - .nodes - .get_resolution_index(copath_index)?; - - let make_ctxt = |idx| async move { - let node = self - .tree_kem_public - .nodes - .borrow_node(idx)? - .as_non_empty()?; - - path_secret - .encrypt(cipher_suite_provider, node.public_key(), context) - .await - }; - - let ctxts = wrap_iter(reso).filter(|&idx| async move { !excluding.contains(&idx) }); - - #[cfg(not(mls_build_async))] - let ctxts = ctxts.map(make_ctxt); - - #[cfg(mls_build_async)] - let ctxts = ctxts.then(make_ctxt); - - let ctxts = ctxts.try_collect().await?; - - let path_index = tree_math::parent(copath_index); - - Ok(UpdatePathNode { - public_key: self - .tree_kem_public - .nodes - .borrow_as_parent(path_index)? - .public_key - .clone(), - encrypted_path_secret: ctxts, - }) - } - #[inline] fn find_resolved_pos( &self, diff --git a/mls-rs/src/tree_kem/path_secret.rs b/mls-rs/src/tree_kem/path_secret.rs index 220689f3..30192c9f 100644 --- a/mls-rs/src/tree_kem/path_secret.rs +++ b/mls-rs/src/tree_kem/path_secret.rs @@ -12,7 +12,7 @@ use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::error::IntoAnyError; use zeroize::Zeroizing; -use super::hpke_encryption::HpkeEncryptable; +use super::hpke_encryption::HpkeInfo; #[derive(Debug, Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)] pub struct PathSecret(#[mls_codec(with = "mls_rs_codec::byte_vec")] Zeroizing>); @@ -53,16 +53,8 @@ impl PathSecret { } } -impl HpkeEncryptable for PathSecret { +impl HpkeInfo for PathSecret { const ENCRYPT_LABEL: &'static str = "UpdatePathNode"; - - fn from_bytes(bytes: Vec) -> Result { - Ok(Self(Zeroizing::new(bytes))) - } - - fn get_bytes(&self) -> Result, MlsError> { - Ok(self.to_vec()) - } } impl PathSecret { diff --git a/mls-rs/test_data/group_state.mls b/mls-rs/test_data/group_state.mls index 2ed6d9e3..a8704825 100644 Binary files a/mls-rs/test_data/group_state.mls and b/mls-rs/test_data/group_state.mls differ