Skip to content

Commit

Permalink
Remove PreProcessedColumn Enum
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Jan 13, 2025
1 parent eacecc0 commit 9b3cf14
Show file tree
Hide file tree
Showing 13 changed files with 111 additions and 187 deletions.
63 changes: 24 additions & 39 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::iter::zip;
use std::ops::Deref;
use std::rc::Rc;

use itertools::Itertools;
#[cfg(feature = "parallel")]
Expand All @@ -12,7 +10,6 @@ use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::logup::LogupSums;
use super::preprocessed_columns::PreprocessedColumn;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
};
Expand Down Expand Up @@ -50,7 +47,7 @@ pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
/// Mapping of preprocessed columns to their index.
preprocessed_columns: HashMap<Rc<dyn PreprocessedColumn>, usize>,
preprocessed_columns: Vec<String>,
/// Controls whether the preprocessed columns are dynamic or static (default=Dynamic).
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode,
}
Expand Down Expand Up @@ -82,38 +79,23 @@ impl TraceLocationAllocator {
}

/// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup.
pub fn new_with_preproccessed_columns(
preprocessed_columns: &[Rc<dyn PreprocessedColumn>],
) -> Self {
pub fn new_with_preproccessed_columns(preprocessed_columns: &[String]) -> Self {
Self {
next_tree_offsets: Default::default(),
preprocessed_columns: preprocessed_columns
.iter()
.enumerate()
.map(|(i, col)| (col.clone(), i))
.collect(),
preprocessed_columns: preprocessed_columns.to_vec(),
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static,
}
}

pub const fn preprocessed_columns(&self) -> &HashMap<Rc<dyn PreprocessedColumn>, usize> {
pub const fn preprocessed_columns(&self) -> &Vec<String> {
&self.preprocessed_columns
}

// validates that `self.preprocessed_columns` is consistent with
// `preprocessed_columns`.
// I.e. preprocessed_columns[i] == self.preprocessed_columns[i].
pub fn validate_preprocessed_columns(
&self,
preprocessed_columns: &[Rc<dyn PreprocessedColumn>],
) {
assert_eq!(preprocessed_columns.len(), self.preprocessed_columns.len());
for (column, idx) in self.preprocessed_columns.iter() {
let preprocessed_column = preprocessed_columns
.get(*idx)
.expect("Preprocessed column is missing from preprocessed_columns");
assert_eq!(column.id(), preprocessed_column.id());
}
pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[String]) {
assert_eq!(self.preprocessed_columns, preprocessed_columns);
}
}

Expand Down Expand Up @@ -152,22 +134,25 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
.iter()
.map(|col| {
let next_column = location_allocator.preprocessed_columns.len();
*location_allocator
if let Some(pos) = location_allocator
.preprocessed_columns
.entry(col.clone())
.or_insert_with(|| {
if matches!(
location_allocator.preprocessed_columns_allocation_mode,
PreprocessedColumnsAllocationMode::Static
) {
panic!(
"Preprocessed column {:?} is missing from static alloction",
col
);
}

next_column
})
.iter()
.position(|x| x == col)
{
pos
} else {
if matches!(
location_allocator.preprocessed_columns_allocation_mode,
PreprocessedColumnsAllocationMode::Static
) {
panic!(
"Preprocessed column {:?} is missing from static allocation",
col
);
}
location_allocator.preprocessed_columns.push(col.clone());
next_column
}
})
.collect();
Self {
Expand Down
9 changes: 3 additions & 6 deletions crates/prover/src/constraint_framework/expr/evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::rc::Rc;

use num_traits::Zero;

use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ColumnExpr;
use crate::constraint_framework::preprocessed_columns::PreprocessedColumn;
use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31;
use crate::core::lookups::utils::Fraction;
Expand Down Expand Up @@ -176,8 +173,8 @@ impl EvalAtRow for ExprEvaluator {
intermediate
}

fn get_preprocessed_column(&mut self, column: Rc<dyn PreprocessedColumn>) -> Self::F {
BaseExpr::Param(column.name().to_string())
fn get_preprocessed_column(&mut self, column: String) -> Self::F {
BaseExpr::Param(column)
}

crate::constraint_framework::logup_proxy!();
Expand Down Expand Up @@ -210,7 +207,7 @@ mod tests {
\
let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0, trace_2_column_6_offset_0]) \
- (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1]) \
- ((total_sum) * (preprocessed_is_first)))) \
- ((total_sum) * (preprocessed_is_first_16)))) \
* (intermediate1) \
- (qm31(1, 0, 0, 0));"
.to_string();
Expand Down
11 changes: 3 additions & 8 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::rc::Rc;
use num_traits::{One, Zero};

use super::logup::{LogupAtRow, LogupSums};
use super::preprocessed_columns::PreprocessedColumn;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
Expand All @@ -22,16 +21,12 @@ use crate::core::pcs::TreeVec;
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<Rc<dyn PreprocessedColumn>>,
pub preprocessed_columns: Vec<String>,
pub logup: LogupAtRow<Self>,
pub arithmetic_counts: ArithmeticCounts,
}
impl InfoEvaluator {
pub fn new(
log_size: u32,
preprocessed_columns: Vec<Rc<dyn PreprocessedColumn>>,
logup_sums: LogupSums,
) -> Self {
pub fn new(log_size: u32, preprocessed_columns: Vec<String>, logup_sums: LogupSums) -> Self {
Self {
mask_offsets: Default::default(),
n_constraints: Default::default(),
Expand Down Expand Up @@ -70,7 +65,7 @@ impl EvalAtRow for InfoEvaluator {
array::from_fn(|_| FieldCounter::one())
}

fn get_preprocessed_column(&mut self, column: Rc<dyn PreprocessedColumn>) -> Self::F {
fn get_preprocessed_column(&mut self, column: String) -> Self::F {
self.preprocessed_columns.push(column);
FieldCounter::one()
}
Expand Down
11 changes: 5 additions & 6 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ mod simd_domain;
use std::array;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Neg, Sub};
use std::rc::Rc;

pub use assert::{assert_constraints, AssertEvaluator};
pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
pub use info::InfoEvaluator;
use num_traits::{One, Zero};
pub use point::PointEvaluator;
use preprocessed_columns::PreprocessedColumn;
pub use simd_domain::SimdDomainEvaluator;

use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -88,7 +86,7 @@ pub trait EvalAtRow {
mask_item
}

fn get_preprocessed_column(&mut self, _column: Rc<dyn PreprocessedColumn>) -> Self::F {
fn get_preprocessed_column(&mut self, _column: String) -> Self::F {
let [mask_item] = self.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
mask_item
}
Expand Down Expand Up @@ -173,11 +171,12 @@ macro_rules! logup_proxy {
() => {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(std::rc::Rc::new(
self.logup.is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::IsFirst::new(
self.logup.log_size,
),
));
)
.id(),
);
self.logup.is_finalized = false;
}
self.logup.fracs.push(fraction.clone());
Expand Down
34 changes: 3 additions & 31 deletions crates/prover/src/constraint_framework/preprocessed_columns.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::fmt::Debug;
use std::hash::Hash;
// use std::hash::Hash;
use std::simd::Simd;

use num_traits::{One, Zero};
Expand All @@ -11,25 +11,6 @@ use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;

pub trait PreprocessedColumn: Debug {
fn name(&self) -> &'static str;
/// Used for comparing preprocessed columns.
/// Column IDs must be unique in a given context.
fn id(&self) -> String;
fn log_size(&self) -> u32;
}
impl PartialEq for dyn PreprocessedColumn {
fn eq(&self, other: &Self) -> bool {
self.id() == other.id()
}
}
impl Eq for dyn PreprocessedColumn {}
impl Hash for dyn PreprocessedColumn {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.id().hash(state);
}
}

/// A column with `1` at the first position, and `0` elsewhere.
#[derive(Debug, Clone)]
pub struct IsFirst {
Expand Down Expand Up @@ -62,18 +43,9 @@ impl IsFirst {
col.set(0, BaseField::one());
CircleEvaluation::new(CanonicCoset::new(self.log_size).circle_domain(), col)
}
}
impl PreprocessedColumn for IsFirst {
fn name(&self) -> &'static str {
"preprocessed_is_first"
}

fn id(&self) -> String {
format!("IsFirst(log_size: {})", self.log_size)
}

fn log_size(&self) -> u32 {
self.log_size
pub fn id(&self) -> String {
format!("preprocessed_is_first_{}", self.log_size).to_string()
}
}

Expand Down
82 changes: 53 additions & 29 deletions crates/prover/src/examples/blake/air.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::rc::Rc;
use std::simd::u32x16;

use itertools::{chain, multiunzip, Itertools};
Expand All @@ -10,7 +9,7 @@ use super::preprocessed_columns::XorTable;
use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval};
use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval};
use super::xor_table::{xor12, xor4, xor7, xor8, xor9};
use crate::constraint_framework::preprocessed_columns::{IsFirst, PreprocessedColumn};
use crate::constraint_framework::preprocessed_columns::IsFirst;
use crate::constraint_framework::{TraceLocationAllocator, PREPROCESSED_TRACE_IDX};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::m31::LOG_N_LANES;
Expand All @@ -28,28 +27,53 @@ use crate::examples::blake::{
round, xor_table, BlakeXorElements, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT,
};

fn preprocessed_xor_columns() -> [Rc<dyn PreprocessedColumn>; 20] {
fn preprocessed_xor_columns() -> [String; 20] {
[
Rc::new(XorTable::new(12, 4, 0)),
Rc::new(XorTable::new(12, 4, 1)),
Rc::new(XorTable::new(12, 4, 2)),
Rc::new(IsFirst::new(xor12::column_bits::<12, 4>())),
Rc::new(XorTable::new(9, 2, 0)),
Rc::new(XorTable::new(9, 2, 1)),
Rc::new(XorTable::new(9, 2, 2)),
Rc::new(IsFirst::new(xor9::column_bits::<9, 2>())),
Rc::new(XorTable::new(8, 2, 0)),
Rc::new(XorTable::new(8, 2, 1)),
Rc::new(XorTable::new(8, 2, 2)),
Rc::new(IsFirst::new(xor8::column_bits::<8, 2>())),
Rc::new(XorTable::new(7, 2, 0)),
Rc::new(XorTable::new(7, 2, 1)),
Rc::new(XorTable::new(7, 2, 2)),
Rc::new(IsFirst::new(xor7::column_bits::<7, 2>())),
Rc::new(XorTable::new(4, 0, 0)),
Rc::new(XorTable::new(4, 0, 1)),
Rc::new(XorTable::new(4, 0, 2)),
Rc::new(IsFirst::new(xor4::column_bits::<4, 0>())),
XorTable::new(12, 4, 0).id(),
XorTable::new(12, 4, 1).id(),
XorTable::new(12, 4, 2).id(),
IsFirst::new(xor12::column_bits::<12, 4>()).id(),
XorTable::new(9, 2, 0).id(),
XorTable::new(9, 2, 1).id(),
XorTable::new(9, 2, 2).id(),
IsFirst::new(xor9::column_bits::<9, 2>()).id(),
XorTable::new(8, 2, 0).id(),
XorTable::new(8, 2, 1).id(),
XorTable::new(8, 2, 2).id(),
IsFirst::new(xor8::column_bits::<8, 2>()).id(),
XorTable::new(7, 2, 0).id(),
XorTable::new(7, 2, 1).id(),
XorTable::new(7, 2, 2).id(),
IsFirst::new(xor7::column_bits::<7, 2>()).id(),
XorTable::new(4, 0, 0).id(),
XorTable::new(4, 0, 1).id(),
XorTable::new(4, 0, 2).id(),
IsFirst::new(xor4::column_bits::<4, 0>()).id(),
]
}

const fn preprocessed_xor_columns_log_sizes() -> [u32; 20] {
[
2 * (12 - 4),
2 * (12 - 4),
2 * (12 - 4),
xor12::column_bits::<12, 4>(),
2 * (9 - 2),
2 * (9 - 2),
2 * (9 - 2),
xor9::column_bits::<9, 2>(),
2 * (8 - 2),
2 * (8 - 2),
2 * (8 - 2),
xor8::column_bits::<8, 2>(),
2 * (7 - 2),
2 * (7 - 2),
2 * (7 - 2),
xor7::column_bits::<7, 2>(),
2 * 4,
2 * 4,
2 * 4,
xor4::column_bits::<4, 0>(),
]
}

Expand Down Expand Up @@ -90,7 +114,7 @@ impl BlakeStatement0 {
log_sizes[PREPROCESSED_TRACE_IDX] = chain!(
[scheduler_is_first_column_log_size],
blake_round_is_first_column_log_sizes,
preprocessed_xor_columns().map(|column| column.log_size()),
preprocessed_xor_columns_log_sizes(),
)
.collect_vec();

Expand Down Expand Up @@ -163,11 +187,11 @@ impl BlakeComponents {
fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self {
let log_size = stmt0.log_size;

let scheduler_is_first_column: Rc<dyn PreprocessedColumn> = Rc::new(IsFirst::new(log_size));
let blake_round_is_first_columns_iter = ROUND_LOG_SPLIT.iter().map(|l| {
let column: Rc<dyn PreprocessedColumn> = Rc::new(IsFirst::new(log_size + l));
column
});
let scheduler_is_first_column = IsFirst::new(log_size).id();
let blake_round_is_first_columns_iter: Vec<String> = ROUND_LOG_SPLIT
.iter()
.map(|l| IsFirst::new(log_size + l).id())
.collect_vec();

let tree_span_provider = &mut TraceLocationAllocator::new_with_preproccessed_columns(
&chain!(
Expand Down
Loading

0 comments on commit 9b3cf14

Please sign in to comment.