Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TraceOodFrame #266

Merged
merged 14 commits into from
Mar 27, 2024
2 changes: 1 addition & 1 deletion air/src/proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mod queries;
pub use queries::Queries;

mod ood_frame;
pub use ood_frame::{OodFrame, OodFrameTraceStates, ParsedOodFrame};
pub use ood_frame::{OodFrame, TraceOodFrame};

mod table;
pub use table::Table;
Expand Down
202 changes: 134 additions & 68 deletions air/src/proof/ood_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,17 @@
// LICENSE file in the root directory of this source tree.

use alloc::vec::Vec;
use crypto::ElementHasher;
use math::FieldElement;
use utils::{
ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader,
};

use crate::LagrangeKernelEvaluationFrame;
use crate::{EvaluationFrame, LagrangeKernelEvaluationFrame};

// OUT-OF-DOMAIN FRAME
// ================================================================================================

/// Represents an [`OodFrame`] where the trace and constraint evaluations have been parsed out.
pub struct ParsedOodFrame<E> {
pub trace_evaluations: Vec<E>,
pub lagrange_kernel_trace_evaluations: Option<Vec<E>>,
pub constraint_evaluations: Vec<E>,
}

/// Trace and constraint polynomial evaluations at an out-of-domain point.
///
/// This struct contains the following evaluations:
Expand All @@ -45,49 +39,53 @@ impl OodFrame {
// UPDATERS
// --------------------------------------------------------------------------------------------

/// Updates the trace state portion of this out-of-domain frame. This also returns a compacted
/// version of the out-of-domain frame (including the Lagrange kernel frame, if any) with the
/// rows interleaved. This is done so that reseeding of the random coin needs to be done only
/// once as opposed to once per each row.
/// Updates the trace state portion of this out-of-domain frame, and returns the hash of the
/// trace states.
///
/// The out-of-domain frame is stored as one vector of interleaved values, one from the current
/// row and the other from the next row. Given the input frame
///
/// +-------+-------+-------+-------+-------+-------+-------+-------+
/// | a1 | a2 | ... | an | c1 | c2 | ... | cm |
/// +-------+-------+-------+-------+-------+-------+-------+-------+
/// | b1 | b2 | ... | bn | d1 | d2 | ... | dm |
/// +-------+-------+-------+-------+-------+-------+-------+-------+
///
/// with n being the main trace width and m the auxiliary trace width, the values are stored as
///
/// [a1, b1, a2, b2, ..., an, bn, c1, d1, c2, d2, ..., cm, dm]
///
/// into `Self::trace_states` (as byte values).
///
/// # Panics
/// Panics if evaluation frame has already been set.
pub fn set_trace_states<E: FieldElement>(
&mut self,
trace_states: &OodFrameTraceStates<E>,
) -> Vec<E> {
pub fn set_trace_states<E, H>(&mut self, trace_ood_frame: &TraceOodFrame<E>) -> H::Digest
where
E: FieldElement,
H: ElementHasher<BaseField = E::BaseField>,
{
assert!(self.trace_states.is_empty(), "trace sates have already been set");

// save the evaluations with the current and next evaluations interleaved for each polynomial

let mut result = vec![];
for col in 0..trace_states.num_columns() {
result.push(trace_states.current_row[col]);
result.push(trace_states.next_row[col]);
}
let (main_and_aux_trace_states, lagrange_trace_states) = trace_ood_frame.to_trace_states();

// there are 2 frames: current and next
let frame_size: u8 = 2;

self.trace_states.write_u8(frame_size);
self.trace_states.write_many(&result);
self.trace_states.write_many(&main_and_aux_trace_states);

// save the Lagrange kernel evaluation frame (if any)
let lagrange_trace_states = {
let lagrange_trace_states = match trace_states.lagrange_kernel_frame {
Some(ref lagrange_trace_states) => lagrange_trace_states.inner().to_vec(),
None => Vec::new(),
};

{
// trace states length will be smaller than u8::MAX, since it is `== log2(trace_len) + 1`
debug_assert!(lagrange_trace_states.len() < u8::MAX.into());
self.lagrange_kernel_trace_states.write_u8(lagrange_trace_states.len() as u8);
self.lagrange_kernel_trace_states.write_many(&lagrange_trace_states);

lagrange_trace_states
};

result.into_iter().chain(lagrange_trace_states).collect()
let elements_to_hash: Vec<E> =
main_and_aux_trace_states.into_iter().chain(lagrange_trace_states).collect();

H::hash_elements(&elements_to_hash)
}

/// Updates constraint evaluation portion of this out-of-domain frame.
Expand All @@ -104,16 +102,16 @@ impl OodFrame {

// PARSER
// --------------------------------------------------------------------------------------------
/// Returns main and auxiliary (if any) trace evaluation frames and a vector of out-of-domain
/// constraint evaluations contained in `self`.
/// Returns an out-of-domain trace frame and a vector of out-of-domain constraint evaluations
/// contained in `self`.
///
/// # Panics
/// Panics if either `main_trace_width` or `num_evaluations` are equal to zero.
///
/// # Errors
/// Returns an error if:
/// * Valid [`crate::EvaluationFrame`]s for the specified `main_trace_width` and `aux_trace_width`
/// could not be parsed from the internal bytes.
/// * Valid [`crate::EvaluationFrame`]s for the specified `main_trace_width` and
/// `aux_trace_width` could not be parsed from the internal bytes.
/// * A vector of evaluations specified by `num_evaluations` could not be parsed from the
/// internal bytes.
/// * Any unconsumed bytes remained after the parsing was complete.
Expand All @@ -122,24 +120,39 @@ impl OodFrame {
main_trace_width: usize,
aux_trace_width: usize,
num_evaluations: usize,
) -> Result<ParsedOodFrame<E>, DeserializationError> {
) -> Result<(TraceOodFrame<E>, Vec<E>), DeserializationError> {
assert!(main_trace_width > 0, "trace width cannot be zero");
assert!(num_evaluations > 0, "number of evaluations cannot be zero");

// parse main and auxiliary trace evaluation frames
let mut reader = SliceReader::new(&self.trace_states);
let frame_size = reader.read_u8()? as usize;
let trace = reader.read_many((main_trace_width + aux_trace_width) * frame_size)?;
// Parse main and auxiliary trace evaluation frames. This does the reverse operation done in
// `set_trace_states()`.
let (current_row, next_row) = {
let mut reader = SliceReader::new(&self.trace_states);
let frame_size = reader.read_u8()? as usize;
let trace = reader.read_many((main_trace_width + aux_trace_width) * frame_size)?;

if reader.has_more_bytes() {
return Err(DeserializationError::UnconsumedBytes);
}
if reader.has_more_bytes() {
return Err(DeserializationError::UnconsumedBytes);
}

let mut current_row = Vec::with_capacity(main_trace_width);
let mut next_row = Vec::with_capacity(main_trace_width);

for col in trace.chunks_exact(2) {
current_row.push(col[0]);
next_row.push(col[1]);
}

(current_row, next_row)
};

// parse Lagrange kernel column trace
let mut reader = SliceReader::new(&self.lagrange_kernel_trace_states);
let lagrange_kernel_frame_size = reader.read_u8()? as usize;
let lagrange_kernel_trace = if lagrange_kernel_frame_size > 0 {
Some(reader.read_many(lagrange_kernel_frame_size)?)
let lagrange_kernel_frame = if lagrange_kernel_frame_size > 0 {
let lagrange_kernel_trace = reader.read_many(lagrange_kernel_frame_size)?;

Some(LagrangeKernelEvaluationFrame::new(lagrange_kernel_trace))
} else {
None
};
Expand All @@ -151,11 +164,10 @@ impl OodFrame {
return Err(DeserializationError::UnconsumedBytes);
}

Ok(ParsedOodFrame {
trace_evaluations: trace,
lagrange_kernel_trace_evaluations: lagrange_kernel_trace,
constraint_evaluations: evaluations,
})
Ok((
TraceOodFrame::new(current_row, next_row, main_trace_width, lagrange_kernel_frame),
evaluations,
))
}
}

Expand Down Expand Up @@ -213,28 +225,31 @@ impl Deserializable for OodFrame {
// OOD FRAME TRACE STATES
// ================================================================================================

/// Stores the trace evaluations at `z` and `gz`, where `z` is a random Field element. If
/// the Air contains a Lagrange kernel auxiliary column, then that column interpolated polynomial
/// will be evaluated at `z`, `gz`, `g^2 z`, ... `g^(2^(v-1)) z`, where `v == log(trace_len)`, and
/// stored in `lagrange_kernel_frame`.
pub struct OodFrameTraceStates<E: FieldElement> {
/// Stores the trace evaluations at `z` and `gz`, where `z` is a random Field element in
/// `current_row` and `next_row`, respectively. If the Air contains a Lagrange kernel auxiliary
/// column, then that column interpolated polynomial will be evaluated at `z`, `gz`, `g^2 z`, ...
/// `g^(2^(v-1)) z`, where `v == log(trace_len)`, and stored in `lagrange_kernel_frame`.
pub struct TraceOodFrame<E: FieldElement> {
current_row: Vec<E>,
next_row: Vec<E>,
main_trace_width: usize,
lagrange_kernel_frame: Option<LagrangeKernelEvaluationFrame<E>>,
}

impl<E: FieldElement> OodFrameTraceStates<E> {
/// Creates a new [`OodFrameTraceStates`] from current, next and optionally Lagrange kernel frames.
impl<E: FieldElement> TraceOodFrame<E> {
/// Creates a new [`TraceOodFrame`] from current, next and optionally Lagrange kernel frames.
pub fn new(
current_frame: Vec<E>,
next_frame: Vec<E>,
current_row: Vec<E>,
next_row: Vec<E>,
main_trace_width: usize,
lagrange_kernel_frame: Option<LagrangeKernelEvaluationFrame<E>>,
) -> Self {
assert_eq!(current_frame.len(), next_frame.len());
assert_eq!(current_row.len(), next_row.len());

Self {
current_row: current_frame,
next_row: next_frame,
current_row,
next_row,
main_trace_width,
lagrange_kernel_frame,
}
}
Expand All @@ -244,18 +259,69 @@ impl<E: FieldElement> OodFrameTraceStates<E> {
self.current_row.len()
}

/// Returns the current frame.
pub fn current_frame(&self) -> &[E] {
/// Returns the current row, consisting of both main and auxiliary columns.
pub fn current_row(&self) -> &[E] {
&self.current_row
}

/// Returns the next frame.
pub fn next_frame(&self) -> &[E] {
/// Returns the next frame, consisting of both main and auxiliary columns.
pub fn next_row(&self) -> &[E] {
&self.next_row
}

/// Returns the evaluation frame for the main trace
pub fn main_frame(&self) -> EvaluationFrame<E> {
let current = self.current_row[0..self.main_trace_width].to_vec();
let next = self.next_row[0..self.main_trace_width].to_vec();

EvaluationFrame::from_rows(current, next)
}

/// Returns the evaluation frame for the auxiliary trace
pub fn aux_frame(&self) -> Option<EvaluationFrame<E>> {
if self.has_aux_frame() {
let current = self.current_row[self.main_trace_width..].to_vec();
let next = self.next_row[self.main_trace_width..].to_vec();

Some(EvaluationFrame::from_rows(current, next))
} else {
None
}
}

/// Hashes the main, auxiliary and Lagrange kernel frame in a manner consistent with
/// [`OodFrame::set_trace_states`], with the purpose of reseeding the public coin.
pub fn hash<H: ElementHasher<BaseField = E::BaseField>>(&self) -> H::Digest {
let (mut trace_states, mut lagrange_trace_states) = self.to_trace_states();
trace_states.append(&mut lagrange_trace_states);

H::hash_elements(&trace_states)
}

/// Returns the Lagrange kernel frame, if any.
pub fn lagrange_kernel_frame(&self) -> Option<&LagrangeKernelEvaluationFrame<E>> {
self.lagrange_kernel_frame.as_ref()
}

/// Returns true if an auxiliary frame is present
fn has_aux_frame(&self) -> bool {
self.current_row.len() > self.main_trace_width
}

/// Returns the main/aux frame and Lagrange kernel frame as element vectors. Specifically, the
/// main and auxiliary frames are interleaved, as described in [`OodFrame::set_trace_states`].
fn to_trace_states(&self) -> (Vec<E>, Vec<E>) {
let mut main_and_aux_frame_states = Vec::new();
for col in 0..self.current_row.len() {
main_and_aux_frame_states.push(self.current_row[col]);
main_and_aux_frame_states.push(self.next_row[col]);
}

let lagrange_frame_states = match self.lagrange_kernel_frame {
Some(ref lagrange_kernel_frame) => lagrange_kernel_frame.inner().to_vec(),
None => Vec::new(),
};

(main_and_aux_frame_states, lagrange_frame_states)
}
}
8 changes: 4 additions & 4 deletions prover/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// LICENSE file in the root directory of this source tree.

use air::{
proof::{Commitments, Context, OodFrame, OodFrameTraceStates, Queries, StarkProof},
proof::{Commitments, Context, OodFrame, Queries, StarkProof, TraceOodFrame},
Air, ConstraintCompositionCoefficients, DeepCompositionCoefficients,
};
use alloc::vec::Vec;
Expand Down Expand Up @@ -85,9 +85,9 @@ where

/// Saves the evaluations of trace polynomials over the out-of-domain evaluation frame. This
/// also reseeds the public coin with the hashes of the evaluation frame states.
pub fn send_ood_trace_states(&mut self, trace_states: &OodFrameTraceStates<E>) {
let result = self.ood_frame.set_trace_states(trace_states);
self.public_coin.reseed(H::hash_elements(&result));
pub fn send_ood_trace_states(&mut self, trace_ood_frame: &TraceOodFrame<E>) {
let trace_states_hash = self.ood_frame.set_trace_states::<E, H>(trace_ood_frame);
self.public_coin.reseed(trace_states_hash);
}

/// Saves the evaluations of constraint composition polynomial columns at the out-of-domain
Expand Down
12 changes: 6 additions & 6 deletions prover/src/composer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// LICENSE file in the root directory of this source tree.

use super::{constraints::CompositionPoly, StarkDomain, TracePolyTable};
use air::{proof::OodFrameTraceStates, DeepCompositionCoefficients};
use air::{proof::TraceOodFrame, DeepCompositionCoefficients};
use alloc::vec::Vec;
use math::{add_in_place, fft, mul_acc, polynom, ExtensionOf, FieldElement, StarkField};
use utils::iter_mut;
Expand Down Expand Up @@ -64,7 +64,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
pub fn add_trace_polys(
&mut self,
trace_polys: TracePolyTable<E>,
ood_trace_states: OodFrameTraceStates<E>,
ood_trace_states: TraceOodFrame<E>,
) {
assert!(self.coefficients.is_empty());

Expand All @@ -89,7 +89,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
acc_trace_poly::<E::BaseField, E>(
&mut t1_composition,
poly,
ood_trace_states.current_frame()[i],
ood_trace_states.current_row()[i],
self.cc.trace[i],
);

Expand All @@ -98,7 +98,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
acc_trace_poly::<E::BaseField, E>(
&mut t2_composition,
poly,
ood_trace_states.next_frame()[i],
ood_trace_states.next_row()[i],
self.cc.trace[i],
);

Expand All @@ -112,7 +112,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
acc_trace_poly::<E, E>(
&mut t1_composition,
poly,
ood_trace_states.current_frame()[i],
ood_trace_states.current_row()[i],
self.cc.trace[i],
);

Expand All @@ -121,7 +121,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
acc_trace_poly::<E, E>(
&mut t2_composition,
poly,
ood_trace_states.next_frame()[i],
ood_trace_states.next_row()[i],
self.cc.trace[i],
);

Expand Down
Loading
Loading