Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify map conditional compilation #158

Merged
merged 4 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mls-rs-crypto-awslc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ license = "Apache-2.0 OR MIT"

[dependencies]
aws-lc-rs = "1.7.0"
aws-lc-sys = { version = "0.16.0" }
mls-rs-core = { path = "../mls-rs-core", version = "0.18.0" }
aws-lc-sys = { version = "0.17.0" }
mls-rs-core = { path = "../mls-rs-core", version = "=0.18.0" }
mls-rs-crypto-hpke = { path = "../mls-rs-crypto-hpke", version = "0.9.0" }
mls-rs-crypto-traits = { path = "../mls-rs-crypto-traits", version = "0.10.0" }
mls-rs-identity-x509 = { path = "../mls-rs-identity-x509", version = "0.11.0" }
Expand Down
10 changes: 3 additions & 7 deletions mls-rs/src/group/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ use crate::psk::{
ResumptionPSKUsage, ResumptionPsk,
};

#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
use std::collections::HashMap;

#[cfg(feature = "private_message")]
use ciphertext_processor::*;

Expand Down Expand Up @@ -265,10 +262,9 @@ where
epoch_secrets: EpochSecrets,
private_tree: TreeKemPrivate,
key_schedule: KeySchedule,
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, // Hash of leaf node hpke public key to secret key
#[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
#[cfg(feature = "by_ref_proposal")]
pending_updates:
crate::map::SmallMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, // Hash of leaf node hpke public key to secret key
pending_commit: Option<CommitGeneration>,
#[cfg(feature = "psk")]
previous_psk: Option<PskSecretInput>,
Expand Down
11 changes: 2 additions & 9 deletions mls-rs/src/group/proposal_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ use crate::group::{proposal_filter::FilterStrategy, ProposalRef, ProtocolVersion

use crate::tree_kem::leaf_node::LeafNode;

#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
use std::collections::HashMap;

#[cfg(feature = "by_ref_proposal")]
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

Expand All @@ -50,10 +47,7 @@ pub struct CachedProposal {
pub(crate) struct ProposalCache {
protocol_version: ProtocolVersion,
group_id: Vec<u8>,
#[cfg(feature = "std")]
pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
#[cfg(not(feature = "std"))]
pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
pub(crate) proposals: crate::map::SmallMap<ProposalRef, CachedProposal>,
}

#[cfg(feature = "by_ref_proposal")]
Expand Down Expand Up @@ -83,8 +77,7 @@ impl ProposalCache {
pub fn import(
protocol_version: ProtocolVersion,
group_id: Vec<u8>,
#[cfg(feature = "std")] proposals: HashMap<ProposalRef, CachedProposal>,
#[cfg(not(feature = "std"))] proposals: Vec<(ProposalRef, CachedProposal)>,
proposals: crate::map::SmallMap<ProposalRef, CachedProposal>,
) -> Self {
Self {
protocol_version,
Expand Down
59 changes: 6 additions & 53 deletions mls-rs/src/group/secret_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,11 @@ use core::{

use zeroize::Zeroizing;

use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider};
use crate::{client::MlsError, map::LargeMap, tree_kem::math::TreeIndex, CipherSuiteProvider};

use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::error::IntoAnyError;

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap;

use super::key_schedule::kdf_expand_with_label;

pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024;
Expand Down Expand Up @@ -94,13 +88,9 @@ impl From<Zeroizing<Vec<u8>>> for TreeSecret {
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecretsVec<T: TreeIndex> {
#[cfg(feature = "std")]
inner: HashMap<T, SecretTreeNode>,
#[cfg(not(feature = "std"))]
inner: Vec<(T, SecretTreeNode)>,
inner: LargeMap<T, SecretTreeNode>,
}

#[cfg(feature = "std")]
impl<T: TreeIndex> TreeSecretsVec<T> {
fn set_node(&mut self, index: T, value: SecretTreeNode) {
self.inner.insert(index, value);
Expand All @@ -111,30 +101,6 @@ impl<T: TreeIndex> TreeSecretsVec<T> {
}
}

#[cfg(not(feature = "std"))]
impl<T: TreeIndex> TreeSecretsVec<T> {
fn set_node(&mut self, index: T, value: SecretTreeNode) {
if let Some(i) = self.find_node(&index) {
self.inner[i] = (index, value)
} else {
self.inner.push((index, value))
}
}

fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
self.find_node(index).map(|i| self.inner.remove(i).1)
}

fn find_node(&self, index: &T) -> Option<usize> {
use itertools::Itertools;

self.inner
.iter()
.find_position(|(i, _)| i == index)
.map(|(i, _)| i)
}
}

#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretTree<T: TreeIndex> {
Expand Down Expand Up @@ -364,10 +330,8 @@ impl MessageKeyData {
pub struct SecretKeyRatchet {
secret: TreeSecret,
generation: u32,
#[cfg(all(feature = "out_of_order", feature = "std"))]
history: HashMap<u32, MessageKeyData>,
#[cfg(all(feature = "out_of_order", not(feature = "std")))]
history: BTreeMap<u32, MessageKeyData>,
#[cfg(feature = "out_of_order")]
history: LargeMap<u32, MessageKeyData>,
}

impl MlsSize for SecretKeyRatchet {
Expand Down Expand Up @@ -404,20 +368,9 @@ impl MlsDecode for SecretKeyRatchet {
Ok(Self {
secret: mls_rs_codec::byte_vec::mls_decode(reader)?,
generation: u32::mls_decode(reader)?,
#[cfg(all(feature = "std", feature = "out_of_order"))]
history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
let mut items = HashMap::default();

while !data.is_empty() {
let item = MessageKeyData::mls_decode(data)?;
items.insert(item.generation, item);
}

Ok(items)
})?,
#[cfg(all(not(feature = "std"), feature = "out_of_order"))]
#[cfg(feature = "out_of_order")]
history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
let mut items = alloc::collections::BTreeMap::default();
let mut items = LargeMap::default();

while !data.is_empty() {
let item = MessageKeyData::mls_decode(data)?;
Expand Down
19 changes: 5 additions & 14 deletions mls-rs/src/group/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::{
use crate::{
crypto::{HpkePublicKey, HpkeSecretKey},
group::ProposalRef,
map::SmallMap,
};

#[cfg(feature = "by_ref_proposal")]
Expand All @@ -27,12 +28,6 @@ use mls_rs_core::crypto::SignatureSecretKey;
#[cfg(feature = "tree_index")]
use mls_rs_core::identity::IdentityProvider;

#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
use std::collections::HashMap;

#[cfg(all(feature = "by_ref_proposal", not(feature = "std")))]
use alloc::vec::Vec;

use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository};

#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
Expand All @@ -43,10 +38,8 @@ pub(crate) struct Snapshot {
private_tree: TreeKemPrivate,
epoch_secrets: EpochSecrets,
key_schedule: KeySchedule,
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
#[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
#[cfg(feature = "by_ref_proposal")]
pending_updates: SmallMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
pending_commit: Option<CommitGeneration>,
signer: SignatureSecretKey,
}
Expand All @@ -55,10 +48,8 @@ pub(crate) struct Snapshot {
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct RawGroupState {
pub(crate) context: GroupContext,
#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
#[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
#[cfg(feature = "by_ref_proposal")]
pub(crate) proposals: SmallMap<ProposalRef, CachedProposal>,
pub(crate) public_tree: TreeKemPublic,
pub(crate) interim_transcript_hash: InterimTranscriptHash,
pub(crate) pending_reinit: Option<ReInitProposal>,
Expand Down
1 change: 1 addition & 0 deletions mls-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ mod hash_reference;
pub mod identity;
mod iter;
mod key_package;
pub(crate) mod map;
/// Pre-shared key support.
pub mod psk;
mod signer;
Expand Down
120 changes: 120 additions & 0 deletions mls-rs/src/map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use alloc::vec::Vec;
use core::{
hash::Hash,
ops::{Deref, DerefMut},
};

use map_impl::SmallMapInner;
pub use map_impl::{LargeMap, LargeMapEntry, SmallMap};
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

#[cfg(feature = "std")]
mod map_impl {
use core::hash::Hash;
use std::collections::{hash_map::Entry, HashMap};

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SmallMap<K: Hash + Eq, V>(pub(super) HashMap<K, V>);

pub type LargeMap<K, V> = SmallMap<K, V>;
pub(super) type SmallMapInner<K, V> = HashMap<K, V>;
pub type LargeMapEntry<'a, K, V> = Entry<'a, K, V>;
}

#[cfg(not(feature = "std"))]
mod map_impl {
use core::hash::Hash;

use alloc::{
collections::{btree_map::Entry, BTreeMap},
vec::Vec,
};
#[cfg(feature = "by_ref_proposal")]
use itertools::Itertools;

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmallMap<K: Hash + Eq, V>(pub(super) Vec<(K, V)>);

pub type LargeMap<K, V> = BTreeMap<K, V>;
pub(super) type SmallMapInner<K, V> = Vec<(K, V)>;
pub type LargeMapEntry<'a, K, V> = Entry<'a, K, V>;

#[cfg(feature = "by_ref_proposal")]
impl<K: Hash + Eq, V> SmallMap<K, V> {
pub fn get(&self, key: &K) -> Option<&V> {
self.find(key).map(|i| &self.0[i].1)
}

pub fn insert(&mut self, key: K, value: V) {
match self.0.iter_mut().find(|(k, _)| (k == &key)) {
Some((_, v)) => *v = value,
None => self.0.push((key, value)),
}
}

pub fn remove(&mut self, key: &K) -> Option<V> {
self.find(key).map(|i| self.0.remove(i).1)
}

fn find(&self, key: &K) -> Option<usize> {
self.0
.iter()
.position(|(k, _)| k == key)
}
}
}

impl<K: Hash + Eq, V> Default for SmallMap<K, V> {
fn default() -> Self {
Self(SmallMapInner::new())
}
}

impl<K: Hash + Eq, V> Deref for SmallMap<K, V> {
type Target = SmallMapInner<K, V>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<K: Hash + Eq, V> DerefMut for SmallMap<K, V> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl<K, V> MlsDecode for SmallMap<K, V>
where
K: Hash + Eq + MlsEncode + MlsDecode + MlsSize,
V: MlsEncode + MlsDecode + MlsSize,
{
fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
SmallMapInner::mls_decode(reader).map(Self)
}
}

impl<K, V> MlsSize for SmallMap<K, V>
where
K: Hash + Eq + MlsEncode + MlsDecode + MlsSize,
V: MlsEncode + MlsDecode + MlsSize,
{
fn mls_encoded_len(&self) -> usize {
self.0.mls_encoded_len()
}
}

impl<K, V> MlsEncode for SmallMap<K, V>
where
K: Hash + Eq + MlsEncode + MlsDecode + MlsSize,
V: MlsEncode + MlsDecode + MlsSize,
{
fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
self.0.mls_encode(writer)
}
}
Loading
Loading