Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup Uint::shr #753

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 68 additions & 37 deletions benches/uint.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion};
use criterion::measurement::WallTime;
use criterion::{
black_box, criterion_group, criterion_main, BatchSize, BenchmarkGroup, BenchmarkId, Criterion,
};
use crypto_bigint::{
Limb, NonZero, Odd, Random, RandomBits, RandomMod, Reciprocal, Uint, U1024, U128, U2048, U256,
U4096, U512,
Limb, NonZero, Odd, Random, RandomBits, RandomMod, Reciprocal, Uint, U1024, U128, U16384,
U2048, U256, U4096, U512, U8192,
};
use rand_chacha::ChaCha8Rng;
use rand_core::{OsRng, RngCore, SeedableRng};
Expand Down Expand Up @@ -332,78 +335,106 @@ fn bench_gcd(c: &mut Criterion) {
group.finish();
}

fn bench_shl(c: &mut Criterion) {
let mut group = c.benchmark_group("left shift");

group.bench_function("shl_vartime, small, U2048", |b| {
/// Registers the left-shift benchmarks for `Uint<LIMBS>` on `group`.
///
/// Each benchmark is identified by a `BenchmarkId` made of the operation name and `LIMBS`.
/// The "small" variants shift by 10 bits; the "large" variants shift by
/// `Uint::<LIMBS>::BITS/2 + 10` bits. Every input is `Uint::<LIMBS>::ONE` (a pair of them
/// for the `_wide` variant).
///
/// NOTE(review): this span looks like a rendered diff with both sides present — the lines
/// using `U2048::ONE`, `1024 + 10` and plain string benchmark names are presumably the
/// removed pre-change lines; confirm against the actual file before applying anything here.
fn shl_benchmarks<const LIMBS: usize>(group: &mut BenchmarkGroup<WallTime>) {
group.bench_function(BenchmarkId::new("shl_vartime, small", LIMBS), |b| {
b.iter_batched(
|| U2048::ONE,
|| Uint::<LIMBS>::ONE,
|x| x.overflowing_shl_vartime(10),
BatchSize::SmallInput,
)
});

group.bench_function("shl_vartime, large, U2048", |b| {
group.bench_function(BenchmarkId::new("shl_vartime, large", LIMBS), |b| {
b.iter_batched(
|| U2048::ONE,
|x| black_box(x.overflowing_shl_vartime(1024 + 10)),
|| Uint::<LIMBS>::ONE,
// NOTE(review): this is the only benchmark in this function whose result is wrapped in
// `black_box`; the others return the value directly — confirm whether that is intended.
|x| black_box(x.overflowing_shl_vartime(Uint::<LIMBS>::BITS/2 + 10)),
BatchSize::SmallInput,
)
});

group.bench_function("shl_vartime_wide, large, U2048", |b| {
group.bench_function(BenchmarkId::new("shl_vartime_wide, large", LIMBS), |b| {
b.iter_batched(
|| (U2048::ONE, U2048::ONE),
|x| Uint::overflowing_shl_vartime_wide(x, 1024 + 10),
|| (Uint::<LIMBS>::ONE, Uint::<LIMBS>::ONE),
|x| Uint::overflowing_shl_vartime_wide(x, Uint::<LIMBS>::BITS/2 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shl, U2048", |b| {
group.bench_function(BenchmarkId::new("shl, small", LIMBS), |b| {
b.iter_batched(
|| Uint::<LIMBS>::ONE,
// NOTE(review): stray space after `(` here and below; `rustfmt` would also write
// `BITS / 2` with spaces, as the right-shift benchmarks do.
|x| x.overflowing_shl( 10),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("shl, large", LIMBS), |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shl(1024 + 10),
|| Uint::<LIMBS>::ONE,
|x| x.overflowing_shl( Uint::<LIMBS>::BITS/2 + 10),
BatchSize::SmallInput,
)
});
}

/// Runs the left-shift benchmarks in the "left shift" group for the `U256`, `U512`, `U1024`,
/// `U2048`, `U4096`, `U8192` and `U16384` limb counts.
fn bench_shl(c: &mut Criterion) {
    let mut shl_group = c.benchmark_group("left shift");

    // One instantiation per integer width, smallest first.
    shl_benchmarks::<{ U256::LIMBS }>(&mut shl_group);
    shl_benchmarks::<{ U512::LIMBS }>(&mut shl_group);
    shl_benchmarks::<{ U1024::LIMBS }>(&mut shl_group);
    shl_benchmarks::<{ U2048::LIMBS }>(&mut shl_group);
    shl_benchmarks::<{ U4096::LIMBS }>(&mut shl_group);
    shl_benchmarks::<{ U8192::LIMBS }>(&mut shl_group);
    shl_benchmarks::<{ U16384::LIMBS }>(&mut shl_group);

    shl_group.finish();
}

fn bench_shr(c: &mut Criterion) {
let mut group = c.benchmark_group("right shift");

group.bench_function("shr_vartime, small, U2048", |b| {
/// Registers the right-shift benchmarks for `Uint<LIMBS>` on `group`.
///
/// Each benchmark is identified by a `BenchmarkId` made of the operation name and `LIMBS`.
/// The "small" variants shift by 10 bits; the "large" variants shift by
/// `Uint::<LIMBS>::BITS / 2 + 10` bits. Every input is `Uint::<LIMBS>::ONE` (a pair of them
/// for the `_wide` variant).
///
/// NOTE(review): this span looks like a rendered diff with both sides present — the lines
/// using `U2048::ONE`, `1024 + 10` and plain string benchmark names are presumably the
/// removed pre-change lines; confirm against the actual file before applying anything here.
fn shr_benchmarks<const LIMBS: usize>(group: &mut BenchmarkGroup<WallTime>) {
group.bench_function(BenchmarkId::new("shr_vartime, small", LIMBS), |b| {
b.iter_batched(
|| U2048::ONE,
|| Uint::<LIMBS>::ONE,
|x| x.overflowing_shr_vartime(10),
BatchSize::SmallInput,
)
});

group.bench_function("shr_vartime, large, U2048", |b| {
group.bench_function(BenchmarkId::new("shr_vartime, large", LIMBS), |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shr_vartime(1024 + 10),
|| Uint::<LIMBS>::ONE,
// NOTE(review): unlike the matching left-shift benchmark, this result is not wrapped in
// `black_box` — confirm whether the two should agree.
|x| x.overflowing_shr_vartime(Uint::<LIMBS>::BITS / 2 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr_vartime_wide, large, U2048", |b| {
group.bench_function(BenchmarkId::new("shr_vartime_wide, large", LIMBS), |b| {
b.iter_batched(
|| (U2048::ONE, U2048::ONE),
|x| Uint::overflowing_shr_vartime_wide(x, 1024 + 10),
|| (Uint::<LIMBS>::ONE, Uint::<LIMBS>::ONE),
|x| Uint::overflowing_shr_vartime_wide(x, Uint::<LIMBS>::BITS / 2 + 10),
BatchSize::SmallInput,
)
});

group.bench_function("shr, U2048", |b| {
group.bench_function(BenchmarkId::new("shr, small", LIMBS), |b| {
b.iter_batched(
|| U2048::ONE,
|x| x.overflowing_shr(1024 + 10),
|| Uint::<LIMBS>::ONE,
|x| x.overflowing_shr(10),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("shr, large", LIMBS), |b| {
b.iter_batched(
|| Uint::<LIMBS>::ONE,
|x| x.overflowing_shr(Uint::<LIMBS>::BITS / 2 + 10),
BatchSize::SmallInput,
)
});
}

/// Runs the right-shift benchmarks in the "right shift" group for the `U256`, `U512`, `U1024`,
/// `U2048`, `U4096`, `U8192` and `U16384` limb counts.
fn bench_shr(c: &mut Criterion) {
    let mut shr_group = c.benchmark_group("right shift");

    // One instantiation per integer width, smallest first.
    shr_benchmarks::<{ U256::LIMBS }>(&mut shr_group);
    shr_benchmarks::<{ U512::LIMBS }>(&mut shr_group);
    shr_benchmarks::<{ U1024::LIMBS }>(&mut shr_group);
    shr_benchmarks::<{ U2048::LIMBS }>(&mut shr_group);
    shr_benchmarks::<{ U4096::LIMBS }>(&mut shr_group);
    shr_benchmarks::<{ U8192::LIMBS }>(&mut shr_group);
    shr_benchmarks::<{ U16384::LIMBS }>(&mut shr_group);

    shr_group.finish();
}
Expand Down
19 changes: 15 additions & 4 deletions src/limb/shl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Limb left bitshift

use crate::Limb;
use crate::{ConstChoice, Limb};
use core::ops::{Shl, ShlAssign};
use num_traits::WrappingShl;

Expand All @@ -12,10 +12,21 @@ impl Limb {
Limb(self.0 << shift)
}

/// Computes `self << 1` and returns the result and the carry (0 or 1).
/// Computes `self << shift` and returns the result as well as the carry: the `shift` _least_
/// significant bits of the `carry` are equal to the `shift` _most_ significant bits of `self`.
///
/// Panics if `shift` overflows `Limb::BITS`.
#[inline(always)]
pub(crate) const fn shl1(self) -> (Self, Self) {
(Self(self.0 << 1), Self(self.0 >> Self::HI_BIT))
pub const fn carrying_shl(self, shift: u32) -> (Self, Self) {
    // For any `shift > 0` the carry is `self >> (Self::BITS - shift)`.
    // `shift == 0` is special-cased through `ConstChoice` selection, because then
    // - the carry has to be zero, and
    // - `Self::BITS - shift` equals `Self::BITS`, which must not be used as a shift amount.
    let zero_shift = ConstChoice::from_u32_eq(shift, 0);
    let carry_source = Self::select(self, Self::ZERO, zero_shift);
    // Amount by which the carry source is shifted to the *right*; kept below `Self::BITS`.
    let carry_shift = zero_shift.select_u32(Self::BITS - shift, 0);

    let shifted = self.shl(shift);
    let carry = carry_source.shr(carry_shift);
    (shifted, carry)
}
}

Expand Down
19 changes: 15 additions & 4 deletions src/limb/shr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Limb right bitshift

use crate::{Limb, WrappingShr};
use crate::{ConstChoice, Limb, WrappingShr};
use core::ops::{Shr, ShrAssign};

impl Limb {
Expand All @@ -11,10 +11,21 @@ impl Limb {
Limb(self.0 >> shift)
}

/// Computes `self >> 1` and returns the result and the carry (0 or `1 << HI_BIT`).
/// Computes `self >> shift` and returns the result as well as the carry: the `shift` _most_
/// significant bits of the `carry` are equal to the `shift` _least_ significant bits of `self`.
///
/// Panics if `shift` overflows `Limb::BITS`.
#[inline(always)]
pub(crate) const fn shr1(self) -> (Self, Self) {
(Self(self.0 >> 1), Self(self.0 << Self::HI_BIT))
pub const fn carrying_shr(self, shift: u32) -> (Self, Self) {
    // For any `shift > 0` the carry is `self << (Self::BITS - shift)`.
    // `shift == 0` is special-cased through `ConstChoice` selection, because then
    // - the carry has to be zero, and
    // - `Self::BITS - shift` equals `Self::BITS`, which must not be used as a shift amount.
    let zero_shift = ConstChoice::from_u32_eq(shift, 0);
    let carry_source = Self::select(self, Self::ZERO, zero_shift);
    // Amount by which the carry source is shifted to the *left*; kept below `Self::BITS`.
    let carry_shift = zero_shift.select_u32(Self::BITS - shift, 0);

    let shifted = self.shr(shift);
    let carry = carry_source.shl(carry_shift);
    (shifted, carry)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/uint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
}

/// Borrow the limbs of this [`Uint`] mutably.
pub fn as_limbs_mut(&mut self) -> &mut [Limb; LIMBS] {
pub const fn as_limbs_mut(&mut self) -> &mut [Limb; LIMBS] {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

&mut self.limbs
}

Expand Down
2 changes: 1 addition & 1 deletion src/uint/add_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl<const LIMBS: usize> Uint<LIMBS> {
///
/// Assumes `self` as unbounded integer is `< p`.
pub const fn double_mod(&self, p: &Self) -> Self {
let (w, carry) = self.overflowing_shl1();
let (w, carry) = self.carrying_shl1();

// Attempt to subtract the modulus, to ensure the result is in the field.
let (w, borrow) = w.sbb(p, Limb::ZERO);
Expand Down
Loading
Loading