Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular inversion improvements #263

Merged
merged 3 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,59 @@ fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
});
}

/// Benchmarks modular inversion on random `U256` inputs.
///
/// Three cases are measured:
/// * `inv_odd_mod` with a modulus forced odd,
/// * `inv_mod` with a modulus forced odd,
/// * `inv_mod` with an arbitrary random modulus.
///
/// In every case the setup step keeps drawing random values until it finds one
/// that is actually invertible for the chosen modulus, so the timed routine
/// always runs on an input for which an inverse exists.
fn bench_inv_mod<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
    group.bench_function("inv_odd_mod, U256", |b| {
        // OR-ing with one sets the lowest bit, making the modulus odd.
        let setup = || {
            let modulus = U256::random(&mut OsRng) | U256::ONE;
            loop {
                let candidate = U256::random(&mut OsRng);
                let (_, invertible) = candidate.inv_odd_mod(&modulus);
                if invertible.into() {
                    break (candidate, modulus);
                }
            }
        };
        b.iter_batched(setup, |(x, m)| x.inv_odd_mod(&m), BatchSize::SmallInput)
    });

    group.bench_function("inv_mod, U256, odd modulus", |b| {
        // Same inputs as above, but timed through the general `inv_mod` entry point.
        let setup = || {
            let modulus = U256::random(&mut OsRng) | U256::ONE;
            loop {
                let candidate = U256::random(&mut OsRng);
                let (_, invertible) = candidate.inv_odd_mod(&modulus);
                if invertible.into() {
                    break (candidate, modulus);
                }
            }
        };
        b.iter_batched(setup, |(x, m)| x.inv_mod(&m), BatchSize::SmallInput)
    });

    group.bench_function("inv_mod, U256", |b| {
        // The modulus is left as drawn, so it may be even.
        let setup = || {
            let modulus = U256::random(&mut OsRng);
            loop {
                let candidate = U256::random(&mut OsRng);
                let (_, invertible) = candidate.inv_mod(&modulus);
                if invertible.into() {
                    break (candidate, modulus);
                }
            }
        };
        b.iter_batched(setup, |(x, m)| x.inv_mod(&m), BatchSize::SmallInput)
    });
}

fn bench_wrapping_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");
bench_division(&mut group);
Expand All @@ -169,6 +222,7 @@ fn bench_montgomery(c: &mut Criterion) {
/// Registers the "modular ops" benchmark group: the shift benchmarks first,
/// then the modular inversion benchmarks.
fn bench_modular_ops(c: &mut Criterion) {
    let mut modular_group = c.benchmark_group("modular ops");
    bench_shifts(&mut modular_group);
    bench_inv_mod(&mut modular_group);
    modular_group.finish();
}

Expand Down
15 changes: 15 additions & 0 deletions src/ct_choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ impl CtChoice {
Self(value.wrapping_neg())
}

/// Returns the truthy value if `value != 0`, and the falsy value otherwise.
///
/// The check is branch-free: it only uses bitwise operations and a shift.
pub(crate) const fn from_usize_being_nonzero(value: usize) -> Self {
    // For any non-zero `value`, either `value` itself or its two's-complement
    // negation has the top bit set; for zero, both are zero.
    let top_bit_if_nonzero = (value | value.wrapping_neg()) >> (usize::BITS - 1);
    Self::from_lsb(top_bit_if_nonzero as Word)
}

/// Returns the truthy value if `x == y`, and the falsy value otherwise.
pub(crate) const fn from_usize_equality(x: usize, y: usize) -> Self {
    // The wrapping difference is zero exactly when the operands are equal.
    let difference = x.wrapping_sub(y);
    let differs = Self::from_usize_being_nonzero(difference);
    differs.not()
}

/// Returns the truthy value if `x < y`, and the falsy value otherwise.
pub(crate) const fn from_usize_lt(x: usize, y: usize) -> Self {
let bit = (((!x) & y) | (((!x) | y) & (x.wrapping_sub(y)))) >> (usize::BITS - 1);
Expand All @@ -39,6 +50,10 @@ impl CtChoice {
Self(!self.0)
}

/// Combines two choices with a bitwise OR of their underlying words.
pub(crate) const fn or(&self, other: Self) -> Self {
    let combined = self.0 | other.0;
    Self(combined)
}

pub(crate) const fn and(&self, other: Self) -> Self {
Self(self.0 & other.0)
}
Expand Down
51 changes: 48 additions & 3 deletions src/uint/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
/// Get the value of the bit at position `index`, as a truthy or falsy `CtChoice`.
/// Returns the falsy value for indices out of range.
pub const fn bit(&self, index: usize) -> CtChoice {
let limb_num = Limb((index / Limb::BITS) as Word);
let limb_num = index / Limb::BITS;
let index_in_limb = index % Limb::BITS;
let index_mask = 1 << index_in_limb;

Expand All @@ -79,18 +79,36 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let mut i = 0;
while i < LIMBS {
let bit = limbs[i] & index_mask;
let is_right_limb = Limb::ct_eq(limb_num, Limb(i as Word));
let is_right_limb = CtChoice::from_usize_equality(i, limb_num);
result |= is_right_limb.if_true(bit);
i += 1;
}

CtChoice::from_lsb(result >> index_in_limb)
}

/// Sets the bit at `index` to 0 or 1 depending on the value of `bit_value`.
///
/// Every limb is visited and rewritten, with the limb containing `index`
/// picked by a `CtChoice` selection rather than by direct indexing. If `index`
/// does not fall inside any limb, no limb is selected and the value is
/// returned unchanged.
pub(crate) const fn set_bit(self, index: usize, bit_value: CtChoice) -> Self {
    let target_limb = index / Limb::BITS;
    let mask = 1 << (index % Limb::BITS);

    let mut out = self;
    let mut limb_idx = 0;
    while limb_idx < LIMBS {
        let current = out.limbs[limb_idx].0;
        // Candidate replacement: the limb with the bit cleared or set, per `bit_value`.
        let cleared = current & !mask;
        let set = current | mask;
        let updated = bit_value.select(cleared, set);
        // Only the limb that actually contains `index` takes the replacement.
        let hit = CtChoice::from_usize_equality(limb_idx, target_limb);
        out.limbs[limb_idx] = Limb(hit.select(current, updated));
        limb_idx += 1;
    }
    out
}
}

#[cfg(test)]
mod tests {
use crate::U256;
use crate::{CtChoice, U256};

fn uint_with_bits_at(positions: &[usize]) -> U256 {
let mut result = U256::ZERO;
Expand Down Expand Up @@ -159,4 +177,31 @@ mod tests {
let u = U256::ZERO;
assert_eq!(u.trailing_zeros() as u32, 256);
}

#[test]
fn set_bit() {
    // Every case starts from the same value with bits 16, 79 and 150 set.
    let initial: [usize; 3] = [16, 79, 150];
    let check = |index: usize, bit_value: CtChoice, expected_bits: &[usize]| {
        let updated = uint_with_bits_at(&initial).set_bit(index, bit_value);
        assert_eq!(updated, uint_with_bits_at(expected_bits));
    };

    // Setting a clear bit adds it.
    check(127, CtChoice::TRUE, &[16, 79, 127, 150]);
    // Setting an already-set bit leaves the value unchanged.
    check(150, CtChoice::TRUE, &[16, 79, 150]);
    // Clearing an already-clear bit leaves the value unchanged.
    check(127, CtChoice::FALSE, &[16, 79, 150]);
    // Clearing a set bit removes it.
    check(150, CtChoice::FALSE, &[16, 79]);
}
}
143 changes: 124 additions & 19 deletions src/uint/inv_mod.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,68 @@
use super::Uint;
use crate::{CtChoice, Limb};
use crate::CtChoice;

impl<const LIMBS: usize> Uint<LIMBS> {
/// Computes 1/`self` mod 2^k as specified in Algorithm 4 from
/// A Secure Algorithm for Inversion Modulo 2k by
/// Sadiel de la Fe and Carles Ferrer. See
/// <https://www.mdpi.com/2410-387X/2/3/23>.
/// Computes 1/`self` mod `2^k`.
/// This method is constant-time w.r.t. `self` but not `k`.
///
/// Conditions: `self` < 2^k and `self` must be odd
pub const fn inv_mod2k(&self, k: usize) -> Self {
let mut x = Self::ZERO;
let mut b = Self::ONE;
pub const fn inv_mod2k_vartime(&self, k: usize) -> Self {
// Using Algorithm 3 from "A Secure Algorithm for Inversion Modulo 2^k"
// by Sadiel de la Fe and Carles Ferrer.
// See <https://www.mdpi.com/2410-387X/2/3/23>.

// Note that we are not using Algorithm 4, since we have a different approach
// of enforcing constant-timeness w.r.t. `self`.

let mut x = Self::ZERO; // keeps `x` during iterations
let mut b = Self::ONE; // keeps `b_i` during iterations
let mut i = 0;

while i < k {
let mut x_i = Self::ZERO;
let j = b.limbs[0].0 & 1;
x_i.limbs[0] = Limb(j);
x = x.bitor(&x_i.shl_vartime(i));
// X_i = b_i mod 2
let x_i = b.limbs[0].0 & 1;
let x_i_choice = CtChoice::from_lsb(x_i);
// b_{i+1} = (b_i - a * X_i) / 2
b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1);
// Store the X_i bit in the result (x = x | (X_i << i))
x = x.bitor(&Uint::from_word(x_i).shl_vartime(i));

i += 1;
}

x
}

/// Computes 1/`self` mod `2^k`.
///
/// Conditions: `self` < 2^k and `self` must be odd
pub const fn inv_mod2k(&self, k: usize) -> Self {
// This is the same algorithm as in `inv_mod2k_vartime()`,
// but made constant-time w.r.t `k` as well.

let mut x = Self::ZERO; // keeps `x` during iterations
let mut b = Self::ONE; // keeps `b_i` during iterations
let mut i = 0;

while i < Self::BITS {
// Only iterations for i = 0..k need to change `x`,
// the rest are dummy ones performed for the sake of constant-timeness.
let within_range = CtChoice::from_usize_lt(i, k);

// X_i = b_i mod 2
let x_i = b.limbs[0].0 & 1;
let x_i_choice = CtChoice::from_lsb(x_i);
// b_{i+1} = (b_i - a * X_i) / 2
b = Self::ct_select(&b, &b.wrapping_sub(self), x_i_choice).shr_vartime(1);

// Store the X_i bit in the result (x = x | (X_i << i))
// Don't change the result in dummy iterations.
let x_i_choice = x_i_choice.and(within_range);
x = x.set_bit(i, x_i_choice);

let t = b.wrapping_sub(self);
b = Self::ct_select(&b, &t, CtChoice::from_lsb(j)).shr_vartime(1);
i += 1;
}

x
}

Expand Down Expand Up @@ -97,10 +137,45 @@ impl<const LIMBS: usize> Uint<LIMBS> {
}

/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
/// Returns `(inverse, Word::MAX)` if an inverse exists, otherwise `(undefined, Word::ZERO)`.
/// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
/// otherwise `(undefined, CtChoice::FALSE)`.
pub const fn inv_odd_mod(&self, modulus: &Self) -> (Self, CtChoice) {
    // Pass the full bit width of the integer for both bound arguments.
    let full_width = Uint::<LIMBS>::BITS;
    self.inv_odd_mod_bounded(modulus, full_width, full_width)
}

/// Computes the multiplicative inverse of `self` mod `modulus`.
/// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
/// otherwise `(undefined, CtChoice::FALSE)`.
pub fn inv_mod(&self, modulus: &Self) -> (Self, CtChoice) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized, it could have been made const.

// Decompose `modulus = s * 2^k` where `s` is odd
let k = modulus.trailing_zeros();
let s = modulus.shr(k);

// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
let (a, a_is_some) = self.inv_odd_mod(&s);
let b = self.inv_mod2k(k);
// inverse modulo 2^k exists either if `k` is 0 or if `self` is odd.
let b_is_some = CtChoice::from_usize_being_nonzero(k)
.not()
.or(self.ct_is_odd());

// Restore from RNS:
// self^{-1} = a mod s = b mod 2^k
// => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
// (essentially one step of Garner's algorithm for recovery from RNS).

let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists

// This part is mod 2^k
let mask = (Uint::ONE << k).wrapping_sub(&Uint::ONE);
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)) & mask;

// Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
// so `a + s * t <= s * 2^k - 1 == modulus - 1`.
let result = a.wrapping_add(&s.wrapping_mul(&t));
(result, a_is_some.and(b_is_some))
}
}

#[cfg(test)]
Expand All @@ -125,7 +200,7 @@ mod tests {
}

#[test]
fn test_invert() {
fn test_invert_odd() {
let a = U1024::from_be_hex(concat![
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
Expand All @@ -138,15 +213,45 @@ mod tests {
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
]);

let (res, is_some) = a.inv_odd_mod(&m);

let expected = U1024::from_be_hex(concat![
"B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
"D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
"88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
"3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
]);

let (res, is_some) = a.inv_odd_mod(&m);
assert!(is_some.is_true_vartime());
assert_eq!(res, expected);

// Even though it is less efficient, it still works
let (res, is_some) = a.inv_mod(&m);
assert!(is_some.is_true_vartime());
assert_eq!(res, expected);
}

#[test]
fn test_invert_even() {
    // The value to invert is odd (its lowest hex digit is 5).
    let value = U1024::from_be_hex(concat![
        "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
        "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
        "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
        "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
    ]);
    // The modulus is even: its three lowest hex digits are zero.
    let even_modulus = U1024::from_be_hex(concat![
        "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
        "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
        "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
        "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
    ]);

    let (inverse, exists) = value.inv_mod(&even_modulus);

    let expected_inverse = U1024::from_be_hex(concat![
        "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
        "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
        "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
        "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
    ]);
    assert!(exists.is_true_vartime());
    assert_eq!(inverse, expected_inverse);
}
Expand Down
2 changes: 1 addition & 1 deletion src/uint/modular/constant_mod/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ macro_rules! impl_modulus {
const MOD_NEG_INV: $crate::Limb = $crate::Limb(
$crate::Word::MIN.wrapping_sub(
Self::MODULUS
.inv_mod2k($crate::Word::BITS as usize)
.inv_mod2k_vartime($crate::Word::BITS as usize)
.as_limbs()[0]
.0,
),
Expand Down
5 changes: 3 additions & 2 deletions src/uint/modular/runtime_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ impl<const LIMBS: usize> DynResidueParams<LIMBS> {
// Since we are calculating the inverse modulo (Word::MAX+1),
// we can take the modulo right away and calculate the inverse of the first limb only.
let modulus_lo = Uint::<1>::from_words([modulus.limbs[0].0]);
let mod_neg_inv =
Limb(Word::MIN.wrapping_sub(modulus_lo.inv_mod2k(Word::BITS as usize).limbs[0].0));
let mod_neg_inv = Limb(
Word::MIN.wrapping_sub(modulus_lo.inv_mod2k_vartime(Word::BITS as usize).limbs[0].0),
);

let r3 = montgomery_reduction(&r2.square_wide(), modulus, mod_neg_inv);

Expand Down
Loading