diff --git a/emulator/Cargo.toml b/emulator/Cargo.toml index 89215a19..6a536976 100644 --- a/emulator/Cargo.toml +++ b/emulator/Cargo.toml @@ -14,6 +14,7 @@ lazy_static = "1.4.0" elf = { version = "0.7", default-features = false } log = { version = "0.4.14", default-features = false } itertools = "0.13.0" +keccak-hash = "0.11.0" [features] test = [] diff --git a/emulator/src/state.rs b/emulator/src/state.rs index 6ab3d7fe..b8f30de1 100644 --- a/emulator/src/state.rs +++ b/emulator/src/state.rs @@ -25,6 +25,7 @@ pub const PAGE_CYCLES: u64 = PAGE_LOAD_CYCLES + PAGE_HASH_CYCLES; pub const IMAGE_ID_CYCLES: u64 = 3; pub const MAX_INSTRUCTION_CYCLES: u64 = PAGE_CYCLES * 6; //TOFIX pub const RESERVE_CYCLES: u64 = IMAGE_ID_CYCLES + MAX_INSTRUCTION_CYCLES; +use keccak_hash::keccak; // image_id = keccak(page_hash_root || end_pc) #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Default)] @@ -529,6 +530,23 @@ impl InstrumentedState { log::debug!("syscall {} {} {} {}", syscall_num, a0, a1, a2); match syscall_num { + 0x010109 => { + assert!((a0 & 3) == 0); + assert!((a2 & 3) == 0); + let bytes = (0..a1) + .map(|i| self.state.memory.byte(a0 + i)) + .collect::>(); + log::debug!("keccak {:X?}", bytes); + let result = keccak(&bytes).0; + log::debug!("result {:X?}", result); + let result: [u32; 8] = core::array::from_fn(|i| { + u32::from_be_bytes(core::array::from_fn(|j| result[i * 4 + j])) + }); + assert!(result.len() == 8); + for (i, data) in result.iter().enumerate() { + self.state.memory.set_memory(a2 + ((i << 2) as u32), *data); + } + } 0xF0 => { if self.state.input_stream_ptr >= self.state.input_stream.len() { panic!("not enough vecs in hint input stream"); @@ -654,7 +672,9 @@ impl InstrumentedState { v0 = a2; } FD_PUBLIC_VALUES => { + log::debug!("commit {:X?}", slice); self.state.public_values_stream.extend_from_slice(slice); + v0 = a2; } FD_HINT => { diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 1477262c..049cf2b5 100644 --- a/prover/Cargo.toml +++ 
b/prover/Cargo.toml @@ -35,6 +35,7 @@ byteorder = "1.5.0" hex = "0.4" hashbrown = { version = "0.14.0", default-features = false, features = ["ahash", "serde"] } # NOTE: When upgrading, see `ahash` dependency. lazy_static = "1.4.0" +keccak-hash = "0.10.0" elf = { version = "0.7", default-features = false } sha2 = { version = "0.10.8", default-features = false } @@ -45,7 +46,6 @@ keccak-hash = "0.10.0" plonky2x = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x", branch = "zkm" } plonky2x-derive = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x-derive", branch = "zkm" } - [features] test = [] diff --git a/prover/examples/Cargo.toml b/prover/examples/Cargo.toml index a663b35b..f9c50110 100644 --- a/prover/examples/Cargo.toml +++ b/prover/examples/Cargo.toml @@ -4,6 +4,7 @@ members = [ "sha2-rust/host", "sha2-precompile/host", "sha2-go/host", + "keccak/host", "split-seg", "prove-seg" ] diff --git a/prover/examples/keccak/guest/Cargo.toml b/prover/examples/keccak/guest/Cargo.toml new file mode 100644 index 00000000..f8e66bae --- /dev/null +++ b/prover/examples/keccak/guest/Cargo.toml @@ -0,0 +1,9 @@ +[workspace] +[package] +version = "0.1.0" +name = "keccak-precompile" +edition = "2021" + +[dependencies] +#zkm-runtime = { git = "https://github.com/zkMIPS/zkm", package = "zkm-runtime" } +zkm-runtime = { path = "../../../../runtime/entrypoint" } diff --git a/prover/examples/keccak/guest/src/main.rs b/prover/examples/keccak/guest/src/main.rs new file mode 100644 index 00000000..90be5bd0 --- /dev/null +++ b/prover/examples/keccak/guest/src/main.rs @@ -0,0 +1,16 @@ +#![no_std] +#![no_main] + +extern crate alloc; +use alloc::vec::Vec; + +zkm_runtime::entrypoint!(main); + +pub fn main() { + let public_input: Vec = zkm_runtime::io::read(); + let input: Vec = zkm_runtime::io::read(); + + let output = zkm_runtime::io::keccak(&input.as_slice()); + assert_eq!(output.to_vec(), public_input); + zkm_runtime::io::commit::<[u8; 
32]>(&output); +} diff --git a/prover/examples/keccak/host/Cargo.toml b/prover/examples/keccak/host/Cargo.toml new file mode 100644 index 00000000..f48d0085 --- /dev/null +++ b/prover/examples/keccak/host/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "keccak-host" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[dependencies] +zkm-prover = { workspace = true } +zkm-emulator = { workspace = true } +zkm-utils = { path = "../../utils" } + + +log = { version = "0.4.14", default-features = false } +serde = { version = "1.0.144", features = ["derive"] } +serde_json = "1.0" +byteorder = "1.5.0" +hex = "0.4" +env_logger = "0.11.5" +anyhow = "1.0.75" + +[build-dependencies] +zkm-build = { workspace = true } +[features] +test = ["zkm-prover/test"] diff --git a/prover/examples/keccak/host/build.rs b/prover/examples/keccak/host/build.rs new file mode 100644 index 00000000..4b9fa8ed --- /dev/null +++ b/prover/examples/keccak/host/build.rs @@ -0,0 +1,3 @@ +fn main() { + zkm_build::build_program(&format!("{}/../guest", env!("CARGO_MANIFEST_DIR"))); +} diff --git a/prover/examples/keccak/host/src/main.rs b/prover/examples/keccak/host/src/main.rs new file mode 100644 index 00000000..9b55241a --- /dev/null +++ b/prover/examples/keccak/host/src/main.rs @@ -0,0 +1,46 @@ +use std::env; + +use zkm_emulator::utils::{load_elf_with_patch, split_prog_into_segs}; +use zkm_utils::utils::prove_segments; + +const ELF_PATH: &str = "../guest/elf/mips-zkm-zkvm-elf"; + +fn prove_keccak_rust() { + // 1. split ELF into segs + let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); + let seg_size = env::var("SEG_SIZE").unwrap_or("65536".to_string()); + let seg_size = seg_size.parse::<_>().unwrap_or(0); + + let mut state = load_elf_with_patch(ELF_PATH, vec![]); + // load input + let args = env::var("ARGS").unwrap_or("data-to-hash".to_string()); + // assume the first arg is the hash output(which is a public input), and the second is the input. 
+ let args: Vec<&str> = args.split_whitespace().collect(); + assert!(args.len() >= 1); + + let public_input: Vec = hex::decode(args[0]).unwrap(); + state.add_input_stream(&public_input); + log::info!("expected public value in hex: {:X?}", args[0]); + log::info!("expected public value: {:X?}", public_input); + + let private_input: Vec = if args.len() > 1 { + hex::decode(args[1]).unwrap() + } else { + vec![] + }; + log::info!("private input value: {:X?}", private_input); + state.add_input_stream(&private_input); + + let (_total_steps, seg_num, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size); + + let value = state.read_public_values::<[u8; 32]>(); + log::info!("public value: {:X?}", value); + log::info!("public value: {} in hex", hex::encode(value)); + + let _ = prove_segments(&seg_path, "", "", "", seg_num, 0, vec![]).unwrap(); +} + +fn main() { + env_logger::try_init().unwrap_or_default(); + prove_keccak_rust(); +} diff --git a/prover/examples/utils/src/utils.rs b/prover/examples/utils/src/utils.rs index 2d03db60..af5dc87f 100644 --- a/prover/examples/utils/src/utils.rs +++ b/prover/examples/utils/src/utils.rs @@ -16,7 +16,7 @@ use zkm_prover::cpu::kernel::assembler::segment_kernel; use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; use zkm_prover::generation::state::{AssumptionReceipts, Receipt}; -const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; +const DEGREE_BITS_RANGE: [Range; 8] = [10..21, 12..22, 11..21, 8..21, 6..21, 6..21, 6..21, 13..23]; const D: usize = 2; type C = PoseidonGoldilocksConfig; diff --git a/prover/src/all_stark.rs b/prover/src/all_stark.rs index 99b7d655..14191577 100644 --- a/prover/src/all_stark.rs +++ b/prover/src/all_stark.rs @@ -9,6 +9,11 @@ use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; +use crate::keccak::keccak_stark; +use crate::keccak::keccak_stark::KeccakStark; +use 
crate::keccak_sponge::columns::KECCAK_RATE_BYTES; +use crate::keccak_sponge::keccak_sponge_stark; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; use crate::logic; use crate::logic::LogicStark; use crate::memory::memory_stark; @@ -26,6 +31,8 @@ pub struct AllStark, const D: usize> { pub cpu_stark: CpuStark, pub poseidon_stark: PoseidonStark, pub poseidon_sponge_stark: PoseidonSpongeStark, + pub keccak_stark: KeccakStark, + pub keccak_sponge_stark: KeccakSpongeStark, pub logic_stark: LogicStark, pub memory_stark: MemoryStark, pub cross_table_lookups: Vec>, @@ -38,6 +45,8 @@ impl, const D: usize> Default for AllStark { cpu_stark: CpuStark::default(), poseidon_stark: PoseidonStark::default(), poseidon_sponge_stark: PoseidonSpongeStark::default(), + keccak_stark: KeccakStark::default(), + keccak_sponge_stark: KeccakSpongeStark::default(), logic_stark: LogicStark::default(), memory_stark: MemoryStark::default(), cross_table_lookups: all_cross_table_lookups(), @@ -52,6 +61,8 @@ impl, const D: usize> AllStark { self.cpu_stark.num_lookup_helper_columns(config), self.poseidon_stark.num_lookup_helper_columns(config), self.poseidon_sponge_stark.num_lookup_helper_columns(config), + self.keccak_stark.num_lookup_helper_columns(config), + self.keccak_sponge_stark.num_lookup_helper_columns(config), self.logic_stark.num_lookup_helper_columns(config), self.memory_stark.num_lookup_helper_columns(config), ] @@ -64,8 +75,10 @@ pub enum Table { Cpu = 1, Poseidon = 2, PoseidonSponge = 3, - Logic = 4, - Memory = 5, + Keccak = 4, + KeccakSponge = 5, + Logic = 6, + Memory = 7, } pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; @@ -80,6 +93,8 @@ impl Table { Self::Cpu, Self::Poseidon, Self::PoseidonSponge, + Self::Keccak, + Self::KeccakSponge, Self::Logic, Self::Memory, ] @@ -92,6 +107,9 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { ctl_poseidon_sponge(), ctl_poseidon_inputs(), ctl_poseidon_outputs(), + ctl_keccak_sponge(), + ctl_keccak_inputs(), + 
ctl_keccak_outputs(), ctl_logic(), ctl_memory(), ] @@ -152,16 +170,72 @@ fn ctl_poseidon_sponge() -> CrossTableLookup { CrossTableLookup::new(vec![cpu_looking], poseidon_sponge_looked) } +// We now need two different looked tables for `KeccakStark`: +// one for the inputs and one for the outputs. +// They are linked with the timestamp. +fn ctl_keccak_inputs() -> CrossTableLookup { + let keccak_sponge_looking = TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looking_keccak_inputs(), + Some(keccak_sponge_stark::ctl_looking_keccak_filter()), + ); + let keccak_looked = TableWithColumns::new( + Table::Keccak, + keccak_stark::ctl_data_inputs(), + Some(keccak_stark::ctl_filter_inputs()), + ); + CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) +} + +fn ctl_keccak_outputs() -> CrossTableLookup { + let keccak_sponge_looking = TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looking_keccak_outputs(), + Some(keccak_sponge_stark::ctl_looking_keccak_filter()), + ); + let keccak_looked = TableWithColumns::new( + Table::Keccak, + keccak_stark::ctl_data_outputs(), + Some(keccak_stark::ctl_filter_outputs()), + ); + CrossTableLookup::new(vec![keccak_sponge_looking], keccak_looked) +} + +fn ctl_keccak_sponge() -> CrossTableLookup { + let cpu_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_keccak_sponge(), + Some(cpu_stark::ctl_filter_keccak_sponge()), + ); + let keccak_sponge_looked = TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looked_data(), + Some(keccak_sponge_stark::ctl_looked_filter()), + ); + CrossTableLookup::new(vec![cpu_looking], keccak_sponge_looked) +} + pub(crate) fn ctl_logic() -> CrossTableLookup { let cpu_looking = TableWithColumns::new( Table::Cpu, cpu_stark::ctl_data_logic(), Some(cpu_stark::ctl_filter_logic()), ); + + let mut all_lookers = vec![cpu_looking]; + for i in 0..keccak_sponge_stark::num_logic_ctls() { + let keccak_sponge_looking = 
TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looking_logic(i), + Some(keccak_sponge_stark::ctl_looking_logic_filter()), + ); + all_lookers.push(keccak_sponge_looking); + } + let logic_looked = TableWithColumns::new(Table::Logic, logic::ctl_data(), Some(logic::ctl_filter())); - CrossTableLookup::new(vec![cpu_looking], logic_looked) + CrossTableLookup::new(all_lookers, logic_looked) } fn ctl_memory() -> CrossTableLookup { @@ -179,10 +253,19 @@ fn ctl_memory() -> CrossTableLookup { Some(poseidon_sponge_stark::ctl_looking_memory_filter(i)), ) }); + + let keccak_sponge_reads = (0..KECCAK_RATE_BYTES).map(|i| { + TableWithColumns::new( + Table::KeccakSponge, + keccak_sponge_stark::ctl_looking_memory(i), + Some(keccak_sponge_stark::ctl_looking_memory_filter(i)), + ) + }); let all_lookers = [] .into_iter() .chain(cpu_memory_gp_ops) .chain(poseidon_sponge_reads) + .chain(keccak_sponge_reads) .collect(); let memory_looked = TableWithColumns::new( Table::Memory, diff --git a/prover/src/cpu/columns/general.rs b/prover/src/cpu/columns/general.rs index b61ddc76..90ffc2b5 100644 --- a/prover/src/cpu/columns/general.rs +++ b/prover/src/cpu/columns/general.rs @@ -11,6 +11,7 @@ pub(crate) union CpuGeneralColumnsView { shift: CpuShiftView, io: CpuIOAuxView, hash: CpuHashView, + khash: CpuKHashView, misc: CpuMiscView, } @@ -25,6 +26,16 @@ impl CpuGeneralColumnsView { unsafe { &mut self.hash } } + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn khash(&self) -> &CpuKHashView { + unsafe { &self.khash } + } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn khash_mut(&mut self) -> &mut CpuKHashView { + unsafe { &mut self.khash } + } + // SAFETY: Each view is a valid interpretation of the underlying array. 
pub(crate) fn syscall(&self) -> &CpuSyscallView { unsafe { &self.syscall } @@ -152,5 +163,10 @@ pub(crate) struct CpuHashView { pub(crate) value: [T; 4], } +#[derive(Copy, Clone)] +pub(crate) struct CpuKHashView { + pub(crate) value: [T; 8], +} + // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_SHARED_COLUMNS: usize = size_of::>(); diff --git a/prover/src/cpu/columns/mod.rs b/prover/src/cpu/columns/mod.rs index 031efbff..f92fd311 100644 --- a/prover/src/cpu/columns/mod.rs +++ b/prover/src/cpu/columns/mod.rs @@ -102,6 +102,7 @@ pub struct CpuColumnsView { // inst_index: [rs_bits, rt_bits, rd_bits, shamt_bits, func_bits] /// Filter. 1 iff a Poseidon sponge lookup is performed on this row. pub is_poseidon_sponge: T, + pub is_keccak_sponge: T, pub(crate) general: CpuGeneralColumnsView, diff --git a/prover/src/cpu/cpu_stark.rs b/prover/src/cpu/cpu_stark.rs index d6c9585c..de688cb3 100644 --- a/prover/src/cpu/cpu_stark.rs +++ b/prover/src/cpu/cpu_stark.rs @@ -23,6 +23,30 @@ use crate::memory::segments::Segment; use crate::memory::{NUM_CHANNELS, VALUE_LIMBS}; use crate::stark::Stark; +pub fn ctl_data_keccak_sponge() -> Vec> { + // When executing KECCAK_GENERAL, the GP memory channels are used as follows: + // GP channel 0: stack[-1] = context + // GP channel 1: stack[-2] = segment + // GP channel 2: stack[-3] = virt + // GP channel 3: stack[-4] = len + // GP channel 4: pushed = outputs + let context = Column::single(COL_MAP.mem_channels[0].value); + let segment = Column::single(COL_MAP.mem_channels[1].value); + let virt = Column::single(COL_MAP.mem_channels[2].value); + let len = Column::single(COL_MAP.mem_channels[3].value); + + let num_channels = F::from_canonical_usize(NUM_CHANNELS); + let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); + + let mut cols = vec![context, segment, virt, len, timestamp]; + cols.extend(COL_MAP.general.khash().value.map(Column::single)); + cols +} + +pub fn ctl_filter_keccak_sponge() -> Filter { + 
Filter::new_simple(Column::single(COL_MAP.is_keccak_sponge)) +} + pub fn ctl_data_poseidon_sponge() -> Vec> { // When executing POSEIDON_GENERAL, the GP memory channels are used as follows: // GP channel 0: stack[-1] = context diff --git a/prover/src/fixed_recursive_verifier.rs b/prover/src/fixed_recursive_verifier.rs index 2269664d..0dd7aaf1 100644 --- a/prover/src/fixed_recursive_verifier.rs +++ b/prover/src/fixed_recursive_verifier.rs @@ -392,6 +392,20 @@ where &all_stark.cross_table_lookups, stark_config, ); + let keccak = RecursiveCircuitsForTable::new( + Table::Keccak, + &all_stark.keccak_stark, + degree_bits_ranges[Table::Keccak as usize].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let keccak_sponge = RecursiveCircuitsForTable::new( + Table::KeccakSponge, + &all_stark.keccak_sponge_stark, + degree_bits_ranges[Table::KeccakSponge as usize].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); let logic = RecursiveCircuitsForTable::new( Table::Logic, &all_stark.logic_stark, @@ -407,7 +421,16 @@ where stark_config, ); - let by_table = [arithmetic, cpu, poseidon, poseidon_sponge, logic, memory]; + let by_table = [ + arithmetic, + cpu, + poseidon, + poseidon_sponge, + keccak, + keccak_sponge, + logic, + memory, + ]; let root = Self::create_root_circuit(&by_table, stark_config); let aggregation = Self::create_aggregation_circuit(&root); let block = Self::create_block_circuit(&aggregation); diff --git a/prover/src/keccak/keccak_stark.rs b/prover/src/keccak/keccak_stark.rs index 8d7a26da..54b9d0e0 100644 --- a/prover/src/keccak/keccak_stark.rs +++ b/prover/src/keccak/keccak_stark.rs @@ -8,8 +8,6 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::plonk_common::reduce_with_powers_ext_circuit; -use plonky2::timed; -use plonky2::util::timing::TimingTree; use super::columns::reg_input_limb; use crate::constraint_consumer::{ConstraintConsumer, 
RecursiveConstraintConsumer}; @@ -231,20 +229,10 @@ impl, const D: usize> KeccakStark { &self, inputs: Vec<([u64; NUM_INPUTS], usize)>, min_rows: usize, - timing: &mut TimingTree, ) -> Vec> { // Generate the witness, except for permuted columns in the lookup argument. - let trace_rows = timed!( - timing, - "generate trace rows", - self.generate_trace_rows(inputs, min_rows) - ); - let trace_polys = timed!( - timing, - "convert to PolynomialValues", - trace_rows_to_poly_values(trace_rows) - ); - trace_polys + let trace_rows = self.generate_trace_rows(inputs, min_rows); + trace_rows_to_poly_values(trace_rows) } } @@ -713,11 +701,7 @@ mod tests { (0..NUM_PERMS).map(|_| (rand::random(), 0)).collect(); let mut timing = TimingTree::new("prove", log::Level::Debug); - let trace_poly_values = timed!( - timing, - "generate trace", - stark.generate_trace(input, 8, &mut timing) - ); + let trace_poly_values = stark.generate_trace(input, 8); // TODO: Cloning this isn't great; consider having `from_values` accept a reference, // or having `compute_permutation_z_polys` read trace values from the `PolynomialBatch`. diff --git a/prover/src/keccak_sponge/keccak_sponge_stark.rs b/prover/src/keccak_sponge/keccak_sponge_stark.rs index 07012e22..038e34af 100644 --- a/prover/src/keccak_sponge/keccak_sponge_stark.rs +++ b/prover/src/keccak_sponge/keccak_sponge_stark.rs @@ -11,8 +11,6 @@ use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::timed; -use plonky2::util::timing::TimingTree; use plonky2_util::ceil_div_usize; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; @@ -225,22 +223,11 @@ impl, const D: usize> KeccakSpongeStark { &self, operations: Vec, min_rows: usize, - timing: &mut TimingTree, ) -> Vec> { // Generate the witness row-wise. 
- let trace_rows = timed!( - timing, - "generate trace rows", - self.generate_trace_rows(operations, min_rows) - ); - - let trace_polys = timed!( - timing, - "convert to PolynomialValues", - trace_rows_to_poly_values(trace_rows) - ); + let trace_rows = self.generate_trace_rows(operations, min_rows); - trace_polys + trace_rows_to_poly_values(trace_rows) } fn generate_trace_rows( diff --git a/prover/src/prover.rs b/prover/src/prover.rs index 837a3f85..b7f1c705 100644 --- a/prover/src/prover.rs +++ b/prover/src/prover.rs @@ -307,6 +307,36 @@ where timing, )? ); + + let keccak_proof = timed!( + timing, + "prove Keccak STARK", + prove_single_table( + &all_stark.keccak_stark, + config, + &trace_poly_values[Table::Keccak as usize], + &trace_commitments[Table::Keccak as usize], + &ctl_data_per_table[Table::Keccak as usize], + ctl_challenges, + challenger, + timing, + )? + ); + let keccak_sponge_proof = timed!( + timing, + "prove Keccak sponge STARK", + prove_single_table( + &all_stark.keccak_sponge_stark, + config, + &trace_poly_values[Table::KeccakSponge as usize], + &trace_commitments[Table::KeccakSponge as usize], + &ctl_data_per_table[Table::KeccakSponge as usize], + ctl_challenges, + challenger, + timing, + )? 
+ ); + let logic_proof = timed!( timing, "prove Logic STARK", @@ -341,6 +371,8 @@ where cpu_proof, poseidon_proof, poseidon_sponge_proof, + keccak_proof, + keccak_sponge_proof, logic_proof, memory_proof, ]) diff --git a/prover/src/verifier.rs b/prover/src/verifier.rs index 3d55cff9..f2114a17 100644 --- a/prover/src/verifier.rs +++ b/prover/src/verifier.rs @@ -48,6 +48,8 @@ where cpu_stark, poseidon_stark, poseidon_sponge_stark, + keccak_stark, + keccak_sponge_stark, logic_stark, memory_stark, cross_table_lookups, @@ -99,6 +101,24 @@ where &ctl_challenges, config, )?; + + verify_stark_proof_with_challenges( + keccak_stark, + &all_proof.stark_proofs[Table::Keccak as usize].proof, + &stark_challenges[Table::Keccak as usize], + &ctl_vars_per_table[Table::Keccak as usize], + &ctl_challenges, + config, + )?; + verify_stark_proof_with_challenges( + keccak_sponge_stark, + &all_proof.stark_proofs[Table::KeccakSponge as usize].proof, + &stark_challenges[Table::KeccakSponge as usize], + &ctl_vars_per_table[Table::KeccakSponge as usize], + &ctl_challenges, + config, + )?; + verify_stark_proof_with_challenges( logic_stark, &all_proof.stark_proofs[Table::Logic as usize].proof, diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index f2d43901..f7e3369a 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -11,8 +11,11 @@ use anyhow::{Context, Result}; use plonky2::field::types::Field; +use super::util::keccak_sponge_log; +use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_RATE_U32S}; use crate::poseidon_sponge::columns::POSEIDON_RATE_BYTES; use itertools::Itertools; +use keccak_hash::keccak; use plonky2::field::extension::Extendable; use plonky2::hash::hash_types::RichField; use plonky2::plonk::config::GenericConfig; @@ -67,6 +70,7 @@ pub fn generate_pinv_diff(val0: u32, val1: u32, lv: &mut CpuColumnsVie logic.diff_pinv = (val0_f - val1_f).try_inverse().unwrap_or(F::ZERO) * num_unequal_limbs_inv; } +pub(crate) 
const SYSKECCAK: usize = 0x010109; pub(crate) const SYSGETPID: usize = 4020; pub(crate) const SYSGETGID: usize = 4047; pub(crate) const SYSMMAP2: usize = 4210; @@ -1092,6 +1096,83 @@ pub(crate) fn commit, C: GenericConfig, c Ok(()) } +pub(crate) fn generate_keccak< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + state: &mut GenerationState, + addr: usize, + len: usize, + ptr: usize, +) -> Result<()> { + let mut map_addr = addr; + let mut cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + let mut j = 0; + let mut keccak_data_addr = Vec::new(); + let mut keccak_value_byte_be = vec![0u8; len]; + + for i in (0..len).step_by(WORD_SIZE) { + if j == 8 { + state.traces.push_cpu(cpu_row); + cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + j = 0; + } + let addr = MemoryAddress::new(0, Segment::Code, map_addr); + let (word, mem_op) = mem_read_gp_with_log_and_fill(j, addr, state, &mut cpu_row); + let bytes = word.to_be_bytes(); + let final_len = if i + 4 > len { len } else { 4 }; + keccak_value_byte_be[i..i + final_len].copy_from_slice(&bytes[0..final_len]); + keccak_data_addr.push(addr); + state.traces.push_memory(mem_op); + map_addr += 4; + j += 1; + } + + state.traces.push_cpu(cpu_row); + state.memory.apply_ops(&state.traces.memory_ops); + + let mut cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + cpu_row.is_keccak_sponge = F::ONE; + + // The Keccak sponge CTL uses memory value columns for its inputs and outputs. 
+ cpu_row.mem_channels[0].value = F::ZERO; // context + cpu_row.mem_channels[1].value = F::from_canonical_usize(Segment::Code as usize); + let final_idx = len / KECCAK_RATE_BYTES * KECCAK_RATE_U32S; + cpu_row.mem_channels[2].value = F::from_canonical_usize(keccak_data_addr[final_idx].virt); + cpu_row.mem_channels[3].value = F::from_canonical_usize(len); + + let hash_data_bytes = keccak(&keccak_value_byte_be).0; + let hash_data_be = core::array::from_fn(|i| { + u32::from_le_bytes(core::array::from_fn(|j| hash_data_bytes[i * 4 + j])) + }); + + let hash_data = hash_data_be.map(u32::from_be); + + cpu_row.general.khash_mut().value = hash_data.map(F::from_canonical_u32); + cpu_row.general.khash_mut().value.reverse(); + + keccak_sponge_log(state, keccak_data_addr, keccak_value_byte_be); + state.traces.push_cpu(cpu_row); + + cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + map_addr = ptr; + assert!(hash_data_be.len() == 8); + for i in 0..hash_data_be.len() { + let addr = MemoryAddress::new(0, Segment::Code, map_addr); + let mem_op = + mem_write_gp_log_and_fill(i, addr, state, &mut cpu_row, hash_data_be[i].to_be()); + state.traces.push_memory(mem_op); + map_addr += 4; + } + state.traces.push_cpu(cpu_row); + Ok(()) +} + pub(crate) fn generate_syscall< F: RichField + Extendable, C: GenericConfig, @@ -1110,6 +1191,7 @@ pub(crate) fn generate_syscall< let mut is_load_preimage = false; let mut is_load_input = false; let mut is_verify = false; + let mut is_keccak = false; let mut is_commit = false; let result = match sys_num { SYSGETPID => { @@ -1263,6 +1345,10 @@ pub(crate) fn generate_syscall< is_verify = true; Ok(()) } + SYSKECCAK => { + is_keccak = true; + Ok(()) + } _ => { row.general.syscall_mut().sysnum[11] = F::ONE; Ok(()) @@ -1291,6 +1377,9 @@ pub(crate) fn generate_syscall< if is_commit { let _ = commit(state, a1, a2); } + if is_keccak { + let _ = generate_keccak(state, a0, a1, a2); + } result } diff --git 
a/prover/src/witness/traces.rs b/prover/src/witness/traces.rs index 4f130615..4939566d 100644 --- a/prover/src/witness/traces.rs +++ b/prover/src/witness/traces.rs @@ -1,3 +1,4 @@ +use itertools::Itertools; use plonky2::field::extension::Extendable; use plonky2::field::polynomial::PolynomialValues; use plonky2::hash::hash_types::RichField; @@ -11,6 +12,10 @@ use crate::arithmetic::{BinaryOperator, Operation}; use crate::config::StarkConfig; use crate::cpu::columns::CpuColumnsView; +use crate::keccak::keccak_stark; +use crate::keccak_sponge; +use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; use crate::poseidon::constants::SPONGE_WIDTH; use crate::poseidon_sponge::columns::POSEIDON_RATE_BYTES; use crate::poseidon_sponge::poseidon_sponge_stark::PoseidonSpongeOp; @@ -24,6 +29,8 @@ pub struct TraceCheckpoint { pub(self) cpu_len: usize, pub(self) poseidon_len: usize, pub(self) poseidon_sponge_len: usize, + pub(self) keccak_len: usize, + pub(self) keccak_sponge_len: usize, pub(self) logic_len: usize, pub(self) memory_len: usize, } @@ -36,6 +43,8 @@ pub(crate) struct Traces { pub(crate) memory_ops: Vec, pub(crate) poseidon_inputs: Vec<([T; SPONGE_WIDTH], usize)>, pub(crate) poseidon_sponge_ops: Vec, + pub(crate) keccak_inputs: Vec<([u64; keccak_stark::NUM_INPUTS], usize)>, + pub(crate) keccak_sponge_ops: Vec, } impl Traces { @@ -47,6 +56,8 @@ impl Traces { memory_ops: vec![], poseidon_inputs: vec![], poseidon_sponge_ops: vec![], + keccak_inputs: vec![], + keccak_sponge_ops: vec![], } } @@ -71,6 +82,12 @@ impl Traces { .iter() .map(|op| op.input.len() / POSEIDON_RATE_BYTES + 1) .sum(), + keccak_len: self.keccak_inputs.len() * keccak_stark::NUM_ROUNDS, + keccak_sponge_len: self + .keccak_sponge_ops + .iter() + .map(|op| op.input.len() / keccak_sponge::columns::KECCAK_RATE_BYTES + 1) + .sum(), logic_len: self.logic_ops.len(), // This is technically a lower-bound, as we may fill gaps, // but this gives a 
relatively good estimate. @@ -85,6 +102,8 @@ impl Traces { cpu_len: self.cpu.len(), poseidon_len: self.poseidon_inputs.len(), poseidon_sponge_len: self.poseidon_sponge_ops.len(), + keccak_len: self.keccak_inputs.len(), + keccak_sponge_len: self.keccak_sponge_ops.len(), logic_len: self.logic_ops.len(), memory_len: self.memory_ops.len(), } @@ -96,6 +115,9 @@ impl Traces { self.poseidon_inputs.truncate(checkpoint.poseidon_len); self.poseidon_sponge_ops .truncate(checkpoint.poseidon_sponge_len); + self.keccak_inputs.truncate(checkpoint.keccak_len); + self.keccak_sponge_ops + .truncate(checkpoint.keccak_sponge_len); self.logic_ops.truncate(checkpoint.logic_len); self.memory_ops.truncate(checkpoint.memory_len); } @@ -128,6 +150,24 @@ impl Traces { self.poseidon_sponge_ops.push(op); } + pub fn push_keccak(&mut self, input: [u64; keccak_stark::NUM_INPUTS], clock: usize) { + self.keccak_inputs.push((input, clock)); + } + + pub fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES], clock: usize) { + let chunks = input + .chunks(size_of::()) + .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) + .collect_vec() + .try_into() + .unwrap(); + self.push_keccak(chunks, clock); + } + + pub fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) { + self.keccak_sponge_ops.push(op); + } + pub fn clock(&self) -> usize { self.cpu.len() } @@ -150,6 +190,8 @@ impl Traces { mut memory_ops, poseidon_inputs, poseidon_sponge_ops, + keccak_inputs, + keccak_sponge_ops, } = self; let mut memory_trace = vec![]; @@ -157,6 +199,8 @@ impl Traces { let mut cpu_trace = vec![]; let mut poseidon_trace = vec![]; let mut poseidon_sponge_trace = vec![]; + let mut keccak_trace = vec![]; + let mut keccak_sponge_trace = vec![]; let mut logic_trace = vec![]; timed!( @@ -187,6 +231,18 @@ impl Traces { .poseidon_sponge_stark .generate_trace(&poseidon_sponge_ops, min_rows) }, + || { + keccak_trace = all_stark + .keccak_stark + .generate_trace(keccak_inputs, min_rows) + }, + ); + rayon::join( + || 
{ + keccak_sponge_trace = all_stark + .keccak_sponge_stark + .generate_trace(keccak_sponge_ops, min_rows) + }, || logic_trace = all_stark.logic_stark.generate_trace(logic_ops, min_rows), ); }, @@ -198,6 +254,8 @@ impl Traces { cpu_trace, poseidon_trace, poseidon_sponge_trace, + keccak_trace, + keccak_sponge_trace, logic_trace, memory_trace, ] diff --git a/prover/src/witness/util.rs b/prover/src/witness/util.rs index c786624f..379f7d36 100644 --- a/prover/src/witness/util.rs +++ b/prover/src/witness/util.rs @@ -6,11 +6,13 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use crate::cpu::columns::CpuColumnsView; +use crate::cpu::kernel::keccak_util::keccakf_u8s; use crate::cpu::membus::NUM_CHANNELS; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::generation::state::GenerationState; use crate::keccak_sponge::columns::KECCAK_RATE_BYTES; use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES; +use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; use crate::logic; use crate::memory::segments::Segment; use crate::poseidon::constants::{SPONGE_RATE, SPONGE_WIDTH}; @@ -463,6 +465,94 @@ pub(crate) fn poseidon_sponge_log< }); } +pub(crate) fn keccak_sponge_log< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + state: &mut GenerationState, + base_address: Vec, + input: Vec, // BE +) { + let clock = state.traces.clock(); + + let mut absorbed_bytes = 0; + let mut input_blocks = input.chunks_exact(KECCAK_RATE_BYTES); + let mut sponge_state = [0u8; KECCAK_WIDTH_BYTES]; + // Keccak absorbs the input byte by byte while a memory word is 4 bytes wide, so the same + // memory word is read once for each of its 4 bytes. + let mut n_gp = 0; + for block in input_blocks.by_ref() { + for i in 0..block.len() { + //for &byte in block { + let align = (i / 4) * 4; + let val = u32::from_le_bytes(block[align..(align + 4)].try_into().unwrap()); + + let addr_idx = absorbed_bytes / 4; + state.traces.push_memory(MemoryOp::new( + 
MemoryChannel::GeneralPurpose(n_gp), + clock, + base_address[addr_idx], + MemoryOpKind::Read, + val.to_be(), + )); + n_gp += 1; + n_gp %= NUM_GP_CHANNELS - 1; + absorbed_bytes += 1; + } + xor_into_sponge(state, &mut sponge_state, block.try_into().unwrap()); + state + .traces + .push_keccak_bytes(sponge_state, clock * NUM_CHANNELS); + keccakf_u8s(&mut sponge_state); + } + + let rem = input_blocks.remainder(); + + // patch data to match sponge logic + let mut rem_data = [0u8; KECCAK_RATE_BYTES]; + rem_data[0..rem.len()].copy_from_slice(&rem[0..rem.len()]); + rem_data[rem.len()] = 1; + rem_data[KECCAK_RATE_BYTES - 1] |= 0b10000000; + for i in 0..rem.len() { + let align = (i / 4) * 4; + let val = u32::from_le_bytes(rem_data[align..align + 4].try_into().unwrap()); + let addr_idx = absorbed_bytes / 4; + + state.traces.push_memory(MemoryOp::new( + MemoryChannel::GeneralPurpose(n_gp), + clock, + base_address[addr_idx], + MemoryOpKind::Read, + val.to_be(), + )); + n_gp += 1; + n_gp %= NUM_GP_CHANNELS - 1; + absorbed_bytes += 1; + } + let mut final_block = [0u8; KECCAK_RATE_BYTES]; + final_block[..input_blocks.remainder().len()].copy_from_slice(input_blocks.remainder()); + // pad10*1 rule + if input_blocks.remainder().len() == KECCAK_RATE_BYTES - 1 { + // Both 1s are placed in the same byte. 
+ final_block[input_blocks.remainder().len()] = 0b10000001; + } else { + final_block[input_blocks.remainder().len()] = 1; + final_block[KECCAK_RATE_BYTES - 1] = 0b10000000; + } + xor_into_sponge(state, &mut sponge_state, &final_block); + state + .traces + .push_keccak_bytes(sponge_state, clock * NUM_CHANNELS); + + //FIXME: how to setup the base address + state.traces.push_keccak_sponge(KeccakSpongeOp { + base_address, + timestamp: clock * NUM_CHANNELS, + input, + }); +} + fn xor_into_sponge, C: GenericConfig, const D: usize>( state: &mut GenerationState, sponge_state: &mut [u8; KECCAK_WIDTH_BYTES], diff --git a/runtime/entrypoint/src/syscalls/keccak.rs b/runtime/entrypoint/src/syscalls/keccak.rs new file mode 100644 index 00000000..6bd4b614 --- /dev/null +++ b/runtime/entrypoint/src/syscalls/keccak.rs @@ -0,0 +1,26 @@ +#[cfg(target_os = "zkvm")] +use core::arch::asm; + +/// Executes the Keccak256 permutation on the given state. +/// +/// ### Safety +/// +/// The caller must ensure that `state` is valid pointer to data that is aligned along a four +/// byte boundary. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_keccak(state: *const u32, len: usize, result: *mut u8) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "syscall", + in("$2") crate::syscalls::KECCAK_PERMUTE, + in("$4") state, + in("$5") len, + in("$6") result, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} diff --git a/runtime/entrypoint/src/syscalls/mod.rs b/runtime/entrypoint/src/syscalls/mod.rs index f63b0b90..fd835bb6 100644 --- a/runtime/entrypoint/src/syscalls/mod.rs +++ b/runtime/entrypoint/src/syscalls/mod.rs @@ -2,11 +2,13 @@ mod halt; mod io; +mod keccak; mod memory; mod sys; pub use halt::*; pub use io::*; +pub use keccak::*; pub use memory::*; pub use sys::*; @@ -27,3 +29,6 @@ pub const HINT_READ: u32 = 0x00_00_00_F1; /// Executes `HINT_READ`. pub const VERIFY: u32 = 0x00_00_00_F2; + +/// Executes `KECCAK_PERMUTE`. 
+pub const KECCAK_PERMUTE: u32 = 0x00_01_01_09;
diff --git a/runtime/precompiles/src/io.rs b/runtime/precompiles/src/io.rs
index e213be53..ae8ec9f7 100644
--- a/runtime/precompiles/src/io.rs
+++ b/runtime/precompiles/src/io.rs
@@ -1,6 +1,7 @@
 //! Ported from Precompiles for SP1 zkVM.
 
 #![allow(unused_unsafe)]
+use crate::syscall_keccak;
 use crate::syscall_verify;
 use crate::syscall_write;
 use crate::{syscall_hint_len, syscall_hint_read};
@@ -110,3 +111,56 @@ pub fn print(buf: Vec<u8>) {
         .write_all(buf.as_slice())
         .unwrap();
 }
+
+/// Computes the Keccak-256 digest of `data` via the `syscall_keccak` syscall.
+///
+/// The input is repacked into big-endian `u32` words so the buffer passed to the syscall is
+/// four-byte aligned. A trailing partial word is zero-filled and patched with the first padding
+/// byte (`0x01`), plus the final `0x80` when it is the last word of a 136-byte rate block.
+pub fn keccak(data: &[u8]) -> [u8; 32] {
+    let len = data.len();
+
+    if len == 0 {
+        // Keccak-256 of the empty input, returned without issuing a syscall.
+        return [
+            0xC5, 0xD2, 0x46, 0x01, 0x86, 0xF7, 0x23, 0x3C, 0x92, 0x7E, 0x7D, 0xB2, 0xDC, 0xC7,
+            0x03, 0xC0, 0xE5, 0, 0xB6, 0x53, 0xCA, 0x82, 0x27, 0x3B, 0x7B, 0xFA, 0xD8, 0x04, 0x5D,
+            0x85, 0xA4, 0x70,
+        ];
+    }
+
+    // Convert to u32 words to align the memory.
+    let mut u32_array = Vec::with_capacity(len / 4 + 1);
+    for i in (0..len).step_by(4) {
+        if i + 4 <= len {
+            // Pack the four bytes starting at offset `i`.
+            let u32_value = u32::from_be_bytes([data[i], data[i + 1], data[i + 2], data[i + 3]]);
+            u32_array.push(u32_value);
+        } else {
+            let mut padded_chunk = [0u8; 4];
+            padded_chunk[..len - i].copy_from_slice(&data[i..]);
+            padded_chunk[len - i] = 1;
+            let end = len % 136;
+            if end + 4 > 136 {
+                padded_chunk[3] |= 0x80;
+            }
+            let u32_value = u32::from_be_bytes(padded_chunk);
+            u32_array.push(u32_value);
+        }
+    }
+
+    // The syscall handler asserts a four-byte aligned destination and stores the digest as
+    // eight big-endian `u32` words, so receive it in a word buffer rather than a byte array.
+    let mut digest_words = [0u32; 8];
+    // SAFETY: `u32_array` is an aligned buffer covering the `len` input bytes, and
+    // `digest_words` provides 32 writable, four-byte aligned bytes for the result.
+    unsafe {
+        syscall_keccak(u32_array.as_ptr(), len, digest_words.as_mut_ptr().cast::<u8>());
+    }
+
+    let mut result = [0u8; 32];
+    for (bytes, word) in result.chunks_exact_mut(4).zip(digest_words) {
+        bytes.copy_from_slice(&word.to_be_bytes());
+    }
+    result
+}
diff --git a/runtime/precompiles/src/lib.rs b/runtime/precompiles/src/lib.rs
index 17774d2d..f91466dc 100644
--- a/runtime/precompiles/src/lib.rs
+++ b/runtime/precompiles/src/lib.rs
@@ -17,4 +17,6 @@ extern "C" {
     pub fn syscall_hint_read(ptr: *mut u8, len: usize);
     pub fn sys_alloc_aligned(bytes: usize, align: usize) -> *mut u8;
     pub fn syscall_verify(claim_digest: &[u8; 32]);
+    /// Hashes `len` bytes at `state` with Keccak-256 and writes the 32-byte digest to `result`.
+    pub fn syscall_keccak(state: *const u32, len: usize, result: *mut u8);
 }