Skip to content

Commit

Permalink
Use associated constants of type Self in Field and PrimeField
Browse files Browse the repository at this point in the history
We now require that the type implementing `Field`, and its particular
values for these constants, can be constructed in a const context. Once
upon a time this might have been onerous, but it should now be a
reasonable requirement given our MSRV of 1.56.0.

Closes #87.
  • Loading branch information
str4d committed Nov 2, 2022
1 parent 9a844a7 commit 6bf93ee
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 66 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ and this library adheres to Rust's notion of

## [Unreleased]
### Added
- `ff::Field::{ZERO, ONE}`
- `ff::Field::pow`
- `ff::Field::{sqrt_ratio, sqrt_alt}`
- `ff::PrimeField::{MULTIPLICATIVE_GENERATOR, ROOT_OF_UNITY}`
- `ff::helpers`:
- `sqrt_tonelli_shanks`
- `sqrt_ratio_generic`
Expand All @@ -21,6 +23,11 @@ and this library adheres to Rust's notion of
of `Field::sqrt` and implement `Field::sqrt_ratio` in terms of that
implementation using the `ff::helpers::sqrt_ratio_generic` helper function.

### Removed
- `ff::Field::{zero, one}` (use `ff::Field::{ZERO, ONE}` instead).
- `ff::PrimeField::{multiplicative_generator, root_of_unity}` (use
`ff::PrimeField::{MULTIPLICATIVE_GENERATOR, ROOT_OF_UNITY}` instead).

## [0.12.1] - 2022-10-28
### Fixed
- `ff_derive` previously generated a `Field::random` implementation that would
Expand Down
40 changes: 11 additions & 29 deletions ff_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ fn prime_field_constants_and_sqrt(
} else if (modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
// Addition chain for (t - 1) // 2
let t_minus_1_over_2 = if t == BigUint::one() {
quote!( #name::one() )
quote!( #name::ONE )
} else {
pow_fixed::generate(&quote! {self}, (&t - BigUint::one()) >> 1)
};
Expand Down Expand Up @@ -547,7 +547,7 @@ fn prime_field_constants_and_sqrt(
let mut j_less_than_v: ::ff::derive::subtle::Choice = 1.into();

for j in 2..max_v {
let tmp_is_one = tmp.ct_eq(&#name::one());
let tmp_is_one = tmp.ct_eq(&#name::ONE);
let squared = #name::conditional_select(&tmp, &z, tmp_is_one).square();
tmp = #name::conditional_select(&squared, &tmp, tmp_is_one);
let new_z = #name::conditional_select(&z, &squared, tmp_is_one);
Expand All @@ -557,7 +557,7 @@ fn prime_field_constants_and_sqrt(
}

let result = x * &z;
x = #name::conditional_select(&result, &x, b.ct_eq(&#name::one()));
x = #name::conditional_select(&result, &x, b.ct_eq(&#name::ONE));
z = z.square();
b *= &z;
v = k;
Expand Down Expand Up @@ -841,7 +841,6 @@ fn prime_field_impl(
/// field.
fn inv_impl(
a: proc_macro2::TokenStream,
name: &syn::Ident,
modulus: &BigUint,
) -> proc_macro2::TokenStream {
// Addition chain for p - 2
Expand All @@ -860,13 +859,13 @@ fn prime_field_impl(
#mod_minus_2
};

::ff::derive::subtle::CtOption::new(inv, !#a.ct_eq(&#name::zero()))
::ff::derive::subtle::CtOption::new(inv, !#a.is_zero())
}
}

let squaring_impl = sqr_impl(quote! {self}, limbs);
let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs);
let invert_impl = inv_impl(quote! {self}, name, modulus);
let invert_impl = inv_impl(quote! {self}, modulus);
let montgomery_impl = mont_impl(limbs);

// self.0[0].ct_eq(&other.0[0]) & self.0[1].ct_eq(&other.0[1]) & ...
Expand Down Expand Up @@ -934,7 +933,7 @@ fn prime_field_impl(
impl ::core::default::Default for #name {
fn default() -> #name {
use ::ff::Field;
#name::zero()
#name::ZERO
}
}

Expand Down Expand Up @@ -1207,20 +1206,19 @@ fn prime_field_impl(

const CAPACITY: u32 = Self::NUM_BITS - 1;

fn multiplicative_generator() -> Self {
GENERATOR
}
const MULTIPLICATIVE_GENERATOR: Self = GENERATOR;

const S: u32 = S;

fn root_of_unity() -> Self {
ROOT_OF_UNITY
}
const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
}

#prime_field_bits_impl

impl ::ff::Field for #name {
const ZERO: Self = #name([0; #limbs]);
const ONE: Self = R;

/// Computes a uniformly random element using rejection sampling.
fn random(mut rng: impl ::ff::derive::rand_core::RngCore) -> Self {
loop {
Expand All @@ -1245,22 +1243,6 @@ fn prime_field_impl(
}
}

#[inline]
fn zero() -> Self {
#name([0; #limbs])
}

#[inline]
fn one() -> Self {
R
}

#[inline]
fn is_zero(&self) -> ::ff::derive::subtle::Choice {
use ::ff::derive::subtle::ConstantTimeEq;
self.ct_eq(&Self::zero())
}

#[inline]
fn is_zero_vartime(&self) -> bool {
self.0.iter().all(|&e| e == 0)
Expand Down
18 changes: 9 additions & 9 deletions src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ where
I: IntoIterator<Item = &'a mut F>,
{
fn batch_invert(self) -> F {
let mut acc = F::one();
let mut acc = F::ONE;
let iter = self.into_iter();
let mut tmp = alloc::vec::Vec::with_capacity(iter.size_hint().0);
for p in iter {
let q = *p;
tmp.push((acc, p));
acc = F::conditional_select(&(acc * q), &acc, q.ct_eq(&F::zero()));
acc = F::conditional_select(&(acc * q), &acc, q.is_zero());
}
acc = acc.invert().unwrap();
let allinv = acc;

for (tmp, p) in tmp.into_iter().rev() {
let skip = p.ct_eq(&F::zero());
let skip = p.is_zero();

let tmp = tmp * acc;
acc = F::conditional_select(&(acc * *p), &acc, skip);
Expand Down Expand Up @@ -74,17 +74,17 @@ impl BatchInverter {
{
assert_eq!(elements.len(), scratch_space.len());

let mut acc = F::one();
let mut acc = F::ONE;
for (p, scratch) in elements.iter().zip(scratch_space.iter_mut()) {
*scratch = acc;
acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
acc = F::conditional_select(&(acc * *p), &acc, p.is_zero());
}
acc = acc.invert().unwrap();
let allinv = acc;

for (p, scratch) in elements.iter_mut().zip(scratch_space.iter()).rev() {
let tmp = *scratch * acc;
let skip = p.ct_eq(&F::zero());
let skip = p.is_zero();
acc = F::conditional_select(&(acc * *p), &acc, skip);
*p = F::conditional_select(&tmp, &p, skip);
}
Expand All @@ -109,19 +109,19 @@ impl BatchInverter {
TE: Fn(&mut T) -> &mut F,
TS: Fn(&mut T) -> &mut F,
{
let mut acc = F::one();
let mut acc = F::ONE;
for item in items.iter_mut() {
*(scratch_space)(item) = acc;
let p = (element)(item);
acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
acc = F::conditional_select(&(acc * *p), &acc, p.is_zero());
}
acc = acc.invert().unwrap();
let allinv = acc;

for item in items.iter_mut().rev() {
let tmp = *(scratch_space)(item) * acc;
let p = (element)(item);
let skip = p.ct_eq(&F::zero());
let skip = p.is_zero();
acc = F::conditional_select(&(acc * *p), &acc, skip);
*p = F::conditional_select(&tmp, &p, skip);
}
Expand Down
12 changes: 6 additions & 6 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub fn sqrt_tonelli_shanks<F: PrimeField, S: AsRef<[u64]>>(f: &F, tm1d2: S) -> C
let mut b = x * w;

// Initialize z as the 2^S root of unity.
let mut z = F::root_of_unity();
let mut z = F::ROOT_OF_UNITY;

for max_v in (1..=F::S).rev() {
let mut k = 1;
Expand All @@ -41,7 +41,7 @@ pub fn sqrt_tonelli_shanks<F: PrimeField, S: AsRef<[u64]>>(f: &F, tm1d2: S) -> C
// - for k < j <= v, we square z in order to calculate ω.
// - for j > v, we do nothing.
for j in 2..max_v {
let b2k_is_one = b2k.ct_eq(&F::one());
let b2k_is_one = b2k.ct_eq(&F::ONE);
let squared = F::conditional_select(&b2k, &z, b2k_is_one).square();
b2k = F::conditional_select(&squared, &b2k, b2k_is_one);
let new_z = F::conditional_select(&z, &squared, b2k_is_one);
Expand All @@ -51,7 +51,7 @@ pub fn sqrt_tonelli_shanks<F: PrimeField, S: AsRef<[u64]>>(f: &F, tm1d2: S) -> C
}

let result = x * z;
x = F::conditional_select(&result, &x, b.ct_eq(&F::one()));
x = F::conditional_select(&result, &x, b.ct_eq(&F::ONE));
z = z.square();
b *= z;
v = k;
Expand All @@ -76,7 +76,7 @@ pub fn sqrt_tonelli_shanks<F: PrimeField, S: AsRef<[u64]>>(f: &F, tm1d2: S) -> C
///
/// where $G_S$ is a non-square.
///
/// For this method, $G_S$ is currently [`PrimeField::root_of_unity`], a generator of the
/// For this method, $G_S$ is currently [`PrimeField::ROOT_OF_UNITY`], a generator of the
/// order $2^S$ subgroup. Users of this crate should not rely on this generator being
/// fixed; it may be changed in future crate versions to simplify the implementation of
/// the SSWU hash-to-curve algorithm.
Expand Down Expand Up @@ -108,8 +108,8 @@ pub fn sqrt_ratio_generic<F: PrimeField>(num: &F, div: &F) -> (Choice, F) {
// based on whether a is square, but for the boolean output we need to handle the
// num != 0 && div == 0 case specifically.

let a = div.invert().unwrap_or_else(F::zero) * num;
let b = a * F::root_of_unity();
let a = div.invert().unwrap_or(F::ZERO) * num;
let b = a * F::ROOT_OF_UNITY;
let sqrt_a = a.sqrt();
let sqrt_b = b.sqrt();

Expand Down
40 changes: 20 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ pub trait Field:
+ for<'a> AddAssign<&'a Self>
+ for<'a> SubAssign<&'a Self>
{
/// Returns an element chosen uniformly at random using a user-provided RNG.
fn random(rng: impl RngCore) -> Self;
/// The zero element of the field, the additive identity.
const ZERO: Self;

/// Returns the zero element of the field, the additive identity.
fn zero() -> Self;
/// The one element of the field, the multiplicative identity.
const ONE: Self;

/// Returns the one element of the field, the multiplicative identity.
fn one() -> Self;
/// Returns an element chosen uniformly at random using a user-provided RNG.
fn random(rng: impl RngCore) -> Self;

/// Returns true iff this element is zero.
fn is_zero(&self) -> Choice {
self.ct_eq(&Self::zero())
self.ct_eq(&Self::ZERO)
}

/// Returns true iff this element is zero.
Expand Down Expand Up @@ -127,15 +127,15 @@ pub trait Field:
///
/// The provided method is implemented in terms of [`Self::sqrt_ratio`].
fn sqrt_alt(&self) -> (Choice, Self) {
Self::sqrt_ratio(self, &Self::one())
Self::sqrt_ratio(self, &Self::ONE)
}

/// Returns the square root of the field element, if it is
/// quadratic residue.
///
/// The provided method is implemented in terms of [`Self::sqrt_ratio`].
fn sqrt(&self) -> CtOption<Self> {
let (is_square, res) = Self::sqrt_ratio(self, &Self::one());
let (is_square, res) = Self::sqrt_ratio(self, &Self::ONE);
CtOption::new(res, is_square)
}

Expand All @@ -148,7 +148,7 @@ pub trait Field:
/// same number of digits (`exp.as_ref().len()`). It is variable time with respect to
/// the number of digits in the exponent.
fn pow<S: AsRef<[u64]>>(&self, exp: S) -> Self {
let mut res = Self::one();
let mut res = Self::ONE;
for e in exp.as_ref().iter().rev() {
for i in (0..64).rev() {
res = res.square();
Expand All @@ -169,7 +169,7 @@ pub trait Field:
/// the exponent is fixed, this operation is effectively constant time. However, for
/// stronger constant-time guarantees, [`Field::pow`] should be used.
fn pow_vartime<S: AsRef<[u64]>>(&self, exp: S) -> Self {
let mut res = Self::one();
let mut res = Self::ONE;
for e in exp.as_ref().iter().rev() {
for i in (0..64).rev() {
res = res.square();
Expand Down Expand Up @@ -202,10 +202,10 @@ pub trait PrimeField: Field + From<u64> {
}

if s == "0" {
return Some(Self::zero());
return Some(Self::ZERO);
}

let mut res = Self::zero();
let mut res = Self::ZERO;

let ten = Self::from(10);

Expand Down Expand Up @@ -281,28 +281,28 @@ pub trait PrimeField: Field + From<u64> {
/// This is usually `Self::NUM_BITS - 1`.
const CAPACITY: u32;

/// Returns a fixed multiplicative generator of `modulus - 1` order. This element must
/// also be a quadratic nonresidue.
/// A fixed multiplicative generator of `modulus - 1` order. This element must also be
/// a quadratic nonresidue.
///
/// It can be calculated using [SageMath] as `GF(modulus).primitive_element()`.
///
/// Implementations of this method MUST ensure that this is the generator used to
/// Implementations of this trait MUST ensure that this is the generator used to
/// derive `Self::root_of_unity`.
///
/// [SageMath]: https://www.sagemath.org/
fn multiplicative_generator() -> Self;
const MULTIPLICATIVE_GENERATOR: Self;

/// An integer `s` satisfying the equation `2^s * t = modulus - 1` with `t` odd.
///
/// This is the number of leading zero bits in the little-endian bit representation of
/// `modulus - 1`.
const S: u32;

/// Returns the `2^s` root of unity.
/// The `2^s` root of unity.
///
/// It can be calculated by exponentiating `Self::multiplicative_generator` by `t`,
/// It can be calculated by exponentiating `Self::MULTIPLICATIVE_GENERATOR` by `t`,
/// where `t = (modulus - 1) >> Self::S`.
fn root_of_unity() -> Self;
const ROOT_OF_UNITY: Self;
}

/// This represents the bits of an element of a prime field.
Expand Down
4 changes: 2 additions & 2 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ mod full_limbs {
fn batch_inversion() {
use ff::{BatchInverter, Field};

let one = Bls381K12Scalar::one();
let one = Bls381K12Scalar::ONE;

// [1, 2, 3, 4]
let values: Vec<_> = (0..4)
Expand All @@ -55,7 +55,7 @@ fn batch_inversion() {
// Test BatchInverter::invert_with_external_scratch
{
let mut elements = values.clone();
let mut scratch_space = vec![Bls381K12Scalar::zero(); elements.len()];
let mut scratch_space = vec![Bls381K12Scalar::ZERO; elements.len()];
BatchInverter::invert_with_external_scratch(&mut elements, &mut scratch_space);
for (a, a_inv) in values.iter().zip(elements.into_iter()) {
assert_eq!(*a * a_inv, one);
Expand Down

0 comments on commit 6bf93ee

Please sign in to comment.