Skip to content

Commit

Permalink
Add inv_mod() that supports any moduli
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Aug 28, 2023
1 parent cf4eb1d commit 1a13201
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 6 deletions.
54 changes: 54 additions & 0 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,59 @@ fn bench_shifts<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
});
}

fn bench_inv_mod<M: Measurement>(group: &mut BenchmarkGroup<'_, M>) {
group.bench_function("inv_odd_mod, U256", |b| {
b.iter_batched(
|| {
let m = U256::random(&mut OsRng) | U256::ONE;
loop {
let x = U256::random(&mut OsRng);
let (_, is_some) = x.inv_odd_mod(&m);
if is_some.into() {
break (x, m);
}
}
},
|(x, m)| x.inv_odd_mod(&m),
BatchSize::SmallInput,
)
});

group.bench_function("inv_mod, U256, odd modulus", |b| {
b.iter_batched(
|| {
let m = U256::random(&mut OsRng) | U256::ONE;
loop {
let x = U256::random(&mut OsRng);
let (_, is_some) = x.inv_odd_mod(&m);
if is_some.into() {
break (x, m);
}
}
},
|(x, m)| x.inv_mod(&m),
BatchSize::SmallInput,
)
});

group.bench_function("inv_mod, U256", |b| {
b.iter_batched(
|| {
let m = U256::random(&mut OsRng);
loop {
let x = U256::random(&mut OsRng);
let (_, is_some) = x.inv_mod(&m);
if is_some.into() {
break (x, m);
}
}
},
|(x, m)| x.inv_mod(&m),
BatchSize::SmallInput,
)
});
}

fn bench_wrapping_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("wrapping ops");
bench_division(&mut group);
Expand All @@ -169,6 +222,7 @@ fn bench_montgomery(c: &mut Criterion) {
fn bench_modular_ops(c: &mut Criterion) {
let mut group = c.benchmark_group("modular ops");
bench_shifts(&mut group);
bench_inv_mod(&mut group);
group.finish();
}

Expand Down
4 changes: 4 additions & 0 deletions src/ct_choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ impl CtChoice {
Self(!self.0)
}

pub(crate) const fn or(&self, other: Self) -> Self {
Self(self.0 | other.0)
}

pub(crate) const fn and(&self, other: Self) -> Self {
Self(self.0 & other.0)
}
Expand Down
73 changes: 68 additions & 5 deletions src/uint/inv_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,43 @@ impl<const LIMBS: usize> Uint<LIMBS> {
}

/// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
/// Returns `(inverse, Word::MAX)` if an inverse exists, otherwise `(undefined, Word::ZERO)`.
/// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
/// otherwise `(undefined, CtChoice::FALSE)`.
pub const fn inv_odd_mod(&self, modulus: &Self) -> (Self, CtChoice) {
self.inv_odd_mod_bounded(modulus, Uint::<LIMBS>::BITS, Uint::<LIMBS>::BITS)
}

/// Computes the multiplicative inverse of `self` mod `modulus`.
/// Returns `(inverse, CtChoice::TRUE)` if an inverse exists,
/// otherwise `(undefined, CtChoice::FALSE)`.
pub fn inv_mod(&self, modulus: &Self) -> (Self, CtChoice) {
// Decompose `modulus = s * 2^k` where `s` is odd
let k = modulus.trailing_zeros();
let s = modulus.shr(k);

// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`.
let (a, a_is_some) = self.inv_odd_mod(&s);
let b = self.inv_mod2k(k);
// inverse modulo 2^k exists either if `k` is 0 or if `self` is odd.
let b_is_some = CtChoice::from_usize_being_nonzero(k)
.not()
.or(self.ct_is_odd());

// Restore from RNS:
// self^{-1} = a mod s = b mod 2^k
// => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
let m_odd_inv = s.inv_mod2k(k); // `s` is odd, so this always exists

// This part is mod 2^k
let mask = (Uint::ONE << k).wrapping_sub(&Uint::ONE);
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)) & mask;

// Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
// so `a + s * t <= s * 2^k - 1 == modulus - 1`.
let result = a.wrapping_add(&s.wrapping_mul(&t));
(result, a_is_some.and(b_is_some))
}
}

#[cfg(test)]
Expand All @@ -165,7 +198,7 @@ mod tests {
}

#[test]
fn test_invert() {
fn test_invert_odd() {
let a = U1024::from_be_hex(concat![
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
Expand All @@ -178,15 +211,45 @@ mod tests {
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
]);

let (res, is_some) = a.inv_odd_mod(&m);

let expected = U1024::from_be_hex(concat![
"B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
"D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
"88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
"3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
]);

let (res, is_some) = a.inv_odd_mod(&m);
assert!(is_some.is_true_vartime());
assert_eq!(res, expected);

// Even though it is less efficient, it still works
let (res, is_some) = a.inv_mod(&m);
assert!(is_some.is_true_vartime());
assert_eq!(res, expected);
}

#[test]
fn test_invert_even() {
let a = U1024::from_be_hex(concat![
"000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
"347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
"BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
"382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
]);
let m = U1024::from_be_hex(concat![
"D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
"37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
"D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
"558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
]);
let expected = U1024::from_be_hex(concat![
"1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
"DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
"FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
"5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
]);

let (res, is_some) = a.inv_mod(&m);
assert!(is_some.is_true_vartime());
assert_eq!(res, expected);
}
Expand Down
19 changes: 18 additions & 1 deletion tests/proptests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crypto_bigint::{
modular::runtime_mod::{DynResidue, DynResidueParams},
Encoding, Limb, NonZero, Word, U256,
CtChoice, Encoding, Limb, NonZero, Word, U256,
};
use num_bigint::BigUint;
use num_integer::Integer;
Expand Down Expand Up @@ -233,6 +233,23 @@ proptest! {
}
}

#[test]
fn inv_mod(a in uint(), b in uint()) {
let a_bi = to_biguint(&a);
let b_bi = to_biguint(&b);

let expected_is_some = if a_bi.gcd(&b_bi) == BigUint::one() { CtChoice::TRUE } else { CtChoice::FALSE };
let (actual, actual_is_some) = a.inv_mod(&b);

assert_eq!(bool::from(expected_is_some), bool::from(actual_is_some));

if actual_is_some.into() {
let inv_bi = to_biguint(&actual);
let res = (inv_bi * a_bi) % b_bi;
assert_eq!(res, BigUint::one());
}
}

#[test]
fn wrapping_sqrt(a in uint()) {
let a_bi = to_biguint(&a);
Expand Down

0 comments on commit 1a13201

Please sign in to comment.