Skip to content

Commit

Permalink
benchmark air and prover
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer committed Mar 14, 2024
1 parent 589fc8a commit 3c9c5ea
Showing 1 changed file with 185 additions and 46 deletions.
231 changes: 185 additions & 46 deletions prover/benches/lagrange_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use air::TraceInfo;
use air::{
Air, AirContext, Assertion, AuxTraceRandElements, ConstraintCompositionCoefficients,
EvaluationFrame, FieldExtension, ProofOptions, TraceInfo, TransitionConstraintDegree,
};
use criterion::{criterion_group, criterion_main, Criterion};
use math::{fields::f64::BaseElement, FieldElement};
use winter_prover::{matrix::ColMatrix, Trace};
use crypto::{hashers::Blake3_256, DefaultRandomCoin};
use math::{fields::f64::BaseElement, ExtensionOf, FieldElement};
use winter_prover::{
matrix::ColMatrix, DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, Trace,
TracePolyTable,
};

fn prove_with_lagrange_kernel(c: &mut Criterion) {}

Expand All @@ -20,25 +27,16 @@ struct LagrangeTrace {
// dummy main trace
main_trace: ColMatrix<BaseElement>,
info: TraceInfo,
lagrange_kernel_col_idx: Option<usize>,
}

impl LagrangeTrace {
fn new(
trace_len: usize,
aux_segment_width: usize,
lagrange_kernel_col_idx: Option<usize>,
) -> Self {
fn new(trace_len: usize, aux_segment_width: usize) -> Self {
assert!(trace_len < u32::MAX.try_into().unwrap());

let main_trace_col: Vec<BaseElement> =
(0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect();

let num_aux_segment_rands = if lagrange_kernel_col_idx.is_some() {
trace_len.ilog2() as usize
} else {
1
};
let num_aux_segment_rands = trace_len.ilog2() as usize;

Self {
main_trace: ColMatrix::new(vec![main_trace_col]),
Expand All @@ -49,7 +47,6 @@ impl LagrangeTrace {
trace_len,
vec![],
),
lagrange_kernel_col_idx,
}
}

Expand All @@ -69,7 +66,8 @@ impl Trace for LagrangeTrace {
&self.main_trace
}

/// Each non-Lagrange kernel segment will simply take the sum the random elements + its index
/// Each non-Lagrange kernel segment will simply take the sum the random elements, and multiply
/// by the main column
fn build_aux_segment<E: FieldElement<BaseField = Self::BaseField>>(
&mut self,
aux_segments: &[ColMatrix<E>],
Expand All @@ -80,39 +78,40 @@ impl Trace for LagrangeTrace {

let mut columns = Vec::new();

for col_idx in 0..self.aux_trace_width() {
let column = if self
.lagrange_kernel_col_idx
.map(|lagrange_col_idx| lagrange_col_idx == col_idx)
.unwrap_or_default()
{
// building the Lagrange kernel column
let r = lagrange_kernel_rand_elements.unwrap();

let mut column = Vec::with_capacity(self.len());

for row_idx in 0..self.len() {
let mut row_value = E::ZERO;
for (bit_idx, &r_i) in r.iter().enumerate() {
if row_idx & (1 << bit_idx) == 0 {
row_value *= E::ONE - r_i;
} else {
row_value *= r_i;
}
// first build the Lagrange kernel column
{
let r = lagrange_kernel_rand_elements.unwrap();

let mut lagrange_col = Vec::with_capacity(self.len());

for row_idx in 0..self.len() {
let mut row_value = E::ZERO;
for (bit_idx, &r_i) in r.iter().enumerate() {
if row_idx & (1 << bit_idx) == 0 {
row_value *= E::ONE - r_i;
} else {
row_value *= r_i;
}
column.push(row_value);
}
lagrange_col.push(row_value);
}

columns.push(lagrange_col);
}

column
} else {
// building a dummy auxiliary column
(0..self.len())
.map(|row_idx| {
rand_elements.iter().fold(E::ZERO, |acc, &r| acc + r)
+ E::from(row_idx as u32)
})
.collect()
};
// Then all other auxiliary columns
for _ in 1..self.aux_trace_width() {
// building a dummy auxiliary column
let column = self
.main_segment()
.get_column(0)
.iter()
.map(|row_val| {
let rand_summed = rand_elements.iter().fold(E::ZERO, |acc, &r| acc + r);

rand_summed.mul_base(*row_val)
})
.collect();

columns.push(column);
}
Expand All @@ -128,3 +127,143 @@ impl Trace for LagrangeTrace {
self.main_trace.read_row_into(next_row_idx, frame.next_mut());
}
}

// AIR
// =================================================================================================

struct LagrangeKernelAir {
context: AirContext<BaseElement>,
}

impl Air for LagrangeKernelAir {
type BaseField = BaseElement;

type PublicInputs = ();

fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self {
Self {
context: AirContext::new_multi_segment(
trace_info,
vec![TransitionConstraintDegree::new(1)],
vec![TransitionConstraintDegree::new(2)],
1,
1,
Some(0),
options,
),
}
}

fn context(&self) -> &AirContext<Self::BaseField> {
&self.context
}

fn evaluate_transition<E: math::FieldElement<BaseField = Self::BaseField>>(
&self,
frame: &EvaluationFrame<E>,
_periodic_values: &[E],
result: &mut [E],
) {
let current = frame.current()[0];
let next = frame.next()[0];

// increments by 1
result[0] = next - current - E::ONE;
}

fn get_assertions(&self) -> Vec<Assertion<Self::BaseField>> {
vec![Assertion::single(0, 0, BaseElement::ZERO)]
}

fn evaluate_aux_transition<F, E>(
&self,
main_frame: &EvaluationFrame<F>,
aux_frame: &EvaluationFrame<E>,
_periodic_values: &[F],
aux_rand_elements: &AuxTraceRandElements<E>,
result: &mut [E],
) where
F: FieldElement<BaseField = Self::BaseField>,
E: FieldElement<BaseField = Self::BaseField> + ExtensionOf<F>,
{
let main_frame_current = main_frame.current()[0];
let aux_next = aux_frame.next()[0];

let rand_summed: E = aux_rand_elements
.get_segment_elements(0)
.iter()
.fold(E::ZERO, |acc, x| acc + *x);

result[0] = aux_next - rand_summed.mul_base(main_frame_current);
}

fn get_aux_assertions<E: FieldElement<BaseField = Self::BaseField>>(
&self,
aux_rand_elements: &AuxTraceRandElements<E>,
) -> Vec<Assertion<E>> {
let rand_summed: E = aux_rand_elements
.get_segment_elements(0)
.iter()
.fold(E::ZERO, |acc, x| acc + *x);

vec![Assertion::single(1, 0, rand_summed)]
}
}

// LagrangeProver
// ================================================================================================

struct LagrangeProver {
options: ProofOptions,
}

impl LagrangeProver {
fn new() -> Self {
Self {
options: ProofOptions::new(1, 2, 0, FieldExtension::None, 2, 1),
}
}
}

impl Prover for LagrangeProver {
type BaseField = BaseElement;
type Air = LagrangeKernelAir;
type Trace = LagrangeTrace;
type HashFn = Blake3_256<BaseElement>;
type RandomCoin = DefaultRandomCoin<Self::HashFn>;
type TraceLde<E: FieldElement<BaseField = BaseElement>> = DefaultTraceLde<E, Self::HashFn>;
type ConstraintEvaluator<'a, E: FieldElement<BaseField = BaseElement>> =
DefaultConstraintEvaluator<'a, LagrangeKernelAir, E>;

fn get_pub_inputs(&self, _trace: &Self::Trace) -> <<Self as Prover>::Air as Air>::PublicInputs {
()
}

fn options(&self) -> &ProofOptions {
&self.options
}

fn new_trace_lde<E>(
&self,
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
) -> (Self::TraceLde<E>, TracePolyTable<E>)
where
E: math::FieldElement<BaseField = Self::BaseField>,
{
DefaultTraceLde::new(trace_info, main_trace, domain)
}

fn new_evaluator<'a, E>(
&self,
air: &'a Self::Air,
aux_rand_elements: AuxTraceRandElements<E>,
composition_coefficients: ConstraintCompositionCoefficients<E>,
) -> Self::ConstraintEvaluator<'a, E>
where
E: math::FieldElement<BaseField = Self::BaseField>,
{
DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients)
}
}

0 comments on commit 3c9c5ea

Please sign in to comment.