Skip to content

Commit

Permalink
Merge pull request #94 from zkcrypto/trait-constants
Browse files Browse the repository at this point in the history
Add associated constants of type `Self` to `Field` and `PrimeField`
  • Loading branch information
str4d authored Nov 2, 2022
2 parents 9a844a7 + 6bf93ee commit 1eddf54
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 66 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ and this library adheres to Rust's notion of

## [Unreleased]
### Added
- `ff::Field::{ZERO, ONE}`
- `ff::Field::pow`
- `ff::Field::{sqrt_ratio, sqrt_alt}`
- `ff::PrimeField::{MULTIPLICATIVE_GENERATOR, ROOT_OF_UNITY}`
- `ff::helpers`:
- `sqrt_tonelli_shanks`
- `sqrt_ratio_generic`
Expand All @@ -21,6 +23,11 @@ and this library adheres to Rust's notion of
of `Field::sqrt` and implement `Field::sqrt_ratio` in terms of that
implementation using the `ff::helpers::sqrt_ratio_generic` helper function.

### Removed
- `ff::Field::{zero, one}` (use `ff::Field::{ZERO, ONE}` instead).
- `ff::PrimeField::{multiplicative_generator, root_of_unity}` (use
`ff::PrimeField::{MULTIPLICATIVE_GENERATOR, ROOT_OF_UNITY}` instead).

## [0.12.1] - 2022-10-28
### Fixed
- `ff_derive` previously generated a `Field::random` implementation that would
Expand Down
40 changes: 11 additions & 29 deletions ff_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ fn prime_field_constants_and_sqrt(
} else if (modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
// Addition chain for (t - 1) // 2
let t_minus_1_over_2 = if t == BigUint::one() {
quote!( #name::one() )
quote!( #name::ONE )
} else {
pow_fixed::generate(&quote! {self}, (&t - BigUint::one()) >> 1)
};
Expand Down Expand Up @@ -547,7 +547,7 @@ fn prime_field_constants_and_sqrt(
let mut j_less_than_v: ::ff::derive::subtle::Choice = 1.into();

for j in 2..max_v {
let tmp_is_one = tmp.ct_eq(&#name::one());
let tmp_is_one = tmp.ct_eq(&#name::ONE);
let squared = #name::conditional_select(&tmp, &z, tmp_is_one).square();
tmp = #name::conditional_select(&squared, &tmp, tmp_is_one);
let new_z = #name::conditional_select(&z, &squared, tmp_is_one);
Expand All @@ -557,7 +557,7 @@ fn prime_field_constants_and_sqrt(
}

let result = x * &z;
x = #name::conditional_select(&result, &x, b.ct_eq(&#name::one()));
x = #name::conditional_select(&result, &x, b.ct_eq(&#name::ONE));
z = z.square();
b *= &z;
v = k;
Expand Down Expand Up @@ -841,7 +841,6 @@ fn prime_field_impl(
/// field.
fn inv_impl(
a: proc_macro2::TokenStream,
name: &syn::Ident,
modulus: &BigUint,
) -> proc_macro2::TokenStream {
// Addition chain for p - 2
Expand All @@ -860,13 +859,13 @@ fn prime_field_impl(
#mod_minus_2
};

::ff::derive::subtle::CtOption::new(inv, !#a.ct_eq(&#name::zero()))
::ff::derive::subtle::CtOption::new(inv, !#a.is_zero())
}
}

let squaring_impl = sqr_impl(quote! {self}, limbs);
let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs);
let invert_impl = inv_impl(quote! {self}, name, modulus);
let invert_impl = inv_impl(quote! {self}, modulus);
let montgomery_impl = mont_impl(limbs);

// self.0[0].ct_eq(&other.0[0]) & self.0[1].ct_eq(&other.0[1]) & ...
Expand Down Expand Up @@ -934,7 +933,7 @@ fn prime_field_impl(
impl ::core::default::Default for #name {
fn default() -> #name {
use ::ff::Field;
#name::zero()
#name::ZERO
}
}

Expand Down Expand Up @@ -1207,20 +1206,19 @@ fn prime_field_impl(

const CAPACITY: u32 = Self::NUM_BITS - 1;

fn multiplicative_generator() -> Self {
GENERATOR
}
const MULTIPLICATIVE_GENERATOR: Self = GENERATOR;

const S: u32 = S;

fn root_of_unity() -> Self {
ROOT_OF_UNITY
}
const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
}

#prime_field_bits_impl

impl ::ff::Field for #name {
const ZERO: Self = #name([0; #limbs]);
const ONE: Self = R;

/// Computes a uniformly random element using rejection sampling.
fn random(mut rng: impl ::ff::derive::rand_core::RngCore) -> Self {
loop {
Expand All @@ -1245,22 +1243,6 @@ fn prime_field_impl(
}
}

#[inline]
fn zero() -> Self {
#name([0; #limbs])
}

#[inline]
fn one() -> Self {
R
}

#[inline]
fn is_zero(&self) -> ::ff::derive::subtle::Choice {
use ::ff::derive::subtle::ConstantTimeEq;
self.ct_eq(&Self::zero())
}

#[inline]
fn is_zero_vartime(&self) -> bool {
self.0.iter().all(|&e| e == 0)
Expand Down
18 changes: 9 additions & 9 deletions src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ where
I: IntoIterator<Item = &'a mut F>,
{
fn batch_invert(self) -> F {
let mut acc = F::one();
let mut acc = F::ONE;
let iter = self.into_iter();
let mut tmp = alloc::vec::Vec::with_capacity(iter.size_hint().0);
for p in iter {
let q = *p;
tmp.push((acc, p));
acc = F::conditional_select(&(acc * q), &acc, q.ct_eq(&F::zero()));
acc = F::conditional_select(&(acc * q), &acc, q.is_zero());
}
acc = acc.invert().unwrap();
let allinv = acc;

for (tmp, p) in tmp.into_iter().rev() {
let skip = p.ct_eq(&F::zero());
let skip = p.is_zero();

let tmp = tmp * acc;
acc = F::conditional_select(&(acc * *p), &acc, skip);
Expand Down Expand Up @@ -74,17 +74,17 @@ impl BatchInverter {
{
assert_eq!(elements.len(), scratch_space.len());

let mut acc = F::one();
let mut acc = F::ONE;
for (p, scratch) in elements.iter().zip(scratch_space.iter_mut()) {
*scratch = acc;
acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
acc = F::conditional_select(&(acc * *p), &acc, p.is_zero());
}
acc = acc.invert().unwrap();
let allinv = acc;

for (p, scratch) in elements.iter_mut().zip(scratch_space.iter()).rev() {
let tmp = *scratch * acc;
let skip = p.ct_eq(&F::zero());
let skip = p.is_zero();
acc = F::conditional_select(&(acc * *p), &acc, skip);
*p = F::conditional_select(&tmp, &p, skip);
}
Expand All @@ -109,19 +109,19 @@ impl BatchInverter {
TE: Fn(&mut T) -> &mut F,
TS: Fn(&mut T) -> &mut F,
{
let mut acc = F::one();
let mut acc = F::ONE;
for item in items.iter_mut() {
*(scratch_space)(item) = acc;
let p = (element)(item);
acc = F::conditional_select(&(acc * *p), &acc, p.ct_eq(&F::zero()));
acc = F::conditional_select(&(acc * *p), &acc, p.is_zero());
}
acc = acc.invert().unwrap();
let allinv = acc;

for item in items.iter_mut().rev() {
let tmp = *(scratch_space)(item) * acc;
let p = (element)(item);
let skip = p.ct_eq(&F::zero());
let skip = p.is_zero();
acc = F::conditional_select(&(acc * *p), &acc, skip);
*p = F::conditional_select(&tmp, &p, skip);
}
Expand Down
12 changes: 6 additions & 6 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub fn sqrt_tonelli_shanks<F: PrimeField, S: AsRef<[u64]>>(f: &F, tm1d2: S) -> C
let mut b = x * w;

// Initialize z as the 2^S root of unity.
let mut z = F::root_of_unity();
let mut z = F::ROOT_OF_UNITY;

for max_v in (1..=F::S).rev() {
let mut k = 1;
Expand All @@ -41,7 +41,7 @@ pub fn sqrt_tonelli_shanks<F: PrimeField, S: AsRef<[u64]>>(f: &F, tm1d2: S) -> C
// - for k < j <= v, we square z in order to calculate ω.
// - for j > v, we do nothing.
for j in 2..max_v {
let b2k_is_one = b2k.ct_eq(&F::one());
let b2k_is_one = b2k.ct_eq(&F::ONE);
let squared = F::conditional_select(&b2k, &z, b2k_is_one).square();
b2k = F::conditional_select(&squared, &b2k, b2k_is_one);
let new_z = F::conditional_select(&z, &squared, b2k_is_one);
Expand All @@ -51,7 +51,7 @@ pub fn sqrt_tonelli_shanks<F: PrimeField, S: AsRef<[u64]>>(f: &F, tm1d2: S) -> C
}

let result = x * z;
x = F::conditional_select(&result, &x, b.ct_eq(&F::one()));
x = F::conditional_select(&result, &x, b.ct_eq(&F::ONE));
z = z.square();
b *= z;
v = k;
Expand All @@ -76,7 +76,7 @@ pub fn sqrt_tonelli_shanks<F: PrimeField, S: AsRef<[u64]>>(f: &F, tm1d2: S) -> C
///
/// where $G_S$ is a non-square.
///
/// For this method, $G_S$ is currently [`PrimeField::root_of_unity`], a generator of the
/// For this method, $G_S$ is currently [`PrimeField::ROOT_OF_UNITY`], a generator of the
/// order $2^S$ subgroup. Users of this crate should not rely on this generator being
/// fixed; it may be changed in future crate versions to simplify the implementation of
/// the SSWU hash-to-curve algorithm.
Expand Down Expand Up @@ -108,8 +108,8 @@ pub fn sqrt_ratio_generic<F: PrimeField>(num: &F, div: &F) -> (Choice, F) {
// based on whether a is square, but for the boolean output we need to handle the
// num != 0 && div == 0 case specifically.

let a = div.invert().unwrap_or_else(F::zero) * num;
let b = a * F::root_of_unity();
let a = div.invert().unwrap_or(F::ZERO) * num;
let b = a * F::ROOT_OF_UNITY;
let sqrt_a = a.sqrt();
let sqrt_b = b.sqrt();

Expand Down
40 changes: 20 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ pub trait Field:
+ for<'a> AddAssign<&'a Self>
+ for<'a> SubAssign<&'a Self>
{
/// Returns an element chosen uniformly at random using a user-provided RNG.
fn random(rng: impl RngCore) -> Self;
/// The zero element of the field, the additive identity.
const ZERO: Self;

/// Returns the zero element of the field, the additive identity.
fn zero() -> Self;
/// The one element of the field, the multiplicative identity.
const ONE: Self;

/// Returns the one element of the field, the multiplicative identity.
fn one() -> Self;
/// Returns an element chosen uniformly at random using a user-provided RNG.
fn random(rng: impl RngCore) -> Self;

/// Returns true iff this element is zero.
fn is_zero(&self) -> Choice {
self.ct_eq(&Self::zero())
self.ct_eq(&Self::ZERO)
}

/// Returns true iff this element is zero.
Expand Down Expand Up @@ -127,15 +127,15 @@ pub trait Field:
///
/// The provided method is implemented in terms of [`Self::sqrt_ratio`].
fn sqrt_alt(&self) -> (Choice, Self) {
Self::sqrt_ratio(self, &Self::one())
Self::sqrt_ratio(self, &Self::ONE)
}

/// Returns the square root of the field element, if it is
/// quadratic residue.
///
/// The provided method is implemented in terms of [`Self::sqrt_ratio`].
fn sqrt(&self) -> CtOption<Self> {
let (is_square, res) = Self::sqrt_ratio(self, &Self::one());
let (is_square, res) = Self::sqrt_ratio(self, &Self::ONE);
CtOption::new(res, is_square)
}

Expand All @@ -148,7 +148,7 @@ pub trait Field:
/// same number of digits (`exp.as_ref().len()`). It is variable time with respect to
/// the number of digits in the exponent.
fn pow<S: AsRef<[u64]>>(&self, exp: S) -> Self {
let mut res = Self::one();
let mut res = Self::ONE;
for e in exp.as_ref().iter().rev() {
for i in (0..64).rev() {
res = res.square();
Expand All @@ -169,7 +169,7 @@ pub trait Field:
/// the exponent is fixed, this operation is effectively constant time. However, for
/// stronger constant-time guarantees, [`Field::pow`] should be used.
fn pow_vartime<S: AsRef<[u64]>>(&self, exp: S) -> Self {
let mut res = Self::one();
let mut res = Self::ONE;
for e in exp.as_ref().iter().rev() {
for i in (0..64).rev() {
res = res.square();
Expand Down Expand Up @@ -202,10 +202,10 @@ pub trait PrimeField: Field + From<u64> {
}

if s == "0" {
return Some(Self::zero());
return Some(Self::ZERO);
}

let mut res = Self::zero();
let mut res = Self::ZERO;

let ten = Self::from(10);

Expand Down Expand Up @@ -281,28 +281,28 @@ pub trait PrimeField: Field + From<u64> {
/// This is usually `Self::NUM_BITS - 1`.
const CAPACITY: u32;

/// Returns a fixed multiplicative generator of `modulus - 1` order. This element must
/// also be a quadratic nonresidue.
/// A fixed multiplicative generator of `modulus - 1` order. This element must also be
/// a quadratic nonresidue.
///
/// It can be calculated using [SageMath] as `GF(modulus).primitive_element()`.
///
/// Implementations of this method MUST ensure that this is the generator used to
/// Implementations of this trait MUST ensure that this is the generator used to
/// derive `Self::root_of_unity`.
///
/// [SageMath]: https://www.sagemath.org/
fn multiplicative_generator() -> Self;
const MULTIPLICATIVE_GENERATOR: Self;

/// An integer `s` satisfying the equation `2^s * t = modulus - 1` with `t` odd.
///
/// This is the number of leading zero bits in the little-endian bit representation of
/// `modulus - 1`.
const S: u32;

/// Returns the `2^s` root of unity.
/// The `2^s` root of unity.
///
/// It can be calculated by exponentiating `Self::multiplicative_generator` by `t`,
/// It can be calculated by exponentiating `Self::MULTIPLICATIVE_GENERATOR` by `t`,
/// where `t = (modulus - 1) >> Self::S`.
fn root_of_unity() -> Self;
const ROOT_OF_UNITY: Self;
}

/// This represents the bits of an element of a prime field.
Expand Down
4 changes: 2 additions & 2 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ mod full_limbs {
fn batch_inversion() {
use ff::{BatchInverter, Field};

let one = Bls381K12Scalar::one();
let one = Bls381K12Scalar::ONE;

// [1, 2, 3, 4]
let values: Vec<_> = (0..4)
Expand All @@ -55,7 +55,7 @@ fn batch_inversion() {
// Test BatchInverter::invert_with_external_scratch
{
let mut elements = values.clone();
let mut scratch_space = vec![Bls381K12Scalar::zero(); elements.len()];
let mut scratch_space = vec![Bls381K12Scalar::ZERO; elements.len()];
BatchInverter::invert_with_external_scratch(&mut elements, &mut scratch_space);
for (a, a_inv) in values.iter().zip(elements.into_iter()) {
assert_eq!(*a * a_inv, one);
Expand Down

0 comments on commit 1eddf54

Please sign in to comment.