From bb9fb741789d177a5023aa3f65a832da94cf1de2 Mon Sep 17 00:00:00 2001 From: Kara Date: Mon, 13 Jan 2025 11:54:03 +0800 Subject: [PATCH] opt: convert trace data to table (#215) * chore: refactor into_tables() * chore: remove test * chore: refactor join() * chore: remove useless comments * chore: modify according to the comments * chore: modify according to the comments * chore: modify according to the comments --- .gitignore | 1 + prover/src/util.rs | 37 ++++++++++++++++++++++ prover/src/witness/traces.rs | 59 ++++++++++++------------------------ 3 files changed, 57 insertions(+), 40 deletions(-) diff --git a/.gitignore b/.gitignore index a8710ed8..0b3e4c06 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /output /examples/target **/.DS_Store +mips-zkm-zkvm-elf diff --git a/prover/src/util.rs b/prover/src/util.rs index 97b629c1..1cd7c5dd 100644 --- a/prover/src/util.rs +++ b/prover/src/util.rs @@ -8,6 +8,8 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::util::transpose; +#[allow(unused_imports)] +use plonky2_maybe_rayon::rayon; /// Construct an integer from its constituent bits (in little-endian order) pub fn limb_from_bits_le(iter: impl IntoIterator) -> P { @@ -68,3 +70,38 @@ pub fn u32_array_to_u8_vec(u32_array: &[u32; 8]) -> Vec { } u8_vec } + +macro_rules! join { + ($($($a:expr),+$(,)?)?) => { + crate::util::__join!{0;;$($($a,)+)?} + }; +} + +macro_rules! __join { + ($len:expr; $($f:ident $r:ident $a:expr),*; $b:expr, $($c:expr,)*) => { + crate::util::__join!{$len + 1; $($f $r $a,)* f r $b; $($c,)* } + }; + ($len:expr; $($f:ident $r:ident $a:expr),* ;) => { + match ($(Some(crate::util::__sendable_closure($a)),)*) { + ($(mut $f,)*) => { + $(let mut $r = None;)* + let array: [&mut (dyn FnMut() + Send); $len] = [ + $(&mut || $r = Some((&mut $f).take().unwrap()())),* + ]; + rayon::iter::ParallelIterator::for_each( + rayon::iter::IntoParallelIterator::into_par_iter(array), + |f| f(), + ); + ($($r.unwrap(),)*) + } + } + }; +} + +#[doc(hidden)] +pub(crate) fn __sendable_closure R + Send>(x: F) -> F { + x +} + +pub(crate) use __join; +pub(crate) use join; diff --git a/prover/src/witness/traces.rs b/prover/src/witness/traces.rs index 4939566d..5cf0c8c1 100644 --- a/prover/src/witness/traces.rs +++ b/prover/src/witness/traces.rs @@ -19,6 +19,7 @@ use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; use crate::poseidon::constants::SPONGE_WIDTH; use crate::poseidon_sponge::columns::POSEIDON_RATE_BYTES; use crate::poseidon_sponge::poseidon_sponge_stark::PoseidonSpongeOp; +use crate::util::join; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryOp; use crate::{arithmetic, logic}; @@ -206,46 +207,24 @@ impl Traces { timed!( timing, "convert trace to table parallelly", - rayon::join( - || rayon::join( - || memory_trace = all_stark.memory_stark.generate_trace(&mut memory_ops,), - || arithmetic_trace = - all_stark.arithmetic_stark.generate_trace(&arithmetic_ops), - ), - || { - rayon::join( - || { - cpu_trace = trace_rows_to_poly_values( - cpu.into_iter().map(|x| x.into()).collect(), - ) - }, - || { - poseidon_trace = all_stark - .poseidon_stark - .generate_trace(&poseidon_inputs, min_rows) - }, - ); - rayon::join( - || { - poseidon_sponge_trace = all_stark - .poseidon_sponge_stark - .generate_trace(&poseidon_sponge_ops, min_rows) - }, - || { - keccak_trace = all_stark - .keccak_stark - .generate_trace(keccak_inputs, min_rows) - }, - ); - rayon::join( - || { - keccak_sponge_trace = all_stark - .keccak_sponge_stark - .generate_trace(keccak_sponge_ops, min_rows) - }, - || logic_trace = all_stark.logic_stark.generate_trace(logic_ops, min_rows), - ); - }, + join!( + || memory_trace = all_stark.memory_stark.generate_trace(&mut memory_ops), + || arithmetic_trace = all_stark.arithmetic_stark.generate_trace(&arithmetic_ops), + || cpu_trace = + trace_rows_to_poly_values(cpu.into_iter().map(|x| x.into()).collect()), + || poseidon_trace = all_stark + .poseidon_stark + .generate_trace(&poseidon_inputs, min_rows), + || poseidon_sponge_trace = all_stark + .poseidon_sponge_stark + .generate_trace(&poseidon_sponge_ops, min_rows), + || keccak_trace = all_stark + .keccak_stark + .generate_trace(keccak_inputs, min_rows), + || keccak_sponge_trace = all_stark + .keccak_sponge_stark + .generate_trace(keccak_sponge_ops, min_rows), + || logic_trace = all_stark.logic_stark.generate_trace(logic_ops, min_rows), ) );