diff --git a/benches/uint.rs b/benches/uint.rs index 3bb2a961..7d98b470 100644 --- a/benches/uint.rs +++ b/benches/uint.rs @@ -1,7 +1,10 @@ -use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use criterion::measurement::WallTime; +use criterion::{ + black_box, criterion_group, criterion_main, BatchSize, BenchmarkGroup, BenchmarkId, Criterion, +}; use crypto_bigint::{ - Limb, NonZero, Odd, Random, RandomBits, RandomMod, Reciprocal, Uint, U1024, U128, U2048, U256, - U4096, U512, + Limb, NonZero, Odd, Random, RandomBits, RandomMod, Reciprocal, Uint, U1024, U128, U16384, + U2048, U256, U4096, U512, U8192, }; use rand_chacha::ChaCha8Rng; use rand_core::{OsRng, RngCore, SeedableRng}; @@ -332,78 +335,106 @@ fn bench_gcd(c: &mut Criterion) { group.finish(); } -fn bench_shl(c: &mut Criterion) { - let mut group = c.benchmark_group("left shift"); - - group.bench_function("shl_vartime, small, U2048", |b| { +fn shl_benchmarks(group: &mut BenchmarkGroup) { + group.bench_function(BenchmarkId::new("shl_vartime, small", LIMBS), |b| { b.iter_batched( - || U2048::ONE, + || Uint::::ONE, |x| x.overflowing_shl_vartime(10), BatchSize::SmallInput, ) }); - - group.bench_function("shl_vartime, large, U2048", |b| { + group.bench_function(BenchmarkId::new("shl_vartime, large", LIMBS), |b| { b.iter_batched( - || U2048::ONE, - |x| black_box(x.overflowing_shl_vartime(1024 + 10)), + || Uint::::ONE, + |x| black_box(x.overflowing_shl_vartime(Uint::::BITS / 2 + 10)), BatchSize::SmallInput, ) }); - - group.bench_function("shl_vartime_wide, large, U2048", |b| { + group.bench_function(BenchmarkId::new("shl_vartime_wide, large", LIMBS), |b| { b.iter_batched( - || (U2048::ONE, U2048::ONE), - |x| Uint::overflowing_shl_vartime_wide(x, 1024 + 10), + || (Uint::::ONE, Uint::::ONE), + |x| Uint::overflowing_shl_vartime_wide(x, Uint::::BITS / 2 + 10), BatchSize::SmallInput, ) }); - - group.bench_function("shl, U2048", |b| { + group.bench_function(BenchmarkId::new("shl, small", LIMBS), |b| { + b.iter_batched( + || Uint::::ONE, + |x| x.overflowing_shl(10), + BatchSize::SmallInput, + ) + }); + group.bench_function(BenchmarkId::new("shl, large", LIMBS), |b| { b.iter_batched( - || U2048::ONE, - |x| x.overflowing_shl(1024 + 10), + || Uint::::ONE, + |x| x.overflowing_shl(Uint::::BITS / 2 + 10), BatchSize::SmallInput, ) }); +} + +fn bench_shl(c: &mut Criterion) { + let mut group = c.benchmark_group("left shift"); + + shl_benchmarks::<{ U256::LIMBS }>(&mut group); + shl_benchmarks::<{ U512::LIMBS }>(&mut group); + shl_benchmarks::<{ U1024::LIMBS }>(&mut group); + shl_benchmarks::<{ U2048::LIMBS }>(&mut group); + shl_benchmarks::<{ U4096::LIMBS }>(&mut group); + shl_benchmarks::<{ U8192::LIMBS }>(&mut group); + shl_benchmarks::<{ U16384::LIMBS }>(&mut group); group.finish(); } -fn bench_shr(c: &mut Criterion) { - let mut group = c.benchmark_group("right shift"); - - group.bench_function("shr_vartime, small, U2048", |b| { +fn shr_benchmarks(group: &mut BenchmarkGroup) { + group.bench_function(BenchmarkId::new("shr_vartime, small", LIMBS), |b| { b.iter_batched( - || U2048::ONE, + || Uint::::ONE, |x| x.overflowing_shr_vartime(10), BatchSize::SmallInput, ) }); - - group.bench_function("shr_vartime, large, U2048", |b| { + group.bench_function(BenchmarkId::new("shr_vartime, large", LIMBS), |b| { b.iter_batched( - || U2048::ONE, - |x| x.overflowing_shr_vartime(1024 + 10), + || Uint::::ONE, + |x| x.overflowing_shr_vartime(Uint::::BITS / 2 + 10), BatchSize::SmallInput, ) }); - - group.bench_function("shr_vartime_wide, large, U2048", |b| { + group.bench_function(BenchmarkId::new("shr_vartime_wide, large", LIMBS), |b| { b.iter_batched( - || (U2048::ONE, U2048::ONE), - |x| Uint::overflowing_shr_vartime_wide(x, 1024 + 10), + || (Uint::::ONE, Uint::::ONE), + |x| Uint::overflowing_shr_vartime_wide(x, Uint::::BITS / 2 + 10), BatchSize::SmallInput, ) }); - - group.bench_function("shr, U2048", |b| { + group.bench_function(BenchmarkId::new("shr, small", LIMBS), |b| { b.iter_batched( - || U2048::ONE, - |x| x.overflowing_shr(1024 + 10), + || Uint::::ONE, + |x| x.overflowing_shr(10), BatchSize::SmallInput, ) }); + group.bench_function(BenchmarkId::new("shr, large", LIMBS), |b| { + b.iter_batched( + || Uint::::ONE, + |x| x.overflowing_shr(Uint::::BITS / 2 + 10), + BatchSize::SmallInput, + ) + }); +} + +fn bench_shr(c: &mut Criterion) { + let mut group = c.benchmark_group("right shift"); + + shr_benchmarks::<{ U256::LIMBS }>(&mut group); + shr_benchmarks::<{ U512::LIMBS }>(&mut group); + shr_benchmarks::<{ U1024::LIMBS }>(&mut group); + shr_benchmarks::<{ U2048::LIMBS }>(&mut group); + shr_benchmarks::<{ U4096::LIMBS }>(&mut group); + shr_benchmarks::<{ U8192::LIMBS }>(&mut group); + shr_benchmarks::<{ U16384::LIMBS }>(&mut group); group.finish(); } diff --git a/src/limb/shl.rs b/src/limb/shl.rs index ebcc6dcc..850bf64d 100644 --- a/src/limb/shl.rs +++ b/src/limb/shl.rs @@ -1,6 +1,6 @@ //! Limb left bitshift -use crate::Limb; +use crate::{ConstChoice, Limb}; use core::ops::{Shl, ShlAssign}; use num_traits::WrappingShl; @@ -12,10 +12,21 @@ impl Limb { Limb(self.0 << shift) } - /// Computes `self << 1` and return the result and the carry (0 or 1). + /// Computes `self << shift` and returns the result as well as the carry: the `shift` _least_ + /// significant bits of the `carry` are equal to the `shift` _most_ significant bits of `self`. + /// + /// Panics if `shift` overflows `Limb::BITS`. #[inline(always)] - pub(crate) const fn shl1(self) -> (Self, Self) { - (Self(self.0 << 1), Self(self.0 >> Self::HI_BIT)) + pub const fn carrying_shl(self, shift: u32) -> (Self, Self) { + // Note that we can compute carry = self >> (Self::BITS - shift) whenever shift > 0. + // However, we need to account for the case that shift = 0: + // - the carry should be 0, and + // - the value by which carry is left shifted should be made to be < Self::BITS. + let shift_is_zero = ConstChoice::from_u32_eq(shift, 0); + let carry = Self::select(self, Self::ZERO, shift_is_zero); + let left_shift = shift_is_zero.select_u32(Self::BITS - shift, 0); + + (self.shl(shift), carry.shr(left_shift)) } } diff --git a/src/limb/shr.rs b/src/limb/shr.rs index 08549732..6d3d0370 100644 --- a/src/limb/shr.rs +++ b/src/limb/shr.rs @@ -1,6 +1,6 @@ //! Limb right bitshift -use crate::{Limb, WrappingShr}; +use crate::{ConstChoice, Limb, WrappingShr}; use core::ops::{Shr, ShrAssign}; impl Limb { @@ -11,10 +11,21 @@ impl Limb { Limb(self.0 >> shift) } - /// Computes `self >> 1` and return the result and the carry (0 or `1 << HI_BIT`). + /// Computes `self >> shift` and returns the result as well as the carry: the `shift` _most_ + /// significant bits of the `carry` are equal to the `shift` _least_ significant bits of `self`. + /// + /// Panics if `shift` overflows `Limb::BITS`. #[inline(always)] - pub(crate) const fn shr1(self) -> (Self, Self) { - (Self(self.0 >> 1), Self(self.0 << Self::HI_BIT)) + pub const fn carrying_shr(self, shift: u32) -> (Self, Self) { + // Note that we can compute carry = self << (Self::BITS - shift) whenever shift > 0. + // However, we need to account for the case that shift = 0: + // - the carry should be 0, and + // - the value by which carry is left shifted should be made to be < Self::BITS. + let shift_is_zero = ConstChoice::from_u32_eq(shift, 0); + let carry = Self::select(self, Self::ZERO, shift_is_zero); + let left_shift = shift_is_zero.select_u32(Self::BITS - shift, 0); + + (self.shr(shift), carry.shl(left_shift)) } } diff --git a/src/uint.rs b/src/uint.rs index fe2208d2..d13e1b58 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -165,7 +165,7 @@ impl Uint { } /// Borrow the limbs of this [`Uint`] mutably. - pub fn as_limbs_mut(&mut self) -> &mut [Limb; LIMBS] { + pub const fn as_limbs_mut(&mut self) -> &mut [Limb; LIMBS] { &mut self.limbs } diff --git a/src/uint/add_mod.rs b/src/uint/add_mod.rs index 506f57c2..a3af579f 100644 --- a/src/uint/add_mod.rs +++ b/src/uint/add_mod.rs @@ -38,7 +38,7 @@ impl Uint { /// /// Assumes `self` as unbounded integer is `< p`. pub const fn double_mod(&self, p: &Self) -> Self { - let (w, carry) = self.overflowing_shl1(); + let (w, carry) = self.carrying_shl1(); // Attempt to subtract the modulus, to ensure the result is in the field. let (w, borrow) = w.sbb(p, Limb::ZERO); diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 5ad17a8e..5206c294 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -25,22 +25,62 @@ impl Uint { /// /// Returns `None` if `shift >= Self::BITS`. pub const fn overflowing_shl(&self, shift: u32) -> ConstCtOption { - // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` - // (which lies in range `0 <= shift < BITS`). - let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros(); - let overflow = ConstChoice::from_u32_lt(shift, Self::BITS).not(); - let shift = shift % Self::BITS; + let (intra_limb_shift, limb_shift) = Self::decompose_shift(shift); + self.intra_limb_carrying_shl(intra_limb_shift) + .0 + .full_limb_overflowing_shl(limb_shift) + } + + /// Computes `self << shift`, for `shift < Limb::BITS`. Also returns a [Limb] containing the + /// `carry`. + /// + /// Panics if `shift >= Limb::BITS`. + #[inline(always)] + const fn intra_limb_carrying_shl(&self, shift: u32) -> (Self, Limb) { + debug_assert!(shift < Limb::BITS); + + let (mut result, mut carry) = (*self, Limb::ZERO); + + let limbs = result.as_limbs_mut(); + let mut i = 0; + while i < limbs.len() { + let (shifted, new_carry) = limbs[i].carrying_shl(shift); + limbs[i] = shifted.bitxor(carry); + carry = new_carry; + + i += 1; + } + + (result, carry) + } + + /// Compute `self << (Limb::BITS * limb_shift)`, for `limb_shift < Self::LIMBS`. + /// In other words, shift `self` left by `limb_shift` full limbs. + /// + /// Returns `None` if `limb_shift >= Self::LIMBS`. + #[inline] + pub const fn full_limb_overflowing_shl(&self, limb_shift: u32) -> ConstCtOption { + let shift_bits = u32::BITS - (LIMBS as u32 - 1).leading_zeros(); + let overflow = ConstChoice::from_u32_lt(limb_shift, LIMBS as u32).not(); + let limb_shift = limb_shift % LIMBS as u32; + let mut result = *self; let mut i = 0; while i < shift_bits { - let bit = ConstChoice::from_u32_lsb((shift >> i) & 1); - result = Uint::select( - &result, - &result - .overflowing_shl_vartime(1 << i) - .expect("shift within range"), - bit, - ); + let bit = ConstChoice::from_u32_lsb((limb_shift >> i) & 1); + + let mut j = Self::LIMBS; + let limbs = result.as_limbs_mut(); + let offset = 1 << i; + while j > offset { + j -= 1; + limbs[j] = Limb::select(limbs[j], limbs[j - offset], bit); + } + while j > 0 { + j -= 1; + limbs[j] = Limb::select(limbs[j], Limb::ZERO, bit); + } + i += 1; } @@ -63,31 +103,36 @@ impl Uint { return ConstCtOption::none(Self::ZERO); } - let shift_num = (shift / Limb::BITS) as usize; - let rem = shift % Limb::BITS; - + let (rem, shift_num) = Self::decompose_shift(shift); + let shift_num = shift_num as usize; let mut i = shift_num; while i < LIMBS { limbs[i] = self.limbs[i - shift_num]; i += 1; } - if rem == 0 { - return ConstCtOption::some(Self { limbs }); - } - - let mut carry = Limb::ZERO; - - let mut i = shift_num; - while i < LIMBS { - let shifted = limbs[i].shl(rem); - let new_carry = limbs[i].shr(Limb::BITS - rem); - limbs[i] = shifted.bitor(carry); - carry = new_carry; - i += 1; + let mut shifted = Self { limbs }; + if rem != 0 { + shifted = shifted.intra_limb_carrying_shl(rem).0; } + ConstCtOption::some(shifted) + } - ConstCtOption::some(Self { limbs }) + /// Split `shift` into `shift % Limb::BITS` (its intra-limb-shift component), and + /// `shift / Limb::BITS` (its limb-shift component). + /// + /// This function achieves this without using a division/remainder operation. + pub(crate) const fn decompose_shift(shift: u32) -> (u32, u32) { + // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` + // (which lies in range `0 <= shift < BITS`). + // + // Split shift into (shift % Limb::BITS, shift / Limb::BITS) + // Since Limb::BITS is known to be a power of two, this can also be computed as follows: + let limb_bits_bits = u32::BITS - (Limb::BITS - 1).leading_zeros(); + let intra_limb_shift = shift & (Limb::BITS - 1); + let limb_shift = shift >> limb_bits_bits; + + (intra_limb_shift, limb_shift) } /// Computes a left shift on a wide input as `(lo, hi)`. @@ -161,21 +206,11 @@ impl Uint { (Uint::::new(limbs), Limb(carry)) } - /// Computes `self << 1` in constant-time, returning [`ConstChoice::TRUE`] - /// if the most significant bit was set, and [`ConstChoice::FALSE`] otherwise. + /// Computes `self << 1` in constant-time, furthermore returning a [Limb] containing the + /// `carry`. #[inline(always)] - pub(crate) const fn overflowing_shl1(&self) -> (Self, Limb) { - let mut ret = Self::ZERO; - let mut i = 0; - let mut carry = Limb::ZERO; - while i < LIMBS { - let (shifted, new_carry) = self.limbs[i].shl1(); - ret.limbs[i] = shifted.bitor(carry); - carry = new_carry; - i += 1; - } - - (ret, carry) + pub(crate) const fn carrying_shl1(&self) -> (Self, Limb) { + self.intra_limb_carrying_shl(1) } } @@ -259,7 +294,7 @@ mod tests { #[test] fn shl1() { assert_eq!(N << 1, TWO_N); - assert_eq!(N.overflowing_shl1(), (TWO_N, Limb::ONE)); + assert_eq!(N.carrying_shl1(), (TWO_N, Limb::ONE)); } #[test] diff --git a/src/uint/shr.rs b/src/uint/shr.rs index 0212b570..1cfd1a5e 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -25,22 +25,61 @@ impl Uint { /// /// Returns `None` if `shift >= Self::BITS`. pub const fn overflowing_shr(&self, shift: u32) -> ConstCtOption { - // `floor(log2(BITS - 1))` is the number of bits in the representation of `shift` - // (which lies in range `0 <= shift < BITS`). - let shift_bits = u32::BITS - (Self::BITS - 1).leading_zeros(); - let overflow = ConstChoice::from_u32_lt(shift, Self::BITS).not(); - let shift = shift % Self::BITS; + let (intra_limb_shift, limb_shift) = Self::decompose_shift(shift); + self.intra_limb_carrying_shr(intra_limb_shift) + .0 + .full_limb_overflowing_shr(limb_shift) + } + + /// Computes `self >> shift` for `shift < Limb::BITS`. Also returns a [Limb] containing the + /// `carry`. + /// + /// Panics if `shift >= Limb::BITS`. + #[inline(always)] + const fn intra_limb_carrying_shr(&self, shift: u32) -> (Self, Limb) { + debug_assert!(shift < Limb::BITS); + + let (mut result, mut carry) = (*self, Limb::ZERO); + + let limbs = result.as_limbs_mut(); + let mut i = limbs.len(); + while i > 0 { + i -= 1; + let (shifted, new_carry) = limbs[i].carrying_shr(shift); + limbs[i] = shifted.bitxor(carry); + carry = new_carry; + } + + (result, carry) + } + + /// Compute `self >> (Limb::BITS * limb_shift)`, for `limb_shift < Self::LIMBS`. + /// In other words, shift `self` right by `limb_shift` full limbs. + /// + /// Returns `None` if `limb_shift >= Self::LIMBS`. + #[inline] + pub const fn full_limb_overflowing_shr(&self, limb_shift: u32) -> ConstCtOption { + let shift_bits = u32::BITS - (LIMBS as u32 - 1).leading_zeros(); + let overflow = ConstChoice::from_u32_lt(limb_shift, LIMBS as u32).not(); + let limb_shift = limb_shift % LIMBS as u32; + let mut result = *self; let mut i = 0; while i < shift_bits { - let bit = ConstChoice::from_u32_lsb((shift >> i) & 1); - result = Uint::select( - &result, - &result - .overflowing_shr_vartime(1 << i) - .expect("shift within range"), - bit, - ); + let bit = ConstChoice::from_u32_lsb((limb_shift >> i) & 1); + + let mut j = 0; + let limbs = result.as_limbs_mut(); + let offset = 1 << i; + while j < Self::LIMBS.saturating_sub(offset) { + limbs[j] = Limb::select(limbs[j], limbs[j + offset], bit); + j += 1; + } + while j < Self::LIMBS { + limbs[j] = Limb::select(limbs[j], Limb::ZERO, bit); + j += 1; + } + i += 1; } @@ -63,30 +102,19 @@ impl Uint { return ConstCtOption::none(Self::ZERO); } - let shift_num = (shift / Limb::BITS) as usize; - let rem = shift % Limb::BITS; - + let (rem, shift_num) = Self::decompose_shift(shift); + let shift_num = shift_num as usize; let mut i = 0; while i < LIMBS - shift_num { limbs[i] = self.limbs[i + shift_num]; i += 1; } - if rem == 0 { - return ConstCtOption::some(Self { limbs }); + let mut shifted = Self { limbs }; + if rem != 0 { + shifted = shifted.intra_limb_carrying_shr(rem).0; } - - let mut carry = Limb::ZERO; - - while i > 0 { - i -= 1; - let shifted = limbs[i].shr(rem); - let new_carry = limbs[i].shl(Limb::BITS - rem); - limbs[i] = shifted.bitor(carry); - carry = new_carry; - } - - ConstCtOption::some(Self { limbs }) + ConstCtOption::some(shifted) } /// Computes a right shift on a wide input as `(lo, hi)`. @@ -145,16 +173,7 @@ impl Uint { /// if the least significant bit was set, and [`ConstChoice::FALSE`] otherwise. #[inline(always)] pub(crate) const fn shr1_with_carry(&self) -> (Self, ConstChoice) { - let mut ret = Self::ZERO; - let mut i = LIMBS; - let mut carry = Limb::ZERO; - while i > 0 { - i -= 1; - let (shifted, new_carry) = self.limbs[i].shr1(); - ret.limbs[i] = shifted.bitor(carry); - carry = new_carry; - } - + let (ret, carry) = self.intra_limb_carrying_shr(1); (ret, ConstChoice::from_word_lsb(carry.0 >> Limb::HI_BIT)) } } diff --git a/tests/limb.rs b/tests/limb.rs new file mode 100644 index 00000000..66e37028 --- /dev/null +++ b/tests/limb.rs @@ -0,0 +1,37 @@ +use crypto_bigint::{Limb, Word}; +use proptest::prelude::*; + +prop_compose! { + fn limb()(x in any::()) -> Limb { + Limb::from(x) + } +} +proptest! { + #[test] + fn carrying_shr_doesnt_panic(limb in limb(), shift in 0..32u32) { + limb.carrying_shr(shift); + } + + #[test] + fn carrying_shr(limb in limb(), shift in 0..32u32) { + if shift == 0 { + assert_eq!(limb.carrying_shr(shift), (limb, Limb::ZERO)); + } else { + assert_eq!(limb.carrying_shr(shift), (limb.shr(shift), limb.shl(Limb::BITS - shift))); + } + } + + #[test] + fn carrying_shl_doesnt_panic(limb in limb(), shift in 0..32u32) { + limb.carrying_shl(shift); + } + + #[test] + fn carrying_shl(limb in limb(), shift in 0..32u32) { + if shift == 0 { + assert_eq!(limb.carrying_shl(shift), (limb, Limb::ZERO)); + } else { + assert_eq!(limb.carrying_shl(shift), (limb.shl(shift), limb.shr(Limb::BITS - shift))); + } + } +}