diff --git a/kbkdf/README.md b/kbkdf/README.md index dfad7c4..9110804 100644 --- a/kbkdf/README.md +++ b/kbkdf/README.md @@ -19,13 +19,13 @@ The most common way to use KBKDF is as follows: you generate a shared secret wit ```rust use hex_literal::hex; use hmac::Hmac; -use kbkdf::{Counter, Kbkdf}; +use kbkdf::{Counter, Kbkdf, Params}; use sha2::Sha256; type HmacSha256 = Hmac; let counter = Counter::::default(); let key = counter - .derive(b"secret", true, true, true, b"label", b"") + .derive(Params::builder(b"secret").with_label(b"label").build()) .unwrap(); assert_eq!( key, diff --git a/kbkdf/src/lib.rs b/kbkdf/src/lib.rs index 3fe002e..68fc9d4 100644 --- a/kbkdf/src/lib.rs +++ b/kbkdf/src/lib.rs @@ -30,6 +30,71 @@ impl fmt::Display for Error { impl core::error::Error for Error {} +/// Parameters used for KBKDF +pub struct Params<'k, 'l, 'c> { + pub kin: &'k [u8], + pub label: &'l [u8], + pub context: &'c [u8], + pub use_l: bool, + pub use_separator: bool, + pub use_counter: bool, +} + +impl<'k, 'l, 'c> Params<'k, 'l, 'c> { + /// Create a new builder for [`Params`] + pub fn builder(kin: &'k [u8]) -> ParamsBuilder<'k, 'l, 'c> { + let params = Params { + kin, + label: &[], + context: &[], + use_l: true, + use_separator: true, + use_counter: true, + }; + ParamsBuilder(params) + } +} + +/// Parameters builders for [`Params`] +pub struct ParamsBuilder<'k, 'l, 'c>(Params<'k, 'l, 'c>); + +impl<'k, 'l, 'c> ParamsBuilder<'k, 'l, 'c> { + /// Return the built [`Params`] + pub fn build(self) -> Params<'k, 'l, 'c> { + self.0 + } + + /// Set the label for the parameters + pub fn with_label(mut self, label: &'l [u8]) -> Self { + self.0.label = label; + self + } + + /// Set the context for the parameters + pub fn with_context(mut self, context: &'c [u8]) -> Self { + self.0.context = context; + self + } + + /// During the iterations, append the length of the Prf + pub fn use_l(mut self, use_l: bool) -> Self { + self.0.use_l = use_l; + self + } + + /// During the iterations, separate the label from the context with a NULL byte + pub fn use_separator(mut self, use_separator: bool) -> Self { + self.0.use_separator = use_separator; + self + } + + /// During the iterations, update the Prf with the iteration counter + pub fn use_counter(mut self, use_counter: bool) -> Self { + self.0.use_counter = use_counter; + self + } +} + // Helper structure along with [`KbkdfUser`] to compute values of L and H. struct KbkdfCore { _marker: PhantomData<(OutputLen, PrfOutputLen)>, @@ -73,15 +138,7 @@ where >::Output: Unsigned, { /// Derives `key` from `kin` and other parameters. - fn derive( - &self, - kin: &[u8], - use_l: bool, - use_separator: bool, - use_counter: bool, - label: &[u8], - context: &[u8], - ) -> Result, Error> { + fn derive(&self, params: Params) -> Result, Error> { // n - An integer whose value is the number of iterations of the PRF needed to generate L // bits of keying material let n: u32 = as KbkdfUser>::L::U32 @@ -97,25 +154,25 @@ where let mut ki = None; self.input_iv(&mut ki); let mut a = { - let mut h = Prf::new_from_slice(kin).unwrap(); - h.update(label); - if use_separator { + let mut h = Prf::new_from_slice(params.kin).unwrap(); + h.update(params.label); + if params.use_separator { h.update(&[0]); } - h.update(context); + h.update(params.context); h.finalize().into_bytes() }; for counter in 1..=n { if counter > 1 { a = { - let mut h = Prf::new_from_slice(kin).unwrap(); + let mut h = Prf::new_from_slice(params.kin).unwrap(); h.update(a.as_slice()); h.finalize().into_bytes() }; } - let mut h = Prf::new_from_slice(kin).unwrap(); + let mut h = Prf::new_from_slice(params.kin).unwrap(); if Self::FEEDBACK_KI { if let Some(ki) = ki { @@ -126,7 +183,7 @@ where if Self::DOUBLE_PIPELINE { h.update(a.as_slice()); } - if use_counter { + if params.use_counter { // counter encoded as big endian u32 // Type parameter R encodes how large the value is to be (either U8, U16, U24, or U32) // @@ -137,12 +194,12 @@ where } // Fixed input data - h.update(label); - if use_separator { + h.update(params.label); + if params.use_separator { h.update(&[0]); } - h.update(context); - if use_l { + h.update(params.context); + if params.use_l { h.update( &( as KbkdfUser>::L::U32).to_be_bytes() [..], diff --git a/kbkdf/src/tests.rs b/kbkdf/src/tests.rs index 625c54c..df89e2b 100644 --- a/kbkdf/src/tests.rs +++ b/kbkdf/src/tests.rs @@ -1,4 +1,4 @@ -use super::{Array, Counter, DoublePipeline, Feedback, Kbkdf}; +use super::{Array, Counter, DoublePipeline, Feedback, Kbkdf, Params}; use core::convert::TryFrom; use digest::{consts::*, crypto_common::KeySizeUser}; use hex_literal::hex; @@ -71,7 +71,14 @@ fn test_static_values_counter() { let counter = Counter::::default(); for (v, i) in KNOWN_VALUES_COUNTER_HMAC_SHA256.iter().zip(0..) { assert_eq!( - counter.derive(v.key, v.use_l, v.use_separator, true, v.label, v.context,), + counter.derive( + Params::builder(v.key) + .use_l(v.use_l) + .use_separator(v.use_separator) + .with_label(v.label) + .with_context(v.context) + .build() + ), Ok(Array::<_, _>::try_from(v.expected).unwrap().clone()), "key derivation failed for (index: {i}):\n{v:x?}" ); @@ -90,13 +97,10 @@ fn test_counter_kbkdfvs() { let counter = Counter::::default(); // KDFCTR_gen.txt count 15 assert_eq!( - counter.derive( - &hex!("43eef6d824fd820405626ab9b6d79f1fd04e126ab8e17729e3afc7cb5af794f8"), - false, - false, - true, - &hex!("5e269b5a7bdedcc3e875e2725693a257fc60011af7dcd68a3358507fe29b0659ca66951daa05a15032033650bc58a27840f8fbe9f4088b9030738f68"), - &[], + counter.derive(Params::builder( + &hex!("43eef6d824fd820405626ab9b6d79f1fd04e126ab8e17729e3afc7cb5af794f8")).use_l(false).use_separator( + false).with_label( + &hex!("5e269b5a7bdedcc3e875e2725693a257fc60011af7dcd68a3358507fe29b0659ca66951daa05a15032033650bc58a27840f8fbe9f4088b9030738f68")).build() ), Ok(Array::::from(hex!("f0a339ecbcae6add1afb27da3ba40a1320c6427a58afb9dc366b219b7eb29ecf")).clone()), ); @@ -160,7 +164,14 @@ fn test_static_values_feedback() { let iv = v.iv.map(|iv| Array::try_from(iv).unwrap()); let feedback = Feedback::::new(iv.as_ref()); assert_eq!( - feedback.derive(v.key, v.use_l, v.use_separator, true, v.label, v.context,), + feedback.derive( + Params::builder(v.key) + .use_l(v.use_l) + .use_separator(v.use_separator) + .with_label(v.label) + .with_context(v.context) + .build() + ), Ok(Array::<_, _>::try_from(v.expected).unwrap().clone()), "key derivation failed for (index: {i}):\n{v:x?}" ); @@ -195,7 +206,15 @@ fn test_static_values_double_pipeline() { for (v, i) in KNOWN_VALUES_DOUBLE_PIPELINE_HMAC_SHA256.iter().zip(0..) { let dbl_pipeline = DoublePipeline::::default(); assert_eq!( - dbl_pipeline.derive(v.key, v.use_l, v.use_separator, false, v.label, v.context,), + dbl_pipeline.derive( + Params::builder(v.key) + .use_l(v.use_l) + .use_separator(v.use_separator) + .use_counter(false) + .with_label(v.label) + .with_context(v.context) + .build(), + ), Ok(Array::<_, _>::try_from(v.expected).unwrap().clone()), "key derivation failed for (index: {i}):\n{v:x?}" ); diff --git a/kbkdf/tests/kbkdf/parser.rs b/kbkdf/tests/kbkdf/parser.rs index 81f0799..ae730e0 100644 --- a/kbkdf/tests/kbkdf/parser.rs +++ b/kbkdf/tests/kbkdf/parser.rs @@ -1,6 +1,6 @@ use digest::consts::*; use hex; -use kbkdf::Kbkdf; +use kbkdf::{Kbkdf, Params}; use core::{convert::TryInto, ops::Mul}; use digest::{ @@ -183,12 +183,13 @@ impl TestData for CounterTestData { let key = counter .derive( - self.ki.as_slice(), - false, - false, - use_counter, - label.as_slice(), - context.as_slice(), + Params::builder(self.ki.as_slice()) + .use_l(false) + .use_separator(false) + .use_counter(use_counter) + .with_label(label.as_slice()) + .with_context(context.as_slice()) + .build(), ) .unwrap(); @@ -238,12 +239,12 @@ impl TestData for DoublePipelineTestData { let key = double_pipeline .derive( - self.ki.as_slice(), - false, - false, - use_counter, - self.fixed_data.as_slice(), - &[], + Params::builder(self.ki.as_slice()) + .use_l(false) + .use_separator(false) + .use_counter(use_counter) + .with_label(self.fixed_data.as_slice()) + .build(), ) .unwrap(); @@ -309,12 +310,12 @@ impl TestData for FeedbackTestData { let key = feedback .derive( - self.ki.as_slice(), - false, - false, - use_counter, - self.fixed_data.as_slice(), - &[], + Params::builder(self.ki.as_slice()) + .use_l(false) + .use_separator(false) + .use_counter(use_counter) + .with_label(self.fixed_data.as_slice()) + .build(), ) .unwrap();