Skip to content

Commit

Permalink
kbkdf: provide a param builder (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
baloo authored Jan 31, 2025
1 parent 0715721 commit b9e3645
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 52 deletions.
4 changes: 2 additions & 2 deletions kbkdf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ The most common way to use KBKDF is as follows: you generate a shared secret wit
```rust
use hex_literal::hex;
use hmac::Hmac;
use kbkdf::{Counter, Kbkdf};
use kbkdf::{Counter, Kbkdf, Params};
use sha2::Sha256;

type HmacSha256 = Hmac<Sha256>;
let counter = Counter::<HmacSha256, HmacSha256>::default();
let key = counter
.derive(b"secret", true, true, true, b"label", b"")
.derive(Params::builder(b"secret").with_label(b"label").build())
.unwrap();
assert_eq!(
key,
Expand Down
97 changes: 77 additions & 20 deletions kbkdf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,71 @@ impl fmt::Display for Error {

impl core::error::Error for Error {}

/// Parameters used for KBKDF
pub struct Params<'k, 'l, 'c> {
pub kin: &'k [u8],
pub label: &'l [u8],
pub context: &'c [u8],
pub use_l: bool,
pub use_separator: bool,
pub use_counter: bool,
}

impl<'k, 'l, 'c> Params<'k, 'l, 'c> {
/// Create a new builder for [`Params`]
pub fn builder(kin: &'k [u8]) -> ParamsBuilder<'k, 'l, 'c> {
let params = Params {
kin,
label: &[],
context: &[],
use_l: true,
use_separator: true,
use_counter: true,
};
ParamsBuilder(params)
}
}

/// Parameters builders for [`Params`]
pub struct ParamsBuilder<'k, 'l, 'c>(Params<'k, 'l, 'c>);

impl<'k, 'l, 'c> ParamsBuilder<'k, 'l, 'c> {
/// Return the built [`Params`]
pub fn build(self) -> Params<'k, 'l, 'c> {
self.0
}

/// Set the label for the parameters
pub fn with_label(mut self, label: &'l [u8]) -> Self {
self.0.label = label;
self
}

/// Set the context for the parameters
pub fn with_context(mut self, context: &'c [u8]) -> Self {
self.0.context = context;
self
}

/// During the iterations, append the length of the Prf
pub fn use_l(mut self, use_l: bool) -> Self {
self.0.use_l = use_l;
self
}

/// During the iterations, separate the label from the context with a NULL byte
pub fn use_separator(mut self, use_separator: bool) -> Self {
self.0.use_separator = use_separator;
self
}

/// During the iterations, update the Prf with the iteration counter
pub fn use_counter(mut self, use_counter: bool) -> Self {
self.0.use_counter = use_counter;
self
}
}

// Helper structure along with [`KbkdfUser`] to compute values of L and H.
struct KbkdfCore<OutputLen, PrfOutputLen> {
_marker: PhantomData<(OutputLen, PrfOutputLen)>,
Expand Down Expand Up @@ -73,15 +138,7 @@ where
<Prf::OutputSize as Mul<U8>>::Output: Unsigned,
{
/// Derives `key` from `kin` and other parameters.
fn derive(
&self,
kin: &[u8],
use_l: bool,
use_separator: bool,
use_counter: bool,
label: &[u8],
context: &[u8],
) -> Result<Array<u8, K::KeySize>, Error> {
fn derive(&self, params: Params) -> Result<Array<u8, K::KeySize>, Error> {
// n - An integer whose value is the number of iterations of the PRF needed to generate L
// bits of keying material
let n: u32 = <KbkdfCore<K::KeySize, Prf::OutputSize> as KbkdfUser>::L::U32
Expand All @@ -97,25 +154,25 @@ where
let mut ki = None;
self.input_iv(&mut ki);
let mut a = {
let mut h = Prf::new_from_slice(kin).unwrap();
h.update(label);
if use_separator {
let mut h = Prf::new_from_slice(params.kin).unwrap();
h.update(params.label);
if params.use_separator {
h.update(&[0]);
}
h.update(context);
h.update(params.context);
h.finalize().into_bytes()
};

for counter in 1..=n {
if counter > 1 {
a = {
let mut h = Prf::new_from_slice(kin).unwrap();
let mut h = Prf::new_from_slice(params.kin).unwrap();
h.update(a.as_slice());
h.finalize().into_bytes()
};
}

let mut h = Prf::new_from_slice(kin).unwrap();
let mut h = Prf::new_from_slice(params.kin).unwrap();

if Self::FEEDBACK_KI {
if let Some(ki) = ki {
Expand All @@ -126,7 +183,7 @@ where
if Self::DOUBLE_PIPELINE {
h.update(a.as_slice());
}
if use_counter {
if params.use_counter {
// counter encoded as big endian u32
// Type parameter R encodes how large the value is to be (either U8, U16, U24, or U32)
//
Expand All @@ -137,12 +194,12 @@ where
}

// Fixed input data
h.update(label);
if use_separator {
h.update(params.label);
if params.use_separator {
h.update(&[0]);
}
h.update(context);
if use_l {
h.update(params.context);
if params.use_l {
h.update(
&(<KbkdfCore<K::KeySize, Prf::OutputSize> as KbkdfUser>::L::U32).to_be_bytes()
[..],
Expand Down
41 changes: 30 additions & 11 deletions kbkdf/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Array, Counter, DoublePipeline, Feedback, Kbkdf};
use super::{Array, Counter, DoublePipeline, Feedback, Kbkdf, Params};
use core::convert::TryFrom;
use digest::{consts::*, crypto_common::KeySizeUser};
use hex_literal::hex;
Expand Down Expand Up @@ -71,7 +71,14 @@ fn test_static_values_counter() {
let counter = Counter::<HmacSha256, HmacSha512>::default();
for (v, i) in KNOWN_VALUES_COUNTER_HMAC_SHA256.iter().zip(0..) {
assert_eq!(
counter.derive(v.key, v.use_l, v.use_separator, true, v.label, v.context,),
counter.derive(
Params::builder(v.key)
.use_l(v.use_l)
.use_separator(v.use_separator)
.with_label(v.label)
.with_context(v.context)
.build()
),
Ok(Array::<_, _>::try_from(v.expected).unwrap().clone()),
"key derivation failed for (index: {i}):\n{v:x?}"
);
Expand All @@ -90,13 +97,10 @@ fn test_counter_kbkdfvs() {
let counter = Counter::<HmacSha256, MockOutput>::default();
// KDFCTR_gen.txt count 15
assert_eq!(
counter.derive(
&hex!("43eef6d824fd820405626ab9b6d79f1fd04e126ab8e17729e3afc7cb5af794f8"),
false,
false,
true,
&hex!("5e269b5a7bdedcc3e875e2725693a257fc60011af7dcd68a3358507fe29b0659ca66951daa05a15032033650bc58a27840f8fbe9f4088b9030738f68"),
&[],
counter.derive(Params::builder(
&hex!("43eef6d824fd820405626ab9b6d79f1fd04e126ab8e17729e3afc7cb5af794f8")).use_l(false).use_separator(
false).with_label(
&hex!("5e269b5a7bdedcc3e875e2725693a257fc60011af7dcd68a3358507fe29b0659ca66951daa05a15032033650bc58a27840f8fbe9f4088b9030738f68")).build()
),
Ok(Array::<u8, U32>::from(hex!("f0a339ecbcae6add1afb27da3ba40a1320c6427a58afb9dc366b219b7eb29ecf")).clone()),
);
Expand Down Expand Up @@ -160,7 +164,14 @@ fn test_static_values_feedback() {
let iv = v.iv.map(|iv| Array::try_from(iv).unwrap());
let feedback = Feedback::<HmacSha256, HmacSha512>::new(iv.as_ref());
assert_eq!(
feedback.derive(v.key, v.use_l, v.use_separator, true, v.label, v.context,),
feedback.derive(
Params::builder(v.key)
.use_l(v.use_l)
.use_separator(v.use_separator)
.with_label(v.label)
.with_context(v.context)
.build()
),
Ok(Array::<_, _>::try_from(v.expected).unwrap().clone()),
"key derivation failed for (index: {i}):\n{v:x?}"
);
Expand Down Expand Up @@ -195,7 +206,15 @@ fn test_static_values_double_pipeline() {
for (v, i) in KNOWN_VALUES_DOUBLE_PIPELINE_HMAC_SHA256.iter().zip(0..) {
let dbl_pipeline = DoublePipeline::<HmacSha256, MockOutput>::default();
assert_eq!(
dbl_pipeline.derive(v.key, v.use_l, v.use_separator, false, v.label, v.context,),
dbl_pipeline.derive(
Params::builder(v.key)
.use_l(v.use_l)
.use_separator(v.use_separator)
.use_counter(false)
.with_label(v.label)
.with_context(v.context)
.build(),
),
Ok(Array::<_, _>::try_from(v.expected).unwrap().clone()),
"key derivation failed for (index: {i}):\n{v:x?}"
);
Expand Down
39 changes: 20 additions & 19 deletions kbkdf/tests/kbkdf/parser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use digest::consts::*;
use hex;
use kbkdf::Kbkdf;
use kbkdf::{Kbkdf, Params};

use core::{convert::TryInto, ops::Mul};
use digest::{
Expand Down Expand Up @@ -183,12 +183,13 @@ impl TestData for CounterTestData {

let key = counter
.derive(
self.ki.as_slice(),
false,
false,
use_counter,
label.as_slice(),
context.as_slice(),
Params::builder(self.ki.as_slice())
.use_l(false)
.use_separator(false)
.use_counter(use_counter)
.with_label(label.as_slice())
.with_context(context.as_slice())
.build(),
)
.unwrap();

Expand Down Expand Up @@ -238,12 +239,12 @@ impl TestData for DoublePipelineTestData {

let key = double_pipeline
.derive(
self.ki.as_slice(),
false,
false,
use_counter,
self.fixed_data.as_slice(),
&[],
Params::builder(self.ki.as_slice())
.use_l(false)
.use_separator(false)
.use_counter(use_counter)
.with_label(self.fixed_data.as_slice())
.build(),
)
.unwrap();

Expand Down Expand Up @@ -309,12 +310,12 @@ impl TestData for FeedbackTestData {

let key = feedback
.derive(
self.ki.as_slice(),
false,
false,
use_counter,
self.fixed_data.as_slice(),
&[],
Params::builder(self.ki.as_slice())
.use_l(false)
.use_separator(false)
.use_counter(use_counter)
.with_label(self.fixed_data.as_slice())
.build(),
)
.unwrap();

Expand Down

0 comments on commit b9e3645

Please sign in to comment.