Skip to content

Commit

Permalink
removed From implementations resulting in silent field element conver…
Browse files Browse the repository at this point in the history
…sions
  • Loading branch information
irakliyk committed Jan 28, 2024
1 parent a450b81 commit 3bde922
Show file tree
Hide file tree
Showing 28 changed files with 233 additions and 285 deletions.
2 changes: 1 addition & 1 deletion air/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ mod tests {
]);
let expected = vec![
BaseElement::from(ext_fri),
BaseElement::from(grinding_factor as u32),
BaseElement::from(grinding_factor),
BaseElement::from(blowup_factor as u32),
BaseElement::from(num_queries as u32),
];
Expand Down
18 changes: 15 additions & 3 deletions air/src/proof/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,22 @@ impl Context {
// --------------------------------------------------------------------------------------------
/// Creates a new context for a computation described by the specified field, trace info, and
/// proof options.
///
/// # Panics
/// Panics if either trace length or the LDE domain size implied by the trace length and the
/// blowup factor is greater then [u32::MAX].
pub fn new<B: StarkField>(trace_info: &TraceInfo, options: ProofOptions) -> Self {
// TODO: return errors instead of panicking?

let trace_length = trace_info.length();
assert!(trace_length <= u32::MAX as usize, "trace length too big");

let lde_domain_size = trace_length * options.blowup_factor();
assert!(lde_domain_size <= u32::MAX as usize, "LDE domain size too big");

Context {
trace_layout: trace_info.layout().clone(),
trace_length: trace_info.length(),
trace_length,
trace_meta: trace_info.meta().to_vec(),
field_modulus_bytes: B::get_modulus_le_bytes(),
options,
Expand Down Expand Up @@ -117,7 +129,7 @@ impl<E: StarkField> ToElements<E> for Context {

// convert proof options and trace length to elements
result.append(&mut self.options.to_elements());
result.push(E::from(self.trace_length as u64));
result.push(E::from(self.trace_length as u32));

// convert trace metadata to elements; this is done by breaking trace metadata into chunks
// of bytes which are slightly smaller than the number of bytes needed to encode a field
Expand Down Expand Up @@ -257,7 +269,7 @@ mod tests {
BaseElement::from(1_u32), // lower bits of field modulus
BaseElement::from(u32::MAX), // upper bits of field modulus
BaseElement::from(ext_fri),
BaseElement::from(grinding_factor as u32),
BaseElement::from(grinding_factor),
BaseElement::from(blowup_factor as u32),
BaseElement::from(num_queries as u32),
BaseElement::from(trace_length as u32),
Expand Down
5 changes: 3 additions & 2 deletions crypto/benches/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use winter_crypto::{build_merkle_nodes, concurrent, hashers::Blake3_256, Hasher}
type Blake3 = Blake3_256<BaseElement>;
type Blake3Digest = <Blake3 as Hasher>::Digest;

#[allow(clippy::needless_range_loop)]
pub fn merkle_tree_construction(c: &mut Criterion) {
let mut merkle_group = c.benchmark_group("merkle tree construction");

Expand All @@ -26,10 +27,10 @@ pub fn merkle_tree_construction(c: &mut Criterion) {
res
};
merkle_group.bench_with_input(BenchmarkId::new("sequential", size), &data, |b, i| {
b.iter(|| build_merkle_nodes::<Blake3>(&i))
b.iter(|| build_merkle_nodes::<Blake3>(i))
});
merkle_group.bench_with_input(BenchmarkId::new("concurrent", size), &data, |b, i| {
b.iter(|| concurrent::build_merkle_nodes::<Blake3>(&i))
b.iter(|| concurrent::build_merkle_nodes::<Blake3>(i))
});
}
}
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/hash/griffin/griffin64_256_jive/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use super::{Digest, DIGEST_SIZE};
use core::slice;
use math::{fields::f64::BaseElement, StarkField};
use math::fields::f64::BaseElement;
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// DIGEST TRAIT IMPLEMENTATIONS
Expand Down
7 changes: 4 additions & 3 deletions crypto/src/hash/griffin/griffin64_256_jive/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use proptest::prelude::*;

use rand_utils::{rand_array, rand_value};

#[allow(clippy::needless_range_loop)]
#[test]
fn mds_inv_test() {
let mut mul_result = [[BaseElement::new(0); STATE_WIDTH]; STATE_WIDTH];
Expand Down Expand Up @@ -196,15 +197,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) {

proptest! {
#[test]
fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) {
fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {

let mut v1 = [BaseElement::ZERO;STATE_WIDTH];
let mut v1 = [BaseElement::ZERO; STATE_WIDTH];
let mut v2;

for i in 0..STATE_WIDTH {
v1[i] = BaseElement::new(a[i]);
}
v2 = v1.clone();
v2 = v1;

apply_mds_naive(&mut v1);
GriffinJive64_256::apply_linear(&mut v2);
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/hash/rescue/rp64_256/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use super::{Digest, DIGEST_SIZE};
use core::slice;
use math::{fields::f64::BaseElement, StarkField};
use math::fields::f64::BaseElement;
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// DIGEST TRAIT IMPLEMENTATIONS
Expand Down
6 changes: 3 additions & 3 deletions crypto/src/hash/rescue/rp64_256/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) {

proptest! {
#[test]
fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) {
fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {

let mut v1 = [BaseElement::ZERO;STATE_WIDTH];
let mut v1 = [BaseElement::ZERO; STATE_WIDTH];
let mut v2;

for i in 0..STATE_WIDTH {
v1[i] = BaseElement::new(a[i]);
}
v2 = v1.clone();
v2 = v1;

apply_mds_naive(&mut v1);
Rp64_256::apply_mds(&mut v2);
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/hash/rescue/rp64_256_jive/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use super::{Digest, DIGEST_SIZE};
use core::slice;
use math::{fields::f64::BaseElement, StarkField};
use math::fields::f64::BaseElement;
use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};

// DIGEST TRAIT IMPLEMENTATIONS
Expand Down
9 changes: 5 additions & 4 deletions crypto/src/hash/rescue/rp64_256_jive/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use proptest::prelude::*;

use rand_utils::{rand_array, rand_value};

#[allow(clippy::needless_range_loop)]
#[test]
fn mds_inv_test() {
let mut mul_result = [[BaseElement::new(0); STATE_WIDTH]; STATE_WIDTH];
Expand All @@ -36,7 +37,7 @@ fn mds_inv_test() {
#[test]
fn test_alphas() {
let e: BaseElement = rand_value();
let e_exp = e.exp(ALPHA.into());
let e_exp = e.exp(ALPHA);
assert_eq!(e, e_exp.exp(INV_ALPHA));
}

Expand Down Expand Up @@ -189,15 +190,15 @@ fn apply_mds_naive(state: &mut [BaseElement; STATE_WIDTH]) {

proptest! {
#[test]
fn mds_freq_proptest(a in any::<[u64;STATE_WIDTH]>()) {
fn mds_freq_proptest(a in any::<[u64; STATE_WIDTH]>()) {

let mut v1 = [BaseElement::ZERO;STATE_WIDTH];
let mut v1 = [BaseElement::ZERO; STATE_WIDTH];
let mut v2;

for i in 0..STATE_WIDTH {
v1[i] = BaseElement::new(a[i]);
}
v2 = v1.clone();
v2 = v1;

apply_mds_naive(&mut v1);
RpJive64_256::apply_mds(&mut v2);
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/merkle/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ mod tests {
proptest! {
#[test]
fn build_merkle_nodes_concurrent(ref data in vec(any::<[u8; 32]>(), 256..257).no_shrink()) {
let leaves = ByteDigest::bytes_as_digests(&data).to_vec();
let leaves = ByteDigest::bytes_as_digests(data).to_vec();
let sequential = super::super::build_merkle_nodes::<Sha3_256<BaseElement>>(&leaves);
let concurrent = super::build_merkle_nodes::<Sha3_256<BaseElement>>(&leaves);
assert_eq!(concurrent, sequential);
Expand Down
4 changes: 2 additions & 2 deletions examples/src/lamport/aggregate/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ fn apply_message_acc(
let m0_bit = state[0];
let m1_bit = state[1];

state[0] = BaseElement::from((m0 >> (cycle_num + 1)) & 1);
state[1] = BaseElement::from((m1 >> (cycle_num + 1)) & 1);
state[0] = BaseElement::new((m0 >> (cycle_num + 1)) & 1);
state[1] = BaseElement::new((m1 >> (cycle_num + 1)) & 1);
state[2] += power_of_two * m0_bit;
state[3] += power_of_two * m1_bit;
}
Expand Down
2 changes: 1 addition & 1 deletion examples/src/lamport/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ pub fn message_to_elements(message: &[u8]) -> [BaseElement; 2] {
let checksum = m0.count_zeros() + m1.count_zeros();
let m1 = m1 | ((checksum as u128) << 119);

[BaseElement::from(m0), BaseElement::from(m1)]
[BaseElement::new(m0), BaseElement::new(m1)]
}

/// Reduces a list of public key elements to a single 32-byte value. The reduction is done
Expand Down
4 changes: 2 additions & 2 deletions examples/src/lamport/threshold/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ impl Air for LamportThresholdAir {
let mut m1_bits = Vec::with_capacity(SIG_CYCLE_LEN);
for i in 0..SIG_CYCLE_LEN {
let cycle_num = i / HASH_CYCLE_LEN;
m0_bits.push(BaseElement::from((m0 >> cycle_num) & 1));
m1_bits.push(BaseElement::from((m1 >> cycle_num) & 1));
m0_bits.push(BaseElement::new((m0 >> cycle_num) & 1));
m1_bits.push(BaseElement::new((m1 >> cycle_num) & 1));
}
result.push(m0_bits);
result.push(m1_bits);
Expand Down
6 changes: 3 additions & 3 deletions examples/src/lamport/threshold/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ fn update_sig_verification_state(
} else {
// for the 8th step of very cycle do the following:

let m0_bit = BaseElement::from((sig_info.m0 >> cycle_num) & 1);
let m1_bit = BaseElement::from((sig_info.m1 >> cycle_num) & 1);
let m0_bit = BaseElement::new((sig_info.m0 >> cycle_num) & 1);
let m1_bit = BaseElement::new((sig_info.m1 >> cycle_num) & 1);
let mp_bit = merkle_path_idx[0];

// copy next set of public keys into the registers computing hash of the public key
Expand Down Expand Up @@ -345,7 +345,7 @@ fn update_merkle_path_index(
let index_bit = state[0];
// the cycle is offset by +1 because the first node in the Merkle path is redundant and we
// get it by hashing the public key
state[0] = BaseElement::from((index >> (cycle_num + 1)) & 1);
state[0] = BaseElement::new((index >> (cycle_num + 1)) & 1);
state[1] += power_of_two * index_bit;
}

Expand Down
2 changes: 1 addition & 1 deletion fri/src/folding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ where
// build offset inverses and twiddles used during polynomial interpolation
let inv_offsets = get_inv_offsets(values.len(), domain_offset, N);
let inv_twiddles = get_inv_twiddles::<B>(N);
let len_offset = E::inv((N as u64).into());
let len_offset = E::inv((N as u32).into());

let mut result = unsafe { uninit_vector(values.len()) };
iter_mut!(result)
Expand Down
35 changes: 23 additions & 12 deletions math/src/fft/concurrent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,47 @@ pub fn evaluate_poly_with_offset<B: StarkField, E: FieldElement<BaseField = B>>(

/// Uses FFT algorithm to interpolate a polynomial from provided `values`; the interpolation
/// is done in-place, meaning `values` are updated with polynomial coefficients.
pub fn interpolate_poly<B, E>(v: &mut [E], inv_twiddles: &[B])
///
/// # Panics
/// Panics if the length of `values` is greater than [u32::MAX].
pub fn interpolate_poly<B, E>(values: &mut [E], inv_twiddles: &[B])
where
B: StarkField,
E: FieldElement<BaseField = B>,
{
split_radix_fft(v, inv_twiddles);
let inv_length = E::inv((v.len() as u64).into());
v.par_iter_mut().for_each(|e| *e *= inv_length);
permute(v);
assert!(values.len() <= u32::MAX as usize, "too many values");

split_radix_fft(values, inv_twiddles);
let inv_length = E::inv((values.len() as u32).into());
values.par_iter_mut().for_each(|e| *e *= inv_length);
permute(values);
}

/// Uses FFT algorithm to interpolate a polynomial from provided `values` over the domain defined
/// by `inv_twiddles` and offset by `domain_offset` factor.
///
///
/// # Panics
/// Panics if the length of `values` is greater than [u32::MAX].
pub fn interpolate_poly_with_offset<B, E>(values: &mut [E], inv_twiddles: &[B], domain_offset: B)
where
B: StarkField,
E: FieldElement<BaseField = B>,
{
assert!(values.len() <= u32::MAX as usize, "too many values");

split_radix_fft(values, inv_twiddles);
permute(values);

let domain_offset = E::inv(domain_offset.into());
let inv_len = E::inv((values.len() as u64).into());
let inv_len = E::inv((values.len() as u32).into());
let batch_size = values.len() / rayon::current_num_threads().next_power_of_two();

values.par_chunks_mut(batch_size).enumerate().for_each(|(i, batch)| {
let mut offset = domain_offset.exp(((i * batch_size) as u64).into()) * inv_len;
for coeff in batch.iter_mut() {
*coeff = *coeff * offset;
offset = offset * domain_offset;
*coeff *= offset;
offset *= domain_offset;
}
});
}
Expand Down Expand Up @@ -136,7 +147,7 @@ pub(super) fn split_radix_fft<B: StarkField, E: FieldElement<BaseField = B>>(
// apply inner FFTs
values
.par_chunks_mut(outer_len)
.for_each(|row| row.fft_in_place_raw(&twiddles, stretch, stretch, 0));
.for_each(|row| row.fft_in_place_raw(twiddles, stretch, stretch, 0));

// transpose inner x inner x stretch square matrix
transpose_square_stretch(values, inner_len, stretch);
Expand All @@ -149,10 +160,10 @@ pub(super) fn split_radix_fft<B: StarkField, E: FieldElement<BaseField = B>>(
let mut outer_twiddle = inner_twiddle;
for element in row.iter_mut().skip(1) {
*element = (*element).mul_base(outer_twiddle);
outer_twiddle = outer_twiddle * inner_twiddle;
outer_twiddle *= inner_twiddle;
}
}
row.fft_in_place(&twiddles);
row.fft_in_place(twiddles);
});
}

Expand Down Expand Up @@ -216,7 +227,7 @@ fn clone_and_shift<E: FieldElement>(source: &[E], destination: &mut [E], offset:
let mut factor = offset.exp(((i * batch_size) as u64).into());
for (s, d) in source.iter().zip(destination.iter_mut()) {
*d = (*s).mul_base(factor);
factor = factor * offset;
factor *= offset;
}
});
}
13 changes: 11 additions & 2 deletions math/src/fft/serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,16 @@ where

/// Interpolates `evaluations` over a domain of length `evaluations.len()` in the field specified
/// `B` into a polynomial in coefficient form using the FFT algorithm.
///
/// # Panics
/// Panics if the length of `evaluations` is greater than [u32::MAX].
pub fn interpolate_poly<B, E>(evaluations: &mut [E], inv_twiddles: &[B])
where
B: StarkField,
E: FieldElement<BaseField = B>,
{
let inv_length = B::inv((evaluations.len() as u64).into());
assert!(evaluations.len() <= u32::MAX as usize, "too many evaluations");
let inv_length = B::inv((evaluations.len() as u32).into());
evaluations.fft_in_place(inv_twiddles);
evaluations.shift_by(inv_length);
evaluations.permute();
Expand All @@ -71,6 +75,9 @@ where
/// Interpolates `evaluations` over a domain of length `evaluations.len()` and shifted by
/// `domain_offset` in the field specified by `B` into a polynomial in coefficient form using
/// the FFT algorithm.
///
/// # Panics
/// Panics if the length of `evaluations` is greater than [u32::MAX].
pub fn interpolate_poly_with_offset<B, E>(
evaluations: &mut [E],
inv_twiddles: &[B],
Expand All @@ -79,11 +86,13 @@ pub fn interpolate_poly_with_offset<B, E>(
B: StarkField,
E: FieldElement<BaseField = B>,
{
assert!(evaluations.len() <= u32::MAX as usize, "too many evaluations");

evaluations.fft_in_place(inv_twiddles);
evaluations.permute();

let domain_offset = B::inv(domain_offset);
let offset = B::inv((evaluations.len() as u64).into());
let offset = B::inv((evaluations.len() as u32).into());

evaluations.shift_by_series(offset, domain_offset);
}
12 changes: 0 additions & 12 deletions math/src/field/extensions/cubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,18 +301,6 @@ impl<B: ExtensibleField<3>> From<B> for CubeExtension<B> {
}
}

impl<B: ExtensibleField<3>> From<u128> for CubeExtension<B> {
fn from(value: u128) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}

impl<B: ExtensibleField<3>> From<u64> for CubeExtension<B> {
fn from(value: u64) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
}
}

impl<B: ExtensibleField<3>> From<u32> for CubeExtension<B> {
fn from(value: u32) -> Self {
Self(B::from(value), B::ZERO, B::ZERO)
Expand Down
Loading

0 comments on commit 3bde922

Please sign in to comment.