Skip to content

Commit

Permalink
Align with core/std on overflowing_sh* (#430)
Browse files Browse the repository at this point in the history
Changes all shift functions which return an overflow flag (as a `Choice`
or `ConstChoice`) to use the `overflowing_sh*` name prefix, which aligns
with similar APIs in `core`/`std`.

In their place, adds new `Uint::{shl, shr}` functions which provide the
trait-like behavior (i.e. panic on overflow) but work in `const fn`
contexts (and can now panic at compile time on overflow).
  • Loading branch information
tarcieri authored Dec 15, 2023
1 parent cc3f984 commit d5e00ba
Show file tree
Hide file tree
Showing 14 changed files with 133 additions and 93 deletions.
4 changes: 2 additions & 2 deletions benches/boxed_uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn bench_shifts(c: &mut Criterion) {
group.bench_function("shl", |b| {
b.iter_batched(
|| BoxedUint::random(&mut OsRng, UINT_BITS),
|x| x.shl(UINT_BITS / 2 + 10),
|x| x.overflowing_shl(UINT_BITS / 2 + 10),
BatchSize::SmallInput,
)
});
Expand All @@ -35,7 +35,7 @@ fn bench_shifts(c: &mut Criterion) {
group.bench_function("shr", |b| {
b.iter_batched(
|| BoxedUint::random(&mut OsRng, UINT_BITS),
|x| x.shr(UINT_BITS / 2 + 10),
|x| x.overflowing_shr(UINT_BITS / 2 + 10),
BatchSize::SmallInput,
)
});
Expand Down
32 changes: 24 additions & 8 deletions benches/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,27 +79,35 @@ fn bench_shl(c: &mut Criterion) {
let mut group = c.benchmark_group("left shift");

group.bench_function("shl_vartime, small, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shl_vartime(10), BatchSize::SmallInput)
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shl_vartime(10),
BatchSize::SmallInput,
)
});

group.bench_function("shl_vartime, large, U2048", |b| {
b.iter_batched(
|| U2048::ONE,
|x| black_box(x.shl_vartime(1024 + 10)),
|x| black_box(x.overflowing_shl_vartime(1024 + 10)),
BatchSize::SmallInput,
)
});

group.bench_function("shl_vartime_wide, large, U2048", |b| {
b.iter_batched(
|| (U2048::ONE, U2048::ONE),
|x| Uint::shl_vartime_wide(x, 1024 + 10),
|x| Uint::overflowing_shl_vartime_wide(x, 1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shl, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shl(1024 + 10), BatchSize::SmallInput)
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shl(1024 + 10),
BatchSize::SmallInput,
)
});

group.finish();
Expand All @@ -109,27 +117,35 @@ fn bench_shr(c: &mut Criterion) {
let mut group = c.benchmark_group("right shift");

group.bench_function("shr_vartime, small, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shr_vartime(10), BatchSize::SmallInput)
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shr_vartime(10),
BatchSize::SmallInput,
)
});

group.bench_function("shr_vartime, large, U2048", |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.shr_vartime(1024 + 10),
|x| x.overflowing_shr_vartime(1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr_vartime_wide, large, U2048", |b| {
b.iter_batched(
|| (U2048::ONE, U2048::ONE),
|x| Uint::shr_vartime_wide(x, 1024 + 10),
|x| Uint::overflowing_shr_vartime_wide(x, 1024 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr, U2048", |b| {
b.iter_batched(|| U2048::ONE, |x| x.shr(1024 + 10), BatchSize::SmallInput)
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shr(1024 + 10),
BatchSize::SmallInput,
)
});

group.finish();
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
//! U256::from_be_hex("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551");
//!
//! // Compute `MODULUS` shifted right by 1 at compile time
//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1).0;
//! pub const MODULUS_SHR1: U256 = MODULUS.shr(1);
//! ```
//!
//! Note that large constant computations may accidentally trigger a the `const_eval_limit` of the compiler.
Expand Down
2 changes: 1 addition & 1 deletion src/uint/boxed/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl BoxedUint {
let bits_precision = self.bits_precision();
let mut rem = self.clone();
let mut quo = Self::zero_with_precision(bits_precision);
let (mut c, _overflow) = rhs.shl(bits_precision - mb);
let (mut c, _overflow) = rhs.overflowing_shl(bits_precision - mb);
let mut i = bits_precision;
let mut done = Choice::from(0u8);

Expand Down
8 changes: 4 additions & 4 deletions src/uint/boxed/shl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ impl BoxedUint {
///
/// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`,
/// or the result and a falsy `Choice` otherwise.
pub fn shl(&self, shift: u32) -> (Self, Choice) {
pub fn overflowing_shl(&self, shift: u32) -> (Self, Choice) {
let mut result = self.clone();
let overflow = result.overflowing_shl_assign(shift);
(result, overflow)
Expand Down Expand Up @@ -125,7 +125,7 @@ impl Shl<u32> for &BoxedUint {
type Output = BoxedUint;

fn shl(self, shift: u32) -> BoxedUint {
let (result, overflow) = self.shl(shift);
let (result, overflow) = self.overflowing_shl(shift);
assert!(!bool::from(overflow), "attempt to shift left with overflow");
result
}
Expand Down Expand Up @@ -154,8 +154,8 @@ mod tests {
fn shl() {
let one = BoxedUint::one_with_precision(128);

assert_eq!(BoxedUint::from(2u8), one.shl(1).0);
assert_eq!(BoxedUint::from(4u8), one.shl(2).0);
assert_eq!(BoxedUint::from(2u8), &one << 1);
assert_eq!(BoxedUint::from(4u8), &one << 2);
assert_eq!(
BoxedUint::from(0x80000000000000000u128),
one.shl_vartime(67).unwrap()
Expand Down
12 changes: 6 additions & 6 deletions src/uint/boxed/shr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ impl BoxedUint {
///
/// Returns a zero and a truthy `Choice` if `shift >= self.bits_precision()`,
/// or the result and a falsy `Choice` otherwise.
pub fn shr(&self, shift: u32) -> (Self, Choice) {
pub fn overflowing_shr(&self, shift: u32) -> (Self, Choice) {
let mut result = self.clone();
let overflow = result.overflowing_shr_assign(shift);
(result, overflow)
Expand Down Expand Up @@ -129,7 +129,7 @@ impl Shr<u32> for &BoxedUint {
type Output = BoxedUint;

fn shr(self, shift: u32) -> BoxedUint {
let (result, overflow) = self.shr(shift);
let (result, overflow) = self.overflowing_shr(shift);
assert!(
!bool::from(overflow),
"attempt to shift right with overflow"
Expand Down Expand Up @@ -163,10 +163,10 @@ mod tests {
#[test]
fn shr() {
let n = BoxedUint::from(0x80000000000000000u128);
assert_eq!(BoxedUint::zero(), n.shr(68).0);
assert_eq!(BoxedUint::one(), n.shr(67).0);
assert_eq!(BoxedUint::from(2u8), n.shr(66).0);
assert_eq!(BoxedUint::from(4u8), n.shr(65).0);
assert_eq!(BoxedUint::zero(), &n >> 68);
assert_eq!(BoxedUint::one(), &n >> 67);
assert_eq!(BoxedUint::from(2u8), &n >> 66);
assert_eq!(BoxedUint::from(4u8), &n >> 65);
}

#[test]
Expand Down
16 changes: 8 additions & 8 deletions src/uint/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let mut rem = *self;
let mut quo = Self::ZERO;
// If there is overflow, it means `mb == 0`, so `rhs == 0`.
let (mut c, _overflow) = rhs.0.shl(Self::BITS - mb);
let (mut c, _overflow) = rhs.0.overflowing_shl(Self::BITS - mb);

let mut i = Self::BITS;
let mut done = ConstChoice::FALSE;
Expand Down Expand Up @@ -64,7 +64,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let mut rem = *self;
let mut quo = Self::ZERO;
// If there is overflow, it means `mb == 0`, so `rhs == 0`.
let (mut c, _overflow) = rhs.0.shl_vartime(bd);
let (mut c, _overflow) = rhs.0.overflowing_shl_vartime(bd);

loop {
let (mut r, borrow) = rem.sbb(&c, Limb::ZERO);
Expand Down Expand Up @@ -92,7 +92,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let mb = rhs.0.bits_vartime();
let mut bd = Self::BITS - mb;
let mut rem = *self;
let (mut c, _overflow) = rhs.0.shl_vartime(bd);
let (mut c, _overflow) = rhs.0.overflowing_shl_vartime(bd);

loop {
let (r, borrow) = rem.sbb(&c, Limb::ZERO);
Expand Down Expand Up @@ -123,7 +123,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
let (mut lower, mut upper) = lower_upper;

// Factor of the modulus, split into two halves
let (mut c, _overflow) = Self::shl_vartime_wide((rhs.0, Uint::ZERO), bd);
let (mut c, _overflow) = Self::overflowing_shl_vartime_wide((rhs.0, Uint::ZERO), bd);

loop {
let (lower_sub, borrow) = lower.sbb(&c.0, Limb::ZERO);
Expand All @@ -135,7 +135,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
break;
}
bd -= 1;
let (new_c, _overflow) = Self::shr_vartime_wide(c, 1);
let (new_c, _overflow) = Self::overflowing_shr_vartime_wide(c, 1);
c = new_c;
}

Expand Down Expand Up @@ -634,8 +634,8 @@ mod tests {
fn div() {
let mut rng = ChaChaRng::from_seed([7u8; 32]);
for _ in 0..25 {
let (num, _) = U256::random(&mut rng).shr_vartime(128);
let den = NonZero::new(U256::random(&mut rng).shr_vartime(128).0).unwrap();
let (num, _) = U256::random(&mut rng).overflowing_shr_vartime(128);
let den = NonZero::new(U256::random(&mut rng).overflowing_shr_vartime(128).0).unwrap();
let n = num.checked_mul(den.as_ref());
if n.is_some().into() {
let (q, _) = n.unwrap().div_rem(&den);
Expand Down Expand Up @@ -724,7 +724,7 @@ mod tests {
for _ in 0..25 {
let num = U256::random(&mut rng);
let k = rng.next_u32() % 256;
let (den, _) = U256::ONE.shl_vartime(k);
let (den, _) = U256::ONE.overflowing_shl_vartime(k);

let a = num.rem2k(k);
let e = num.wrapping_rem(&den);
Expand Down
6 changes: 3 additions & 3 deletions src/uint/inv_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
// b_{i+1} = (b_i - a * X_i) / 2
b = Self::select(&b, &b.wrapping_sub(self), x_i_choice).shr1();
// Store the X_i bit in the result (x = x | (1 << X_i))
let (shifted, _overflow) = Uint::from_word(x_i).shl_vartime(i);
let (shifted, _overflow) = Uint::from_word(x_i).overflowing_shl_vartime(i);
x = x.bitor(&shifted);

i += 1;
Expand Down Expand Up @@ -162,7 +162,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
pub const fn inv_mod(&self, modulus: &Self) -> (Self, ConstChoice) {
// Decompose `modulus = s * 2^k` where `s` is odd
let k = modulus.trailing_zeros();
let (s, _overflow) = modulus.shr(k);
let (s, _overflow) = modulus.overflowing_shr(k);

// Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
// Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
Expand All @@ -178,7 +178,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {

// This part is mod 2^k
// Will not overflow since `modulus` is nonzero, and therefore `k < BITS`.
let (shifted, _overflow) = Uint::ONE.shl(k);
let (shifted, _overflow) = Uint::ONE.overflowing_shl(k);
let mask = shifted.wrapping_sub(&Uint::ONE);
let t = (b.wrapping_sub(&a).wrapping_mul(&m_odd_inv)).bitand(&mask);

Expand Down
2 changes: 1 addition & 1 deletion src/uint/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {

// Double the current result, this accounts for the other half of the multiplication grid.
// TODO: The top word is empty so we can also use a special purpose shl.
(lo, hi) = Self::shl_vartime_wide((lo, hi), 1).0;
(lo, hi) = Self::overflowing_shl_vartime_wide((lo, hi), 1).0;

// Handle the diagonal of the multiplication grid, which finishes the multiplication grid.
let mut carry = Limb::ZERO;
Expand Down
Loading

0 comments on commit d5e00ba

Please sign in to comment.