From 6df07fe4c10ab33453f22f1e1a7890cfeb607ed9 Mon Sep 17 00:00:00 2001 From: Shahar Samocha Date: Mon, 25 Mar 2024 17:42:53 +0200 Subject: [PATCH] Constraints eval for wide fib --- .../wide_fibonacci/constraint_eval.rs | 406 +++++++++++++++++- src/examples/wide_fibonacci/mod.rs | 50 ++- 2 files changed, 451 insertions(+), 5 deletions(-) diff --git a/src/examples/wide_fibonacci/constraint_eval.rs b/src/examples/wide_fibonacci/constraint_eval.rs index 42e0275c9..81eda35be 100644 --- a/src/examples/wide_fibonacci/constraint_eval.rs +++ b/src/examples/wide_fibonacci/constraint_eval.rs @@ -1,9 +1,16 @@ +use num_traits::Zero; + use super::structs::WideFibComponent; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentTrace, Mask}; use crate::core::backend::CPUBackend; use crate::core::circle::CirclePoint; +use crate::core::constraints::coset_vanishing; +use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; +use crate::core::poly::circle::CanonicCoset; +use crate::core::utils::bit_reverse_index; use crate::core::ColumnVec; impl Component for WideFibComponent { @@ -19,12 +26,405 @@ impl Component for WideFibComponent { Mask(vec![vec![0_usize]; 256]) } + // TODO(ShaharS), precompute random coeff powers. + // TODO(ShaharS), use intermidiate value to save the computation of the squares. fn evaluate_constraint_quotients_on_domain( &self, - _trace: &ComponentTrace<'_, CPUBackend>, - _evaluation_accumulator: &mut DomainEvaluationAccumulator, + trace: &ComponentTrace<'_, CPUBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, ) { - unimplemented!("not implemented") + let mut trace_evals = vec![]; + // TODO(ShaharS), Share this LDE with the commitment LDE. + for poly_index in 0..64 { + let poly = &trace.columns[poly_index]; + let trace_eval_domain = + CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); + trace_evals.push(poly.evaluate(trace_eval_domain).bit_reverse()); + } + let zero_domain = CanonicCoset::new(self.log_size).coset; + let eval_domain = CanonicCoset::new(self.log_size + 1).circle_domain(); + let mut denoms = vec![]; + for point in eval_domain.iter() { + denoms.push(coset_vanishing(zero_domain, point)); + } + let mut denom_inverses = + vec![BaseField::zero(); 1 << (self.max_constraint_log_degree_bound())]; + BaseField::batch_inverse(&denoms, &mut denom_inverses); + let mut numerators = + vec![SecureField::zero(); 1 << (self.max_constraint_log_degree_bound())]; + let random_coeff = evaluation_accumulator.random_coeff; + let [mut accum] = + evaluation_accumulator.columns([(self.max_constraint_log_degree_bound(), 64)]); + for (i, point_index) in eval_domain.iter_indices().enumerate() { + numerators[i] = numerators[i] * random_coeff + + (trace_evals[2].get_at(point_index) + - ((trace_evals[0].get_at(point_index) * trace_evals[0].get_at(point_index)) + + (trace_evals[1].get_at(point_index) + * trace_evals[1].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[3].get_at(point_index) + - ((trace_evals[1].get_at(point_index) * trace_evals[1].get_at(point_index)) + + (trace_evals[2].get_at(point_index) + * trace_evals[2].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[4].get_at(point_index) + - ((trace_evals[2].get_at(point_index) * trace_evals[2].get_at(point_index)) + + (trace_evals[3].get_at(point_index) + * trace_evals[3].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[5].get_at(point_index) + - ((trace_evals[3].get_at(point_index) * trace_evals[3].get_at(point_index)) + + (trace_evals[4].get_at(point_index) + * trace_evals[4].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[6].get_at(point_index) + - ((trace_evals[4].get_at(point_index) * trace_evals[4].get_at(point_index)) + + (trace_evals[5].get_at(point_index) + * trace_evals[5].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[7].get_at(point_index) + - ((trace_evals[5].get_at(point_index) * trace_evals[5].get_at(point_index)) + + (trace_evals[6].get_at(point_index) + * trace_evals[6].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[8].get_at(point_index) + - ((trace_evals[6].get_at(point_index) * trace_evals[6].get_at(point_index)) + + (trace_evals[7].get_at(point_index) + * trace_evals[7].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[9].get_at(point_index) + - ((trace_evals[7].get_at(point_index) * trace_evals[7].get_at(point_index)) + + (trace_evals[8].get_at(point_index) + * trace_evals[8].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[10].get_at(point_index) + - ((trace_evals[8].get_at(point_index) * trace_evals[8].get_at(point_index)) + + (trace_evals[9].get_at(point_index) + * trace_evals[9].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[11].get_at(point_index) + - ((trace_evals[9].get_at(point_index) * trace_evals[9].get_at(point_index)) + + (trace_evals[10].get_at(point_index) + * trace_evals[10].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[12].get_at(point_index) + - ((trace_evals[10].get_at(point_index) + * trace_evals[10].get_at(point_index)) + + (trace_evals[11].get_at(point_index) + * trace_evals[11].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[13].get_at(point_index) + - ((trace_evals[11].get_at(point_index) + * trace_evals[11].get_at(point_index)) + + (trace_evals[12].get_at(point_index) + * trace_evals[12].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[14].get_at(point_index) + - ((trace_evals[12].get_at(point_index) + * trace_evals[12].get_at(point_index)) + + (trace_evals[13].get_at(point_index) + * trace_evals[13].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[15].get_at(point_index) + - ((trace_evals[13].get_at(point_index) + * trace_evals[13].get_at(point_index)) + + (trace_evals[14].get_at(point_index) + * trace_evals[14].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[16].get_at(point_index) + - ((trace_evals[14].get_at(point_index) + * trace_evals[14].get_at(point_index)) + + (trace_evals[15].get_at(point_index) + * trace_evals[15].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[17].get_at(point_index) + - ((trace_evals[15].get_at(point_index) + * trace_evals[15].get_at(point_index)) + + (trace_evals[16].get_at(point_index) + * trace_evals[16].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[18].get_at(point_index) + - ((trace_evals[16].get_at(point_index) + * trace_evals[16].get_at(point_index)) + + (trace_evals[17].get_at(point_index) + * trace_evals[17].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[19].get_at(point_index) + - ((trace_evals[17].get_at(point_index) + * trace_evals[17].get_at(point_index)) + + (trace_evals[18].get_at(point_index) + * trace_evals[18].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[20].get_at(point_index) + - ((trace_evals[18].get_at(point_index) + * trace_evals[18].get_at(point_index)) + + (trace_evals[19].get_at(point_index) + * trace_evals[19].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[21].get_at(point_index) + - ((trace_evals[19].get_at(point_index) + * trace_evals[19].get_at(point_index)) + + (trace_evals[20].get_at(point_index) + * trace_evals[20].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[22].get_at(point_index) + - ((trace_evals[20].get_at(point_index) + * trace_evals[20].get_at(point_index)) + + (trace_evals[21].get_at(point_index) + * trace_evals[21].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[23].get_at(point_index) + - ((trace_evals[21].get_at(point_index) + * trace_evals[21].get_at(point_index)) + + (trace_evals[22].get_at(point_index) + * trace_evals[22].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[24].get_at(point_index) + - ((trace_evals[22].get_at(point_index) + * trace_evals[22].get_at(point_index)) + + (trace_evals[23].get_at(point_index) + * trace_evals[23].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[25].get_at(point_index) + - ((trace_evals[23].get_at(point_index) + * trace_evals[23].get_at(point_index)) + + (trace_evals[24].get_at(point_index) + * trace_evals[24].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[26].get_at(point_index) + - ((trace_evals[24].get_at(point_index) + * trace_evals[24].get_at(point_index)) + + (trace_evals[25].get_at(point_index) + * trace_evals[25].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[27].get_at(point_index) + - ((trace_evals[25].get_at(point_index) + * trace_evals[25].get_at(point_index)) + + (trace_evals[26].get_at(point_index) + * trace_evals[26].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[28].get_at(point_index) + - ((trace_evals[26].get_at(point_index) + * trace_evals[26].get_at(point_index)) + + (trace_evals[27].get_at(point_index) + * trace_evals[27].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[29].get_at(point_index) + - ((trace_evals[27].get_at(point_index) + * trace_evals[27].get_at(point_index)) + + (trace_evals[28].get_at(point_index) + * trace_evals[28].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[30].get_at(point_index) + - ((trace_evals[28].get_at(point_index) + * trace_evals[28].get_at(point_index)) + + (trace_evals[29].get_at(point_index) + * trace_evals[29].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[31].get_at(point_index) + - ((trace_evals[29].get_at(point_index) + * trace_evals[29].get_at(point_index)) + + (trace_evals[30].get_at(point_index) + * trace_evals[30].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[32].get_at(point_index) + - ((trace_evals[30].get_at(point_index) + * trace_evals[30].get_at(point_index)) + + (trace_evals[31].get_at(point_index) + * trace_evals[31].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[33].get_at(point_index) + - ((trace_evals[31].get_at(point_index) + * trace_evals[31].get_at(point_index)) + + (trace_evals[32].get_at(point_index) + * trace_evals[32].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[34].get_at(point_index) + - ((trace_evals[32].get_at(point_index) + * trace_evals[32].get_at(point_index)) + + (trace_evals[33].get_at(point_index) + * trace_evals[33].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[35].get_at(point_index) + - ((trace_evals[33].get_at(point_index) + * trace_evals[33].get_at(point_index)) + + (trace_evals[34].get_at(point_index) + * trace_evals[34].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[36].get_at(point_index) + - ((trace_evals[34].get_at(point_index) + * trace_evals[34].get_at(point_index)) + + (trace_evals[35].get_at(point_index) + * trace_evals[35].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[37].get_at(point_index) + - ((trace_evals[35].get_at(point_index) + * trace_evals[35].get_at(point_index)) + + (trace_evals[36].get_at(point_index) + * trace_evals[36].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[38].get_at(point_index) + - ((trace_evals[36].get_at(point_index) + * trace_evals[36].get_at(point_index)) + + (trace_evals[37].get_at(point_index) + * trace_evals[37].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[39].get_at(point_index) + - ((trace_evals[37].get_at(point_index) + * trace_evals[37].get_at(point_index)) + + (trace_evals[38].get_at(point_index) + * trace_evals[38].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[40].get_at(point_index) + - ((trace_evals[38].get_at(point_index) + * trace_evals[38].get_at(point_index)) + + (trace_evals[39].get_at(point_index) + * trace_evals[39].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[41].get_at(point_index) + - ((trace_evals[39].get_at(point_index) + * trace_evals[39].get_at(point_index)) + + (trace_evals[40].get_at(point_index) + * trace_evals[40].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[42].get_at(point_index) + - ((trace_evals[40].get_at(point_index) + * trace_evals[40].get_at(point_index)) + + (trace_evals[41].get_at(point_index) + * trace_evals[41].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[43].get_at(point_index) + - ((trace_evals[41].get_at(point_index) + * trace_evals[41].get_at(point_index)) + + (trace_evals[42].get_at(point_index) + * trace_evals[42].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[44].get_at(point_index) + - ((trace_evals[42].get_at(point_index) + * trace_evals[42].get_at(point_index)) + + (trace_evals[43].get_at(point_index) + * trace_evals[43].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[45].get_at(point_index) + - ((trace_evals[43].get_at(point_index) + * trace_evals[43].get_at(point_index)) + + (trace_evals[44].get_at(point_index) + * trace_evals[44].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[46].get_at(point_index) + - ((trace_evals[44].get_at(point_index) + * trace_evals[44].get_at(point_index)) + + (trace_evals[45].get_at(point_index) + * trace_evals[45].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[47].get_at(point_index) + - ((trace_evals[45].get_at(point_index) + * trace_evals[45].get_at(point_index)) + + (trace_evals[46].get_at(point_index) + * trace_evals[46].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[48].get_at(point_index) + - ((trace_evals[46].get_at(point_index) + * trace_evals[46].get_at(point_index)) + + (trace_evals[47].get_at(point_index) + * trace_evals[47].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[49].get_at(point_index) + - ((trace_evals[47].get_at(point_index) + * trace_evals[47].get_at(point_index)) + + (trace_evals[48].get_at(point_index) + * trace_evals[48].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[50].get_at(point_index) + - ((trace_evals[48].get_at(point_index) + * trace_evals[48].get_at(point_index)) + + (trace_evals[49].get_at(point_index) + * trace_evals[49].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[51].get_at(point_index) + - ((trace_evals[49].get_at(point_index) + * trace_evals[49].get_at(point_index)) + + (trace_evals[50].get_at(point_index) + * trace_evals[50].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[52].get_at(point_index) + - ((trace_evals[50].get_at(point_index) + * trace_evals[50].get_at(point_index)) + + (trace_evals[51].get_at(point_index) + * trace_evals[51].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[53].get_at(point_index) + - ((trace_evals[51].get_at(point_index) + * trace_evals[51].get_at(point_index)) + + (trace_evals[52].get_at(point_index) + * trace_evals[52].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[54].get_at(point_index) + - ((trace_evals[52].get_at(point_index) + * trace_evals[52].get_at(point_index)) + + (trace_evals[53].get_at(point_index) + * trace_evals[53].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[55].get_at(point_index) + - ((trace_evals[53].get_at(point_index) + * trace_evals[53].get_at(point_index)) + + (trace_evals[54].get_at(point_index) + * trace_evals[54].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[56].get_at(point_index) + - ((trace_evals[54].get_at(point_index) + * trace_evals[54].get_at(point_index)) + + (trace_evals[55].get_at(point_index) + * trace_evals[55].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[57].get_at(point_index) + - ((trace_evals[55].get_at(point_index) + * trace_evals[55].get_at(point_index)) + + (trace_evals[56].get_at(point_index) + * trace_evals[56].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[58].get_at(point_index) + - ((trace_evals[56].get_at(point_index) + * trace_evals[56].get_at(point_index)) + + (trace_evals[57].get_at(point_index) + * trace_evals[57].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[59].get_at(point_index) + - ((trace_evals[57].get_at(point_index) + * trace_evals[57].get_at(point_index)) + + (trace_evals[58].get_at(point_index) + * trace_evals[58].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[60].get_at(point_index) + - ((trace_evals[58].get_at(point_index) + * trace_evals[58].get_at(point_index)) + + (trace_evals[59].get_at(point_index) + * trace_evals[59].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[61].get_at(point_index) + - ((trace_evals[59].get_at(point_index) + * trace_evals[59].get_at(point_index)) + + (trace_evals[60].get_at(point_index) + * trace_evals[60].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[62].get_at(point_index) + - ((trace_evals[60].get_at(point_index) + * trace_evals[60].get_at(point_index)) + + (trace_evals[61].get_at(point_index) + * trace_evals[61].get_at(point_index)))); + numerators[i] = numerators[i] * random_coeff + + (trace_evals[63].get_at(point_index) + - ((trace_evals[61].get_at(point_index) + * trace_evals[61].get_at(point_index)) + + (trace_evals[62].get_at(point_index) + * trace_evals[62].get_at(point_index)))); + } + for (i, (num, denom)) in numerators.iter().zip(denom_inverses.iter()).enumerate() { + accum.accumulate( + bit_reverse_index(i, self.max_constraint_log_degree_bound()), + *num * *denom, + ); + } } fn evaluate_constraint_quotients_at_point( diff --git a/src/examples/wide_fibonacci/mod.rs b/src/examples/wide_fibonacci/mod.rs index 397c35a34..3680946f2 100644 --- a/src/examples/wide_fibonacci/mod.rs +++ b/src/examples/wide_fibonacci/mod.rs @@ -6,12 +6,17 @@ pub mod trace_gen; #[cfg(test)] mod tests { use itertools::Itertools; - use num_traits::Zero; + use num_traits::{One, Zero}; - use super::structs::Input; + use super::structs::{Input, WideFibComponent}; use super::trace_asserts::assert_constraints_on_row; use super::trace_gen::write_trace_row; + use crate::core::air::accumulation::DomainEvaluationAccumulator; + use crate::core::air::{Component, ComponentTrace}; + use crate::core::backend::cpu::CPUCircleEvaluation; use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::QM31; + use crate::core::poly::circle::CanonicCoset; fn fill_trace(private_input: &[Input]) -> Vec> { let zero_vec = vec![BaseField::zero(); private_input.len()]; @@ -33,4 +38,45 @@ mod tests { let flat_trace = trace.into_iter().flatten().collect_vec(); assert_constraints_on_row(&flat_trace); } + + #[test] + fn test_wide_fib_constraints() { + let wide_fib = WideFibComponent { log_size: 7 }; + let mut acc = DomainEvaluationAccumulator::new( + QM31::from_u32_unchecked(1, 2, 3, 4), + wide_fib.log_size + 1, + ); + let inputs = (0..1 << wide_fib.log_size) + .map(|i| Input { + a: BaseField::one(), + b: BaseField::from_u32_unchecked(i as u32), + }) + .collect_vec(); + + let trace = fill_trace(&inputs); + + let trace_domain = CanonicCoset::new(wide_fib.log_size); + let trace = trace + .into_iter() + .map(|col| CPUCircleEvaluation::new_canonical_ordered(trace_domain, col)) + .collect_vec(); + let trace_polys = trace + .into_iter() + .map(|eval| eval.interpolate()) + .collect_vec(); + + let trace = ComponentTrace { + columns: trace_polys.iter().collect(), + }; + + wide_fib.evaluate_constraint_quotients_on_domain(&trace, &mut acc); + + let res = acc.finalize(); + let poly = res.0[0].clone(); + for coeff in + poly.coeffs[(1 << (wide_fib.max_constraint_log_degree_bound() - 1)) + 1..].iter() + { + assert_eq!(*coeff, BaseField::zero()); + } + } }