From f8bf439cacddca57f0bcf8ece4ad621d4386f579 Mon Sep 17 00:00:00 2001 From: Angell Li Date: Mon, 13 Jan 2025 22:14:02 +0800 Subject: [PATCH 01/25] add temply test for sha256 syscall --- prover/examples/Cargo.toml | 1 + prover/examples/sha2-syscall/guest/Cargo.toml | 17 +++ .../examples/sha2-syscall/guest/src/consts.rs | 108 ++++++++++++++++++ .../sha2-syscall/guest/src/core_api.rs | 92 +++++++++++++++ prover/examples/sha2-syscall/guest/src/lib.rs | 46 ++++++++ .../examples/sha2-syscall/guest/src/main.rs | 44 +++++++ prover/examples/sha2-syscall/host/Cargo.toml | 22 ++++ prover/examples/sha2-syscall/host/build.rs | 3 + prover/examples/sha2-syscall/host/src/main.rs | 42 +++++++ 9 files changed, 375 insertions(+) create mode 100644 prover/examples/sha2-syscall/guest/Cargo.toml create mode 100644 prover/examples/sha2-syscall/guest/src/consts.rs create mode 100644 prover/examples/sha2-syscall/guest/src/core_api.rs create mode 100644 prover/examples/sha2-syscall/guest/src/lib.rs create mode 100644 prover/examples/sha2-syscall/guest/src/main.rs create mode 100644 prover/examples/sha2-syscall/host/Cargo.toml create mode 100644 prover/examples/sha2-syscall/host/build.rs create mode 100644 prover/examples/sha2-syscall/host/src/main.rs diff --git a/prover/examples/Cargo.toml b/prover/examples/Cargo.toml index f9c50110..aac035b6 100644 --- a/prover/examples/Cargo.toml +++ b/prover/examples/Cargo.toml @@ -4,6 +4,7 @@ members = [ "sha2-rust/host", "sha2-precompile/host", "sha2-go/host", + "sha2-syscall/host", "keccak/host", "split-seg", "prove-seg" diff --git a/prover/examples/sha2-syscall/guest/Cargo.toml b/prover/examples/sha2-syscall/guest/Cargo.toml new file mode 100644 index 00000000..77537980 --- /dev/null +++ b/prover/examples/sha2-syscall/guest/Cargo.toml @@ -0,0 +1,17 @@ +[workspace] +[package] +version = "0.1.0" +name = "sha2-syscall" +edition = "2021" + +[dependencies] +#zkm-runtime = { git = "https://github.com/zkMIPS/zkm", package = "zkm-runtime" } +zkm-runtime = { 
path = "../../../../runtime/entrypoint" } +digest = "0.10.4" +cfg-if = "1.0" +hex-literal = "0.2.2" + +[features] +default = ["std"] +std = ["digest/std"] +oid = ["digest/oid"] # Enable OID support. WARNING: Bumps MSRV to 1.57 diff --git a/prover/examples/sha2-syscall/guest/src/consts.rs b/prover/examples/sha2-syscall/guest/src/consts.rs new file mode 100644 index 00000000..6913769d --- /dev/null +++ b/prover/examples/sha2-syscall/guest/src/consts.rs @@ -0,0 +1,108 @@ + +#![allow(dead_code, clippy::unreadable_literal)] + +pub const STATE_LEN: usize = 8; +pub const BLOCK_LEN: usize = 16; + +pub type State256 = [u32; STATE_LEN]; +pub type State512 = [u64; STATE_LEN]; + +/// Constants necessary for SHA-256 family of digests. +pub const K32: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +/// Constants necessary for SHA-256 family of digests. 
+pub const K32X4: [[u32; 4]; 16] = [ + [K32[3], K32[2], K32[1], K32[0]], + [K32[7], K32[6], K32[5], K32[4]], + [K32[11], K32[10], K32[9], K32[8]], + [K32[15], K32[14], K32[13], K32[12]], + [K32[19], K32[18], K32[17], K32[16]], + [K32[23], K32[22], K32[21], K32[20]], + [K32[27], K32[26], K32[25], K32[24]], + [K32[31], K32[30], K32[29], K32[28]], + [K32[35], K32[34], K32[33], K32[32]], + [K32[39], K32[38], K32[37], K32[36]], + [K32[43], K32[42], K32[41], K32[40]], + [K32[47], K32[46], K32[45], K32[44]], + [K32[51], K32[50], K32[49], K32[48]], + [K32[55], K32[54], K32[53], K32[52]], + [K32[59], K32[58], K32[57], K32[56]], + [K32[63], K32[62], K32[61], K32[60]], +]; + +/// Constants necessary for SHA-512 family of digests. +pub const K64: [u64; 80] = [ + 0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f, 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, 0x59f111f1b605d019, 0x923f82a4af194f9b, 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, 0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235, 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, 0xefbe4786384f25e3, 0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, 0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5, + 0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f, 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, 0xd5a79147930aa725, 0x06ca6351e003826f, 0x142929670a0e6e70, + 0x27b70a8546d22ffc, 0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df, + 0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6, 0x92722c851482353b, + 0xa2bfe8a14cf10364, 0xa81a664bbc423001, 0xc24b8b70d0f89791, 0xc76c51a30654be30, + 0xd192e819d6ef5218, 0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99, 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb, 0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, 0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec, + 
0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915, 0xc67178f2e372532b, + 0xca273eceea26619c, 0xd186b8c721c0c207, 0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, 0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b, + 0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc, 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a, 0x5fcb6fab3ad6faec, 0x6c44198c4a475817, +]; + +/// Constants necessary for SHA-512 family of digests. +pub const K64X2: [[u64; 2]; 40] = [ + [K64[1], K64[0]], [K64[3], K64[2]], [K64[5], K64[4]], [K64[7], K64[6]], + [K64[9], K64[8]], [K64[11], K64[10]], [K64[13], K64[12]], [K64[15], K64[14]], + [K64[17], K64[16]], [K64[19], K64[18]], [K64[21], K64[20]], [K64[23], K64[22]], + [K64[25], K64[24]], [K64[27], K64[26]], [K64[29], K64[28]], [K64[31], K64[30]], + [K64[33], K64[32]], [K64[35], K64[34]], [K64[37], K64[36]], [K64[39], K64[38]], + [K64[41], K64[40]], [K64[43], K64[42]], [K64[45], K64[44]], [K64[47], K64[46]], + [K64[49], K64[48]], [K64[51], K64[50]], [K64[53], K64[52]], [K64[55], K64[54]], + [K64[57], K64[56]], [K64[59], K64[58]], [K64[61], K64[60]], [K64[63], K64[62]], + [K64[65], K64[64]], [K64[67], K64[66]], [K64[69], K64[68]], [K64[71], K64[70]], + [K64[73], K64[72]], [K64[75], K64[74]], [K64[77], K64[76]], [K64[79], K64[78]], +]; + +pub const H256_224: State256 = [ + 0xc1059ed8, 0x367cd507, 0x3070dd17, 0xf70e5939, + 0xffc00b31, 0x68581511, 0x64f98fa7, 0xbefa4fa4, +]; + +pub const H256_256: State256 = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub const H512_224: State512 = [ + 0x8c3d37c819544da2, 0x73e1996689dcd4d6, 0x1dfab7ae32ff9c82, 0x679dd514582f9fcf, + 0x0f6d2b697bd44da8, 0x77e36f7304c48942, 0x3f9d85a86a1d36c8, 0x1112e6ad91d692a1, +]; + +pub const H512_256: State512 = [ + 0x22312194fc2bf72c, 0x9f555fa3c84c64c2, 0x2393b86b6f53b151, 0x963877195940eabd, + 0x96283ee2a88effe3, 0xbe5e1e2553863992, 0x2b0199fc2c85b8aa, 
0x0eb72ddc81c52ca2, +]; + +pub const H512_384: State512 = [ + 0xcbbb9d5dc1059ed8, 0x629a292a367cd507, 0x9159015a3070dd17, 0x152fecd8f70e5939, + 0x67332667ffc00b31, 0x8eb44a8768581511, 0xdb0c2e0d64f98fa7, 0x47b5481dbefa4fa4, +]; + +pub const H512_512: State512 = [ + 0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b, 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, 0x9b05688c2b3e6c1f, 0x1f83d9abfb41bd6b, 0x5be0cd19137e2179, +]; diff --git a/prover/examples/sha2-syscall/guest/src/core_api.rs b/prover/examples/sha2-syscall/guest/src/core_api.rs new file mode 100644 index 00000000..ca1faaf2 --- /dev/null +++ b/prover/examples/sha2-syscall/guest/src/core_api.rs @@ -0,0 +1,92 @@ +use crate::consts; +use core::{fmt, slice::from_ref}; +use digest::{ + block_buffer::Eager, + core_api::{ + AlgorithmName, Block, BlockSizeUser, Buffer, BufferKindUser, OutputSizeUser, TruncSide, + UpdateCore, VariableOutputCore, + }, + typenum::{Unsigned, U32, U64}, + HashMarker, InvalidOutputSize, Output, + generic_array::GenericArray, +}; + +/// Core block-level SHA-256 hasher with variable output size. +/// +/// Supports initialization only for 28 and 32 byte output sizes, +/// i.e. 224 and 256 bits respectively. 
+#[derive(Clone)] +pub struct Sha256VarCore { + state: consts::State256, + block_len: u64, +} + +impl HashMarker for Sha256VarCore {} + +impl BlockSizeUser for Sha256VarCore { + type BlockSize = U64; +} + +impl BufferKindUser for Sha256VarCore { + type BufferKind = Eager; +} + +impl UpdateCore for Sha256VarCore { + #[inline] + fn update_blocks(&mut self, blocks: &[Block]) { + self.block_len += blocks.len() as u64; + compress256(&mut self.state, blocks); + } +} + +impl OutputSizeUser for Sha256VarCore { + type OutputSize = U32; +} + +impl VariableOutputCore for Sha256VarCore { + const TRUNC_SIDE: TruncSide = TruncSide::Left; + + #[inline] + fn new(output_size: usize) -> Result { + let state = match output_size { + 28 => consts::H256_224, + 32 => consts::H256_256, + _ => return Err(InvalidOutputSize), + }; + let block_len = 0; + Ok(Self { state, block_len }) + } + + #[inline] + fn finalize_variable_core(&mut self, buffer: &mut Buffer, out: &mut Output) { + let bs = Self::BlockSize::U64; + let bit_len = 8 * (buffer.get_pos() as u64 + bs * self.block_len); + buffer.len64_padding_be(bit_len, |b| compress256(&mut self.state, from_ref(b))); + + for (chunk, v) in out.chunks_exact_mut(4).zip(self.state.iter()) { + chunk.copy_from_slice(&v.to_be_bytes()); + } + } +} + +impl AlgorithmName for Sha256VarCore { + #[inline] + fn write_alg_name(f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Sha256") + } +} + +impl fmt::Debug for Sha256VarCore { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("Sha256VarCore { ... 
}") + } +} + +pub fn compress256(state: &mut [u32; 8], blocks: &[GenericArray]) { + // SAFETY: GenericArray and [u8; 64] have + // exactly the same memory layout + let p = blocks.as_ptr() as *const [u8; 64]; + let blocks = unsafe { core::slice::from_raw_parts(p, blocks.len()) }; + zkm_runtime::io::compress(state, blocks) +} diff --git a/prover/examples/sha2-syscall/guest/src/lib.rs b/prover/examples/sha2-syscall/guest/src/lib.rs new file mode 100644 index 00000000..f449f04f --- /dev/null +++ b/prover/examples/sha2-syscall/guest/src/lib.rs @@ -0,0 +1,46 @@ +//! An implementation of the [SHA-2][1] cryptographic hash algorithms. +//! +//! There are 6 standard algorithms specified in the SHA-2 standard: [`Sha224`], +//! [`Sha256`], [`Sha512_224`], [`Sha512_256`], [`Sha384`], and [`Sha512`]. +//! +//! Algorithmically, there are only 2 core algorithms: SHA-256 and SHA-512. +//! All other algorithms are just applications of these with different initial +//! hash values, and truncated to different digest bit lengths. The first two +//! algorithms in the list are based on SHA-256, while the last three on SHA-512. +//! +//! # Usage +//! +//! ```rust +//! use hex_literal::hex; +//! use sha2::{Sha256, Sha512, Digest}; +//! +//! // create a Sha256 object +//! let mut hasher = Sha256::new(); +//! +//! // write input message +//! hasher.update(b"hello world"); +//! +//! // read hash digest and consume hasher +//! let result = hasher.finalize(); +//! +//! assert_eq!(result[..], hex!(" +//! b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9 +//! ")[..]); +//! +//! // same for Sha512 +//! let mut hasher = Sha512::new(); +//! hasher.update(b"hello world"); +//! let result = hasher.finalize(); +//! +//! assert_eq!(result[..], hex!(" +//! 309ecc489c12d6eb4cc40f50c902f2b4d0ed77ee511a7c7a9bcd3ca86d4cd86f +//! 989dd35bc5ff499670da34255b45b0cfd830e81f605dcf7dc5542e93ae9cd76f +//! ")[..]); +//! ``` +//! +//! Also see [RustCrypto/hashes][2] readme. +//! +//! 
[1]: https://en.wikipedia.org/wiki/SHA-2 +//! [2]: https://github.com/RustCrypto/hashes + +#![no_std] diff --git a/prover/examples/sha2-syscall/guest/src/main.rs b/prover/examples/sha2-syscall/guest/src/main.rs new file mode 100644 index 00000000..28ffce0d --- /dev/null +++ b/prover/examples/sha2-syscall/guest/src/main.rs @@ -0,0 +1,44 @@ +#![no_std] +#![no_main] + +extern crate alloc; +use alloc::vec::Vec; + +pub use digest::{self, Digest}; + +#[cfg(feature = "oid")] +use digest::const_oid::{AssociatedOid, ObjectIdentifier}; +use digest::{ + consts::{U28, U32}, + core_api::{CoreWrapper, CtVariableCoreWrapper}, + impl_oid_carrier, +}; + +#[rustfmt::skip] +mod consts; +mod core_api; + +pub use core_api::{compress256, Sha256VarCore}; + +impl_oid_carrier!(OidSha256, "2.16.840.1.101.3.4.2.1"); +impl_oid_carrier!(OidSha224, "2.16.840.1.101.3.4.2.4"); + +/// SHA-224 hasher. +pub type Sha224 = CoreWrapper>; +/// SHA-256 hasher. +pub type Sha256 = CoreWrapper>; + + +zkm_runtime::entrypoint!(main); + +pub fn main() { + let public_input: Vec = zkm_runtime::io::read(); + let input: Vec = zkm_runtime::io::read(); + + let result = Sha256::digest(input); + + let output: [u8; 32] = result.into(); + assert_eq!(output.to_vec(), public_input); + + zkm_runtime::io::commit::<[u8; 32]>(&output); +} diff --git a/prover/examples/sha2-syscall/host/Cargo.toml b/prover/examples/sha2-syscall/host/Cargo.toml new file mode 100644 index 00000000..76bf7490 --- /dev/null +++ b/prover/examples/sha2-syscall/host/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "sha2-syscall-host" +version = { workspace = true } +edition = { workspace = true } +publish = false + +[dependencies] +zkm-prover = { workspace = true } +zkm-emulator = { workspace = true } +zkm-utils = { path = "../../utils" } + + +log = { version = "0.4.14", default-features = false } +serde = { version = "1.0.144", features = ["derive"] } +serde_json = "1.0" +byteorder = "1.5.0" +hex = "0.4" +env_logger = "0.11.5" +anyhow = "1.0.75" + 
+[build-dependencies] +zkm-build = { workspace = true } diff --git a/prover/examples/sha2-syscall/host/build.rs b/prover/examples/sha2-syscall/host/build.rs new file mode 100644 index 00000000..4b9fa8ed --- /dev/null +++ b/prover/examples/sha2-syscall/host/build.rs @@ -0,0 +1,3 @@ +fn main() { + zkm_build::build_program(&format!("{}/../guest", env!("CARGO_MANIFEST_DIR"))); +} diff --git a/prover/examples/sha2-syscall/host/src/main.rs b/prover/examples/sha2-syscall/host/src/main.rs new file mode 100644 index 00000000..be7041ed --- /dev/null +++ b/prover/examples/sha2-syscall/host/src/main.rs @@ -0,0 +1,42 @@ +use std::env; + +use zkm_emulator::utils::{load_elf_with_patch, split_prog_into_segs}; +use zkm_utils::utils::prove_segments; + +const ELF_PATH: &str = "../guest/elf/mips-zkm-zkvm-elf"; + +fn prove_sha2_rust() { + // 1. split ELF into segs + let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); + let seg_size = env::var("SEG_SIZE").unwrap_or("65536".to_string()); + let seg_size = seg_size.parse::<_>().unwrap_or(0); + + let mut state = load_elf_with_patch(ELF_PATH, vec![]); + // load input + let args = env::var("ARGS").unwrap_or("data-to-hash".to_string()); + // assume the first arg is the hash output(which is a public input), and the second is the input. 
+ let args: Vec<&str> = args.split_whitespace().collect(); + assert_eq!(args.len(), 2); + + let public_input: Vec = hex::decode(args[0]).unwrap(); + state.add_input_stream(&public_input); + log::info!("expected public value in hex: {:X?}", args[0]); + log::info!("expected public value: {:X?}", public_input); + + let private_input = args[1].as_bytes().to_vec(); + log::info!("private input value: {:X?}", private_input); + state.add_input_stream(&private_input); + + let (_total_steps, seg_num, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size); + + let value = state.read_public_values::<[u8; 32]>(); + log::info!("public value: {:X?}", value); + log::info!("public value: {} in hex", hex::encode(value)); + + let _ = prove_segments(&seg_path, "", "", "", seg_num, 0, vec![]).unwrap(); +} + +fn main() { + env_logger::try_init().unwrap_or_default(); + prove_sha2_rust(); +} From 689d582cab63922a4a57f42bb27fe84bd3711bb5 Mon Sep 17 00:00:00 2001 From: Angell Li Date: Mon, 13 Jan 2025 22:14:42 +0800 Subject: [PATCH 02/25] add runtime/emualtor/witness support for sha256 syscall --- emulator/src/state.rs | 90 +++++++++++- prover/src/witness/operation.rs | 159 ++++++++++++++++++++++ runtime/entrypoint/src/syscalls/mod.rs | 8 ++ runtime/entrypoint/src/syscalls/sha256.rs | 39 ++++++ runtime/precompiles/src/io.rs | 19 +++ runtime/precompiles/src/lib.rs | 4 + 6 files changed, 318 insertions(+), 1 deletion(-) create mode 100644 runtime/entrypoint/src/syscalls/sha256.rs diff --git a/emulator/src/state.rs b/emulator/src/state.rs index b8f30de1..ea5b72e8 100644 --- a/emulator/src/state.rs +++ b/emulator/src/state.rs @@ -498,6 +498,17 @@ impl Display for InstrumentedState { } } +pub const SHA_COMPRESS_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 
0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + impl InstrumentedState { pub fn new(state: Box, block_path: String) -> Box { Box::new(Self { @@ -530,7 +541,84 @@ impl InstrumentedState { log::debug!("syscall {} {} {} {}", syscall_num, a0, a1, a2); match syscall_num { - 0x010109 => { + 0x300105 => { // SHA_EXTEND + let w_ptr = a0; + assert!(a1 == 0, "arg2 must be 0"); + + for i in 16..64 { + // Read w[i-15]. + let w_i_minus_15 = self.state.memory.get_memory(w_ptr + (i - 15) * 4); + // Compute `s0`. + let s0 = + w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + + // Read w[i-2]. + let w_i_minus_2 = self.state.memory.get_memory(w_ptr + (i - 2) * 4); + // Compute `s1`. + let s1 = + w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); + + // Read w[i-16]. + let w_i_minus_16 = self.state.memory.get_memory(w_ptr + (i - 16) * 4); + + // Read w[i-7]. + let w_i_minus_7 = self.state.memory.get_memory(w_ptr + (i - 7) * 4); + + // Compute `w_i`. + let w_i = s1.wrapping_add(w_i_minus_16).wrapping_add(s0).wrapping_add(w_i_minus_7); + + // Write w[i]. + self.state.memory.set_memory(w_ptr + i * 4, w_i); + } + }, + 0x010106 => { // SHA_COMPRESS + let w_ptr = a0; + let h_ptr = a1; + let mut hx = [0u32; 8]; + for i in 0..8 { + hx[i] = self.state.memory.get_memory(h_ptr + i as u32 * 4); + } + + let mut original_w = Vec::new(); + // Execute the "compress" phase. 
+ let mut a = hx[0]; + let mut b = hx[1]; + let mut c = hx[2]; + let mut d = hx[3]; + let mut e = hx[4]; + let mut f = hx[5]; + let mut g = hx[6]; + let mut h = hx[7]; + for i in 0..64 { + let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); + let ch = (e & f) ^ (!e & g); + let w_i = self.state.memory.get_memory(w_ptr + i * 4); + original_w.push(w_i); + let temp1 = h + .wrapping_add(s1) + .wrapping_add(ch) + .wrapping_add(SHA_COMPRESS_K[i as usize]) + .wrapping_add(w_i); + let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22); + let maj = (a & b) ^ (a & c) ^ (b & c); + let temp2 = s0.wrapping_add(maj); + + h = g; + g = f; + f = e; + e = d.wrapping_add(temp1); + d = c; + c = b; + b = a; + a = temp1.wrapping_add(temp2); + } + // Execute the "finalize" phase. + let v = [a, b, c, d, e, f, g, h]; + for i in 0..8 { + self.state.memory.set_memory(h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); + } + }, + 0x010109 => { //keccak assert!((a0 & 3) == 0); assert!((a2 & 3) == 0); let bytes = (0..a1) diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index 6b7c251e..1b438252 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -70,6 +70,8 @@ pub fn generate_pinv_diff(val0: u32, val1: u32, lv: &mut CpuColumnsVie logic.diff_pinv = (val0_f - val1_f).try_inverse().unwrap_or(F::ZERO) * num_unequal_limbs_inv; } +pub(crate) const SYSSHAEXTEND: usize = 0x00300105; +pub(crate) const SYSSHACOMPRESS: usize = 0x00010106; pub(crate) const SYSKECCAK: usize = 0x010109; pub(crate) const SYSGETPID: usize = 4020; pub(crate) const SYSGETGID: usize = 4047; @@ -1178,6 +1180,147 @@ pub(crate) fn generate_keccak< Ok(()) } +pub(crate) fn generate_sha_extend< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + state: &mut GenerationState, + w_ptr: usize, + a1: usize, +) -> Result<()> { + assert!(a1 == 0, "arg2 must be 0"); + + for i in 16..64 { + let mut cpu_row = CpuColumnsView::default(); 
+ cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + + let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 15) * 4); + let (w_i_minus_15, mem_op) = mem_read_gp_with_log_and_fill(0, addr, state, &mut cpu_row); + state.traces.push_memory(mem_op); + let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + + // Read w[i-2]. + let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 2) * 4); + let (w_i_minus_2, mem_op) = mem_read_gp_with_log_and_fill(1, addr, state, &mut cpu_row); + state.traces.push_memory(mem_op); + // Compute `s1`. + let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); + + // Read w[i-16]. + let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 16) * 4); + let (w_i_minus_16, mem_op) = mem_read_gp_with_log_and_fill(2, addr, state, &mut cpu_row); + state.traces.push_memory(mem_op); + + // Read w[i-7]. + let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 7) * 4); + let (w_i_minus_7, mem_op) = mem_read_gp_with_log_and_fill(3, addr, state, &mut cpu_row); + state.traces.push_memory(mem_op); + + // Compute `w_i`. + let w_i = s1 + .wrapping_add(w_i_minus_16) + .wrapping_add(s0) + .wrapping_add(w_i_minus_7); + + // Write w[i]. 
+ let addr = MemoryAddress::new(0, Segment::Code, w_ptr + i * 4); + let mem_op = mem_write_gp_log_and_fill(i, addr, state, &mut cpu_row, w_i); + state.traces.push_memory(mem_op); + state.traces.push_cpu(cpu_row); + } + + Ok(()) +} + +pub const SHA_COMPRESS_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +pub(crate) fn generate_sha_compress< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + state: &mut GenerationState, + w_ptr: usize, + h_ptr: usize, +) -> Result<()> { + let mut hx = [0u32; 8]; + let mut cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + for i in 0..8 { + let addr = MemoryAddress::new(0, Segment::Code, h_ptr + i * 4); + let (value, mem_op) = mem_read_gp_with_log_and_fill(i, addr, state, &mut cpu_row); + state.traces.push_memory(mem_op); + hx[i] = value + } + state.traces.push_cpu(cpu_row); + let mut original_w = Vec::new(); + // Execute the "compress" phase. 
+ let mut a = hx[0]; + let mut b = hx[1]; + let mut c = hx[2]; + let mut d = hx[3]; + let mut e = hx[4]; + let mut f = hx[5]; + let mut g = hx[6]; + let mut h = hx[7]; + let mut j = 0; + for i in 0..64 { + let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); + let ch = (e & f) ^ (!e & g); + if j == 8 { + state.traces.push_cpu(cpu_row); + cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + j = 0; + } + let addr = MemoryAddress::new(0, Segment::Code, w_ptr + i * 4); + let (w_i, mem_op) = mem_read_gp_with_log_and_fill(j, addr, state, &mut cpu_row); + state.traces.push_memory(mem_op); + j += 1; + original_w.push(w_i); + let temp1 = h + .wrapping_add(s1) + .wrapping_add(ch) + .wrapping_add(SHA_COMPRESS_K[i as usize]) + .wrapping_add(w_i); + let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22); + let maj = (a & b) ^ (a & c) ^ (b & c); + let temp2 = s0.wrapping_add(maj); + + h = g; + g = f; + f = e; + e = d.wrapping_add(temp1); + d = c; + c = b; + b = a; + a = temp1.wrapping_add(temp2); + } + state.traces.push_cpu(cpu_row); + // Execute the "finalize" phase. 
+ let v = [a, b, c, d, e, f, g, h]; + let mut cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + for i in 0..8 { + let addr = MemoryAddress::new(0, Segment::Code, h_ptr + i * 4); + let mem_op = + mem_write_gp_log_and_fill(i, addr, state, &mut cpu_row, hx[i].wrapping_add(v[i])); + state.traces.push_memory(mem_op); + } + state.traces.push_cpu(cpu_row); + Ok(()) +} + pub(crate) fn generate_syscall< F: RichField + Extendable, C: GenericConfig, @@ -1198,6 +1341,8 @@ pub(crate) fn generate_syscall< let mut is_verify = false; let mut is_keccak = false; let mut is_commit = false; + let mut is_sha_extend = false; + let mut is_sha_compress = false; let result = match sys_num { SYSGETPID => { row.general.syscall_mut().sysnum[0] = F::ONE; @@ -1354,6 +1499,14 @@ pub(crate) fn generate_syscall< is_keccak = true; Ok(()) } + SYSSHACOMPRESS => { + is_sha_compress = true; + Ok(()) + } + SYSSHAEXTEND => { + is_sha_extend = true; + Ok(()) + } _ => { row.general.syscall_mut().sysnum[11] = F::ONE; Ok(()) @@ -1385,6 +1538,12 @@ pub(crate) fn generate_syscall< if is_keccak { let _ = generate_keccak(state, a0, a1, a2); } + if is_sha_compress { + let _ = generate_sha_compress(state, a0, a1); + } + if is_sha_extend { + let _ = generate_sha_compress(state, a0, a1); + } result } diff --git a/runtime/entrypoint/src/syscalls/mod.rs b/runtime/entrypoint/src/syscalls/mod.rs index fd835bb6..ca3c6dbf 100644 --- a/runtime/entrypoint/src/syscalls/mod.rs +++ b/runtime/entrypoint/src/syscalls/mod.rs @@ -3,12 +3,14 @@ mod halt; mod io; mod keccak; +mod sha256; mod memory; mod sys; pub use halt::*; pub use io::*; pub use keccak::*; +pub use sha256::*; pub use memory::*; pub use sys::*; @@ -32,3 +34,9 @@ pub const VERIFY: u32 = 0x00_00_00_F2; /// Executes `KECCAK_PERMUTE`. pub const KECCAK_PERMUTE: u32 = 0x00_01_01_09; + +/// Executes `SHA_EXTEND`. +pub const SHA_EXTEND: u32 = 0x00_30_01_05; + +/// Executes `SHA_COMPRESS`. 
+pub const SHA_COMPRESS: u32 = 0x00_01_01_06; diff --git a/runtime/entrypoint/src/syscalls/sha256.rs b/runtime/entrypoint/src/syscalls/sha256.rs new file mode 100644 index 00000000..d0d7c0ab --- /dev/null +++ b/runtime/entrypoint/src/syscalls/sha256.rs @@ -0,0 +1,39 @@ +#[cfg(target_os = "zkvm")] +use core::arch::asm; + +/// Executes the SHA-256 compression function on `state` using the message schedule `w`. +/// +/// ### Safety +/// +/// The caller must ensure that `w` and `state` are valid pointers to 64 and 8 `u32` words, +/// respectively, each aligned along a four byte boundary. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_sha256_compress(w: *mut u32, state: *mut u32) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "syscall", + in("$2") crate::syscalls::SHA_COMPRESS, + in("$4") w, + in("$5") state, + ); + } +} + +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_sha256_extend(w: *mut u32) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "syscall", + in("$2") crate::syscalls::SHA_EXTEND, + in("$4") w, + in("$5") 0 + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} diff --git a/runtime/precompiles/src/io.rs b/runtime/precompiles/src/io.rs index 335b7d4f..346786b0 100644 --- a/runtime/precompiles/src/io.rs +++ b/runtime/precompiles/src/io.rs @@ -5,6 +5,7 @@ use crate::syscall_keccak; use crate::syscall_verify; use crate::syscall_write; use crate::{syscall_hint_len, syscall_hint_read}; +use crate::{syscall_sha256_extend, syscall_sha256_compress}; use serde::de::DeserializeOwned; use serde::Serialize; use sha2::{Digest, Sha256}; @@ -150,3 +151,21 @@ pub fn keccak(data: &[u8]) -> [u8; 32] { } result } + +pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + unsafe { + for i in 0..blocks.len() { + let mut w = [0u32; 64]; + for j in 0..16 { + w[j] = u32::from_be_bytes([ + blocks[i][j * 4], + blocks[i][j * 4 + 1], + blocks[i][j * 4 + 2], + blocks[i][j * 4 + 3], + ]); + } + syscall_sha256_extend(w.as_mut_ptr()); + syscall_sha256_compress(w.as_mut_ptr(), 
state.as_mut_ptr()); + } + } +} diff --git a/runtime/precompiles/src/lib.rs b/runtime/precompiles/src/lib.rs index f91466dc..eae9b91d 100644 --- a/runtime/precompiles/src/lib.rs +++ b/runtime/precompiles/src/lib.rs @@ -19,4 +19,8 @@ extern "C" { pub fn syscall_verify(claim_digest: &[u8; 32]); /// Executes the Keccak-256 permutation on the given state. pub fn syscall_keccak(state: *const u32, len: usize, result: *mut u8); + pub fn syscall_sha256(state: *const u32, len: usize, result: *mut u8); + pub fn syscall_sha256_compress(w: *mut u32, state: *mut u32); + pub fn syscall_sha256_extend(w: *mut u32); + } From ff70039661e409b9d70b935f99f6ae4a5b4141bb Mon Sep 17 00:00:00 2001 From: Angell Li Date: Mon, 13 Jan 2025 22:24:33 +0800 Subject: [PATCH 03/25] fix memory channel --- prover/src/witness/operation.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index 1b438252..3a5b9791 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -1225,7 +1225,7 @@ pub(crate) fn generate_sha_extend< // Write w[i]. 
let addr = MemoryAddress::new(0, Segment::Code, w_ptr + i * 4); - let mem_op = mem_write_gp_log_and_fill(i, addr, state, &mut cpu_row, w_i); + let mem_op = mem_write_gp_log_and_fill(4, addr, state, &mut cpu_row, w_i); state.traces.push_memory(mem_op); state.traces.push_cpu(cpu_row); } From 83281c284efd69107f39853e6875accd76741b91 Mon Sep 17 00:00:00 2001 From: Angell Li Date: Mon, 13 Jan 2025 22:29:20 +0800 Subject: [PATCH 04/25] add missed cpu row --- prover/src/witness/operation.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index 3a5b9791..1138ea1d 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -1274,6 +1274,8 @@ pub(crate) fn generate_sha_compress< let mut g = hx[6]; let mut h = hx[7]; let mut j = 0; + cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); for i in 0..64 { let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); let ch = (e & f) ^ (!e & g); From f03c06980d9addd7b66fa1def535561665bdd886 Mon Sep 17 00:00:00 2001 From: Angell Li Date: Mon, 13 Jan 2025 22:57:45 +0800 Subject: [PATCH 05/25] fix typo. and add debug log --- emulator/src/state.rs | 2 ++ prover/src/witness/operation.rs | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/emulator/src/state.rs b/emulator/src/state.rs index ea5b72e8..0ada2c0e 100644 --- a/emulator/src/state.rs +++ b/emulator/src/state.rs @@ -569,6 +569,7 @@ impl InstrumentedState { // Write w[i]. 
self.state.memory.set_memory(w_ptr + i * 4, w_i); + log::info!("write {:X} {:X}", w_ptr + i * 4, w_i); } }, 0x010106 => { // SHA_COMPRESS @@ -616,6 +617,7 @@ impl InstrumentedState { let v = [a, b, c, d, e, f, g, h]; for i in 0..8 { self.state.memory.set_memory(h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); + log::info!("write {:X} {:X}", h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); } }, 0x010109 => { //keccak diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index 1138ea1d..f10f2697 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -1225,6 +1225,7 @@ pub(crate) fn generate_sha_extend< // Write w[i]. let addr = MemoryAddress::new(0, Segment::Code, w_ptr + i * 4); + log::info!("write {:X} {:X}", w_ptr + i * 4, w_i); let mem_op = mem_write_gp_log_and_fill(4, addr, state, &mut cpu_row, w_i); state.traces.push_memory(mem_op); state.traces.push_cpu(cpu_row); @@ -1260,7 +1261,7 @@ pub(crate) fn generate_sha_compress< let addr = MemoryAddress::new(0, Segment::Code, h_ptr + i * 4); let (value, mem_op) = mem_read_gp_with_log_and_fill(i, addr, state, &mut cpu_row); state.traces.push_memory(mem_op); - hx[i] = value + hx[i] = value; } state.traces.push_cpu(cpu_row); let mut original_w = Vec::new(); @@ -1318,6 +1319,7 @@ pub(crate) fn generate_sha_compress< let mem_op = mem_write_gp_log_and_fill(i, addr, state, &mut cpu_row, hx[i].wrapping_add(v[i])); state.traces.push_memory(mem_op); + log::info!("write {:X} {:X}", h_ptr + i * 4, hx[i].wrapping_add(v[i])); } state.traces.push_cpu(cpu_row); Ok(()) @@ -1544,7 +1546,7 @@ pub(crate) fn generate_syscall< let _ = generate_sha_compress(state, a0, a1); } if is_sha_extend { - let _ = generate_sha_compress(state, a0, a1); + let _ = generate_sha_extend(state, a0, a1); } result } From 24a884d8bcb97d014d07f1f6ca9f0dd4f0822539 Mon Sep 17 00:00:00 2001 From: Angell Li Date: Tue, 14 Jan 2025 20:42:34 +0800 Subject: [PATCH 06/25] fix: update memory every loop --- 
emulator/src/state.rs | 5 +++-- prover/src/witness/operation.rs | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/emulator/src/state.rs b/emulator/src/state.rs index 0ada2c0e..a5b5fbea 100644 --- a/emulator/src/state.rs +++ b/emulator/src/state.rs @@ -568,8 +568,9 @@ impl InstrumentedState { let w_i = s1.wrapping_add(w_i_minus_16).wrapping_add(s0).wrapping_add(w_i_minus_7); // Write w[i]. + log::debug!("{:X}, {:X}, {:X} {:X} {:X} {:X}", s1, s0, w_i_minus_16, w_i_minus_7, w_i_minus_15, w_i_minus_2); self.state.memory.set_memory(w_ptr + i * 4, w_i); - log::info!("write {:X} {:X}", w_ptr + i * 4, w_i); + log::debug!("extend write {:X} {:X}", w_ptr + i * 4, w_i); } }, 0x010106 => { // SHA_COMPRESS @@ -617,7 +618,7 @@ impl InstrumentedState { let v = [a, b, c, d, e, f, g, h]; for i in 0..8 { self.state.memory.set_memory(h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); - log::info!("write {:X} {:X}", h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); + log::debug!("write {:X} {:X}", h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); } }, 0x010109 => { //keccak diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index f10f2697..be448823 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -1224,11 +1224,13 @@ pub(crate) fn generate_sha_extend< .wrapping_add(w_i_minus_7); // Write w[i]. 
+ log::debug!("{:X}, {:X}, {:X} {:X} {:X} {:X}", s1, s0, w_i_minus_16, w_i_minus_7, w_i_minus_15, w_i_minus_2); let addr = MemoryAddress::new(0, Segment::Code, w_ptr + i * 4); - log::info!("write {:X} {:X}", w_ptr + i * 4, w_i); + log::debug!("extend write {:X} {:X}", w_ptr + i * 4, w_i); let mem_op = mem_write_gp_log_and_fill(4, addr, state, &mut cpu_row, w_i); state.traces.push_memory(mem_op); state.traces.push_cpu(cpu_row); + state.memory.apply_ops(&state.traces.memory_ops); } Ok(()) @@ -1319,7 +1321,7 @@ pub(crate) fn generate_sha_compress< let mem_op = mem_write_gp_log_and_fill(i, addr, state, &mut cpu_row, hx[i].wrapping_add(v[i])); state.traces.push_memory(mem_op); - log::info!("write {:X} {:X}", h_ptr + i * 4, hx[i].wrapping_add(v[i])); + log::debug!("write {:X} {:X}", h_ptr + i * 4, hx[i].wrapping_add(v[i])); } state.traces.push_cpu(cpu_row); Ok(()) From 4461f90cc3560c2707027d2abc33269f57782be2 Mon Sep 17 00:00:00 2001 From: vanhger Date: Mon, 20 Jan 2025 12:18:28 +0700 Subject: [PATCH 07/25] feat: add columns for SHA circuit --- prover/src/lib.rs | 4 + prover/src/sha_compress/columns.rs | 119 ++++++++++++++++++++++ prover/src/sha_compress/mod.rs | 1 + prover/src/sha_compress_sponge/columns.rs | 32 ++++++ prover/src/sha_compress_sponge/mod.rs | 1 + prover/src/sha_extend/columns.rs | 86 ++++++++++++++++ prover/src/sha_extend/constants.rs | 2 + prover/src/sha_extend/mod.rs | 2 + prover/src/sha_extend_sponge/columns.rs | 87 ++++++++++++++++ prover/src/sha_extend_sponge/mod.rs | 1 + 10 files changed, 335 insertions(+) create mode 100644 prover/src/sha_compress/columns.rs create mode 100644 prover/src/sha_compress/mod.rs create mode 100644 prover/src/sha_compress_sponge/columns.rs create mode 100644 prover/src/sha_compress_sponge/mod.rs create mode 100644 prover/src/sha_extend/columns.rs create mode 100644 prover/src/sha_extend/constants.rs create mode 100644 prover/src/sha_extend/mod.rs create mode 100644 prover/src/sha_extend_sponge/columns.rs create mode 
100644 prover/src/sha_extend_sponge/mod.rs diff --git a/prover/src/lib.rs b/prover/src/lib.rs index ae003319..9eea1ca6 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -34,3 +34,7 @@ pub mod util; pub mod vanishing_poly; pub mod verifier; pub mod witness; +pub mod sha_extend; +pub mod sha_extend_sponge; +pub mod sha_compress; +pub mod sha_compress_sponge; diff --git a/prover/src/sha_compress/columns.rs b/prover/src/sha_compress/columns.rs new file mode 100644 index 00000000..c0258fd4 --- /dev/null +++ b/prover/src/sha_compress/columns.rs @@ -0,0 +1,119 @@ +use std::borrow::{Borrow, BorrowMut}; +use std::intrinsics::transmute; +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; + +pub(crate) struct ShaCompressColumnsView { + /// The timestamp at which inputs should be read from memory. + pub timestamp: T, + + /// Round number + pub i: T, + + /// 8 temp buffer values as input + pub a: T, + pub b: T, + pub c: T, + pub d: T, + pub e: T, + pub f: T, + pub g: T, + pub h: T, + + /// w[i] + pub w: [T; 64], + + /// Selector + pub round_i_filter: [T; 64], + + /// Intermediate values + pub k_i: T, + pub w_i: T, + pub e_rr_6: T, + pub e_rr_11: T, + pub e_rr_25: T, + pub s_1_inter: T, + pub s_1: T, + pub e_and_f: T, + pub e_not: T, + pub e_not_and_g: T, + pub ch: T, + pub temp1: T, + pub a_rr_2: T, + pub a_rr_13: T, + pub a_rr_22: T, + pub s_0_inter: T, + pub s_0: T, + pub a_and_b: T, + pub a_and_c: T, + pub b_and_c: T, + pub maj_inter: T, + pub maj: T, + pub temp2: T, + + /// Out + pub new_a: T, + pub new_b: T, + pub new_c: T, + pub new_d: T, + pub new_e: T, + pub new_f: T, + pub new_g: T, + pub new_h: T, + + /// 1 if this is the final round of the compress phase, 0 otherwise + pub is_final: T, + +} + +pub const NUM_SHA_COMPRESS_COLUMNS: usize = size_of::>(); + +impl From<[T; NUM_SHA_COMPRESS_COLUMNS]> for ShaCompressColumnsView { + fn from(value: [T; NUM_SHA_COMPRESS_COLUMNS]) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } 
+} + +impl From> for [T; NUM_SHA_COMPRESS_COLUMNS] { + fn from(value: ShaCompressColumnsView) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl Borrow> for [T; NUM_SHA_COMPRESS_COLUMNS] { + fn borrow(&self) -> &ShaCompressColumnsView { + unsafe { transmute(self) } + } +} + +impl BorrowMut> for [T; NUM_SHA_COMPRESS_COLUMNS] { + fn borrow_mut(&mut self) -> &mut ShaCompressColumnsView { + unsafe { transmute(self) } + } +} + +impl Borrow<[T; NUM_SHA_COMPRESS_COLUMNS]> for ShaCompressColumnsView { + fn borrow(&self) -> &[T; NUM_SHA_COMPRESS_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl BorrowMut<[T; NUM_SHA_COMPRESS_COLUMNS]> for ShaCompressColumnsView { + fn borrow_mut(&mut self) -> &mut [T; NUM_SHA_COMPRESS_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl Default for ShaCompressColumnsView { + fn default() -> Self { + [T::default(); NUM_SHA_COMPRESS_COLUMNS].into() + } +} + +const fn make_col_map() -> ShaCompressColumnsView { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_SHA_COMPRESS_COLUMNS], ShaCompressColumnsView>(indices_arr) + } +} + +pub(crate) const SHA_COMPRESS_COL_MAP: ShaCompressColumnsView = make_col_map(); diff --git a/prover/src/sha_compress/mod.rs b/prover/src/sha_compress/mod.rs new file mode 100644 index 00000000..eff2b9b8 --- /dev/null +++ b/prover/src/sha_compress/mod.rs @@ -0,0 +1 @@ +mod columns; \ No newline at end of file diff --git a/prover/src/sha_compress_sponge/columns.rs b/prover/src/sha_compress_sponge/columns.rs new file mode 100644 index 00000000..dd4e95f7 --- /dev/null +++ b/prover/src/sha_compress_sponge/columns.rs @@ -0,0 +1,32 @@ +pub(crate) struct ShaCompressSpongeColumnsView { + /// The timestamp at which inputs should be read from memory. 
+ pub timestamp: T, + + /// hx_i + pub hx: [T;8], + + /// w[i] + pub w: [T; 64], + + /// a,b...,h values after compressed + pub new_a: T, + pub new_b: T, + pub new_c: T, + pub new_d: T, + pub new_e: T, + pub new_f: T, + pub new_g: T, + pub new_h: T, + + /// output + pub final_hx: [T;8], + + /// The base address at which we will read the input block. + pub context: T, + pub segment: T, + /// Hx addresses + pub hx_virt: [T; 8], + + /// W_i addresses + pub w_virt: [T;64], +} \ No newline at end of file diff --git a/prover/src/sha_compress_sponge/mod.rs b/prover/src/sha_compress_sponge/mod.rs new file mode 100644 index 00000000..eff2b9b8 --- /dev/null +++ b/prover/src/sha_compress_sponge/mod.rs @@ -0,0 +1 @@ +mod columns; \ No newline at end of file diff --git a/prover/src/sha_extend/columns.rs b/prover/src/sha_extend/columns.rs new file mode 100644 index 00000000..a5327b25 --- /dev/null +++ b/prover/src/sha_extend/columns.rs @@ -0,0 +1,86 @@ +use std::borrow::{Borrow, BorrowMut}; +use std::intrinsics::transmute; +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; + +pub(crate) struct ShaExtendColumnsView { + + /// The timestamp at which inputs should be read from memory. 
+ pub timestamp: T, + + /// round + pub i: T, + + /// Input + pub w_i_minus_15: T, + pub w_i_minus_2: T, + pub w_i_minus_16: T, + pub w_i_minus_7: T, + + /// Intermediate values + pub w_i_minus_15_rr_7: T, + pub w_i_minus_15_rr_18: T, + pub w_i_minus_15_rs_3: T, + pub s_0_inter: T, + pub s_0: T, + pub w_i_minus_2_rr_17: T, + pub w_i_minus_2_rr_19: T, + pub w_i_minus_2_rs_10: T, + pub s_i_inter: T, + pub s_1: T, + + /// Output + pub w_i: T, +} + +pub const NUM_SHA_EXTEND_COLUMNS: usize = size_of::>(); + +impl From<[T; NUM_SHA_EXTEND_COLUMNS]> for ShaExtendColumnsView { + fn from(value: [T; NUM_SHA_EXTEND_COLUMNS]) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl From> for [T; NUM_SHA_EXTEND_COLUMNS] { + fn from(value: ShaExtendColumnsView) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl Borrow> for [T; NUM_SHA_EXTEND_COLUMNS] { + fn borrow(&self) -> &ShaExtendColumnsView { + unsafe { transmute(self) } + } +} + +impl BorrowMut> for [T; NUM_SHA_EXTEND_COLUMNS] { + fn borrow_mut(&mut self) -> &mut ShaExtendColumnsView { + unsafe { transmute(self) } + } +} + +impl Borrow<[T; NUM_SHA_EXTEND_COLUMNS]> for ShaExtendColumnsView { + fn borrow(&self) -> &[T; NUM_SHA_EXTEND_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl BorrowMut<[T; NUM_SHA_EXTEND_COLUMNS]> for ShaExtendColumnsView { + fn borrow_mut(&mut self) -> &mut [T; NUM_SHA_EXTEND_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl Default for ShaExtendColumnsView { + fn default() -> Self { + [T::default(); NUM_SHA_EXTEND_COLUMNS].into() + } +} + +const fn make_col_map() -> ShaExtendColumnsView { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_SHA_EXTEND_COLUMNS], ShaExtendColumnsView>(indices_arr) + } +} + +pub(crate) const SHA_EXTEND_COL_MAP: ShaExtendColumnsView = make_col_map(); diff --git a/prover/src/sha_extend/constants.rs b/prover/src/sha_extend/constants.rs new file mode 100644 index 00000000..b2849ca6 --- 
/dev/null +++ b/prover/src/sha_extend/constants.rs @@ -0,0 +1,2 @@ +const NUM_ROUND_CONSTANTS: usize = 48; + diff --git a/prover/src/sha_extend/mod.rs b/prover/src/sha_extend/mod.rs new file mode 100644 index 00000000..20aa4912 --- /dev/null +++ b/prover/src/sha_extend/mod.rs @@ -0,0 +1,2 @@ +mod constants; +mod columns; \ No newline at end of file diff --git a/prover/src/sha_extend_sponge/columns.rs b/prover/src/sha_extend_sponge/columns.rs new file mode 100644 index 00000000..d4c1e088 --- /dev/null +++ b/prover/src/sha_extend_sponge/columns.rs @@ -0,0 +1,87 @@ +use std::borrow::{Borrow, BorrowMut}; +use std::intrinsics::transmute; +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; + +pub(crate) const NUM_EXTEND_INPUT: usize = 4; +pub(crate) struct ShaExtendSpongeColumnsView { + /// The timestamp at which inputs should be read from memory. + pub timestamp: T, + + /// Round number + pub i: T, + + /// Input + pub w_i_minus_15: T, + pub w_i_minus_2: T, + pub w_i_minus_16: T, + pub w_i_minus_7: T, + + /// Output + pub w_i: T, + + /// The base address at which we will read the input block. 
+ pub context: T, + pub segment: T, + /// Input address + pub input_virt: [T; NUM_EXTEND_INPUT], + + /// Output address + pub output_virt: T, + + /// 1 if this is the final round of the extending phase, 0 otherwise + pub is_final: T, +} + +pub const NUM_SHA_EXTEND_SPONGE_COLUMNS: usize = size_of::>(); + +impl From<[T; NUM_SHA_EXTEND_SPONGE_COLUMNS]> for ShaExtendSpongeColumnsView { + fn from(value: [T; NUM_SHA_EXTEND_SPONGE_COLUMNS]) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl From> for [T; NUM_SHA_EXTEND_SPONGE_COLUMNS] { + fn from(value: ShaExtendSpongeColumnsView) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl Borrow> for [T; NUM_SHA_EXTEND_SPONGE_COLUMNS] { + fn borrow(&self) -> &ShaExtendSpongeColumnsView { + unsafe { transmute(self) } + } +} + +impl BorrowMut> for [T; NUM_SHA_EXTEND_SPONGE_COLUMNS] { + fn borrow_mut(&mut self) -> &mut ShaExtendSpongeColumnsView { + unsafe { transmute(self) } + } +} + +impl Borrow<[T; NUM_SHA_EXTEND_SPONGE_COLUMNS]> for ShaExtendSpongeColumnsView { + fn borrow(&self) -> &[T; NUM_SHA_EXTEND_SPONGE_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl BorrowMut<[T; NUM_SHA_EXTEND_SPONGE_COLUMNS]> for ShaExtendSpongeColumnsView { + fn borrow_mut(&mut self) -> &mut [T; NUM_SHA_EXTEND_SPONGE_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl Default for ShaExtendSpongeColumnsView { + fn default() -> Self { + [T::default(); NUM_SHA_EXTEND_SPONGE_COLUMNS].into() + } +} + +const fn make_col_map() -> ShaExtendSpongeColumnsView { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_SHA_EXTEND_SPONGE_COLUMNS], ShaExtendSpongeColumnsView>(indices_arr) + } +} + +pub(crate) const SHA_EXTEND_SPONGE_COL_MAP: ShaExtendSpongeColumnsView = make_col_map(); + diff --git a/prover/src/sha_extend_sponge/mod.rs b/prover/src/sha_extend_sponge/mod.rs new file mode 100644 index 00000000..eff2b9b8 --- /dev/null +++ b/prover/src/sha_extend_sponge/mod.rs @@ 
-0,0 +1 @@ +mod columns; \ No newline at end of file From fdd684dfa3e5319201e8908b7521fdf69193c259 Mon Sep 17 00:00:00 2001 From: vanhger Date: Tue, 21 Jan 2025 14:45:30 +0700 Subject: [PATCH 08/25] feat: add trace generation for SHA precompile circuits --- prover/src/sha_compress/columns.rs | 4 +- prover/src/sha_compress/constants.rs | 10 + prover/src/sha_compress/mod.rs | 4 +- prover/src/sha_compress/sha_compress_stark.rs | 272 ++++++++++++++++++ prover/src/sha_compress_sponge/columns.rs | 60 +++- prover/src/sha_compress_sponge/mod.rs | 3 +- .../sha_compress_sponge_stark.rs | 250 ++++++++++++++++ prover/src/sha_extend/columns.rs | 5 +- prover/src/sha_extend/constants.rs | 2 - prover/src/sha_extend/mod.rs | 4 +- prover/src/sha_extend/sha_extend_stark.rs | 172 +++++++++++ prover/src/sha_extend_sponge/mod.rs | 3 +- .../sha_extend_sponge_stark.rs | 193 +++++++++++++ 13 files changed, 968 insertions(+), 14 deletions(-) create mode 100644 prover/src/sha_compress/constants.rs create mode 100644 prover/src/sha_compress/sha_compress_stark.rs create mode 100644 prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs delete mode 100644 prover/src/sha_extend/constants.rs create mode 100644 prover/src/sha_extend/sha_extend_stark.rs create mode 100644 prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs diff --git a/prover/src/sha_compress/columns.rs b/prover/src/sha_compress/columns.rs index c0258fd4..87da0a42 100644 --- a/prover/src/sha_compress/columns.rs +++ b/prover/src/sha_compress/columns.rs @@ -1,7 +1,7 @@ use std::borrow::{Borrow, BorrowMut}; use std::intrinsics::transmute; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; - +#[derive(Clone)] pub(crate) struct ShaCompressColumnsView { /// The timestamp at which inputs should be read from memory. 
pub timestamp: T, @@ -65,7 +65,7 @@ pub(crate) struct ShaCompressColumnsView { } -pub const NUM_SHA_COMPRESS_COLUMNS: usize = size_of::>(); +pub const NUM_SHA_COMPRESS_COLUMNS: usize = size_of::>(); // 170 impl From<[T; NUM_SHA_COMPRESS_COLUMNS]> for ShaCompressColumnsView { fn from(value: [T; NUM_SHA_COMPRESS_COLUMNS]) -> Self { diff --git a/prover/src/sha_compress/constants.rs b/prover/src/sha_compress/constants.rs new file mode 100644 index 00000000..b386f60d --- /dev/null +++ b/prover/src/sha_compress/constants.rs @@ -0,0 +1,10 @@ +pub const SHA_COMPRESS_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; \ No newline at end of file diff --git a/prover/src/sha_compress/mod.rs b/prover/src/sha_compress/mod.rs index eff2b9b8..910b48c5 100644 --- a/prover/src/sha_compress/mod.rs +++ b/prover/src/sha_compress/mod.rs @@ -1 +1,3 @@ -mod columns; \ No newline at end of file +mod columns; +mod sha_compress_stark; +mod constants; \ No newline at end of file diff --git a/prover/src/sha_compress/sha_compress_stark.rs b/prover/src/sha_compress/sha_compress_stark.rs new file mode 100644 index 00000000..c2d928ae --- /dev/null +++ b/prover/src/sha_compress/sha_compress_stark.rs @@ -0,0 +1,272 @@ +use std::marker::PhantomData; +use 
plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::evaluation_frame::StarkFrame; +use crate::sha_compress::columns::{ShaCompressColumnsView, NUM_SHA_COMPRESS_COLUMNS}; +use crate::sha_compress::constants::SHA_COMPRESS_K; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; + +pub const NUM_ROUND_CONSTANTS: usize = 64; + +pub const NUM_INPUTS: usize = 72; // 8 + 64 + +#[derive(Copy, Clone, Default)] +pub struct ShaCompressStark { + pub(crate) f: PhantomData, +} + +impl, const D: usize> ShaCompressStark { + pub(crate) fn generate_trace( + &self, + inputs: Vec<([u32; NUM_INPUTS], usize)>, + min_rows: usize, + ) -> Vec> { + // Generate the witness row-wise + let trace_rows = self.generate_trace_rows(inputs, min_rows); + trace_rows_to_poly_values(trace_rows) + } + + fn generate_trace_rows( + &self, + inputs_and_timestamps: Vec<([u32; NUM_INPUTS], usize)>, + min_rows: usize, + ) -> Vec<[F; NUM_SHA_COMPRESS_COLUMNS]> { + let num_rows = (inputs_and_timestamps.len() * NUM_ROUND_CONSTANTS) + .max(min_rows) + .next_power_of_two(); + + let mut rows = Vec::with_capacity(num_rows); + for input_and_timestamp in inputs_and_timestamps.iter() { + let rows_for_compress = self.generate_trace_rows_for_compress(*input_and_timestamp); + rows.extend(rows_for_compress); + } + + while rows.len() < num_rows { + rows.push([F::ZERO; NUM_SHA_COMPRESS_COLUMNS]); + } + rows + } + + fn generate_trace_rows_for_compress( + &self, + input_and_timestamp: ([u32; NUM_INPUTS], usize), + ) -> Vec<[F; NUM_SHA_COMPRESS_COLUMNS]> { + + let mut rows = vec![ShaCompressColumnsView::default(); NUM_ROUND_CONSTANTS]; + + let timestamp = input_and_timestamp.1; + 
let inputs = input_and_timestamp.0; + + // set the first row + + + for round in 0..NUM_ROUND_CONSTANTS { + rows[round].timestamp = F::from_canonical_usize(timestamp); + rows[round].i = F::from_canonical_usize(round); + rows[round].is_final = F::ZERO; + if round == NUM_ROUND_CONSTANTS - 1 { + rows[round].is_final = F::ONE; + } + } + + // Populate the round input for the first round. + [rows[0].a, rows[0].b, rows[0].c, rows[0].d, + rows[0].e, rows[0].f, rows[0].g, rows[0].h] = inputs[0..8] + .iter() + .map(|&x| F::from_canonical_u32(x)) + .collect::>() + .try_into() + .unwrap(); + + + rows[0].w = inputs[8..inputs.len()].iter() + .map(|&x| F::from_canonical_u32(x)) + .collect::>() + .try_into() + .unwrap(); + + self.generate_trace_row_for_round(&mut rows[0], 0); + for round in 1..NUM_ROUND_CONSTANTS{ + self.copy_output_to_input(&mut rows, round); + self.generate_trace_row_for_round(&mut rows[round], round); + } + + + rows.into_iter().map(|row| row.into()).collect::>() + } + + fn generate_trace_row_for_round(&self, row: &mut ShaCompressColumnsView, round: usize) { + row.round_i_filter = [F::ZERO; NUM_ROUND_CONSTANTS]; + row.round_i_filter[round] = F::ONE; + + row.k_i = F::from_canonical_u32(SHA_COMPRESS_K[round]); + row.w_i = row.w[round]; + + let e = row.e.to_canonical_u64() as u32; + let g = row.g.to_canonical_u64() as u32; + let e_rr_6 = e.rotate_right(6); + let e_rr_11 = e.rotate_right(11); + let s_1_inter = e_rr_6 ^ e_rr_11; + let e_rr_25 = e.rotate_right(25); + let s_1 = s_1_inter ^ e_rr_25; + + [row.e_rr_6, row.e_rr_11, row.e_rr_25, row.s_1_inter, row.s_1] + = [e_rr_6, e_rr_11, e_rr_25, s_1_inter, s_1].map(F::from_canonical_u32); + + let e_and_f = e & (row.f.to_canonical_u64() as u32); + let e_not = !e; + let e_not_and_g = e_not & g; + let ch = e_and_f ^ e_not_and_g; + let temp1 = (row.h.to_canonical_u64() as u32).wrapping_add(s_1) + .wrapping_add(ch) + .wrapping_add(row.k_i.to_canonical_u64() as u32) + .wrapping_add(row.w_i.to_canonical_u64() as u32); + + 
[row.e_and_f, row.e_not, row.e_not_and_g, row.ch, row.temp1] + = [e_and_f, e_not, e_not_and_g, ch, temp1].map(F::from_canonical_u32); + + let a = row.a.to_canonical_u64() as u32; + let a_rr_2 = a.rotate_right(2); + let a_rr_13 = a.rotate_right(13); + let a_rr_22 = a.rotate_right(22); + let s_0_inter = a_rr_2 ^ a_rr_13; + let s_0 = s_0_inter ^ a_rr_22; + + [row.a_rr_22, row.a_rr_13, row.a_rr_2, row.s_0_inter, row.s_0] + = [a_rr_22, a_rr_13, a_rr_2, s_0_inter, s_0].map(F::from_canonical_u32); + + let a_and_b = a & (row.b.to_canonical_u64() as u32); + let a_and_c = a & (row.c.to_canonical_u64() as u32); + let b_and_c = (row.b.to_canonical_u64() as u32) & (row.c.to_canonical_u64() as u32); + let maj_inter = a_and_b ^ a_and_c; + let maj = maj_inter ^ b_and_c; + let temp2 = s_0.wrapping_add(maj); + + let new_e = (row.d.to_canonical_u64() as u32).wrapping_add(temp1); + let new_a = temp1.wrapping_add(temp2); + [row.a_and_b, row.a_and_c, row.b_and_c, row.maj_inter, row.maj, row.temp2] + = [a_and_b, a_and_c, b_and_c, maj_inter, maj, temp2].map(F::from_canonical_u32); + + row.new_h = row.g; + row.new_g = row.f; + row.new_f = row.e; + row.new_e = F::from_canonical_u32(new_e); + row.new_d = row.c; + row.new_c = row.b; + row.new_b = row.a; + row.new_a = F::from_canonical_u32(new_a); + } + + fn copy_output_to_input(&self, rows: &mut Vec>, round: usize) { + rows[round].a = rows[round-1].new_a; + rows[round].b = rows[round-1].new_b; + rows[round].c = rows[round-1].new_c; + rows[round].d = rows[round-1].new_d; + rows[round].e = rows[round-1].new_e; + rows[round].f = rows[round-1].new_f; + rows[round].g = rows[round-1].new_g; + rows[round].h = rows[round-1].new_h; + rows[round].w = rows[round-1].w; + } +} + +impl, const D: usize> Stark for ShaCompressStark { + type EvaluationFrame + = StarkFrame + where + FE: FieldExtension, + P: PackedField; + type EvaluationFrameTarget = StarkFrame, NUM_SHA_COMPRESS_COLUMNS>; + + fn eval_packed_generic( + &self, + vars: &Self::EvaluationFrame, + 
yield_constr: &mut ConstraintConsumer<P>
+ ) where + FE: FieldExtension, + P: PackedField + { + todo!() + } + + fn eval_ext_circuit( + &self, + builder: &mut CircuitBuilder, + vars: &Self::EvaluationFrameTarget, + yield_constr: &mut RecursiveConstraintConsumer + ) { + todo!() + } + + fn constraint_degree(&self) -> usize { + todo!() + } +} + +#[cfg(test)] +mod test { + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::Field; + use crate::sha_compress::columns::ShaCompressColumnsView; + use crate::sha_compress::sha_compress_stark::ShaCompressStark; + + const W: [u32; 64] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, + 67559435, 1711661200, 3020350282, 1447362251, 3118632270, 4004188394, 690615167, + 6070360, 1105370215, 2385558114, 2348232513, 507799627, 2098764358, 5845374, 823657968, + 2969863067, 3903496557, 4274682881, 2059629362, 1849247231, 2656047431, 835162919, + 2096647516, 2259195856, 1779072524, 3152121987, 4210324067, 1557957044, 376930560, + 982142628, 3926566666, 4164334963, 789545383, 1028256580, 2867933222, 3843938318, 1135234440, + 390334875, 2025924737, 3318322046, 3436065867, 652746999, 4261492214, 2543173532, 3334668051, + 3166416553, 634956631]; + + pub const H256_256: [u32;8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, + ]; + + #[test] + fn test_generation() -> Result<(), String>{ + + const D: usize = 2; + type F = GoldilocksField; + type S = ShaCompressStark; + + let w = W; + let h = H256_256; + + let mut input = vec![]; + input.extend(h);input.extend(w); + + + let stark = S::default(); + let rows = stark.generate_trace_rows_for_compress((input.try_into().unwrap(), 0)); + + assert_eq!(rows.len(), 64); + + + // check first row + let first_row: ShaCompressColumnsView = rows[0].into(); + assert_eq!(first_row.a, F::from_canonical_u32(0x6a09e667)); + assert_eq!(first_row.new_a, F::from_canonical_u32(4228417613)); + + // output + let last_row: ShaCompressColumnsView = 
rows[63].into(); + assert_eq!(last_row.is_final, F::ONE); + + assert_eq!(last_row.new_a, F::from_canonical_u32(1813631354)); + assert_eq!(last_row.new_b, F::from_canonical_u32(3315363907)); + assert_eq!(last_row.new_c, F::from_canonical_u32(209435322)); + assert_eq!(last_row.new_d, F::from_canonical_u32(267716009)); + assert_eq!(last_row.new_e, F::from_canonical_u32(646830348)); + assert_eq!(last_row.new_f, F::from_canonical_u32(362222596)); + assert_eq!(last_row.new_g, F::from_canonical_u32(3323089566)); + assert_eq!(last_row.new_h, F::from_canonical_u32(1912443780)); + Ok(()) + } +} \ No newline at end of file diff --git a/prover/src/sha_compress_sponge/columns.rs b/prover/src/sha_compress_sponge/columns.rs index dd4e95f7..64f922dc 100644 --- a/prover/src/sha_compress_sponge/columns.rs +++ b/prover/src/sha_compress_sponge/columns.rs @@ -1,3 +1,7 @@ +use std::borrow::{Borrow, BorrowMut}; +use std::intrinsics::transmute; +use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; + pub(crate) struct ShaCompressSpongeColumnsView { /// The timestamp at which inputs should be read from memory. 
pub timestamp: T, @@ -29,4 +33,58 @@ pub(crate) struct ShaCompressSpongeColumnsView { /// W_i addresses pub w_virt: [T;64], -} \ No newline at end of file +} + + +pub const NUM_SHA_COMPRESS_SPONGE_COLUMNS: usize = size_of::>(); + +impl From<[T; NUM_SHA_COMPRESS_SPONGE_COLUMNS]> for ShaCompressSpongeColumnsView { + fn from(value: [T; NUM_SHA_COMPRESS_SPONGE_COLUMNS]) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl From> for [T; NUM_SHA_COMPRESS_SPONGE_COLUMNS] { + fn from(value: ShaCompressSpongeColumnsView) -> Self { + unsafe { transmute_no_compile_time_size_checks(value) } + } +} + +impl Borrow> for [T; NUM_SHA_COMPRESS_SPONGE_COLUMNS] { + fn borrow(&self) -> &ShaCompressSpongeColumnsView { + unsafe { transmute(self) } + } +} + +impl BorrowMut> for [T; NUM_SHA_COMPRESS_SPONGE_COLUMNS] { + fn borrow_mut(&mut self) -> &mut ShaCompressSpongeColumnsView { + unsafe { transmute(self) } + } +} + +impl Borrow<[T; NUM_SHA_COMPRESS_SPONGE_COLUMNS]> for ShaCompressSpongeColumnsView { + fn borrow(&self) -> &[T; NUM_SHA_COMPRESS_SPONGE_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl BorrowMut<[T; NUM_SHA_COMPRESS_SPONGE_COLUMNS]> for ShaCompressSpongeColumnsView { + fn borrow_mut(&mut self) -> &mut [T; NUM_SHA_COMPRESS_SPONGE_COLUMNS] { + unsafe { transmute(self) } + } +} + +impl Default for ShaCompressSpongeColumnsView { + fn default() -> Self { + [T::default(); NUM_SHA_COMPRESS_SPONGE_COLUMNS].into() + } +} + +const fn make_col_map() -> ShaCompressSpongeColumnsView { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_SHA_COMPRESS_SPONGE_COLUMNS], ShaCompressSpongeColumnsView>(indices_arr) + } +} + +pub(crate) const SHA_COMPRESS_SPONGE_COL_MAP: ShaCompressSpongeColumnsView = make_col_map(); diff --git a/prover/src/sha_compress_sponge/mod.rs b/prover/src/sha_compress_sponge/mod.rs index eff2b9b8..9f1ac542 100644 --- a/prover/src/sha_compress_sponge/mod.rs +++ b/prover/src/sha_compress_sponge/mod.rs @@ -1 +1,2 @@ 
-mod columns; \ No newline at end of file +mod columns; +mod sha_compress_sponge_stark; \ No newline at end of file diff --git a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs new file mode 100644 index 00000000..cadbd1c1 --- /dev/null +++ b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs @@ -0,0 +1,250 @@ +use std::marker::PhantomData; +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::evaluation_frame::StarkFrame; +use crate::memory::segments::Segment; +use crate::sha_compress_sponge::columns::{ShaCompressSpongeColumnsView, NUM_SHA_COMPRESS_SPONGE_COLUMNS}; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; +use crate::witness::memory::MemoryAddress; +use crate::witness::operation::SHA_COMPRESS_K; + +#[derive(Clone, Debug)] +pub(crate) struct ShaCompressSpongeOp { + /// The base address at which inputs are read. + pub(crate) base_address: Vec, + + /// The timestamp at which inputs are read. + pub(crate) timestamp: usize, + + /// The input that was read. + pub(crate) input: Vec, +} + +#[derive(Copy, Clone, Default)] +pub struct ShaCompressSpongeStark { + f: PhantomData, +} + +impl, const D: usize> ShaCompressSpongeStark { + pub(crate) fn generate_trace( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec> { + // Generate the witness row-wise. 
+ let trace_rows = self.generate_trace_rows(operations, min_rows); + + trace_rows_to_poly_values(trace_rows) + } + + fn generate_trace_rows( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec<[F; NUM_SHA_COMPRESS_SPONGE_COLUMNS]> { + let base_len = operations.len(); + let mut rows = Vec::with_capacity(base_len.max(min_rows).next_power_of_two()); + for op in operations { + rows.push(self.generate_rows_for_op(op).into()); + } + + let padded_rows = rows.len().max(min_rows).next_power_of_two(); + for _ in rows.len()..padded_rows { + rows.push(ShaCompressSpongeColumnsView::default().into()); + } + + rows + } + + fn generate_rows_for_op( + &self, + op: ShaCompressSpongeOp, + ) -> ShaCompressSpongeColumnsView { + let mut row = ShaCompressSpongeColumnsView::default(); + + row.timestamp = F::from_canonical_usize(op.timestamp); + + let new_buffer = self.compress(&op.input); + + row.hx = op.input[0..8] + .iter() + .map(|&x| F::from_canonical_u32(x)) + .collect::>() + .try_into() + .unwrap(); + + row.w = op.input[8..op.input.len()] + .iter() + .map(|&x| F::from_canonical_u32(x)) + .collect::>() + .try_into() + .unwrap(); + + row.context = F::from_canonical_usize(op.base_address[0].context); + row.segment = F::from_canonical_usize(op.base_address[Segment::Code as usize].segment); + + [row.new_a, row.new_b, row.new_c, row.new_d, row.new_e, row.new_f, row.new_g, row.new_h] + = new_buffer.iter() + .map(|&x| F::from_canonical_u32(x)) + .collect::>() + .try_into() + .unwrap(); + + row.final_hx = new_buffer.iter().zip(row.hx.iter()) + .map(|(&x, &hx)| F::from_canonical_u32(x.wrapping_add(hx.to_canonical_u64() as u32))) + .collect::>() + .try_into() + .unwrap(); + + let hx_virt = (0..8) + .map(|i| op.base_address[i].virt) + .collect_vec(); + let hx_virt: [usize; 8] = hx_virt.try_into().unwrap(); + row.hx_virt = hx_virt.map(F::from_canonical_usize); + + let w_virt = (8..op.input.len()) + .map(|i| op.base_address[i].virt) + .collect_vec(); + let w_virt: [usize; 64] = 
w_virt.try_into().unwrap(); + row.w_virt = w_virt.map(F::from_canonical_usize); + + row + } + + fn compress(&self, input: &[u32]) -> [u32; 8] { + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h]: [u32; 8] = input[0..8].try_into().unwrap(); + let mut t1: u32; + let mut t2: u32; + + for i in 0..64 { + t1 = h.wrapping_add(e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25)) + .wrapping_add((e & f) ^ ((!e) & g)).wrapping_add(SHA_COMPRESS_K[i]).wrapping_add(input[8 + i]); + t2 = (a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22)) + .wrapping_add((a & b) ^ (a & c) ^ (b & c)); + h = g; + g = f; + f = e; + e = d.wrapping_add(t1); + d = c; + c = b; + b = a; + a = t1.wrapping_add(t2); + } + + [a, b, c, d, e, f, g, h] + } +} + +impl, const D: usize> Stark for ShaCompressSpongeStark { + type EvaluationFrame + = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_SHA_COMPRESS_SPONGE_COLUMNS>; + + fn eval_packed_generic( + &self, + vars: &Self::EvaluationFrame, + yield_constr: &mut ConstraintConsumer

+ ) where + FE: FieldExtension, + P: PackedField + { + todo!() + } + + fn eval_ext_circuit( + &self, + builder: &mut CircuitBuilder, + vars: &Self::EvaluationFrameTarget, + yield_constr: &mut RecursiveConstraintConsumer + ) { + todo!() + } + + fn constraint_degree(&self) -> usize { + todo!() + } +} + + +#[cfg(test)] +mod test { + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::Field; + use crate::sha_compress_sponge::sha_compress_sponge_stark::{ShaCompressSpongeOp, ShaCompressSpongeStark}; + use crate::witness::memory::MemoryAddress; + + + const W: [u32; 64] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, + 67559435, 1711661200, 3020350282, 1447362251, 3118632270, 4004188394, 690615167, + 6070360, 1105370215, 2385558114, 2348232513, 507799627, 2098764358, 5845374, 823657968, + 2969863067, 3903496557, 4274682881, 2059629362, 1849247231, 2656047431, 835162919, + 2096647516, 2259195856, 1779072524, 3152121987, 4210324067, 1557957044, 376930560, + 982142628, 3926566666, 4164334963, 789545383, 1028256580, 2867933222, 3843938318, 1135234440, + 390334875, 2025924737, 3318322046, 3436065867, 652746999, 4261492214, 2543173532, 3334668051, + 3166416553, 634956631]; + + pub const H256_256: [u32;8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, + 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, + ]; + #[test] + fn test_generation() -> Result<(), String> { + const D: usize = 2; + type F = GoldilocksField; + + type S = ShaCompressSpongeStark; + + let stark = S::default(); + let hx_addresses: Vec = (0..32).step_by(4).map(|i| { + MemoryAddress { + context: 0, + segment: 0, + virt: i, + } + }).collect(); + + let w_addresses: Vec = (32..288).step_by(4).map(|i| { + MemoryAddress { + context: 0, + segment: 0, + virt: i, + } + }).collect(); + + let op = ShaCompressSpongeOp { + base_address: hx_addresses.iter().chain(w_addresses.iter()).cloned().collect(), + timestamp: 0, + input: 
H256_256.iter().chain(W.iter()).cloned().collect(), + }; + let row = stark.generate_rows_for_op(op); + + assert_eq!(row.new_a, F::from_canonical_u32(1813631354)); + assert_eq!(row.new_b, F::from_canonical_u32(3315363907)); + assert_eq!(row.new_c, F::from_canonical_u32(209435322)); + assert_eq!(row.new_d, F::from_canonical_u32(267716009)); + assert_eq!(row.new_e, F::from_canonical_u32(646830348)); + assert_eq!(row.new_f, F::from_canonical_u32(362222596)); + assert_eq!(row.new_g, F::from_canonical_u32(3323089566)); + assert_eq!(row.new_h, F::from_canonical_u32(1912443780)); + + let expected_values: [F; 8] = [3592665057_u32, 2164530888, 1223339564, 3041196771, 2006723467, + 2963045520, 3851824201, 3453903005].into_iter().map(F::from_canonical_u32) + .collect::>().try_into().unwrap(); + + + assert_eq!(row.final_hx, expected_values); + Ok(()) + + } +} \ No newline at end of file diff --git a/prover/src/sha_extend/columns.rs b/prover/src/sha_extend/columns.rs index a5327b25..623457e3 100644 --- a/prover/src/sha_extend/columns.rs +++ b/prover/src/sha_extend/columns.rs @@ -7,9 +7,6 @@ pub(crate) struct ShaExtendColumnsView { /// The timestamp at which inputs should be read from memory. 
pub timestamp: T, - /// round - pub i: T, - /// Input pub w_i_minus_15: T, pub w_i_minus_2: T, @@ -25,7 +22,7 @@ pub(crate) struct ShaExtendColumnsView { pub w_i_minus_2_rr_17: T, pub w_i_minus_2_rr_19: T, pub w_i_minus_2_rs_10: T, - pub s_i_inter: T, + pub s_1_inter: T, pub s_1: T, /// Output diff --git a/prover/src/sha_extend/constants.rs b/prover/src/sha_extend/constants.rs deleted file mode 100644 index b2849ca6..00000000 --- a/prover/src/sha_extend/constants.rs +++ /dev/null @@ -1,2 +0,0 @@ -const NUM_ROUND_CONSTANTS: usize = 48; - diff --git a/prover/src/sha_extend/mod.rs b/prover/src/sha_extend/mod.rs index 20aa4912..7cafd8a3 100644 --- a/prover/src/sha_extend/mod.rs +++ b/prover/src/sha_extend/mod.rs @@ -1,2 +1,2 @@ -mod constants; -mod columns; \ No newline at end of file +pub mod columns; +pub mod sha_extend_stark; \ No newline at end of file diff --git a/prover/src/sha_extend/sha_extend_stark.rs b/prover/src/sha_extend/sha_extend_stark.rs new file mode 100644 index 00000000..0dbc4e22 --- /dev/null +++ b/prover/src/sha_extend/sha_extend_stark.rs @@ -0,0 +1,172 @@ +use std::marker::PhantomData; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::evaluation_frame::StarkFrame; +use crate::sha_extend::columns::{ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS}; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; + +const NUM_ROUND_CONSTANTS: usize = 48; +const NUM_INPUTS: usize = 4; // w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 + +#[derive(Copy, Clone, Default)] +pub struct ShaExtendStark { + pub(crate) f: PhantomData, +} + + +impl, const D: usize> ShaExtendStark { + pub(crate) fn generate_trace( 
+ &self, + inputs: Vec<([u32; NUM_INPUTS], usize)>, + min_rows: usize, + ) -> Vec> { + // Generate the witness row-wise + let trace_rows = self.generate_trace_rows(inputs, min_rows); + trace_rows_to_poly_values(trace_rows) + } + + fn generate_trace_rows( + &self, + inputs_and_timestamps: Vec<([u32; NUM_INPUTS], usize)>, + min_rows: usize, + ) -> Vec<[F; NUM_SHA_EXTEND_COLUMNS]> { + let num_rows = inputs_and_timestamps.len() + .max(min_rows).next_power_of_two(); + + let mut rows = Vec::with_capacity(num_rows); + for input_and_timestamp in inputs_and_timestamps.iter() { + let rows_for_extend = self.generate_trace_rows_for_extend(*input_and_timestamp); + rows.push(rows_for_extend.into()); + } + + // padding + while rows.len() < num_rows { + rows.push([F::ZERO; NUM_SHA_EXTEND_COLUMNS]); + } + + rows + } + + fn generate_trace_rows_for_extend( + &self, + input_and_timestamp: ([u32; NUM_INPUTS], usize), + ) -> ShaExtendColumnsView { + let mut row = ShaExtendColumnsView::default(); + + row.timestamp = F::from_canonical_usize(input_and_timestamp.1); + [row.w_i_minus_15, row.w_i_minus_2, row.w_i_minus_16, row.w_i_minus_7] + = input_and_timestamp.0.map(F::from_canonical_u32); + + self.generate_trace_row_for_round(&mut row); + row + } + + fn generate_trace_row_for_round(&self, row: &mut ShaExtendColumnsView) { + let w_i_minus_15_u32 = row.w_i_minus_15.to_canonical_u64() as u32; + row.w_i_minus_15_rr_7 = F::from_canonical_u32(w_i_minus_15_u32.rotate_right(7)); + row.w_i_minus_15_rr_18 = F::from_canonical_u32(w_i_minus_15_u32.rotate_right(18)); + row.w_i_minus_15_rs_3 = F::from_canonical_u32(w_i_minus_15_u32 >> 3); + + // (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) + + row.s_0_inter = F::from_canonical_u32(w_i_minus_15_u32.rotate_right(7) ^ w_i_minus_15_u32.rotate_right(18)); + // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) + row.s_0 = F::from_canonical_u32((row.s_0_inter.to_canonical_u64() as u32) ^ (w_i_minus_15_u32 >> 3)); + 
+ let w_i_minus_2_u32 = row.w_i_minus_2.to_canonical_u64() as u32; + row.w_i_minus_2_rr_17 = F::from_canonical_u32(w_i_minus_2_u32.rotate_right(17)); + row.w_i_minus_2_rr_19 = F::from_canonical_u32(w_i_minus_2_u32.rotate_right(19)); + row.w_i_minus_2_rs_10 = F::from_canonical_u32(w_i_minus_2_u32 >> 10); + + // (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) + row.s_1_inter = F::from_canonical_u32(w_i_minus_2_u32.rotate_right(17) ^ w_i_minus_2_u32.rotate_right(19)); + // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) + row.s_1 = F::from_canonical_u32((row.s_1_inter.to_canonical_u64() as u32) ^ (w_i_minus_2_u32 >> 10)); + + // w_i = w[i-16] + s0 + w[i-7] + s1. + row.w_i = F::from_canonical_u32((row.w_i_minus_16.to_canonical_u64() as u32) + .wrapping_add(row.s_0.to_canonical_u64() as u32) + .wrapping_add(row.w_i_minus_7.to_canonical_u64() as u32) + .wrapping_add(row.s_1.to_canonical_u64() as u32)); + } +} + + +impl, const D: usize> Stark for ShaExtendStark { + type EvaluationFrame + = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_SHA_EXTEND_COLUMNS>; + + fn eval_packed_generic( + &self, + vars: &Self::EvaluationFrame, + yield_constr: &mut ConstraintConsumer

+ ) where + FE: FieldExtension, + P: PackedField + { + todo!() + } + + fn eval_ext_circuit( + &self, + builder: &mut CircuitBuilder, + vars: &Self::EvaluationFrameTarget, + yield_constr: &mut RecursiveConstraintConsumer) { + todo!() + } + + fn constraint_degree(&self) -> usize { + todo!() + } +} + +#[cfg(test)] +mod test { + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::{Field}; + use crate::sha_extend::sha_extend_stark::ShaExtendStark; + + #[test] + fn test_generation() -> Result<(), String> { + const D: usize = 2; + type F = GoldilocksField; + + type S = ShaExtendStark; + + let input = ([1, 2, 3, 4 as u32], 0); + + let stark = S::default(); + let row = stark.generate_trace_rows_for_extend(input); + + + // extend phase + let w_i_minus_15 = input.0[0]; + let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + + let w_i_minus_2 = input.0[1]; + // Compute `s1`. + let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); + let w_i_minus_16 = input.0[2]; + let w_i_minus_7 = input.0[3]; + // Compute `w_i`. 
+ let w_i = s1 + .wrapping_add(w_i_minus_16) + .wrapping_add(s0) + .wrapping_add(w_i_minus_7); + // println!("w_i: {}", w_i); + assert_eq!(row.w_i, F::from_canonical_u32(w_i)); + + Ok(()) + } +} \ No newline at end of file diff --git a/prover/src/sha_extend_sponge/mod.rs b/prover/src/sha_extend_sponge/mod.rs index eff2b9b8..96a740d5 100644 --- a/prover/src/sha_extend_sponge/mod.rs +++ b/prover/src/sha_extend_sponge/mod.rs @@ -1 +1,2 @@ -mod columns; \ No newline at end of file +pub mod columns; +pub mod sha_extend_sponge_stark; \ No newline at end of file diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs new file mode 100644 index 00000000..2efb2809 --- /dev/null +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -0,0 +1,193 @@ +use std::marker::PhantomData; +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::evaluation_frame::StarkFrame; +use crate::memory::segments::Segment; +use crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_SHA_EXTEND_SPONGE_COLUMNS}; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; +use crate::witness::memory::MemoryAddress; + +pub(crate) struct ShaExtendSpongeOp { + /// The base address at which inputs are read + pub(crate) base_address: Vec, + + /// The timestamp at which inputs are read and output are written (same for both). + pub(crate) timestamp: usize, + + /// The input that was read. 
+ /// Values: w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 + pub(crate) input: Vec, + + /// The index of round + pub(crate) i: usize, + + /// The base address at which the output is written. + pub(crate) output_address: MemoryAddress, +} + +#[derive(Copy, Clone, Default)] +pub struct ShaExtendSpongeStark { + f: PhantomData, +} + +impl, const D: usize> ShaExtendSpongeStark { + pub(crate) fn generate_trace( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec> { + // Generate the witness row-wise. + let trace_rows = self.generate_trace_rows(operations, min_rows); + + trace_rows_to_poly_values(trace_rows) + } + + fn generate_trace_rows( + &self, + operations: Vec, + min_rows: usize, + ) -> Vec<[F; NUM_SHA_EXTEND_SPONGE_COLUMNS]> { + let base_len = operations.len(); + let mut rows = Vec::with_capacity(base_len.max(min_rows).next_power_of_two()); + for op in operations { + rows.push(self.generate_rows_for_op(op).into()); + } + + let padded_rows = rows.len().max(min_rows).next_power_of_two(); + for _ in rows.len()..padded_rows { + rows.push(ShaExtendSpongeColumnsView::default().into()); + } + + rows + } + + fn generate_rows_for_op(&self, op: ShaExtendSpongeOp) -> ShaExtendSpongeColumnsView{ + let mut row = ShaExtendSpongeColumnsView::default(); + row.timestamp = F::from_canonical_usize(op.timestamp); + row.i = F::from_canonical_usize(op.i); + if op.i == 63 { + row.is_final = F::ONE; + } else { + row.is_final = F::ZERO; + } + + row.context = F::from_canonical_usize(op.base_address[0].context); + row.segment = F::from_canonical_usize(op.base_address[Segment::Code as usize].segment); + let mut virt = (0..op.input.len()) + .map(|i| op.base_address[i].virt) + .collect_vec(); + let virt: [usize; 4] = virt.try_into().unwrap(); + row.input_virt = virt.map(F::from_canonical_usize); + row.output_virt = F::from_canonical_usize(op.output_address.virt); + + row.w_i_minus_15 = F::from_canonical_u32(op.input[0]); + row.w_i_minus_2 = F::from_canonical_u32(op.input[1]); + 
row.w_i_minus_16 = F::from_canonical_u32(op.input[2]); + row.w_i_minus_7 = F::from_canonical_u32(op.input[3]); + + row.w_i = self.compute_w_i(&op.input.try_into().unwrap()); + row + } + + fn compute_w_i(&self, input: &[u32; 4]) -> F { + let s0 = input[0].rotate_right(7) ^ input[0].rotate_right(18) ^ (input[0] >> 3); + let s1 = input[1].rotate_right(17) ^ input[1].rotate_right(19) ^ (input[1] >> 10); + let w_i_u32 = s1 + .wrapping_add(input[2]) + .wrapping_add(s0) + .wrapping_add(input[3]); + F::from_canonical_u32(w_i_u32) + } +} + +impl, const D: usize> Stark for ShaExtendSpongeStark { + type EvaluationFrame + = StarkFrame + where + FE: FieldExtension, + P: PackedField; + + type EvaluationFrameTarget = StarkFrame, NUM_SHA_EXTEND_SPONGE_COLUMNS>; + + fn eval_packed_generic( + &self, + vars: &Self::EvaluationFrame, + yield_constr: &mut ConstraintConsumer

+ ) where + FE: FieldExtension, + P: PackedField + { + todo!() + } + + fn eval_ext_circuit( + &self, + builder: &mut CircuitBuilder, + vars: &Self::EvaluationFrameTarget, + yield_constr: &mut RecursiveConstraintConsumer + ) { + todo!() + } + + fn constraint_degree(&self) -> usize { + todo!() + } +} + +#[cfg(test)] +mod test { + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::Field; + use crate::memory::segments::Segment; + use crate::sha_extend_sponge::sha_extend_sponge_stark::{ShaExtendSpongeOp, ShaExtendSpongeStark}; + use crate::witness::memory::MemoryAddress; + + #[test] + fn test_generation() -> Result<(), String> { + const D: usize = 2; + type F = GoldilocksField; + + type S = ShaExtendSpongeStark; + + let op = ShaExtendSpongeOp { + base_address: vec![MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 4, + }, MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 56, + }, MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 0, + }, MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 36, + }], + timestamp: 0, + input: vec![1, 2, 3, 4], + i: 0, + output_address: MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 64, + }, + }; + + let stark = S::default(); + let row = stark.generate_rows_for_op(op); + assert_eq!(row.is_final, F::ZERO); + assert_eq!(row.w_i, F::from_canonical_u32(33652743)); + Ok(()) + } +} \ No newline at end of file From 906761fc9795102b2ba35e2050deb4b2d9b73939 Mon Sep 17 00:00:00 2001 From: vanhger Date: Wed, 22 Jan 2025 17:41:51 +0700 Subject: [PATCH 09/25] chore: change the columns in SHA extend phase --- prover/src/sha_extend/columns.rs | 46 ++++++---- prover/src/sha_extend/logic.rs | 68 ++++++++++++++ prover/src/sha_extend/mod.rs | 3 +- prover/src/sha_extend/sha_extend_stark.rs | 92 +++++++++++-------- prover/src/sha_extend_sponge/columns.rs | 30 +++--- .../sha_extend_sponge_stark.rs | 77 
+++++++++++----- 6 files changed, 215 insertions(+), 101 deletions(-) create mode 100644 prover/src/sha_extend/logic.rs diff --git a/prover/src/sha_extend/columns.rs b/prover/src/sha_extend/columns.rs index 623457e3..3ff2718e 100644 --- a/prover/src/sha_extend/columns.rs +++ b/prover/src/sha_extend/columns.rs @@ -4,29 +4,30 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; pub(crate) struct ShaExtendColumnsView { - /// The timestamp at which inputs should be read from memory. - pub timestamp: T, - - /// Input - pub w_i_minus_15: T, - pub w_i_minus_2: T, - pub w_i_minus_16: T, - pub w_i_minus_7: T, + /// Input in big-endian order + pub w_i_minus_15: [T; 32], + pub w_i_minus_2: [T; 32], + pub w_i_minus_16: [T; 32], + pub w_i_minus_7: [T; 32], /// Intermediate values - pub w_i_minus_15_rr_7: T, - pub w_i_minus_15_rr_18: T, - pub w_i_minus_15_rs_3: T, - pub s_0_inter: T, - pub s_0: T, - pub w_i_minus_2_rr_17: T, - pub w_i_minus_2_rr_19: T, - pub w_i_minus_2_rs_10: T, - pub s_1_inter: T, - pub s_1: T, - + pub w_i_minus_15_rr_7: [T; 32], + pub w_i_minus_15_rr_18: [T; 32], + pub w_i_minus_15_rs_3: [T; 32], + pub s_0: [T; 32], + pub w_i_minus_2_rr_17: [T; 32], + pub w_i_minus_2_rr_19: [T; 32], + pub w_i_minus_2_rs_10: [T; 32], + pub s_1: [T; 32], + pub w_i_inter_0: [T; 32], // s_1 + w_i_minus_7] + pub carry_0: [T; 32], + pub w_i_inter_1: [T; 32], // w_i_inter_0 + s_0 + pub carry_1: [T; 32], + pub carry_2: [T; 32], /// Output - pub w_i: T, + pub w_i: [T; 32], // w_i_inter_1 + w_i_minus_16 + /// The timestamp at which inputs should be read from memory. 
+ pub timestamp: T, } pub const NUM_SHA_EXTEND_COLUMNS: usize = size_of::>(); @@ -81,3 +82,8 @@ const fn make_col_map() -> ShaExtendColumnsView { } pub(crate) const SHA_EXTEND_COL_MAP: ShaExtendColumnsView = make_col_map(); + +pub fn get_input_range(i: usize) -> std::ops::Range { + (0 + i * 32)..(32 + i * 32) +} + diff --git a/prover/src/sha_extend/logic.rs b/prover/src/sha_extend/logic.rs new file mode 100644 index 00000000..e21a274c --- /dev/null +++ b/prover/src/sha_extend/logic.rs @@ -0,0 +1,68 @@ +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +// these operators are applied in big-endian form + +pub fn rotate_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { + let mut result = [F::ZERO; 32]; + for i in 0..32 { + result[i] = value[(i + amount) % 32]; + } + result +} + +pub fn shift_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { + let mut result = [F::ZERO; 32]; + if amount < 32 { + for i in 0..32 - amount { + result[i] = value[i + amount]; + } + } + result +} + +pub fn xor3 , const D: usize>(a: [F; 32], b: [F; 32], c: [F; 32]) -> [F; 32] { + let mut result = [F::ZERO; 32]; + for i in 0..32 { + result[i] = crate::keccak::logic::xor([a[i], b[i], c[i]]); + } + result +} + +pub fn wrapping_add, const D: usize>(a: [F; 32], b: [F; 32]) -> ([F; 32], [F; 32]) { + let mut result = [F::ZERO; 32]; + let mut carries = [F::ZERO; 32]; + let mut sum = F::ZERO; + let mut carry = F::ZERO; + for i in 0..32 { + debug_assert!(a[i].is_zero() || a[i].is_one()); + debug_assert!(b[i].is_zero() || b[i].is_one()); + + let tmp = (a[i] + b[i] + carry).to_canonical_u64(); + sum = F::from_canonical_u64(tmp & 1); + carry = F::from_canonical_u64(tmp >> 1); + carries[i] = carry; + result[i] = sum; + } + (result, carries) +} + +pub fn 
from_be_bits_to_u32, const D: usize>(value: [F; 32]) -> u32 { + let mut result = 0; + for i in 0..32 { + debug_assert!(value[i].is_zero() || value[i].is_one()); + result |= (value[i].to_canonical_u64() as u32) << i; + } + result +} + +pub fn from_u32_to_be_bits(value: u32) -> [u32; 32] { + let mut result = [0; 32]; + for i in 0..32 { + result[i] = ((value >> i) & 1) as u32; + } + result +} \ No newline at end of file diff --git a/prover/src/sha_extend/mod.rs b/prover/src/sha_extend/mod.rs index 7cafd8a3..8fb26d6f 100644 --- a/prover/src/sha_extend/mod.rs +++ b/prover/src/sha_extend/mod.rs @@ -1,2 +1,3 @@ pub mod columns; -pub mod sha_extend_stark; \ No newline at end of file +pub mod sha_extend_stark; +pub mod logic; \ No newline at end of file diff --git a/prover/src/sha_extend/sha_extend_stark.rs b/prover/src/sha_extend/sha_extend_stark.rs index 0dbc4e22..ce38977a 100644 --- a/prover/src/sha_extend/sha_extend_stark.rs +++ b/prover/src/sha_extend/sha_extend_stark.rs @@ -1,3 +1,4 @@ +use std::borrow::Borrow; use std::marker::PhantomData; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; @@ -6,13 +7,13 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::evaluation_frame::StarkFrame; -use crate::sha_extend::columns::{ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS}; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; +use crate::sha_extend::columns::{get_input_range, ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS}; +use crate::sha_extend::logic::{rotate_right, shift_right, wrapping_add, xor3}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; -const NUM_ROUND_CONSTANTS: usize = 48; -const NUM_INPUTS: usize = 4; // w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 +const NUM_INPUTS: usize = 4 * 32; // 
w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 #[derive(Copy, Clone, Default)] pub struct ShaExtendStark { @@ -23,11 +24,11 @@ pub struct ShaExtendStark { impl, const D: usize> ShaExtendStark { pub(crate) fn generate_trace( &self, - inputs: Vec<([u32; NUM_INPUTS], usize)>, + inputs_and_timestamps: Vec<([u32; NUM_INPUTS], usize)>, min_rows: usize, ) -> Vec> { // Generate the witness row-wise - let trace_rows = self.generate_trace_rows(inputs, min_rows); + let trace_rows = self.generate_trace_rows(inputs_and_timestamps, min_rows); trace_rows_to_poly_values(trace_rows) } @@ -60,40 +61,39 @@ impl, const D: usize> ShaExtendStark { let mut row = ShaExtendColumnsView::default(); row.timestamp = F::from_canonical_usize(input_and_timestamp.1); - [row.w_i_minus_15, row.w_i_minus_2, row.w_i_minus_16, row.w_i_minus_7] - = input_and_timestamp.0.map(F::from_canonical_u32); + row.w_i_minus_15 = input_and_timestamp.0[get_input_range(0)] + .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + row.w_i_minus_2 = input_and_timestamp.0[get_input_range(1)] + .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + row.w_i_minus_16 = input_and_timestamp.0[get_input_range(2)] + .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + row.w_i_minus_7 = input_and_timestamp.0[get_input_range(3)] + .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); self.generate_trace_row_for_round(&mut row); row } fn generate_trace_row_for_round(&self, row: &mut ShaExtendColumnsView) { - let w_i_minus_15_u32 = row.w_i_minus_15.to_canonical_u64() as u32; - row.w_i_minus_15_rr_7 = F::from_canonical_u32(w_i_minus_15_u32.rotate_right(7)); - row.w_i_minus_15_rr_18 = F::from_canonical_u32(w_i_minus_15_u32.rotate_right(18)); - row.w_i_minus_15_rs_3 = F::from_canonical_u32(w_i_minus_15_u32 >> 3); - - // (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) + row.w_i_minus_15_rr_7 = rotate_right(row.w_i_minus_15, 7); + 
row.w_i_minus_15_rr_18 = rotate_right(row.w_i_minus_15, 18); + row.w_i_minus_15_rs_3 = shift_right(row.w_i_minus_15, 3); - row.s_0_inter = F::from_canonical_u32(w_i_minus_15_u32.rotate_right(7) ^ w_i_minus_15_u32.rotate_right(18)); // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) - row.s_0 = F::from_canonical_u32((row.s_0_inter.to_canonical_u64() as u32) ^ (w_i_minus_15_u32 >> 3)); + row.s_0 = xor3(row.w_i_minus_15_rr_7, row.w_i_minus_15_rr_18, row.w_i_minus_15_rs_3); - let w_i_minus_2_u32 = row.w_i_minus_2.to_canonical_u64() as u32; - row.w_i_minus_2_rr_17 = F::from_canonical_u32(w_i_minus_2_u32.rotate_right(17)); - row.w_i_minus_2_rr_19 = F::from_canonical_u32(w_i_minus_2_u32.rotate_right(19)); - row.w_i_minus_2_rs_10 = F::from_canonical_u32(w_i_minus_2_u32 >> 10); + row.w_i_minus_2_rr_17 = rotate_right(row.w_i_minus_2, 17); + row.w_i_minus_2_rr_19 = rotate_right(row.w_i_minus_2, 19); + row.w_i_minus_2_rs_10 = shift_right(row.w_i_minus_2, 10); - // (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) - row.s_1_inter = F::from_canonical_u32(w_i_minus_2_u32.rotate_right(17) ^ w_i_minus_2_u32.rotate_right(19)); // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) - row.s_1 = F::from_canonical_u32((row.s_1_inter.to_canonical_u64() as u32) ^ (w_i_minus_2_u32 >> 10)); + row.s_1 = xor3(row.w_i_minus_2_rr_17, row.w_i_minus_2_rr_19, row.w_i_minus_2_rs_10); + + // (w_i_inter_0, carry) = w[i-7] + s1. + (row.w_i_inter_0, row.carry_0) = wrapping_add(row.w_i_minus_7, row.s_1); + (row.w_i_inter_1, row.carry_1) = wrapping_add(row.w_i_inter_0, row.s_0); - // w_i = w[i-16] + s0 + w[i-7] + s1. 
- row.w_i = F::from_canonical_u32((row.w_i_minus_16.to_canonical_u64() as u32) - .wrapping_add(row.s_0.to_canonical_u64() as u32) - .wrapping_add(row.w_i_minus_7.to_canonical_u64() as u32) - .wrapping_add(row.s_1.to_canonical_u64() as u32)); + (row.w_i, row.carry_2) = wrapping_add(row.w_i_inter_1, row.w_i_minus_16); } } @@ -122,7 +122,9 @@ impl, const D: usize> Stark for ShaExtendStar &self, builder: &mut CircuitBuilder, vars: &Self::EvaluationFrameTarget, - yield_constr: &mut RecursiveConstraintConsumer) { + yield_constr: &mut RecursiveConstraintConsumer) + { + todo!() } @@ -137,35 +139,47 @@ mod test { use plonky2::field::types::{Field}; use crate::sha_extend::sha_extend_stark::ShaExtendStark; + fn to_be_bits(value: u32) -> [u32; 32] { + let mut result = [0; 32]; + for i in 0..32 { + result[i] = ((value >> i) & 1) as u32; + } + result + } + #[test] - fn test_generation() -> Result<(), String> { + fn test_correction() -> Result<(), String> { const D: usize = 2; type F = GoldilocksField; type S = ShaExtendStark; - - let input = ([1, 2, 3, 4 as u32], 0); + let mut input_values = vec![]; + input_values.extend((0..4).map(|i| to_be_bits(i as u32))); + let input_values = input_values.into_iter().flatten().collect::>(); + let input_values: [u32; 128] = input_values.try_into().unwrap(); + let input_and_timestamp = (input_values, 0); let stark = S::default(); - let row = stark.generate_trace_rows_for_extend(input); + let row = stark.generate_trace_rows_for_extend(input_and_timestamp.try_into().unwrap()); // extend phase - let w_i_minus_15 = input.0[0]; + let w_i_minus_15 = 0 as u32; let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); - let w_i_minus_2 = input.0[1]; + let w_i_minus_2 = 1 as u32; // Compute `s1`. 
let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); - let w_i_minus_16 = input.0[2]; - let w_i_minus_7 = input.0[3]; + let w_i_minus_16 = 2 as u32; + let w_i_minus_7 = 3 as u32; // Compute `w_i`. let w_i = s1 .wrapping_add(w_i_minus_16) .wrapping_add(s0) .wrapping_add(w_i_minus_7); - // println!("w_i: {}", w_i); - assert_eq!(row.w_i, F::from_canonical_u32(w_i)); + + let w_i_bin = to_be_bits(w_i); + assert_eq!(row.w_i, w_i_bin.map(F::from_canonical_u32)); Ok(()) } diff --git a/prover/src/sha_extend_sponge/columns.rs b/prover/src/sha_extend_sponge/columns.rs index d4c1e088..e39e0b4f 100644 --- a/prover/src/sha_extend_sponge/columns.rs +++ b/prover/src/sha_extend_sponge/columns.rs @@ -4,35 +4,33 @@ use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; pub(crate) const NUM_EXTEND_INPUT: usize = 4; pub(crate) struct ShaExtendSpongeColumnsView { - /// The timestamp at which inputs should be read from memory. - pub timestamp: T, - - /// Round number - pub i: T, /// Input - pub w_i_minus_15: T, - pub w_i_minus_2: T, - pub w_i_minus_16: T, - pub w_i_minus_7: T, + pub w_i_minus_15: [T; 32], + pub w_i_minus_2: [T; 32], + pub w_i_minus_16: [T; 32], + pub w_i_minus_7: [T; 32], /// Output - pub w_i: T, + pub w_i: [T; 32], + + /// round + pub round: [T; 48], - /// The base address at which we will read the input block. - pub context: T, - pub segment: T, /// Input address pub input_virt: [T; NUM_EXTEND_INPUT], /// Output address pub output_virt: T, - /// 1 if this is the final round of the extending phase, 0 otherwise - pub is_final: T, + pub context: T, + pub segment: T, + + /// The timestamp at which inputs should be read from memory. 
+ pub timestamp: T, } -pub const NUM_SHA_EXTEND_SPONGE_COLUMNS: usize = size_of::>(); +pub const NUM_SHA_EXTEND_SPONGE_COLUMNS: usize = size_of::>(); //170 impl From<[T; NUM_SHA_EXTEND_SPONGE_COLUMNS]> for ShaExtendSpongeColumnsView { fn from(value: [T; NUM_SHA_EXTEND_SPONGE_COLUMNS]) -> Self { diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs index 2efb2809..df2be7be 100644 --- a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -1,19 +1,25 @@ use std::marker::PhantomData; +use std::borrow::Borrow; use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::evaluation_frame::StarkFrame; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; +use crate::sha_extend::columns::get_input_range; +use crate::sha_extend::logic::{from_be_bits_to_u32, from_u32_to_be_bits}; use crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_SHA_EXTEND_SPONGE_COLUMNS}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryAddress; +pub const NUM_ROUNDS: usize = 48; + pub(crate) struct ShaExtendSpongeOp { /// The base address at which inputs are read pub(crate) base_address: Vec, @@ -22,11 +28,11 @@ pub(crate) struct ShaExtendSpongeOp { pub(crate) timestamp: usize, /// The input that was read. - /// Values: w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 + /// Values: w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 in big-endian order. 
pub(crate) input: Vec, /// The index of round - pub(crate) i: usize, + pub(crate) i: u32, /// The base address at which the output is written. pub(crate) output_address: MemoryAddress, @@ -71,39 +77,45 @@ impl, const D: usize> ShaExtendSpongeStark { fn generate_rows_for_op(&self, op: ShaExtendSpongeOp) -> ShaExtendSpongeColumnsView{ let mut row = ShaExtendSpongeColumnsView::default(); row.timestamp = F::from_canonical_usize(op.timestamp); - row.i = F::from_canonical_usize(op.i); - if op.i == 63 { - row.is_final = F::ONE; - } else { - row.is_final = F::ZERO; - } + row.round = [F::ZEROS; 48]; + row.round[op.i as usize] = F::ONE; row.context = F::from_canonical_usize(op.base_address[0].context); row.segment = F::from_canonical_usize(op.base_address[Segment::Code as usize].segment); - let mut virt = (0..op.input.len()) + let virt = (0..op.input.len() / 32) .map(|i| op.base_address[i].virt) .collect_vec(); let virt: [usize; 4] = virt.try_into().unwrap(); row.input_virt = virt.map(F::from_canonical_usize); row.output_virt = F::from_canonical_usize(op.output_address.virt); - row.w_i_minus_15 = F::from_canonical_u32(op.input[0]); - row.w_i_minus_2 = F::from_canonical_u32(op.input[1]); - row.w_i_minus_16 = F::from_canonical_u32(op.input[2]); - row.w_i_minus_7 = F::from_canonical_u32(op.input[3]); + row.w_i_minus_15 = op.input[get_input_range(0)] + .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + row.w_i_minus_2 = op.input[get_input_range(1)] + .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + row.w_i_minus_16 = op.input[get_input_range(2)] + .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + row.w_i_minus_7 = op.input[get_input_range(3)] + .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); - row.w_i = self.compute_w_i(&op.input.try_into().unwrap()); + row.w_i = self.compute_w_i(&mut row); row } - fn compute_w_i(&self, input: &[u32; 4]) -> F { - let s0 = 
input[0].rotate_right(7) ^ input[0].rotate_right(18) ^ (input[0] >> 3); - let s1 = input[1].rotate_right(17) ^ input[1].rotate_right(19) ^ (input[1] >> 10); + fn compute_w_i(&self, row: &mut ShaExtendSpongeColumnsView) -> [F; 32] { + let w_i_minus_15 = from_be_bits_to_u32(row.w_i_minus_15); + let w_i_minus_2 = from_be_bits_to_u32(row.w_i_minus_2); + let w_i_minus_16 = from_be_bits_to_u32(row.w_i_minus_16); + let w_i_minus_7 = from_be_bits_to_u32(row.w_i_minus_7); + let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); let w_i_u32 = s1 - .wrapping_add(input[2]) + .wrapping_add(w_i_minus_16) .wrapping_add(s0) - .wrapping_add(input[3]); - F::from_canonical_u32(w_i_u32) + .wrapping_add(w_i_minus_7); + + let w_i_bin = from_u32_to_be_bits(w_i_u32); + w_i_bin.iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap() } } @@ -149,13 +161,26 @@ mod test { use crate::sha_extend_sponge::sha_extend_sponge_stark::{ShaExtendSpongeOp, ShaExtendSpongeStark}; use crate::witness::memory::MemoryAddress; + + fn to_be_bits(value: u32) -> [u32; 32] { + let mut result = [0; 32]; + for i in 0..32 { + result[i] = ((value >> i) & 1) as u32; + } + result + } + #[test] - fn test_generation() -> Result<(), String> { + fn test_correction() -> Result<(), String> { const D: usize = 2; type F = GoldilocksField; type S = ShaExtendSpongeStark; + let mut input_values = vec![]; + input_values.extend((0..4).map(|i| to_be_bits(i as u32))); + let input_values = input_values.into_iter().flatten().collect::>(); + let op = ShaExtendSpongeOp { base_address: vec![MemoryAddress { context: 0, @@ -175,7 +200,7 @@ mod test { virt: 36, }], timestamp: 0, - input: vec![1, 2, 3, 4], + input: input_values, i: 0, output_address: MemoryAddress { context: 0, @@ -186,8 +211,10 @@ mod test { let stark = S::default(); let row = stark.generate_rows_for_op(op); - 
assert_eq!(row.is_final, F::ZERO); - assert_eq!(row.w_i, F::from_canonical_u32(33652743)); + + let w_i_bin = to_be_bits(40965); + assert_eq!(row.w_i, w_i_bin.map(F::from_canonical_u32)); + Ok(()) } } \ No newline at end of file From baae0365cf481d66d955f57c440fd3c27df94ab0 Mon Sep 17 00:00:00 2001 From: vanhger Date: Wed, 22 Jan 2025 17:47:27 +0700 Subject: [PATCH 10/25] feat: add constraints for ShaExtend precompile --- prover/src/sha_extend/logic.rs | 121 +++++- prover/src/sha_extend/sha_extend_stark.rs | 373 ++++++++++++++++- prover/src/sha_extend_sponge/logic.rs | 50 +++ prover/src/sha_extend_sponge/mod.rs | 3 +- .../sha_extend_sponge_stark.rs | 383 +++++++++++++++++- 5 files changed, 917 insertions(+), 13 deletions(-) create mode 100644 prover/src/sha_extend_sponge/logic.rs diff --git a/prover/src/sha_extend/logic.rs b/prover/src/sha_extend/logic.rs index e21a274c..371f589f 100644 --- a/prover/src/sha_extend/logic.rs +++ b/prover/src/sha_extend/logic.rs @@ -4,8 +4,8 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; -// these operators are applied in big-endian form +// these operators are applied in big-endian form pub fn rotate_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { let mut result = [F::ZERO; 32]; for i in 0..32 { @@ -14,6 +14,31 @@ pub fn rotate_right, const D: usize>(value: [F; 32] result } +pub(crate) fn rotate_right_packed_constraints( + value: [P; 32], + rotated_value: [P;32], + amount: usize, +) -> Vec

{ + let mut result = Vec::new(); + for i in 0..32 { + result.push(value[i] - rotated_value[(i + 32 - amount) % 32]); + } + result +} + +pub(crate) fn rotate_right_ext_circuit_constraint, const D: usize>( + builder: &mut CircuitBuilder, + value: [ExtensionTarget;32], + rotated_value: [ExtensionTarget; 32], + amount: usize +) -> Vec> { + let mut result = Vec::new(); + for i in 0..32 { + result.push(builder.sub_extension(value[i], rotated_value[(i + 32 - amount) % 32])); + } + result +} + pub fn shift_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { let mut result = [F::ZERO; 32]; if amount < 32 { @@ -24,6 +49,37 @@ pub fn shift_right, const D: usize>(value: [F; 32], result } +pub(crate) fn shift_right_packed_constraints( + value: [P; 32], + shifted_value: [P;32], + amount: usize, +) -> Vec

{ + let mut result = Vec::new(); + for i in 0..32 - amount { + result.push(value[i + amount] - shifted_value[i]); + } + for i in (32 - amount)..32 { + result.push(shifted_value[i]); + } + result +} + +pub(crate) fn shift_right_ext_circuit_constraints, const D: usize>( + builder: &mut CircuitBuilder, + value: [ExtensionTarget;32], + shifted_value: [ExtensionTarget; 32], + amount: usize +) -> Vec> { + let mut result = Vec::new(); + for i in 0..32 - amount { + result.push(builder.sub_extension(value[i + amount], shifted_value[i])); + } + for i in (32 - amount)..32 { + result.push(shifted_value[i]); + } + result +} + pub fn xor3 , const D: usize>(a: [F; 32], b: [F; 32], c: [F; 32]) -> [F; 32] { let mut result = [F::ZERO; 32]; for i in 0..32 { @@ -65,4 +121,67 @@ pub fn from_u32_to_be_bits(value: u32) -> [u32; 32] { result[i] = ((value >> i) & 1) as u32; } result +} + +/// Computes the constraints of wrapping add +pub(crate) fn wrapping_add_packed_constraints( + x: [P; 32], + y: [P; 32], + carry: [P; 32], + out: [P; 32] +) -> Vec

{ + + let mut result = vec![]; + let mut pre_carry = P::ZEROS; + for i in 0..32 { + let sum = x[i] + y[i] + pre_carry; + + let out_constraint = (sum - P::ONES) * (sum - P::ONES - P::ONES - P::ONES) * out[i] + + sum * (sum - P::ONES - P::ONES) * (out[i] - P::ONES); + + let carry_constraint = carry[i] + carry[i] + out[i] - sum; + result.push(out_constraint); + result.push(carry_constraint); + pre_carry = carry[i]; + } + result +} + +pub(crate) fn wrapping_add_ext_circuit_constraints, const D: usize>( + builder: &mut CircuitBuilder, + x: [ExtensionTarget; 32], + y: [ExtensionTarget; 32], + carry: [ExtensionTarget; 32], + out: [ExtensionTarget; 32] +) -> Vec> { + + let mut result = vec![]; + let mut pre_carry= builder.zero_extension(); + let one_ext = builder.one_extension(); + let two_ext = builder.two_extension(); + let three_ext = builder.constant_extension(F::Extension::from_canonical_u8(3)); + for i in 0..32 { + let sum = builder.add_many_extension([x[i], y[i], pre_carry]); + + let inner_1 = builder.sub_extension(sum, one_ext); + let inner_2 = builder.sub_extension(sum, three_ext); + let tmp1 = builder.mul_many_extension( + [inner_1, inner_2, out[i]] + ); + + let inner_1 = builder.sub_extension(sum, two_ext); + let inner_2 = builder.sub_extension(out[i], one_ext); + let tmp2 = builder.mul_many_extension( + [sum, inner_1, inner_2] + ); + result.push(builder.add_extension(tmp1, tmp2)); + + let tmp3 = builder.add_many_extension( + [carry[i], carry[i], out[i]] + ); + result.push(builder.sub_extension(tmp3, sum)); + + pre_carry = carry[i]; + } + result } \ No newline at end of file diff --git a/prover/src/sha_extend/sha_extend_stark.rs b/prover/src/sha_extend/sha_extend_stark.rs index ce38977a..c1176997 100644 --- a/prover/src/sha_extend/sha_extend_stark.rs +++ b/prover/src/sha_extend/sha_extend_stark.rs @@ -8,8 +8,9 @@ use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use 
crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; +use crate::keccak::logic::{xor3_gen, xor3_gen_circuit}; use crate::sha_extend::columns::{get_input_range, ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS}; -use crate::sha_extend::logic::{rotate_right, shift_right, wrapping_add, xor3}; +use crate::sha_extend::logic::{rotate_right, rotate_right_ext_circuit_constraint, rotate_right_packed_constraints, shift_right, shift_right_ext_circuit_constraints, shift_right_packed_constraints, wrapping_add, wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints, xor3}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; @@ -115,29 +116,283 @@ impl, const D: usize> Stark for ShaExtendStar FE: FieldExtension, P: PackedField { - todo!() + let local_values: &[P; NUM_SHA_EXTEND_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let local_values: &ShaExtendColumnsView

= local_values.borrow(); + + // check the bit values are zero or one in input + for i in 0..32 { + yield_constr.constraint(local_values.w_i_minus_15[i] * (local_values.w_i_minus_15[i] - P::ONES)); + yield_constr.constraint(local_values.w_i_minus_2[i] * (local_values.w_i_minus_2[i] - P::ONES)); + yield_constr.constraint(local_values.w_i_minus_16[i] * (local_values.w_i_minus_16[i] - P::ONES)); + yield_constr.constraint(local_values.w_i_minus_7[i] * (local_values.w_i_minus_7[i] - P::ONES)); + } + + // check the bit values are zero or one in intermediate values + for i in 0..32 { + yield_constr.constraint(local_values.w_i_inter_0[i] * (local_values.w_i_inter_0[i] - P::ONES)); + yield_constr.constraint(local_values.w_i_inter_1[i] * (local_values.w_i_inter_1[i] - P::ONES)); + yield_constr.constraint(local_values.carry_0[i] * (local_values.carry_0[i] - P::ONES)); + yield_constr.constraint(local_values.carry_1[i] * (local_values.carry_1[i] - P::ONES)); + yield_constr.constraint(local_values.carry_2[i] * (local_values.carry_2[i] - P::ONES)); + } + + // check the bit values are zero or one in output + for i in 0..32 { + yield_constr.constraint(local_values.w_i[i] * (local_values.w_i[i] - P::ONES)); + } + + // check the rotation + rotate_right_packed_constraints( + local_values.w_i_minus_15, + local_values.w_i_minus_15_rr_7, + 7 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + rotate_right_packed_constraints( + local_values.w_i_minus_15, + local_values.w_i_minus_15_rr_18, + 18 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + rotate_right_packed_constraints( + local_values.w_i_minus_2, + local_values.w_i_minus_2_rr_17, + 17 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + rotate_right_packed_constraints( + local_values.w_i_minus_2, + local_values.w_i_minus_2_rr_19, + 19 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + + // check the shift + shift_right_packed_constraints( + local_values.w_i_minus_15, + 
local_values.w_i_minus_15_rs_3, + 3 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + shift_right_packed_constraints( + local_values.w_i_minus_2, + local_values.w_i_minus_2_rs_10, + 10 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + + // check the computation of s0 and s1 + for i in 0..32 { + let s0 = xor3_gen(local_values.w_i_minus_15_rr_7[i], + local_values.w_i_minus_15_rr_18[i], + local_values.w_i_minus_15_rs_3[i] + ); + yield_constr.constraint(local_values.s_0[i] - s0); + + let s1 = xor3_gen( + local_values.w_i_minus_2_rr_17[i], + local_values.w_i_minus_2_rr_19[i], + local_values.w_i_minus_2_rs_10[i] + ); + yield_constr.constraint(local_values.s_1[i] - s1); + } + + // check the computation of w_i_inter_0 = w[i-7] + s1. + wrapping_add_packed_constraints( + local_values.w_i_minus_7, + local_values.s_1, + local_values.carry_0, + local_values.w_i_inter_0 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + // check the computation of w_i_inter_1 = w_i_inter_0 + s0 + wrapping_add_packed_constraints( + local_values.w_i_inter_0, + local_values.s_0, + local_values.carry_1, + local_values.w_i_inter_1 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + // check the computation of w_i = w_i_inter_1 + w_i_minus_16 + wrapping_add_packed_constraints( + local_values.w_i_inter_1, + local_values.w_i_minus_16, + local_values.carry_2, + local_values.w_i + ).into_iter().for_each(|c| yield_constr.constraint(c)); + } fn eval_ext_circuit( &self, builder: &mut CircuitBuilder, vars: &Self::EvaluationFrameTarget, - yield_constr: &mut RecursiveConstraintConsumer) - { + yield_constr: &mut RecursiveConstraintConsumer) { + + let local_values: &[ExtensionTarget; NUM_SHA_EXTEND_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let local_values: &ShaExtendColumnsView> = local_values.borrow(); - todo!() + // check the bit values are zero or one in input + for i in 0..32 { + let constraint = builder.mul_sub_extension( + 
local_values.w_i_minus_15[i], local_values.w_i_minus_15[i], local_values.w_i_minus_15[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.w_i_minus_2[i], local_values.w_i_minus_2[i], local_values.w_i_minus_2[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.w_i_minus_16[i], local_values.w_i_minus_16[i], local_values.w_i_minus_16[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.w_i_minus_7[i], local_values.w_i_minus_7[i], local_values.w_i_minus_7[i]); + yield_constr.constraint(builder, constraint); + } + + // check the bit values are zero or one in intermediate values + for i in 0..32 { + let constraint = builder.mul_sub_extension( + local_values.w_i_inter_0[i], local_values.w_i_inter_0[i], local_values.w_i_inter_0[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.w_i_inter_1[i], local_values.w_i_inter_1[i], local_values.w_i_inter_1[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.carry_0[i], local_values.carry_0[i], local_values.carry_0[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.carry_1[i], local_values.carry_1[i], local_values.carry_1[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.carry_2[i], local_values.carry_2[i], local_values.carry_2[i]); + yield_constr.constraint(builder, constraint); + } + + // check the bit values are zero or one in output + for i in 0..32 { + let constraint = builder.mul_sub_extension( + local_values.w_i[i], local_values.w_i[i], local_values.w_i[i]); + yield_constr.constraint(builder, constraint); + } + + // check the rotation + rotate_right_ext_circuit_constraint( + builder, 
+ local_values.w_i_minus_15, + local_values.w_i_minus_15_rr_7, + 7 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + rotate_right_ext_circuit_constraint( + builder, + local_values.w_i_minus_15, + local_values.w_i_minus_15_rr_18, + 18 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + rotate_right_ext_circuit_constraint( + builder, + local_values.w_i_minus_2, + local_values.w_i_minus_2_rr_17, + 17 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + rotate_right_ext_circuit_constraint( + builder, + local_values.w_i_minus_2, + local_values.w_i_minus_2_rr_19, + 19 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + // check the shift + shift_right_ext_circuit_constraints( + builder, + local_values.w_i_minus_15, + local_values.w_i_minus_15_rs_3, + 3 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + shift_right_ext_circuit_constraints( + builder, + local_values.w_i_minus_2, + local_values.w_i_minus_2_rs_10, + 10 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + // check the computation of s0 and s1 + for i in 0..32 { + let s0 = xor3_gen_circuit( + builder, + local_values.w_i_minus_15_rr_7[i], + local_values.w_i_minus_15_rr_18[i], + local_values.w_i_minus_15_rs_3[i] + ); + let constraint = builder.sub_extension(local_values.s_0[i], s0); + yield_constr.constraint(builder, constraint); + + let s1 = xor3_gen_circuit( + builder, + local_values.w_i_minus_2_rr_17[i], + local_values.w_i_minus_2_rr_19[i], + local_values.w_i_minus_2_rs_10[i] + ); + let constraint = builder.sub_extension(local_values.s_1[i], s1); + yield_constr.constraint(builder, constraint); + } + + // check the computation of w_i_inter_0 = w[i-7] + s1. 
+ wrapping_add_ext_circuit_constraints( + builder, + local_values.w_i_minus_7, + local_values.s_1, + local_values.carry_0, + local_values.w_i_inter_0 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + // check the computation of w_i_inter_1 = w_i_inter_0 + s0 + wrapping_add_ext_circuit_constraints( + builder, + local_values.w_i_inter_0, + local_values.s_0, + local_values.carry_1, + local_values.w_i_inter_1 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + // check the computation of w_i = w_i_inter_1 + w_i_minus_16 + wrapping_add_ext_circuit_constraints( + builder, + local_values.w_i_inter_1, + local_values.w_i_minus_16, + local_values.carry_2, + local_values.w_i + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); } fn constraint_degree(&self) -> usize { - todo!() + 3 } } + #[cfg(test)] mod test { + use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::{Field}; + use plonky2::fri::oracle::PolynomialBatch; + use plonky2::iop::challenger::Challenger; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::timed; + use plonky2::util::timing::TimingTree; + use crate::config::StarkConfig; + use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; + use crate::prover::prove_single_table; use crate::sha_extend::sha_extend_stark::ShaExtendStark; + use crate::sha_extend_sponge::columns::NUM_EXTEND_INPUT; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; fn to_be_bits(value: u32) -> [u32; 32] { let mut result = [0; 32]; @@ -147,6 +402,14 @@ mod test { result } + fn get_random_input() -> [u32; NUM_EXTEND_INPUT * 32] { + let mut input_values = vec![]; + let rand = rand::random::(); + input_values.extend((rand..rand + 4).map(|i| to_be_bits(i as u32))); + let 
input_values = input_values.into_iter().flatten().collect::>(); + input_values.try_into().unwrap() + } + #[test] fn test_correction() -> Result<(), String> { const D: usize = 2; @@ -183,4 +446,102 @@ mod test { Ok(()) } + + #[test] + fn test_stark_degree() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaExtendStark; + + let stark = S { + f: Default::default(), + }; + test_stark_low_degree(stark) + } + + #[test] + fn test_stark_circuit() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaExtendStark; + + let stark = S { + f: Default::default(), + }; + test_stark_circuit_constraints::(stark) + } + + #[test] + fn sha_extend_benchmark() -> anyhow::Result<()> { + const NUM_EXTEND: usize = 48; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaExtendStark; + let stark = S::default(); + let config = StarkConfig::standard_fast_config(); + + init_logger(); + + let input: Vec<([u32; NUM_EXTEND_INPUT * 32], usize)> = + (0..NUM_EXTEND).map(|_| (get_random_input(), 0)).collect(); + + let mut timing = TimingTree::new("prove", log::Level::Debug); + let trace_poly_values = stark.generate_trace(input, 8); + + // TODO: Cloning this isn't great; consider having `from_values` accept a reference, + // or having `compute_permutation_z_polys` read trace values from the `PolynomialBatch`. + let cloned_trace_poly_values = timed!(timing, "clone", trace_poly_values.clone()); + + let trace_commitments = timed!( + timing, + "compute trace commitment", + PolynomialBatch::::from_values( + cloned_trace_poly_values, + config.fri_config.rate_bits, + false, + config.fri_config.cap_height, + &mut timing, + None, + ) + ); + let degree = 1 << trace_commitments.degree_log; + + // Fake CTL data. 
+ let ctl_z_data = CtlZData { + helper_columns: vec![PolynomialValues::zero(degree)], + z: PolynomialValues::zero(degree), + challenge: GrandProductChallenge { + beta: F::ZERO, + gamma: F::ZERO, + }, + columns: vec![], + filter: vec![Some(Filter::new_simple(Column::constant(F::ZERO)))], + }; + let ctl_data = CtlData { + zs_columns: vec![ctl_z_data.clone(); config.num_challenges], + }; + + prove_single_table( + &stark, + &config, + &trace_poly_values, + &trace_commitments, + &ctl_data, + &GrandProductChallengeSet { + challenges: vec![ctl_z_data.challenge; config.num_challenges], + }, + &mut Challenger::new(), + &mut timing, + )?; + + timing.print(); + Ok(()) + } + + fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); + } } \ No newline at end of file diff --git a/prover/src/sha_extend_sponge/logic.rs b/prover/src/sha_extend_sponge/logic.rs new file mode 100644 index 00000000..be31b63f --- /dev/null +++ b/prover/src/sha_extend_sponge/logic.rs @@ -0,0 +1,50 @@ +use plonky2::field::extension::Extendable; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use crate::sha_extend_sponge::sha_extend_sponge_stark::NUM_ROUNDS; + + +// Compute (x - y - diff) * sum_round_flags +pub(crate) fn diff_address_ext_circuit_constraint, const D: usize>( + builder: &mut CircuitBuilder, + sum_round_flags: ExtensionTarget, + x: ExtensionTarget, + y: ExtensionTarget, + diff: usize +) -> ExtensionTarget { + let inter_1 = builder.sub_extension(x, y); + let diff_ext = builder.constant_extension(F::Extension::from_canonical_u32(diff as u32)); + let address_diff = builder.sub_extension(inter_1, diff_ext); + builder.mul_extension(sum_round_flags, address_diff) +} + +// Compute nxt_round - local_round - 1 +pub(crate) fn round_increment_ext_circuit_constraint, const D: usize>( + builder: &mut CircuitBuilder, + 
local_round: [ExtensionTarget; NUM_ROUNDS], + next_round: [ExtensionTarget; NUM_ROUNDS], +) -> ExtensionTarget { + + let one_ext = builder.one_extension(); + let local_round_indices: Vec<_> = + (0..NUM_ROUNDS).map(|i| { + let index = builder.constant_extension(F::Extension::from_canonical_u32(i as u32)); + builder.mul_extension(local_round[i], index) + }).collect(); + + let local_round_index = builder.add_many_extension(local_round_indices); + + let next_round_indices: Vec<_> = + (0..NUM_ROUNDS).map(|i| { + let index = builder.constant_extension(F::Extension::from_canonical_u32(i as u32)); + builder.mul_extension(next_round[i], index) + }).collect(); + + let next_round_index = builder.add_many_extension(next_round_indices); + + let increment = builder.sub_extension(next_round_index, local_round_index); + builder.sub_extension(increment, one_ext) + +} \ No newline at end of file diff --git a/prover/src/sha_extend_sponge/mod.rs b/prover/src/sha_extend_sponge/mod.rs index 96a740d5..afdca798 100644 --- a/prover/src/sha_extend_sponge/mod.rs +++ b/prover/src/sha_extend_sponge/mod.rs @@ -1,2 +1,3 @@ pub mod columns; -pub mod sha_extend_sponge_stark; \ No newline at end of file +pub mod sha_extend_sponge_stark; +pub mod logic; diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs index df2be7be..3c40f574 100644 --- a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -13,7 +13,8 @@ use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; use crate::sha_extend::columns::get_input_range; use crate::sha_extend::logic::{from_be_bits_to_u32, from_u32_to_be_bits}; -use crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_SHA_EXTEND_SPONGE_COLUMNS}; +use crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_EXTEND_INPUT, NUM_SHA_EXTEND_SPONGE_COLUMNS}; +use 
crate::sha_extend_sponge::logic::{diff_address_ext_circuit_constraint, round_increment_ext_circuit_constraint}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryAddress; @@ -136,7 +137,86 @@ impl, const D: usize> Stark for ShaExtendSpon FE: FieldExtension, P: PackedField { - todo!() + + let local_values: &[P; NUM_SHA_EXTEND_SPONGE_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let local_values: &ShaExtendSpongeColumnsView

= local_values.borrow(); + let next_values: &[P; NUM_SHA_EXTEND_SPONGE_COLUMNS] = + vars.get_next_values().try_into().unwrap(); + let next_values: &ShaExtendSpongeColumnsView

= next_values.borrow(); + + // check the binary form + for i in 0..32 { + yield_constr.constraint(local_values.w_i_minus_15[i] * (local_values.w_i_minus_15[i] - P::ONES)); + yield_constr.constraint(local_values.w_i_minus_2[i] * (local_values.w_i_minus_2[i] - P::ONES)); + yield_constr.constraint(local_values.w_i_minus_16[i] * (local_values.w_i_minus_16[i] - P::ONES)); + yield_constr.constraint(local_values.w_i_minus_7[i] * (local_values.w_i_minus_7[i] - P::ONES)); + yield_constr.constraint(local_values.w_i[i] * (local_values.w_i[i] - P::ONES)); + } + + // check the round + for i in 0..NUM_ROUNDS { + yield_constr.constraint(local_values.round[i] * (local_values.round[i] - P::ONES)); + } + + // check the filter + let is_final = local_values.round[NUM_ROUNDS - 1]; + yield_constr.constraint(is_final * (is_final - P::ONES)); + let not_final = P::ONES - is_final; + + let sum_round_flags = (0..NUM_ROUNDS) + .map(|i| local_values.round[i]) + .sum::

(); + + // If this is not the final step or a padding row, + // the local and next timestamps must match. + yield_constr.constraint( + sum_round_flags * not_final * (next_values.timestamp - local_values.timestamp), + ); + + // If this is not the final step or a padding row, + // round index should be increased by one + + let local_round_index = (0..NUM_ROUNDS) + .map(|i| local_values.round[i] * FE::from_canonical_u32(i as u32)) + .sum::

(); + let next_round_index = (0..NUM_ROUNDS) + .map(|i| next_values.round[i] * FE::from_canonical_u32(i as u32)) + .sum::

(); + yield_constr.constraint( + sum_round_flags * not_final * (next_round_index - local_round_index - P::ONES) + ); + + // If this is not the final step or a padding row, + // input and output addresses should be increased by 4 each + (0..NUM_EXTEND_INPUT).for_each(|i| { + yield_constr.constraint( + sum_round_flags * not_final * (next_values.input_virt[i] - local_values.input_virt[i] - FE::from_canonical_u32(4)) + ); + }); + yield_constr.constraint( + sum_round_flags * not_final * (next_values.output_virt - local_values.output_virt - FE::from_canonical_u32(4)) + ); + + // If it's not the padding row, check the virtual addresses + // The list of input addresses are: w[i-15], w[i-2], w[i-16], w[i-7] + + // add_w[i-15] = add_w[i-16] + 4 + yield_constr.constraint( + sum_round_flags * (local_values.input_virt[0] - local_values.input_virt[2] - FE::from_canonical_u32(4)) + ); + // add_w[i-2] = add_w[i-16] + 56 + yield_constr.constraint( + sum_round_flags * (local_values.input_virt[1] - local_values.input_virt[2] - FE::from_canonical_u32(56)) + ); + // add_w[i-7] = add_w[i-16] + 36 + yield_constr.constraint( + sum_round_flags * (local_values.input_virt[3] - local_values.input_virt[2] - FE::from_canonical_u32(36)) + ); + // add_w[i] = add_w[i-16] + 64 + yield_constr.constraint( + sum_round_flags * (local_values.output_virt - local_values.input_virt[2] - FE::from_canonical_u32(64)) + ); } fn eval_ext_circuit( @@ -145,23 +225,165 @@ impl, const D: usize> Stark for ShaExtendSpon vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer ) { - todo!() + + let local_values: &[ExtensionTarget; NUM_SHA_EXTEND_SPONGE_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let local_values: &ShaExtendSpongeColumnsView> = local_values.borrow(); + let next_values: &[ExtensionTarget; NUM_SHA_EXTEND_SPONGE_COLUMNS] = + vars.get_next_values().try_into().unwrap(); + let next_values: &ShaExtendSpongeColumnsView> = next_values.borrow(); + + let one_ext = 
builder.one_extension(); + let four_ext = builder.constant_extension(F::Extension::from_canonical_u32(4)); + + // check the binary form + for i in 0..32 { + let constraint = builder.mul_sub_extension( + local_values.w_i_minus_15[i], local_values.w_i_minus_15[i], local_values.w_i_minus_15[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.w_i_minus_2[i], local_values.w_i_minus_2[i], local_values.w_i_minus_2[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.w_i_minus_16[i], local_values.w_i_minus_16[i], local_values.w_i_minus_16[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.w_i_minus_7[i], local_values.w_i_minus_7[i], local_values.w_i_minus_7[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.w_i[i], local_values.w_i[i], local_values.w_i[i]); + yield_constr.constraint(builder, constraint); + } + + // check the round + for i in 0..NUM_ROUNDS { + let constraint = builder.mul_sub_extension( + local_values.round[i], local_values.round[i], local_values.round[i] + ); + yield_constr.constraint(builder, constraint); + } + + // check the filter + let is_final = local_values.round[NUM_ROUNDS - 1]; + let constraint = builder.mul_sub_extension(is_final, is_final, is_final); + yield_constr.constraint(builder, constraint); + let not_final = builder.sub_extension(one_ext, is_final); + + let sum_round_flags = + builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values.round[i])); + + // If this is not the final step or a padding row, + // the local and next timestamps must match. 
+ let diff = builder.sub_extension(next_values.timestamp, local_values.timestamp); + let constraint = builder.mul_many_extension([sum_round_flags, not_final, diff]); + yield_constr.constraint(builder, constraint); + + // If this is not the final step or a padding row, + // round index should be increased by one + + let round_increment = round_increment_ext_circuit_constraint( + builder, + local_values.round, + next_values.round + ); + let constraint = builder.mul_many_extension( + [sum_round_flags, not_final, round_increment] + ); + yield_constr.constraint(builder, constraint); + + // If this is not the final step or a padding row, + // input and output addresses should be increased by 4 each + (0..NUM_EXTEND_INPUT).for_each(|i| { + + let increment = builder.sub_extension(next_values.input_virt[i], local_values.input_virt[i]); + let address_increment = builder.sub_extension(increment, four_ext); + let constraint = builder.mul_many_extension( + [sum_round_flags, not_final, address_increment] + ); + yield_constr.constraint(builder, constraint); + }); + + let increment = builder.sub_extension(next_values.output_virt, local_values.output_virt); + let address_increment = builder.sub_extension(increment, four_ext); + let constraint = builder.mul_many_extension( + [sum_round_flags, not_final, address_increment] + ); + yield_constr.constraint(builder, constraint); + + + // If it's not the padding row, check the virtual addresses + // The list of input addresses are: w[i-15], w[i-2], w[i-16], w[i-7] + + // add_w[i-15] = add_w[i-16] + 4 + let constraint = diff_address_ext_circuit_constraint( + builder, + sum_round_flags, + local_values.input_virt[0], + local_values.input_virt[2], + 4 + ); + yield_constr.constraint(builder, constraint); + + // add_w[i-2] = add_w[i-16] + 56 + let constraint = diff_address_ext_circuit_constraint( + builder, + sum_round_flags, + local_values.input_virt[1], + local_values.input_virt[2], + 56 + ); + yield_constr.constraint(builder, constraint); + 
+ // add_w[i-7] = add_w[i-16] + 36 + let constraint = diff_address_ext_circuit_constraint( + builder, + sum_round_flags, + local_values.input_virt[3], + local_values.input_virt[2], + 36 + ); + yield_constr.constraint(builder, constraint); + + // add_w[i] = add_w[i-16] + 64 + let constraint = diff_address_ext_circuit_constraint( + builder, + sum_round_flags, + local_values.output_virt, + local_values.input_virt[2], + 64 + ); + yield_constr.constraint(builder, constraint); } fn constraint_degree(&self) -> usize { - todo!() + 3 } } + #[cfg(test)] mod test { + use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; + use plonky2::fri::oracle::PolynomialBatch; + use plonky2::iop::challenger::Challenger; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::timed; + use plonky2::util::timing::TimingTree; + use crate::config::StarkConfig; + use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; use crate::memory::segments::Segment; + use crate::prover::prove_single_table; use crate::sha_extend_sponge::sha_extend_sponge_stark::{ShaExtendSpongeOp, ShaExtendSpongeStark}; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use crate::witness::memory::MemoryAddress; - fn to_be_bits(value: u32) -> [u32; 32] { let mut result = [0; 32]; for i in 0..32 { @@ -217,4 +439,155 @@ mod test { Ok(()) } + + #[test] + fn test_stark_circuit() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaExtendSpongeStark; + + let stark = S::default(); + test_stark_circuit_constraints::(stark) + } + + #[test] + fn test_stark_degree() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaExtendSpongeStark; + + 
let stark = S { + f: Default::default(), + }; + test_stark_low_degree(stark) + } + + fn get_random_input() -> Vec { + let mut w = [0u32; 64]; + for i in 0..16 { + w[i] = rand::random::(); + } + for i in 16..64 { + + let w_i_minus_15 = w[i-15]; + let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + + // Read w[i-2]. + let w_i_minus_2 = w[i-2]; + // Compute `s1`. + let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); + + // Read w[i-16]. + let w_i_minus_16 = w[i-16]; + let w_i_minus_7 = w[i-7]; + + // Compute `w_i`. + w[i] = s1 + .wrapping_add(w_i_minus_16) + .wrapping_add(s0) + .wrapping_add(w_i_minus_7); + } + + let mut addresses = vec![]; + for i in 0..64 { + addresses.push(MemoryAddress{ + context: 0, + segment: Segment::Code as usize, + virt: i * 4 + }); + } + + let mut res = vec![]; + + for i in 16..64 { + let mut input_values = vec![]; + input_values.extend(to_be_bits(w[i - 15])); + input_values.extend(to_be_bits(w[i - 2])); + input_values.extend(to_be_bits(w[i - 16])); + input_values.extend(to_be_bits(w[i - 7])); + + let op = ShaExtendSpongeOp { + base_address: vec![addresses[i - 15], addresses[i - 2], addresses[i - 16], addresses[i - 7]], + timestamp: 0, + input: input_values, + i: i as u32 - 16, + output_address: addresses[i], + }; + + res.push(op); + } + + res + + } + #[test] + fn sha_extend_sponge_benchmark() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaExtendSpongeStark; + let stark = S::default(); + let config = StarkConfig::standard_fast_config(); + + init_logger(); + + let input = get_random_input(); + let mut timing = TimingTree::new("prove", log::Level::Debug); + let trace_poly_values = stark.generate_trace(input, 8); + + // TODO: Cloning this isn't great; consider having `from_values` accept a reference, + // or having `compute_permutation_z_polys` read trace values from the `PolynomialBatch`. 
+ let cloned_trace_poly_values = timed!(timing, "clone", trace_poly_values.clone()); + + let trace_commitments = timed!( + timing, + "compute trace commitment", + PolynomialBatch::::from_values( + cloned_trace_poly_values, + config.fri_config.rate_bits, + false, + config.fri_config.cap_height, + &mut timing, + None, + ) + ); + let degree = 1 << trace_commitments.degree_log; + + // Fake CTL data. + let ctl_z_data = CtlZData { + helper_columns: vec![PolynomialValues::zero(degree)], + z: PolynomialValues::zero(degree), + challenge: GrandProductChallenge { + beta: F::ZERO, + gamma: F::ZERO, + }, + columns: vec![], + filter: vec![Some(Filter::new_simple(Column::constant(F::ZERO)))], + }; + let ctl_data = CtlData { + zs_columns: vec![ctl_z_data.clone(); config.num_challenges], + }; + + prove_single_table( + &stark, + &config, + &trace_poly_values, + &trace_commitments, + &ctl_data, + &GrandProductChallengeSet { + challenges: vec![ctl_z_data.challenge; config.num_challenges], + }, + &mut Challenger::new(), + &mut timing, + )?; + + timing.print(); + Ok(()) + } + + fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); + } } \ No newline at end of file From 6b2dfba3421e28e635388bf1d33f3e40a083047b Mon Sep 17 00:00:00 2001 From: vanhger Date: Thu, 23 Jan 2025 09:39:16 +0700 Subject: [PATCH 11/25] chore: make the logic functions support generic input sizes --- prover/src/sha_extend/logic.rs | 41 ++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/prover/src/sha_extend/logic.rs b/prover/src/sha_extend/logic.rs index 371f589f..2f850a6b 100644 --- a/prover/src/sha_extend/logic.rs +++ b/prover/src/sha_extend/logic.rs @@ -80,20 +80,23 @@ pub(crate) fn shift_right_ext_circuit_constraints, result } -pub fn xor3 , const D: usize>(a: [F; 32], b: [F; 32], c: [F; 32]) -> [F; 32] { - let mut result = [F::ZERO; 32]; - for i in 0..32 { +pub fn xor3 , const D: usize, const N: usize>(a: [F; N], b: 
[F; N], c: [F; N]) -> [F; N] { + let mut result = [F::ZERO; N]; + for i in 0..N { result[i] = crate::keccak::logic::xor([a[i], b[i], c[i]]); } result } -pub fn wrapping_add, const D: usize>(a: [F; 32], b: [F; 32]) -> ([F; 32], [F; 32]) { - let mut result = [F::ZERO; 32]; - let mut carries = [F::ZERO; 32]; +pub fn wrapping_add, const D: usize, const N: usize>( + a: [F; N], + b: [F; N] +) -> ([F; N], [F; N]) { + let mut result = [F::ZERO; N]; + let mut carries = [F::ZERO; N]; let mut sum = F::ZERO; let mut carry = F::ZERO; - for i in 0..32 { + for i in 0..N { debug_assert!(a[i].is_zero() || a[i].is_one()); debug_assert!(b[i].is_zero() || b[i].is_one()); @@ -124,16 +127,16 @@ pub fn from_u32_to_be_bits(value: u32) -> [u32; 32] { } /// Computes the constraints of wrapping add -pub(crate) fn wrapping_add_packed_constraints( - x: [P; 32], - y: [P; 32], - carry: [P; 32], - out: [P; 32] +pub(crate) fn wrapping_add_packed_constraints( + x: [P; N], + y: [P; N], + carry: [P; N], + out: [P; N] ) -> Vec

{ let mut result = vec![]; let mut pre_carry = P::ZEROS; - for i in 0..32 { + for i in 0..N { let sum = x[i] + y[i] + pre_carry; let out_constraint = (sum - P::ONES) * (sum - P::ONES - P::ONES - P::ONES) * out[i] @@ -147,12 +150,12 @@ pub(crate) fn wrapping_add_packed_constraints( result } -pub(crate) fn wrapping_add_ext_circuit_constraints, const D: usize>( +pub(crate) fn wrapping_add_ext_circuit_constraints, const D: usize, const N: usize>( builder: &mut CircuitBuilder, - x: [ExtensionTarget; 32], - y: [ExtensionTarget; 32], - carry: [ExtensionTarget; 32], - out: [ExtensionTarget; 32] + x: [ExtensionTarget; N], + y: [ExtensionTarget; N], + carry: [ExtensionTarget; N], + out: [ExtensionTarget; N] ) -> Vec> { let mut result = vec![]; @@ -160,7 +163,7 @@ pub(crate) fn wrapping_add_ext_circuit_constraints, let one_ext = builder.one_extension(); let two_ext = builder.two_extension(); let three_ext = builder.constant_extension(F::Extension::from_canonical_u8(3)); - for i in 0..32 { + for i in 0..N { let sum = builder.add_many_extension([x[i], y[i], pre_carry]); let inner_1 = builder.sub_extension(sum, one_ext); From 3971f5bfb4a5f0bc32fae3cd936f3ad0a74c7091 Mon Sep 17 00:00:00 2001 From: vanhger Date: Thu, 23 Jan 2025 10:31:20 +0700 Subject: [PATCH 12/25] chore: change the location of get_input_range function --- prover/src/sha_extend/columns.rs | 4 ---- prover/src/sha_extend/logic.rs | 18 ++++++++++++------ prover/src/sha_extend/sha_extend_stark.rs | 4 ++-- .../sha_extend_sponge_stark.rs | 3 +-- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/prover/src/sha_extend/columns.rs b/prover/src/sha_extend/columns.rs index 3ff2718e..e4674d99 100644 --- a/prover/src/sha_extend/columns.rs +++ b/prover/src/sha_extend/columns.rs @@ -83,7 +83,3 @@ const fn make_col_map() -> ShaExtendColumnsView { pub(crate) const SHA_EXTEND_COL_MAP: ShaExtendColumnsView = make_col_map(); -pub fn get_input_range(i: usize) -> std::ops::Range { - (0 + i * 32)..(32 + i * 32) -} - diff 
--git a/prover/src/sha_extend/logic.rs b/prover/src/sha_extend/logic.rs index 2f850a6b..cbd24ec0 100644 --- a/prover/src/sha_extend/logic.rs +++ b/prover/src/sha_extend/logic.rs @@ -5,8 +5,14 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; + +pub(crate) fn get_input_range(i: usize) -> std::ops::Range { + (0 + i * 32)..(32 + i * 32) +} + + // these operators are applied in big-endian form -pub fn rotate_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { +pub(crate) fn rotate_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { let mut result = [F::ZERO; 32]; for i in 0..32 { result[i] = value[(i + amount) % 32]; @@ -39,7 +45,7 @@ pub(crate) fn rotate_right_ext_circuit_constraint, result } -pub fn shift_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { +pub(crate) fn shift_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { let mut result = [F::ZERO; 32]; if amount < 32 { for i in 0..32 - amount { @@ -80,7 +86,7 @@ pub(crate) fn shift_right_ext_circuit_constraints, result } -pub fn xor3 , const D: usize, const N: usize>(a: [F; N], b: [F; N], c: [F; N]) -> [F; N] { +pub(crate) fn xor3 , const D: usize, const N: usize>(a: [F; N], b: [F; N], c: [F; N]) -> [F; N] { let mut result = [F::ZERO; N]; for i in 0..N { result[i] = crate::keccak::logic::xor([a[i], b[i], c[i]]); @@ -88,7 +94,7 @@ pub fn xor3 , const D: usize, const N: usize>(a: [F result } -pub fn wrapping_add, const D: usize, const N: usize>( +pub(crate) fn wrapping_add, const D: usize, const N: usize>( a: [F; N], b: [F; N] ) -> ([F; N], [F; N]) { @@ -109,7 +115,7 @@ pub fn wrapping_add, const D: usize, const N: usize (result, carries) } -pub fn from_be_bits_to_u32, const D: usize>(value: [F; 32]) -> u32 { +pub(crate) fn from_be_bits_to_u32, const D: usize>(value: [F; 32]) -> u32 { let mut result = 0; for i in 0..32 { debug_assert!(value[i].is_zero() || 
value[i].is_one()); @@ -118,7 +124,7 @@ pub fn from_be_bits_to_u32, const D: usize>(value: result } -pub fn from_u32_to_be_bits(value: u32) -> [u32; 32] { +pub(crate) fn from_u32_to_be_bits(value: u32) -> [u32; 32] { let mut result = [0; 32]; for i in 0..32 { result[i] = ((value >> i) & 1) as u32; diff --git a/prover/src/sha_extend/sha_extend_stark.rs b/prover/src/sha_extend/sha_extend_stark.rs index c1176997..13459296 100644 --- a/prover/src/sha_extend/sha_extend_stark.rs +++ b/prover/src/sha_extend/sha_extend_stark.rs @@ -9,8 +9,8 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::keccak::logic::{xor3_gen, xor3_gen_circuit}; -use crate::sha_extend::columns::{get_input_range, ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS}; -use crate::sha_extend::logic::{rotate_right, rotate_right_ext_circuit_constraint, rotate_right_packed_constraints, shift_right, shift_right_ext_circuit_constraints, shift_right_packed_constraints, wrapping_add, wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints, xor3}; +use crate::sha_extend::columns::{ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS}; +use crate::sha_extend::logic::{get_input_range, rotate_right, rotate_right_ext_circuit_constraint, rotate_right_packed_constraints, shift_right, shift_right_ext_circuit_constraints, shift_right_packed_constraints, wrapping_add, wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints, xor3}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs index 3c40f574..1ba2e0fd 100644 --- a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -11,8 +11,7 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; 
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; -use crate::sha_extend::columns::get_input_range; -use crate::sha_extend::logic::{from_be_bits_to_u32, from_u32_to_be_bits}; +use crate::sha_extend::logic::{get_input_range, from_be_bits_to_u32, from_u32_to_be_bits}; use crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_EXTEND_INPUT, NUM_SHA_EXTEND_SPONGE_COLUMNS}; use crate::sha_extend_sponge::logic::{diff_address_ext_circuit_constraint, round_increment_ext_circuit_constraint}; use crate::stark::Stark; From d4763e784fa05266065ebf7ba85adaa1cbd74c2e Mon Sep 17 00:00:00 2001 From: vanhger Date: Thu, 23 Jan 2025 11:30:33 +0700 Subject: [PATCH 13/25] chore: adjust the type of input --- prover/src/sha_extend/logic.rs | 6 ++--- prover/src/sha_extend/sha_extend_stark.rs | 26 +++++++++---------- .../sha_extend_sponge_stark.rs | 22 ++++++++-------- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/prover/src/sha_extend/logic.rs b/prover/src/sha_extend/logic.rs index cbd24ec0..d3e79350 100644 --- a/prover/src/sha_extend/logic.rs +++ b/prover/src/sha_extend/logic.rs @@ -124,10 +124,10 @@ pub(crate) fn from_be_bits_to_u32, const D: usize>( result } -pub(crate) fn from_u32_to_be_bits(value: u32) -> [u32; 32] { - let mut result = [0; 32]; +pub(crate) fn from_u32_to_be_bits(value: u32) -> [u8; 32] { + let mut result = [0_u8; 32]; for i in 0..32 { - result[i] = ((value >> i) & 1) as u32; + result[i] = ((value >> i) & 1) as u8; } result } diff --git a/prover/src/sha_extend/sha_extend_stark.rs b/prover/src/sha_extend/sha_extend_stark.rs index 13459296..5110e24f 100644 --- a/prover/src/sha_extend/sha_extend_stark.rs +++ b/prover/src/sha_extend/sha_extend_stark.rs @@ -25,7 +25,7 @@ pub struct ShaExtendStark { impl, const D: usize> ShaExtendStark { pub(crate) fn generate_trace( &self, - inputs_and_timestamps: Vec<([u32; 
NUM_INPUTS], usize)>, + inputs_and_timestamps: Vec<([u8; NUM_INPUTS], usize)>, min_rows: usize, ) -> Vec> { // Generate the witness row-wise @@ -35,7 +35,7 @@ impl, const D: usize> ShaExtendStark { fn generate_trace_rows( &self, - inputs_and_timestamps: Vec<([u32; NUM_INPUTS], usize)>, + inputs_and_timestamps: Vec<([u8; NUM_INPUTS], usize)>, min_rows: usize, ) -> Vec<[F; NUM_SHA_EXTEND_COLUMNS]> { let num_rows = inputs_and_timestamps.len() @@ -57,19 +57,19 @@ impl, const D: usize> ShaExtendStark { fn generate_trace_rows_for_extend( &self, - input_and_timestamp: ([u32; NUM_INPUTS], usize), + input_and_timestamp: ([u8; NUM_INPUTS], usize), ) -> ShaExtendColumnsView { let mut row = ShaExtendColumnsView::default(); row.timestamp = F::from_canonical_usize(input_and_timestamp.1); row.w_i_minus_15 = input_and_timestamp.0[get_input_range(0)] - .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.w_i_minus_2 = input_and_timestamp.0[get_input_range(1)] - .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.w_i_minus_16 = input_and_timestamp.0[get_input_range(2)] - .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.w_i_minus_7 = input_and_timestamp.0[get_input_range(3)] - .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); self.generate_trace_row_for_round(&mut row); row @@ -394,15 +394,15 @@ mod test { use crate::sha_extend_sponge::columns::NUM_EXTEND_INPUT; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; - fn to_be_bits(value: u32) -> [u32; 32] { + fn to_be_bits(value: u32) -> [u8; 32] { let mut result = [0; 32]; for i in 
0..32 { - result[i] = ((value >> i) & 1) as u32; + result[i] = ((value >> i) & 1) as u8; } result } - fn get_random_input() -> [u32; NUM_EXTEND_INPUT * 32] { + fn get_random_input() -> [u8; NUM_EXTEND_INPUT * 32] { let mut input_values = vec![]; let rand = rand::random::(); input_values.extend((rand..rand + 4).map(|i| to_be_bits(i as u32))); @@ -419,7 +419,7 @@ mod test { let mut input_values = vec![]; input_values.extend((0..4).map(|i| to_be_bits(i as u32))); let input_values = input_values.into_iter().flatten().collect::>(); - let input_values: [u32; 128] = input_values.try_into().unwrap(); + let input_values: [u8; 128] = input_values.try_into().unwrap(); let input_and_timestamp = (input_values, 0); let stark = S::default(); @@ -442,7 +442,7 @@ mod test { .wrapping_add(w_i_minus_7); let w_i_bin = to_be_bits(w_i); - assert_eq!(row.w_i, w_i_bin.map(F::from_canonical_u32)); + assert_eq!(row.w_i, w_i_bin.map(F::from_canonical_u8)); Ok(()) } @@ -485,7 +485,7 @@ mod test { init_logger(); - let input: Vec<([u32; NUM_EXTEND_INPUT * 32], usize)> = + let input: Vec<([u8; NUM_EXTEND_INPUT * 32], usize)> = (0..NUM_EXTEND).map(|_| (get_random_input(), 0)).collect(); let mut timing = TimingTree::new("prove", log::Level::Debug); diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs index 1ba2e0fd..8c7be9e6 100644 --- a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -29,10 +29,10 @@ pub(crate) struct ShaExtendSpongeOp { /// The input that was read. /// Values: w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 in big-endian order. - pub(crate) input: Vec, + pub(crate) input: Vec, /// The index of round - pub(crate) i: u32, + pub(crate) i: usize, /// The base address at which the output is written. 
pub(crate) output_address: MemoryAddress, @@ -90,13 +90,13 @@ impl, const D: usize> ShaExtendSpongeStark { row.output_virt = F::from_canonical_usize(op.output_address.virt); row.w_i_minus_15 = op.input[get_input_range(0)] - .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.w_i_minus_2 = op.input[get_input_range(1)] - .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.w_i_minus_16 = op.input[get_input_range(2)] - .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.w_i_minus_7 = op.input[get_input_range(3)] - .iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap(); + .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.w_i = self.compute_w_i(&mut row); row @@ -115,7 +115,7 @@ impl, const D: usize> ShaExtendSpongeStark { .wrapping_add(w_i_minus_7); let w_i_bin = from_u32_to_be_bits(w_i_u32); - w_i_bin.iter().map(|&x| F::from_canonical_u32(x)).collect::>().try_into().unwrap() + w_i_bin.iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap() } } @@ -383,10 +383,10 @@ mod test { use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use crate::witness::memory::MemoryAddress; - fn to_be_bits(value: u32) -> [u32; 32] { + fn to_be_bits(value: u32) -> [u8; 32] { let mut result = [0; 32]; for i in 0..32 { - result[i] = ((value >> i) & 1) as u32; + result[i] = ((value >> i) & 1) as u8; } result } @@ -434,7 +434,7 @@ mod test { let row = stark.generate_rows_for_op(op); let w_i_bin = to_be_bits(40965); - assert_eq!(row.w_i, w_i_bin.map(F::from_canonical_u32)); + assert_eq!(row.w_i, w_i_bin.map(F::from_canonical_u8)); Ok(()) } @@ -511,7 +511,7 @@ mod test { base_address: 
vec![addresses[i - 15], addresses[i - 2], addresses[i - 16], addresses[i - 7]], timestamp: 0, input: input_values, - i: i as u32 - 16, + i: i - 16, output_address: addresses[i], }; From ac022f02f072c3c552940ea76f8a7407a67ae1b7 Mon Sep 17 00:00:00 2001 From: vanhger Date: Fri, 24 Jan 2025 07:53:13 +0700 Subject: [PATCH 14/25] chore: change the columns in SHA compress table. --- prover/src/sha_compress/columns.rs | 98 +++--- prover/src/sha_compress/constants.rs | 10 - prover/src/sha_compress/logic.rs | 53 +++ prover/src/sha_compress/mod.rs | 6 +- prover/src/sha_compress/sha_compress_stark.rs | 312 ++++++++++-------- 5 files changed, 269 insertions(+), 210 deletions(-) delete mode 100644 prover/src/sha_compress/constants.rs create mode 100644 prover/src/sha_compress/logic.rs diff --git a/prover/src/sha_compress/columns.rs b/prover/src/sha_compress/columns.rs index 87da0a42..739404df 100644 --- a/prover/src/sha_compress/columns.rs +++ b/prover/src/sha_compress/columns.rs @@ -3,69 +3,57 @@ use std::intrinsics::transmute; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; #[derive(Clone)] pub(crate) struct ShaCompressColumnsView { - /// The timestamp at which inputs should be read from memory. 
- pub timestamp: T, - - /// Round number - pub i: T, - /// 8 temp buffer values as input - pub a: T, - pub b: T, - pub c: T, - pub d: T, - pub e: T, - pub f: T, - pub g: T, - pub h: T, - /// w[i] - pub w: [T; 64], - - /// Selector - pub round_i_filter: [T; 64], + /// input state: a,b,c,d,e,f,g,h in binary form + pub input_state: [T; 256], + /// Out + pub output_state: [T; 256], + /// w[i] and key[i] + pub w_i: [T; 32], + pub k_i: [T; 32], /// Intermediate values - pub k_i: T, - pub w_i: T, - pub e_rr_6: T, - pub e_rr_11: T, - pub e_rr_25: T, - pub s_1_inter: T, - pub s_1: T, - pub e_and_f: T, - pub e_not: T, - pub e_not_and_g: T, - pub ch: T, - pub temp1: T, - pub a_rr_2: T, - pub a_rr_13: T, - pub a_rr_22: T, - pub s_0_inter: T, - pub s_0: T, - pub a_and_b: T, - pub a_and_c: T, - pub b_and_c: T, - pub maj_inter: T, - pub maj: T, - pub temp2: T, + pub e_rr_6: [T; 32], + pub e_rr_11: [T; 32], + pub e_rr_25: [T; 32], + pub s_1: [T; 32], + pub e_and_f: [T; 32], + pub not_e_and_g: [T; 32], + pub ch: [T;32], + // h.wrapping_add(s1) + pub inter_1: [T;32], + pub carry_1: [T;32], + // inter_1.wrapping_add(ch) + pub inter_2: [T;32], + pub carry_2: [T;32], + // inter_2.wrapping_add(SHA_COMPRESS_K[i]) + pub inter_3: [T;32], + pub carry_3: [T;32], + // inter_3.wrapping_add(w_i) + pub temp1: [T;32], + pub carry_4: [T;32], + + pub a_rr_2: [T;32], + pub a_rr_13: [T;32], + pub a_rr_22: [T;32], + pub s_0: [T;32], + pub a_and_b: [T;32], + pub a_and_c: [T;32], + pub b_and_c: [T;32], + pub maj: [T;32], + pub temp2: [T;32], + pub carry_5: [T;32], + pub carry_a: [T; 32], + pub carry_e: [T; 32], - /// Out - pub new_a: T, - pub new_b: T, - pub new_c: T, - pub new_d: T, - pub new_e: T, - pub new_f: T, - pub new_g: T, - pub new_h: T, - - /// 1 if this is the final round of the compress phase, 0 otherwise - pub is_final: T, + + /// The timestamp at which inputs should be read from memory. 
+ pub timestamp: T, } -pub const NUM_SHA_COMPRESS_COLUMNS: usize = size_of::>(); // 170 +pub const NUM_SHA_COMPRESS_COLUMNS: usize = size_of::>(); impl From<[T; NUM_SHA_COMPRESS_COLUMNS]> for ShaCompressColumnsView { fn from(value: [T; NUM_SHA_COMPRESS_COLUMNS]) -> Self { diff --git a/prover/src/sha_compress/constants.rs b/prover/src/sha_compress/constants.rs deleted file mode 100644 index b386f60d..00000000 --- a/prover/src/sha_compress/constants.rs +++ /dev/null @@ -1,10 +0,0 @@ -pub const SHA_COMPRESS_K: [u32; 64] = [ - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, - 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, - 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, - 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, - 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, -]; \ No newline at end of file diff --git a/prover/src/sha_compress/logic.rs b/prover/src/sha_compress/logic.rs new file mode 100644 index 00000000..feeece1c --- /dev/null +++ b/prover/src/sha_compress/logic.rs @@ -0,0 +1,53 @@ +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use crate::keccak::logic::andn_gen_circuit; + +pub(crate) fn and_op, const D: usize, const N: usize>( + x: [F; N], + y: [F; N] +) -> [F; N] { + let mut result = [F::ZERO; N]; + for i in 0..N { + debug_assert!(x[i].is_zero() || 
x[i].is_one()); + debug_assert!(y[i].is_zero() || y[i].is_one()); + result[i] = x[i] * y[i]; + } + result +} + +pub(crate) fn andn_op, const D: usize, const N: usize>( + x: [F; N], + y: [F; N] +) -> [F; N] { + let mut result = [F::ZERO; N]; + for i in 0..N { + debug_assert!(x[i].is_zero() || x[i].is_one()); + debug_assert!(y[i].is_zero() || y[i].is_one()); + result[i] = crate::keccak::logic::andn(x[i], y[i]); + } + result +} + +pub(crate) fn xor_op, const D: usize, const N: usize>( + x: [F; N], + y: [F; N] +) -> [F; N] { + let mut result = [F::ZERO; N]; + for i in 0..N { + debug_assert!(x[i].is_zero() || x[i].is_one()); + debug_assert!(y[i].is_zero() || y[i].is_one()); + result[i] = crate::keccak::logic::xor([x[i], y[i]]); + } + result +} + +pub(crate) fn from_be_bits_to_u32( bits: [u8; 32]) -> u32 { + let mut result = 0; + for i in 0..32 { + result |= (bits[i] as u32) << i; + } + result +} \ No newline at end of file diff --git a/prover/src/sha_compress/mod.rs b/prover/src/sha_compress/mod.rs index 910b48c5..f48755ac 100644 --- a/prover/src/sha_compress/mod.rs +++ b/prover/src/sha_compress/mod.rs @@ -1,3 +1,3 @@ -mod columns; -mod sha_compress_stark; -mod constants; \ No newline at end of file +pub mod columns; +pub mod sha_compress_stark; +pub mod logic; \ No newline at end of file diff --git a/prover/src/sha_compress/sha_compress_stark.rs b/prover/src/sha_compress/sha_compress_stark.rs index c2d928ae..425fbcc4 100644 --- a/prover/src/sha_compress/sha_compress_stark.rs +++ b/prover/src/sha_compress/sha_compress_stark.rs @@ -1,4 +1,5 @@ use std::marker::PhantomData; +use std::borrow::Borrow; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; @@ -6,15 +7,17 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, 
RecursiveConstraintConsumer}; -use crate::evaluation_frame::StarkFrame; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; +use crate::keccak::logic::{xor3_gen, xor3_gen_circuit, xor_gen, xor_gen_circuit}; use crate::sha_compress::columns::{ShaCompressColumnsView, NUM_SHA_COMPRESS_COLUMNS}; -use crate::sha_compress::constants::SHA_COMPRESS_K; +use crate::sha_compress::logic::{and_op, and_op_ext_circuit_constraints, and_op_packed_constraints, andn_op, andn_op_ext_circuit_constraints, andn_op_packed_constraints, equal_ext_circuit_constraints, equal_packed_constraint, xor_op}; +use crate::sha_extend::logic::{rotate_right, get_input_range, xor3, wrapping_add, rotate_right_packed_constraints, wrapping_add_packed_constraints, rotate_right_ext_circuit_constraint, wrapping_add_ext_circuit_constraints}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; pub const NUM_ROUND_CONSTANTS: usize = 64; -pub const NUM_INPUTS: usize = 72; // 8 + 64 +pub const NUM_INPUTS: usize = 10; // 8 states + w_i + key_i #[derive(Copy, Clone, Default)] pub struct ShaCompressStark { @@ -24,7 +27,7 @@ pub struct ShaCompressStark { impl, const D: usize> ShaCompressStark { pub(crate) fn generate_trace( &self, - inputs: Vec<([u32; NUM_INPUTS], usize)>, + inputs: Vec<([u8; NUM_INPUTS * 32], usize)>, min_rows: usize, ) -> Vec> { // Generate the witness row-wise @@ -34,17 +37,17 @@ impl, const D: usize> ShaCompressStark { fn generate_trace_rows( &self, - inputs_and_timestamps: Vec<([u32; NUM_INPUTS], usize)>, + inputs_and_timestamps: Vec<([u8; NUM_INPUTS * 32], usize)>, min_rows: usize, ) -> Vec<[F; NUM_SHA_COMPRESS_COLUMNS]> { - let num_rows = (inputs_and_timestamps.len() * NUM_ROUND_CONSTANTS) + let num_rows = inputs_and_timestamps.len() .max(min_rows) .next_power_of_two(); let mut rows = Vec::with_capacity(num_rows); for input_and_timestamp in inputs_and_timestamps.iter() { - let rows_for_compress = self.generate_trace_rows_for_compress(*input_and_timestamp); - 
rows.extend(rows_for_compress); + let row_for_compress = self.generate_trace_rows_for_compress(*input_and_timestamp); + rows.push(row_for_compress); } while rows.len() < num_rows { @@ -55,130 +58,113 @@ impl, const D: usize> ShaCompressStark { fn generate_trace_rows_for_compress( &self, - input_and_timestamp: ([u32; NUM_INPUTS], usize), - ) -> Vec<[F; NUM_SHA_COMPRESS_COLUMNS]> { - - let mut rows = vec![ShaCompressColumnsView::default(); NUM_ROUND_CONSTANTS]; + input_and_timestamp: ([u8; NUM_INPUTS * 32], usize), + ) -> [F; NUM_SHA_COMPRESS_COLUMNS] { let timestamp = input_and_timestamp.1; let inputs = input_and_timestamp.0; - // set the first row - - - for round in 0..NUM_ROUND_CONSTANTS { - rows[round].timestamp = F::from_canonical_usize(timestamp); - rows[round].i = F::from_canonical_usize(round); - rows[round].is_final = F::ZERO; - if round == NUM_ROUND_CONSTANTS - 1 { - rows[round].is_final = F::ONE; - } + let mut row = ShaCompressColumnsView::::default(); + row.timestamp = F::from_canonical_usize(timestamp); + // read inputs + row.input_state = inputs[0..256].iter().map(|x| F::from_canonical_u8(*x)).collect::>().try_into().unwrap(); + row.w_i = inputs[256..288].iter().map(|x| F::from_canonical_u8(*x)).collect::>().try_into().unwrap(); + row.k_i = inputs[288..320].iter().map(|x| F::from_canonical_u8(*x)).collect::>().try_into().unwrap(); + + // compute + row.e_rr_6 = rotate_right(row.input_state[get_input_range(4)].try_into().unwrap(), 6); + row.e_rr_11 = rotate_right(row.input_state[get_input_range(4)].try_into().unwrap(), 11); + row.e_rr_25 = rotate_right(row.input_state[get_input_range(4)].try_into().unwrap(), 25); + row.s_1 = xor3(row.e_rr_6, row.e_rr_11, row.e_rr_25); + + row.e_and_f = and_op( + row.input_state[get_input_range(4)].try_into().unwrap(), + row.input_state[get_input_range(5)].try_into().unwrap(), + ); + + row.not_e_and_g = andn_op( + row.input_state[get_input_range(4)].try_into().unwrap(), + 
row.input_state[get_input_range(6)].try_into().unwrap(), + ); + + row.ch = xor_op(row.e_and_f, row.not_e_and_g); + + (row.inter_1, row.carry_1) = wrapping_add( + row.input_state[get_input_range(7)].try_into().unwrap(), + row.s_1, + ); + + (row.inter_2, row.carry_2) = wrapping_add( + row.inter_1, + row.ch, + ); + + (row.inter_3, row.carry_3) = wrapping_add( + row.inter_2, + row.k_i, + ); + + (row.temp1, row.carry_4) = wrapping_add( + row.inter_3, + row.w_i, + ); + + row.a_rr_2 = rotate_right(row.input_state[get_input_range(0)].try_into().unwrap(), 2); + row.a_rr_13 = rotate_right(row.input_state[get_input_range(0)].try_into().unwrap(), 13); + row.a_rr_22 = rotate_right(row.input_state[get_input_range(0)].try_into().unwrap(), 22); + row.s_0 = xor3(row.a_rr_2, row.a_rr_13, row.a_rr_22); + + row.b_and_c = and_op( + row.input_state[get_input_range(1)].try_into().unwrap(), + row.input_state[get_input_range(2)].try_into().unwrap(), + ); + + row.a_and_b = and_op( + row.input_state[get_input_range(0)].try_into().unwrap(), + row.input_state[get_input_range(1)].try_into().unwrap(), + ); + + row.a_and_c = and_op( + row.input_state[get_input_range(0)].try_into().unwrap(), + row.input_state[get_input_range(2)].try_into().unwrap(), + ); + + row.maj = xor3(row.a_and_b, row.a_and_c, row.b_and_c); + (row.temp2, row.carry_5) = wrapping_add( + row.s_0, + row.maj, + ); + + + for i in 32..256 { + row.output_state[i] = row.input_state[i - 32]; } - // Populate the round input for the first round. 
- [rows[0].a, rows[0].b, rows[0].c, rows[0].d, - rows[0].e, rows[0].f, rows[0].g, rows[0].h] = inputs[0..8] - .iter() - .map(|&x| F::from_canonical_u32(x)) - .collect::>() - .try_into() - .unwrap(); - - - rows[0].w = inputs[8..inputs.len()].iter() - .map(|&x| F::from_canonical_u32(x)) - .collect::>() - .try_into() - .unwrap(); - - self.generate_trace_row_for_round(&mut rows[0], 0); - for round in 1..NUM_ROUND_CONSTANTS{ - self.copy_output_to_input(&mut rows, round); - self.generate_trace_row_for_round(&mut rows[round], round); - } + let mut new_e; + let mut new_a; + (new_e, row.carry_e) = wrapping_add( + row.input_state[get_input_range(3)].try_into().unwrap(), + row.temp1, + ); - rows.into_iter().map(|row| row.into()).collect::>() - } + (new_a, row.carry_a) = wrapping_add( + row.temp1, + row.temp2, + ); - fn generate_trace_row_for_round(&self, row: &mut ShaCompressColumnsView, round: usize) { - row.round_i_filter = [F::ZERO; NUM_ROUND_CONSTANTS]; - row.round_i_filter[round] = F::ONE; - - row.k_i = F::from_canonical_u32(SHA_COMPRESS_K[round]); - row.w_i = row.w[round]; - - let e = row.e.to_canonical_u64() as u32; - let g = row.g.to_canonical_u64() as u32; - let e_rr_6 = e.rotate_right(6); - let e_rr_11 = e.rotate_right(11); - let s_1_inter = e_rr_6 ^ e_rr_11; - let e_rr_25 = e.rotate_right(25); - let s_1 = s_1_inter ^ e_rr_25; - - [row.e_rr_6, row.e_rr_11, row.e_rr_25, row.s_1_inter, row.s_1] - = [e_rr_6, e_rr_11, e_rr_25, s_1_inter, s_1].map(F::from_canonical_u32); - - let e_and_f = e & (row.f.to_canonical_u64() as u32); - let e_not = !e; - let e_not_and_g = e_not & g; - let ch = e_and_f ^ e_not_and_g; - let temp1 = (row.h.to_canonical_u64() as u32).wrapping_add(s_1) - .wrapping_add(ch) - .wrapping_add(row.k_i.to_canonical_u64() as u32) - .wrapping_add(row.w_i.to_canonical_u64() as u32); - - [row.e_and_f, row.e_not, row.e_not_and_g, row.ch, row.temp1] - = [e_and_f, e_not, e_not_and_g, ch, temp1].map(F::from_canonical_u32); - - let a = row.a.to_canonical_u64() as 
u32; - let a_rr_2 = a.rotate_right(2); - let a_rr_13 = a.rotate_right(13); - let a_rr_22 = a.rotate_right(22); - let s_0_inter = a_rr_2 ^ a_rr_13; - let s_0 = s_0_inter ^ a_rr_22; - - [row.a_rr_22, row.a_rr_13, row.a_rr_2, row.s_0_inter, row.s_0] - = [a_rr_22, a_rr_13, a_rr_2, s_0_inter, s_0].map(F::from_canonical_u32); - - let a_and_b = a & (row.b.to_canonical_u64() as u32); - let a_and_c = a & (row.c.to_canonical_u64() as u32); - let b_and_c = (row.b.to_canonical_u64() as u32) & (row.c.to_canonical_u64() as u32); - let maj_inter = a_and_b ^ a_and_c; - let maj = maj_inter ^ b_and_c; - let temp2 = s_0.wrapping_add(maj); - - let new_e = (row.d.to_canonical_u64() as u32).wrapping_add(temp1); - let new_a = temp1.wrapping_add(temp2); - [row.a_and_b, row.a_and_c, row.b_and_c, row.maj_inter, row.maj, row.temp2] - = [a_and_b, a_and_c, b_and_c, maj_inter, maj, temp2].map(F::from_canonical_u32); - - row.new_h = row.g; - row.new_g = row.f; - row.new_f = row.e; - row.new_e = F::from_canonical_u32(new_e); - row.new_d = row.c; - row.new_c = row.b; - row.new_b = row.a; - row.new_a = F::from_canonical_u32(new_a); - } + for i in 0..32 { + row.output_state[i] = new_a[i]; + row.output_state[i + 32 * 4] = new_e[i]; + } - fn copy_output_to_input(&self, rows: &mut Vec>, round: usize) { - rows[round].a = rows[round-1].new_a; - rows[round].b = rows[round-1].new_b; - rows[round].c = rows[round-1].new_c; - rows[round].d = rows[round-1].new_d; - rows[round].e = rows[round-1].new_e; - rows[round].f = rows[round-1].new_f; - rows[round].g = rows[round-1].new_g; - rows[round].h = rows[round-1].new_h; - rows[round].w = rows[round-1].w; + row.into() } } impl, const D: usize> Stark for ShaCompressStark { type EvaluationFrame - = StarkFrame + = StarkFrame where FE: FieldExtension, P: PackedField; @@ -212,9 +198,25 @@ impl, const D: usize> Stark for ShaCompressSt #[cfg(test)] mod test { use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::Field; use 
crate::sha_compress::columns::ShaCompressColumnsView; - use crate::sha_compress::sha_compress_stark::ShaCompressStark; + use crate::sha_compress::sha_compress_stark::{ShaCompressStark, NUM_INPUTS}; + use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range}; + use std::borrow::Borrow; + use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; + use plonky2::field::polynomial::PolynomialValues; + use plonky2::field::types::Field; + use plonky2::fri::oracle::PolynomialBatch; + use plonky2::iop::challenger::Challenger; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::timed; + use plonky2::util::timing::TimingTree; + use crate::config::StarkConfig; + use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; + use crate::prover::prove_single_table; + use crate::sha_compress_sponge::constants::SHA_COMPRESS_K; + use crate::sha_extend::sha_extend_stark::ShaExtendStark; + use crate::sha_extend_sponge::columns::NUM_EXTEND_INPUT; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; const W: [u32; 64] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200, 3020350282, 1447362251, 3118632270, 4004188394, 690615167, @@ -230,6 +232,15 @@ mod test { 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, ]; + fn get_random_input() -> [u8; NUM_INPUTS * 32] { + let mut input = [0u8; NUM_INPUTS * 32]; + for i in 0..NUM_INPUTS * 32 { + input[i] = rand::random::() % 2; + debug_assert!(input[i] == 0 || input[i] == 1); + } + input + } + #[test] fn test_generation() -> Result<(), String>{ @@ -241,32 +252,49 @@ mod test { let h = H256_256; let mut input = vec![]; - input.extend(h);input.extend(w); + for hx in h { + input.extend(from_u32_to_be_bits(hx)); + } + input.extend(from_u32_to_be_bits(w[0])); + input.extend(from_u32_to_be_bits(SHA_COMPRESS_K[0])); let stark = S::default(); - let rows = 
stark.generate_trace_rows_for_compress((input.try_into().unwrap(), 0)); - - assert_eq!(rows.len(), 64); - - - // check first row - let first_row: ShaCompressColumnsView = rows[0].into(); - assert_eq!(first_row.a, F::from_canonical_u32(0x6a09e667)); - assert_eq!(first_row.new_a, F::from_canonical_u32(4228417613)); - - // output - let last_row: ShaCompressColumnsView = rows[63].into(); - assert_eq!(last_row.is_final, F::ONE); - - assert_eq!(last_row.new_a, F::from_canonical_u32(1813631354)); - assert_eq!(last_row.new_b, F::from_canonical_u32(3315363907)); - assert_eq!(last_row.new_c, F::from_canonical_u32(209435322)); - assert_eq!(last_row.new_d, F::from_canonical_u32(267716009)); - assert_eq!(last_row.new_e, F::from_canonical_u32(646830348)); - assert_eq!(last_row.new_f, F::from_canonical_u32(362222596)); - assert_eq!(last_row.new_g, F::from_canonical_u32(3323089566)); - assert_eq!(last_row.new_h, F::from_canonical_u32(1912443780)); + let row = stark.generate_trace_rows_for_compress((input.try_into().unwrap(), 0)); + let local_values: &ShaCompressColumnsView = row.borrow(); + + assert_eq!( + local_values.output_state[get_input_range(0)], + from_u32_to_be_bits(4228417613).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(1)], + from_u32_to_be_bits(1779033703).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(2)], + from_u32_to_be_bits(3144134277).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(3)], + from_u32_to_be_bits(1013904242).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(4)], + from_u32_to_be_bits(2563236514).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(5)], + from_u32_to_be_bits(1359893119).iter().map(|&x| 
F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(6)], + from_u32_to_be_bits(2600822924).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(7)], + from_u32_to_be_bits(528734635).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); Ok(()) } } \ No newline at end of file From 7ed2e0f6d58b2b0f7381a653af061dcee24597e4 Mon Sep 17 00:00:00 2001 From: vanhger Date: Fri, 24 Jan 2025 07:54:55 +0700 Subject: [PATCH 15/25] feat: add constraints for SHA compress table. --- prover/src/sha_compress/logic.rs | 82 +++ prover/src/sha_compress/sha_compress_stark.rs | 542 +++++++++++++++++- 2 files changed, 621 insertions(+), 3 deletions(-) diff --git a/prover/src/sha_compress/logic.rs b/prover/src/sha_compress/logic.rs index feeece1c..dbd7da1c 100644 --- a/prover/src/sha_compress/logic.rs +++ b/prover/src/sha_compress/logic.rs @@ -18,6 +18,35 @@ pub(crate) fn and_op, const D: usize, const N: usiz result } +pub(crate) fn and_op_packed_constraints( + x: [P; N], + y: [P; N], + out: [P; N] +) -> Vec

{ + let mut result = vec![]; + for i in 0..N { + let out_constraint = x[i] * y[i] - out[i]; + result.push(out_constraint); + } + result +} + +pub(crate) fn and_op_ext_circuit_constraints, const D: usize, const N: usize>( + builder: &mut CircuitBuilder, + x: [ExtensionTarget; N], + y: [ExtensionTarget; N], + out: [ExtensionTarget; N] +) -> Vec> { + let mut result = vec![]; + for i in 0..N { + + let expected_out = builder.mul_extension(x[i], y[i]); + let out_constraint = builder.sub_extension(expected_out, out[i]); + result.push(out_constraint); + } + result +} + pub(crate) fn andn_op, const D: usize, const N: usize>( x: [F; N], y: [F; N] @@ -31,6 +60,35 @@ pub(crate) fn andn_op, const D: usize, const N: usi result } +pub(crate) fn andn_op_packed_constraints( + x: [P; N], + y: [P; N], + out: [P; N] +) -> Vec

{ + let mut result = vec![]; + for i in 0..N { + let out_constraint = crate::keccak::logic::andn_gen(x[i], y[i]) - out[i]; + result.push(out_constraint); + } + result +} + +pub(crate) fn andn_op_ext_circuit_constraints, const D: usize, const N: usize>( + builder: &mut CircuitBuilder, + x: [ExtensionTarget; N], + y: [ExtensionTarget; N], + out: [ExtensionTarget; N] +) -> Vec> { + let mut result = vec![]; + for i in 0..N { + + let expected_out = andn_gen_circuit(builder, x[i], y[i]); + let out_constraint = builder.sub_extension(expected_out, out[i]); + result.push(out_constraint); + } + result +} + pub(crate) fn xor_op, const D: usize, const N: usize>( x: [F; N], y: [F; N] @@ -44,6 +102,30 @@ pub(crate) fn xor_op, const D: usize, const N: usiz result } +pub(crate) fn equal_packed_constraint( + x: [P; N], + y: [P; N], +) -> Vec

{ + let mut result = vec![]; + for i in 0..N { + result.push(x[i] - y[i]); + } + result +} + +pub(crate) fn equal_ext_circuit_constraints, const D: usize, const N: usize>( + builder: &mut CircuitBuilder, + x: [ExtensionTarget; N], + y: [ExtensionTarget; N], +) -> Vec> { + let mut result = vec![]; + for i in 0..N { + let out_constraint = builder.sub_extension(x[i], y[i]); + result.push(out_constraint); + } + result +} + pub(crate) fn from_be_bits_to_u32( bits: [u8; 32]) -> u32 { let mut result = 0; for i in 0..32 { diff --git a/prover/src/sha_compress/sha_compress_stark.rs b/prover/src/sha_compress/sha_compress_stark.rs index 425fbcc4..98bc4515 100644 --- a/prover/src/sha_compress/sha_compress_stark.rs +++ b/prover/src/sha_compress/sha_compress_stark.rs @@ -178,7 +178,205 @@ impl, const D: usize> Stark for ShaCompressSt FE: FieldExtension, P: PackedField { - todo!() + let local_values: &[P; NUM_SHA_COMPRESS_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let local_values: &ShaCompressColumnsView

= local_values.borrow(); + + // check the input binary form + for i in 0..256 { + yield_constr.constraint(local_values.input_state[i] * (local_values.input_state[i] - P::ONES)); + } + for i in 0..32 { + yield_constr.constraint(local_values.w_i[i] * (local_values.w_i[i] - P::ONES)); + yield_constr.constraint(local_values.k_i[i] * (local_values.k_i[i] - P::ONES)); + } + + // check the bit values are zero or one in output + for i in 0..256 { + yield_constr.constraint(local_values.output_state[i] * (local_values.output_state[i] - P::ONES)); + } + + // check the rotation + rotate_right_packed_constraints( + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.e_rr_6, + 6 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + rotate_right_packed_constraints( + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.e_rr_11, + 11 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + rotate_right_packed_constraints( + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.e_rr_25, + 25 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + rotate_right_packed_constraints( + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.a_rr_2, + 2 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + rotate_right_packed_constraints( + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.a_rr_13, + 13 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + rotate_right_packed_constraints( + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.a_rr_22, + 22 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + // check the xor + for i in 0..32 { + let s1 = xor3_gen( + local_values.e_rr_6[i], + local_values.e_rr_11[i], + local_values.e_rr_25[i] + ); + yield_constr.constraint(local_values.s_1[i] - s1); + + let s0 = xor3_gen( + local_values.a_rr_2[i], + local_values.a_rr_13[i], + 
local_values.a_rr_22[i] + ); + yield_constr.constraint(local_values.s_0[i] - s0); + + let ch = xor_gen( + local_values.e_and_f[i], + local_values.not_e_and_g[i] + ); + yield_constr.constraint(local_values.ch[i] - ch); + + let maj = xor3_gen( + local_values.a_and_b[i], + local_values.a_and_c[i], + local_values.b_and_c[i] + ); + yield_constr.constraint(local_values.maj[i] - maj); + } + + // wrapping add constraints + + wrapping_add_packed_constraints( + local_values.input_state[get_input_range(7)].try_into().unwrap(), + local_values.s_1, + local_values.carry_1, + local_values.inter_1 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + wrapping_add_packed_constraints( + local_values.inter_1, + local_values.ch, + local_values.carry_2, + local_values.inter_2 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + wrapping_add_packed_constraints( + local_values.inter_2, + local_values.k_i, + local_values.carry_3, + local_values.inter_3 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + wrapping_add_packed_constraints( + local_values.inter_3, + local_values.w_i, + local_values.carry_4, + local_values.temp1 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + wrapping_add_packed_constraints( + local_values.s_0, + local_values.maj, + local_values.carry_5, + local_values.temp2 + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + wrapping_add_packed_constraints( + local_values.input_state[get_input_range(3)].try_into().unwrap(), + local_values.temp1, + local_values.carry_e, + local_values.output_state[get_input_range(4)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + wrapping_add_packed_constraints( + local_values.temp1, + local_values.temp2, + local_values.carry_a, + local_values.output_state[get_input_range(0)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + // The op constraints + and_op_packed_constraints( + 
local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(5)].try_into().unwrap(), + local_values.e_and_f + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + and_op_packed_constraints( + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(1)].try_into().unwrap(), + local_values.a_and_b + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + and_op_packed_constraints( + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(2)].try_into().unwrap(), + local_values.a_and_c + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + and_op_packed_constraints( + local_values.input_state[get_input_range(1)].try_into().unwrap(), + local_values.input_state[get_input_range(2)].try_into().unwrap(), + local_values.b_and_c + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + andn_op_packed_constraints( + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(6)].try_into().unwrap(), + local_values.not_e_and_g + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + + // output constraint + equal_packed_constraint::( + local_values.output_state[get_input_range(1)].try_into().unwrap(), + local_values.input_state[get_input_range(0)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + equal_packed_constraint::( + local_values.output_state[get_input_range(2)].try_into().unwrap(), + local_values.input_state[get_input_range(1)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + equal_packed_constraint::( + local_values.output_state[get_input_range(3)].try_into().unwrap(), + local_values.input_state[get_input_range(2)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + // equal_packed_constraint( + // 
local_values.output_state[get_input_range(4)].try_into().unwrap(), + // local_values.input_state[get_input_range(3)].try_into().unwrap(), + // ).into_iter().for_each(|c| yield_constr.constraint(c)); + + equal_packed_constraint::( + local_values.output_state[get_input_range(5)].try_into().unwrap(), + local_values.input_state[get_input_range(4)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + equal_packed_constraint::( + local_values.output_state[get_input_range(6)].try_into().unwrap(), + local_values.input_state[get_input_range(5)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + equal_packed_constraint::( + local_values.output_state[get_input_range(7)].try_into().unwrap(), + local_values.input_state[get_input_range(6)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); } fn eval_ext_circuit( @@ -187,11 +385,251 @@ impl, const D: usize> Stark for ShaCompressSt vars: &Self::EvaluationFrameTarget, yield_constr: &mut RecursiveConstraintConsumer ) { - todo!() + let local_values: &[ExtensionTarget; NUM_SHA_COMPRESS_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let local_values: &ShaCompressColumnsView> = local_values.borrow(); + + // check the input binary form + for i in 0..256 { + let constraint = builder.mul_sub_extension( + local_values.input_state[i], local_values.input_state[i], local_values.input_state[i]); + yield_constr.constraint(builder, constraint); + + } + for i in 0..32 { + let constraint = builder.mul_sub_extension( + local_values.w_i[i], local_values.w_i[i], local_values.w_i[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.k_i[i], local_values.k_i[i], local_values.k_i[i]); + yield_constr.constraint(builder, constraint); + } + + // check the bit values are zero or one in output + for i in 0..256 { + let constraint = builder.mul_sub_extension( + local_values.output_state[i], 
local_values.output_state[i], local_values.output_state[i]); + yield_constr.constraint(builder, constraint); + } + + // check the rotation + rotate_right_ext_circuit_constraint( + builder, + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.e_rr_6, + 6 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + rotate_right_ext_circuit_constraint( + builder, + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.e_rr_11, + 11 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + rotate_right_ext_circuit_constraint( + builder, + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.e_rr_25, + 25 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + rotate_right_ext_circuit_constraint( + builder, + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.a_rr_2, + 2 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + rotate_right_ext_circuit_constraint( + builder, + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.a_rr_13, + 13 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + rotate_right_ext_circuit_constraint( + builder, + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.a_rr_22, + 22 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + // check the xor + for i in 0..32 { + let s1 = xor3_gen_circuit( + builder, + local_values.e_rr_6[i], + local_values.e_rr_11[i], + local_values.e_rr_25[i] + ); + let constraint = builder.sub_extension(local_values.s_1[i], s1); + yield_constr.constraint(builder, constraint); + + let s0 = xor3_gen_circuit( + builder, + local_values.a_rr_2[i], + local_values.a_rr_13[i], + local_values.a_rr_22[i] + ); + let constraint = builder.sub_extension(local_values.s_0[i], s0); + yield_constr.constraint(builder, constraint); + + let ch = xor_gen_circuit( + builder, + 
local_values.e_and_f[i], + local_values.not_e_and_g[i] + ); + let constraint = builder.sub_extension(local_values.ch[i], ch); + yield_constr.constraint(builder, constraint); + + let maj = xor3_gen_circuit( + builder, + local_values.a_and_b[i], + local_values.a_and_c[i], + local_values.b_and_c[i] + ); + let constraint = builder.sub_extension(local_values.maj[i], maj); + yield_constr.constraint(builder, constraint); + } + + // wrapping add constraints + + wrapping_add_ext_circuit_constraints( + builder, + local_values.input_state[get_input_range(7)].try_into().unwrap(), + local_values.s_1, + local_values.carry_1, + local_values.inter_1 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + wrapping_add_ext_circuit_constraints( + builder, + local_values.inter_1, + local_values.ch, + local_values.carry_2, + local_values.inter_2 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + wrapping_add_ext_circuit_constraints( + builder, + local_values.inter_2, + local_values.k_i, + local_values.carry_3, + local_values.inter_3 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + wrapping_add_ext_circuit_constraints( + builder, + local_values.inter_3, + local_values.w_i, + local_values.carry_4, + local_values.temp1 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + wrapping_add_ext_circuit_constraints( + builder, + local_values.s_0, + local_values.maj, + local_values.carry_5, + local_values.temp2 + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + wrapping_add_ext_circuit_constraints( + builder, + local_values.input_state[get_input_range(3)].try_into().unwrap(), + local_values.temp1, + local_values.carry_e, + local_values.output_state[get_input_range(4)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + wrapping_add_ext_circuit_constraints( + builder, + local_values.temp1, + local_values.temp2, + local_values.carry_a, + 
local_values.output_state[get_input_range(0)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + // The op constraints + and_op_ext_circuit_constraints( + builder, + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(5)].try_into().unwrap(), + local_values.e_and_f + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + and_op_ext_circuit_constraints( + builder, + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(1)].try_into().unwrap(), + local_values.a_and_b + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + and_op_ext_circuit_constraints( + builder, + local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(2)].try_into().unwrap(), + local_values.a_and_c + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + and_op_ext_circuit_constraints( + builder, + local_values.input_state[get_input_range(1)].try_into().unwrap(), + local_values.input_state[get_input_range(2)].try_into().unwrap(), + local_values.b_and_c + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + andn_op_ext_circuit_constraints( + builder, + local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(6)].try_into().unwrap(), + local_values.not_e_and_g + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + + // output constraint + equal_ext_circuit_constraints::( + builder, + local_values.output_state[get_input_range(1)].try_into().unwrap(), + local_values.input_state[get_input_range(0)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + equal_ext_circuit_constraints::( + builder, + local_values.output_state[get_input_range(2)].try_into().unwrap(), + local_values.input_state[get_input_range(1)].try_into().unwrap(), + 
).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + equal_ext_circuit_constraints::( + builder, + local_values.output_state[get_input_range(3)].try_into().unwrap(), + local_values.input_state[get_input_range(2)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + // equal_packed_constraint( + // local_values.output_state[get_input_range(4)].try_into().unwrap(), + // local_values.input_state[get_input_range(3)].try_into().unwrap(), + // ).into_iter().for_each(|c| yield_constr.constraint(c)); + + equal_ext_circuit_constraints::( + builder, + local_values.output_state[get_input_range(5)].try_into().unwrap(), + local_values.input_state[get_input_range(4)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + equal_ext_circuit_constraints::( + builder, + local_values.output_state[get_input_range(6)].try_into().unwrap(), + local_values.input_state[get_input_range(5)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + equal_ext_circuit_constraints::( + builder, + local_values.output_state[get_input_range(7)].try_into().unwrap(), + local_values.input_state[get_input_range(6)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); } fn constraint_degree(&self) -> usize { - todo!() + 3 } } @@ -297,4 +735,102 @@ mod test { ); Ok(()) } + + #[test] + fn test_stark_degree() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaCompressStark; + + let stark = S { + f: Default::default(), + }; + test_stark_low_degree(stark) + } + + #[test] + fn test_stark_circuit() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaCompressStark; + + let stark = S { + f: Default::default(), + }; + test_stark_circuit_constraints::(stark) + } + + #[test] + fn sha_extend_benchmark() -> anyhow::Result<()> { + const 
NUM_EXTEND: usize = 64; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaCompressStark; + let stark = S::default(); + let config = StarkConfig::standard_fast_config(); + + init_logger(); + + let input: Vec<([u8; NUM_INPUTS * 32], usize)> = + (0..NUM_EXTEND).map(|_| (get_random_input(), 0)).collect(); + + let mut timing = TimingTree::new("prove", log::Level::Debug); + let trace_poly_values = stark.generate_trace(input, 8); + + // TODO: Cloning this isn't great; consider having `from_values` accept a reference, + // or having `compute_permutation_z_polys` read trace values from the `PolynomialBatch`. + let cloned_trace_poly_values = timed!(timing, "clone", trace_poly_values.clone()); + + let trace_commitments = timed!( + timing, + "compute trace commitment", + PolynomialBatch::::from_values( + cloned_trace_poly_values, + config.fri_config.rate_bits, + false, + config.fri_config.cap_height, + &mut timing, + None, + ) + ); + let degree = 1 << trace_commitments.degree_log; + + // Fake CTL data. 
+ let ctl_z_data = CtlZData { + helper_columns: vec![PolynomialValues::zero(degree)], + z: PolynomialValues::zero(degree), + challenge: GrandProductChallenge { + beta: F::ZERO, + gamma: F::ZERO, + }, + columns: vec![], + filter: vec![Some(Filter::new_simple(Column::constant(F::ZERO)))], + }; + let ctl_data = CtlData { + zs_columns: vec![ctl_z_data.clone(); config.num_challenges], + }; + + prove_single_table( + &stark, + &config, + &trace_poly_values, + &trace_commitments, + &ctl_data, + &GrandProductChallengeSet { + challenges: vec![ctl_z_data.challenge; config.num_challenges], + }, + &mut Challenger::new(), + &mut timing, + )?; + + timing.print(); + Ok(()) + } + + fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); + } } \ No newline at end of file From 693fed41ff5a1ff2417863a2b48c28a6d8bf0863 Mon Sep 17 00:00:00 2001 From: vanhger Date: Fri, 24 Jan 2025 08:19:15 +0700 Subject: [PATCH 16/25] chore: change the columns in SHA compress sponge table. --- prover/src/sha_compress_sponge/columns.rs | 38 +-- prover/src/sha_compress_sponge/constants.rs | 79 ++++++ prover/src/sha_compress_sponge/mod.rs | 5 +- .../sha_compress_sponge_stark.rs | 242 ++++++++++++------ 4 files changed, 260 insertions(+), 104 deletions(-) create mode 100644 prover/src/sha_compress_sponge/constants.rs diff --git a/prover/src/sha_compress_sponge/columns.rs b/prover/src/sha_compress_sponge/columns.rs index 64f922dc..7485a6a9 100644 --- a/prover/src/sha_compress_sponge/columns.rs +++ b/prover/src/sha_compress_sponge/columns.rs @@ -3,40 +3,28 @@ use std::intrinsics::transmute; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; pub(crate) struct ShaCompressSpongeColumnsView { - /// The timestamp at which inputs should be read from memory. 
- pub timestamp: T, - - /// hx_i - pub hx: [T;8], - /// w[i] - pub w: [T; 64], - - /// a,b...,h values after compressed - pub new_a: T, - pub new_b: T, - pub new_c: T, - pub new_d: T, - pub new_e: T, - pub new_f: T, - pub new_g: T, - pub new_h: T, + pub hx: [T;256], + pub input_state: [T; 256], + pub output_state: [T; 256], + pub output_hx: [T; 256], + pub carry: [T; 256], + pub round: [T; 64], + pub w_i: [T; 32], + pub k_i: [T; 32], + pub hx_virt: [T; 8], + pub w_virt: T, - /// output - pub final_hx: [T;8], + /// The timestamp at which inputs should be read from memory. + pub timestamp: T, /// The base address at which we will read the input block. pub context: T, pub segment: T, - /// Hx addresses - pub hx_virt: [T; 8], - - /// W_i addresses - pub w_virt: [T;64], } -pub const NUM_SHA_COMPRESS_SPONGE_COLUMNS: usize = size_of::>(); +pub const NUM_SHA_COMPRESS_SPONGE_COLUMNS: usize = size_of::>(); //1420 impl From<[T; NUM_SHA_COMPRESS_SPONGE_COLUMNS]> for ShaCompressSpongeColumnsView { fn from(value: [T; NUM_SHA_COMPRESS_SPONGE_COLUMNS]) -> Self { diff --git a/prover/src/sha_compress_sponge/constants.rs b/prover/src/sha_compress_sponge/constants.rs new file mode 100644 index 00000000..f656ddc5 --- /dev/null +++ b/prover/src/sha_compress_sponge/constants.rs @@ -0,0 +1,79 @@ +pub const SHA_COMPRESS_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 
0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + + +// big-endian form +pub const SHA_COMPRESS_K_BINARY: [[u8; 32]; 64] = [ + [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0], + [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0], + [1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1], + [1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1], + [1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0], + [1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0], + [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1], + [1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1], + [0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0], + [0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], + [0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1], + [1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1], + [0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1], + [0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 
1], + [0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0], + [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0], + [0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0], + [0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0], + [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1], + [1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], + [0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1], + [1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1], + [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1], + [1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1], + [1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], + [1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0], + [1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0], + [0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0], + [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 
0], + [0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1], + [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1], + [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], + [1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], + [0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1], + [1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1], + [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1], + [1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0], + [1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0], + [1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0], + [0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0], + [1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0], + [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0], + [1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 
0], + [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1], + [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1], + [1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1], + [1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1], + [0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1], +]; \ No newline at end of file diff --git a/prover/src/sha_compress_sponge/mod.rs b/prover/src/sha_compress_sponge/mod.rs index 9f1ac542..1c6919dc 100644 --- a/prover/src/sha_compress_sponge/mod.rs +++ b/prover/src/sha_compress_sponge/mod.rs @@ -1,2 +1,3 @@ -mod columns; -mod sha_compress_sponge_stark; \ No newline at end of file +pub mod columns; +mod sha_compress_sponge_stark; +pub mod constants; \ No newline at end of file diff --git a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs index cadbd1c1..369e8999 100644 --- a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs +++ b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs @@ -9,7 +9,10 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::evaluation_frame::StarkFrame; use crate::memory::segments::Segment; +use crate::sha_compress::logic::from_be_bits_to_u32; use crate::sha_compress_sponge::columns::{ShaCompressSpongeColumnsView, NUM_SHA_COMPRESS_SPONGE_COLUMNS}; +use crate::sha_compress_sponge::constants::SHA_COMPRESS_K_BINARY; +use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range, wrapping_add}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryAddress; @@ -18,13 
+21,18 @@ use crate::witness::operation::SHA_COMPRESS_K; #[derive(Clone, Debug)] pub(crate) struct ShaCompressSpongeOp { /// The base address at which inputs are read. + /// h[0],...,h[7], w[i]. pub(crate) base_address: Vec, /// The timestamp at which inputs are read. pub(crate) timestamp: usize, + /// The index of round + pub(crate) i: usize, + /// The input that was read. - pub(crate) input: Vec, + /// Values: h[0],..., h[7], w[i] in big-endian order. + pub(crate) input: Vec, } #[derive(Copy, Clone, Default)] @@ -70,75 +78,85 @@ impl, const D: usize> ShaCompressSpongeStark let mut row = ShaCompressSpongeColumnsView::default(); row.timestamp = F::from_canonical_usize(op.timestamp); - - let new_buffer = self.compress(&op.input); - - row.hx = op.input[0..8] - .iter() - .map(|&x| F::from_canonical_u32(x)) - .collect::>() - .try_into() - .unwrap(); - - row.w = op.input[8..op.input.len()] - .iter() - .map(|&x| F::from_canonical_u32(x)) - .collect::>() - .try_into() - .unwrap(); - row.context = F::from_canonical_usize(op.base_address[0].context); row.segment = F::from_canonical_usize(op.base_address[Segment::Code as usize].segment); - [row.new_a, row.new_b, row.new_c, row.new_d, row.new_e, row.new_f, row.new_g, row.new_h] - = new_buffer.iter() - .map(|&x| F::from_canonical_u32(x)) - .collect::>() - .try_into() - .unwrap(); - - row.final_hx = new_buffer.iter().zip(row.hx.iter()) - .map(|(&x, &hx)| F::from_canonical_u32(x.wrapping_add(hx.to_canonical_u64() as u32))) - .collect::>() - .try_into() - .unwrap(); - let hx_virt = (0..8) .map(|i| op.base_address[i].virt) .collect_vec(); let hx_virt: [usize; 8] = hx_virt.try_into().unwrap(); row.hx_virt = hx_virt.map(F::from_canonical_usize); - let w_virt = (8..op.input.len()) - .map(|i| op.base_address[i].virt) - .collect_vec(); - let w_virt: [usize; 64] = w_virt.try_into().unwrap(); - row.w_virt = w_virt.map(F::from_canonical_usize); + let w_virt = op.base_address[8].virt; + row.w_virt = F::from_canonical_usize(w_virt); + + 
row.round = [F::ZEROS; 64]; + row.round[op.i] = F::ONE; + row.k_i = SHA_COMPRESS_K_BINARY[op.i].map(|k| F::from_canonical_u8(k)); + row.w_i = op.input[256..288].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + if op.i == 0 { + row.hx = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + row.input_state = row.hx; + } else if op.i != 63 { + row.input_state = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + } else { + row.input_state = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + row.hx = op.input[288..].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + } - row - } + let output = self.compress(&op.input[..288], op.i); + row.output_state = output.map(F::from_canonical_u8); - fn compress(&self, input: &[u32]) -> [u32; 8] { - let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h]: [u32; 8] = input[0..8].try_into().unwrap(); - let mut t1: u32; - let mut t2: u32; - - for i in 0..64 { - t1 = h.wrapping_add(e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25)) - .wrapping_add((e & f) ^ ((!e) & g)).wrapping_add(SHA_COMPRESS_K[i]).wrapping_add(input[8 + i]); - t2 = (a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22)) - .wrapping_add((a & b) ^ (a & c) ^ (b & c)); - h = g; - g = f; - f = e; - e = d.wrapping_add(t1); - d = c; - c = b; - b = a; - a = t1.wrapping_add(t2); + if op.i == 63 { + for i in 0..8 { + + let (output_hx, carry) = wrapping_add::( + row.hx[get_input_range(i)].try_into().unwrap(), + row.output_state[get_input_range(i)].try_into().unwrap() + ); + + row.output_hx[get_input_range(i)].copy_from_slice(&output_hx[0..]); + row.carry[get_input_range(i)].copy_from_slice(&carry[0..]); + } + + } else { + row.output_hx = row.hx; + row.carry = [F::ZEROS; 256]; } - [a, b, c, d, e, f, g, h] + row + } + + fn compress(&self, input: &[u8], round: usize) -> [u8; 256] { + let 
values: Vec<[u8; 32]> = input.chunks(32).map(|chunk| chunk.try_into().unwrap()).collect(); + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h, w_i] = values.into_iter().map( + |x| from_be_bits_to_u32(x) + ).collect::>().try_into().unwrap(); + + let t1 = h.wrapping_add(e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25)) + .wrapping_add((e & f) ^ ((!e) & g)).wrapping_add(SHA_COMPRESS_K[round]).wrapping_add(w_i); + let t2 = (a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22)) + .wrapping_add((a & b) ^ (a & c) ^ (b & c)); + h = g; + g = f; + f = e; + e = d.wrapping_add(t1); + d = c; + c = b; + b = a; + a = t1.wrapping_add(t2); + + let mut result = vec![]; + result.extend(from_u32_to_be_bits(a)); + result.extend(from_u32_to_be_bits(b)); + result.extend(from_u32_to_be_bits(c)); + result.extend(from_u32_to_be_bits(d)); + result.extend(from_u32_to_be_bits(e)); + result.extend(from_u32_to_be_bits(f)); + result.extend(from_u32_to_be_bits(g)); + result.extend(from_u32_to_be_bits(h)); + + result.try_into().unwrap() } } @@ -180,8 +198,11 @@ impl, const D: usize> Stark for ShaCompressSp #[cfg(test)] mod test { use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::Field; + use plonky2::field::types::{Field}; + use std::borrow::Borrow; + use crate::sha_compress_sponge::columns::ShaCompressSpongeColumnsView; use crate::sha_compress_sponge::sha_compress_sponge_stark::{ShaCompressSpongeOp, ShaCompressSpongeStark}; + use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range}; use crate::witness::memory::MemoryAddress; @@ -221,30 +242,97 @@ mod test { virt: i, } }).collect(); - + let mut input = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + input.extend(from_u32_to_be_bits(W[0])); let op = ShaCompressSpongeOp { - base_address: hx_addresses.iter().chain(w_addresses.iter()).cloned().collect(), + base_address: hx_addresses.iter().chain([w_addresses[0]].iter()).cloned().collect(), + i: 0, 
timestamp: 0, - input: H256_256.iter().chain(W.iter()).cloned().collect(), + input: input, }; let row = stark.generate_rows_for_op(op); + let local_values: &ShaCompressSpongeColumnsView = row.borrow(); + + assert_eq!( + local_values.output_state[get_input_range(0)], + from_u32_to_be_bits(4228417613).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(1)], + from_u32_to_be_bits(1779033703).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(2)], + from_u32_to_be_bits(3144134277).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(3)], + from_u32_to_be_bits(1013904242).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(4)], + from_u32_to_be_bits(2563236514).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(5)], + from_u32_to_be_bits(1359893119).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(6)], + from_u32_to_be_bits(2600822924).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_state[get_input_range(7)], + from_u32_to_be_bits(528734635).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + + let mut input = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + input.extend(from_u32_to_be_bits(W[63])); + input.extend(H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>()); - assert_eq!(row.new_a, F::from_canonical_u32(1813631354)); - assert_eq!(row.new_b, F::from_canonical_u32(3315363907)); - assert_eq!(row.new_c, F::from_canonical_u32(209435322)); - assert_eq!(row.new_d, F::from_canonical_u32(267716009)); - assert_eq!(row.new_e, F::from_canonical_u32(646830348)); - assert_eq!(row.new_f, 
F::from_canonical_u32(362222596)); - assert_eq!(row.new_g, F::from_canonical_u32(3323089566)); - assert_eq!(row.new_h, F::from_canonical_u32(1912443780)); - - let expected_values: [F; 8] = [3592665057_u32, 2164530888, 1223339564, 3041196771, 2006723467, - 2963045520, 3851824201, 3453903005].into_iter().map(F::from_canonical_u32) - .collect::>().try_into().unwrap(); - - - assert_eq!(row.final_hx, expected_values); + let op = ShaCompressSpongeOp { + base_address: hx_addresses.iter().chain([w_addresses[0]].iter()).cloned().collect(), + i: 63, + timestamp: 0, + input: input, + }; + let row = stark.generate_rows_for_op(op); + let local_values: &ShaCompressSpongeColumnsView = row.borrow(); + + + assert_eq!( + local_values.output_hx[get_input_range(0)], + from_u32_to_be_bits(H256_256[0].wrapping_add(2781379838 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_hx[get_input_range(1)], + from_u32_to_be_bits(H256_256[1].wrapping_add(1779033703 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_hx[get_input_range(2)], + from_u32_to_be_bits(H256_256[2].wrapping_add(3144134277 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_hx[get_input_range(3)], + from_u32_to_be_bits(H256_256[3].wrapping_add(1013904242 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_hx[get_input_range(4)], + from_u32_to_be_bits(H256_256[4].wrapping_add(1116198739 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_hx[get_input_range(5)], + from_u32_to_be_bits(H256_256[5].wrapping_add(1359893119 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); + assert_eq!( + local_values.output_hx[get_input_range(6)], + from_u32_to_be_bits(H256_256[6].wrapping_add(2600822924 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); 
+ assert_eq!( + local_values.output_hx[get_input_range(7)], + from_u32_to_be_bits(H256_256[7].wrapping_add(528734635 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + ); Ok(()) - } + } \ No newline at end of file From 38811e5783afd6878239f551b6edd70019a697f0 Mon Sep 17 00:00:00 2001 From: vanhger Date: Fri, 24 Jan 2025 10:39:44 +0700 Subject: [PATCH 17/25] chore: adjust SHA compress sponge operation. --- .../sha_compress_sponge_stark.rs | 50 +++++++++---------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs index 369e8999..5a01458d 100644 --- a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs +++ b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs @@ -27,6 +27,9 @@ pub(crate) struct ShaCompressSpongeOp { /// The timestamp at which inputs are read. pub(crate) timestamp: usize, + /// The input state + pub(crate) input_state: Vec, + /// The index of round pub(crate) i: usize, @@ -94,36 +97,27 @@ impl, const D: usize> ShaCompressSpongeStark row.round[op.i] = F::ONE; row.k_i = SHA_COMPRESS_K_BINARY[op.i].map(|k| F::from_canonical_u8(k)); row.w_i = op.input[256..288].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - if op.i == 0 { - row.hx = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - row.input_state = row.hx; - } else if op.i != 63 { - row.input_state = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - } else { - row.input_state = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - row.hx = op.input[288..].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - } + row.hx = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + row.input_state = op.input_state.iter().map(|&x| 
F::from_canonical_u8(x)).collect::>().try_into().unwrap(); let output = self.compress(&op.input[..288], op.i); row.output_state = output.map(F::from_canonical_u8); - if op.i == 63 { - for i in 0..8 { + // We use the result if only we are at the final round. + // The computation in other rounds are ensure the constraint degree + // not to be exceeded 3. + for i in 0..8 { - let (output_hx, carry) = wrapping_add::( - row.hx[get_input_range(i)].try_into().unwrap(), - row.output_state[get_input_range(i)].try_into().unwrap() - ); - - row.output_hx[get_input_range(i)].copy_from_slice(&output_hx[0..]); - row.carry[get_input_range(i)].copy_from_slice(&carry[0..]); - } + let (output_hx, carry) = wrapping_add::( + row.hx[get_input_range(i)].try_into().unwrap(), + row.output_state[get_input_range(i)].try_into().unwrap() + ); - } else { - row.output_hx = row.hx; - row.carry = [F::ZEROS; 256]; + row.output_hx[get_input_range(i)].copy_from_slice(&output_hx[0..]); + row.carry[get_input_range(i)].copy_from_slice(&carry[0..]); } + row } @@ -177,7 +171,7 @@ impl, const D: usize> Stark for ShaCompressSp FE: FieldExtension, P: PackedField { - todo!() + // } fn eval_ext_circuit( @@ -244,11 +238,13 @@ mod test { }).collect(); let mut input = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); input.extend(from_u32_to_be_bits(W[0])); + let input_state = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); let op = ShaCompressSpongeOp { base_address: hx_addresses.iter().chain([w_addresses[0]].iter()).cloned().collect(), i: 0, timestamp: 0, - input: input, + input_state, + input, }; let row = stark.generate_rows_for_op(op); let local_values: &ShaCompressSpongeColumnsView = row.borrow(); @@ -288,13 +284,13 @@ mod test { let mut input = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); input.extend(from_u32_to_be_bits(W[63])); - input.extend(H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>()); - + let 
input_state = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); let op = ShaCompressSpongeOp { base_address: hx_addresses.iter().chain([w_addresses[0]].iter()).cloned().collect(), i: 63, timestamp: 0, - input: input, + input_state, + input, }; let row = stark.generate_rows_for_op(op); let local_values: &ShaCompressSpongeColumnsView = row.borrow(); From ff8c5bb0dceb1a9f2868a09d5812430968721c6f Mon Sep 17 00:00:00 2001 From: vanhger Date: Fri, 24 Jan 2025 15:26:01 +0700 Subject: [PATCH 18/25] feat: add constraints for SHA compress sponge table. --- prover/src/sha_compress_sponge/constants.rs | 1 + .../sha_compress_sponge_stark.rs | 413 +++++++++++++++++- 2 files changed, 403 insertions(+), 11 deletions(-) diff --git a/prover/src/sha_compress_sponge/constants.rs b/prover/src/sha_compress_sponge/constants.rs index f656ddc5..d50bde1f 100644 --- a/prover/src/sha_compress_sponge/constants.rs +++ b/prover/src/sha_compress_sponge/constants.rs @@ -9,6 +9,7 @@ pub const SHA_COMPRESS_K: [u32; 64] = [ 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, ]; +pub const NUM_COMPRESS_ROWS: usize = 64; // big-endian form pub const SHA_COMPRESS_K_BINARY: [[u8; 32]; 64] = [ diff --git a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs index 5a01458d..74b47bae 100644 --- a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs +++ b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs @@ -1,18 +1,21 @@ use std::marker::PhantomData; +use std::borrow::Borrow; use itertools::Itertools; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use 
crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::evaluation_frame::StarkFrame; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; use crate::sha_compress::logic::from_be_bits_to_u32; use crate::sha_compress_sponge::columns::{ShaCompressSpongeColumnsView, NUM_SHA_COMPRESS_SPONGE_COLUMNS}; -use crate::sha_compress_sponge::constants::SHA_COMPRESS_K_BINARY; -use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range, wrapping_add}; +use crate::sha_compress_sponge::constants::{NUM_COMPRESS_ROWS, SHA_COMPRESS_K_BINARY}; +use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range, wrapping_add, wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints}; +use crate::sha_extend_sponge::sha_extend_sponge_stark::NUM_ROUNDS; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryAddress; @@ -100,7 +103,7 @@ impl, const D: usize> ShaCompressSpongeStark row.hx = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.input_state = op.input_state.iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - let output = self.compress(&op.input[..288], op.i); + let output = self.compress(&op.input_state, &op.input[256..288], op.i); row.output_state = output.map(F::from_canonical_u8); // We use the result if only we are at the final round. 
@@ -117,15 +120,15 @@ impl, const D: usize> ShaCompressSpongeStark row.carry[get_input_range(i)].copy_from_slice(&carry[0..]); } - row } - fn compress(&self, input: &[u8], round: usize) -> [u8; 256] { - let values: Vec<[u8; 32]> = input.chunks(32).map(|chunk| chunk.try_into().unwrap()).collect(); - let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h, w_i] = values.into_iter().map( + fn compress(&self, input_state: &[u8], w_i: &[u8], round: usize) -> [u8; 256] { + let values: Vec<[u8; 32]> = input_state.chunks(32).map(|chunk| chunk.try_into().unwrap()).collect(); + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = values.into_iter().map( |x| from_be_bits_to_u32(x) ).collect::>().try_into().unwrap(); + let w_i = from_be_bits_to_u32(w_i.try_into().unwrap()); let t1 = h.wrapping_add(e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25)) .wrapping_add((e & f) ^ ((!e) & g)).wrapping_add(SHA_COMPRESS_K[round]).wrapping_add(w_i); @@ -171,7 +174,106 @@ impl, const D: usize> Stark for ShaCompressSp FE: FieldExtension, P: PackedField { - // + + let local_values: &[P; NUM_SHA_COMPRESS_SPONGE_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let local_values: &ShaCompressSpongeColumnsView

= local_values.borrow(); + + let next_values: &[P; NUM_SHA_COMPRESS_SPONGE_COLUMNS] = + vars.get_next_values().try_into().unwrap(); + let next_values: &ShaCompressSpongeColumnsView

= next_values.borrow(); + + // check the bit values are zero or one in input + for i in 0..256 { + yield_constr.constraint(local_values.hx[i] * (local_values.hx[i] - P::ONES)); + yield_constr.constraint(local_values.input_state[i] * (local_values.input_state[i] - P::ONES)); + } + for i in 0..32 { + yield_constr.constraint(local_values.w_i[i] * (local_values.w_i[i] - P::ONES)); + yield_constr.constraint(local_values.k_i[i] * (local_values.k_i[i] - P::ONES)); + } + + // check the bit values are zero or one in output + for i in 0..256 { + yield_constr.constraint(local_values.output_state[i] * (local_values.output_state[i] - P::ONES)); + yield_constr.constraint(local_values.output_hx[i] * (local_values.output_hx[i] - P::ONES)); + yield_constr.constraint(local_values.carry[i] * (local_values.carry[i] - P::ONES)); + } + + // // check the round + for i in 0..NUM_ROUNDS { + yield_constr.constraint(local_values.round[i] * (local_values.round[i] - P::ONES)); + } + + // check the filter + let is_final = local_values.round[NUM_COMPRESS_ROWS - 1]; + yield_constr.constraint(is_final * (is_final - P::ONES)); + let not_final = P::ONES - is_final; + + let sum_round_flags = (0..NUM_COMPRESS_ROWS) + .map(|i| local_values.round[i]) + .sum::

(); + yield_constr.constraint(sum_round_flags * (sum_round_flags - P::ONES)); + + + // If this is not the final step or a padding row: + + // the local and next timestamps must match. + yield_constr.constraint( + sum_round_flags * not_final * (next_values.timestamp - local_values.timestamp), + ); + + // the local and next context hx_virt must match + for i in 0..8 { + yield_constr.constraint( + sum_round_flags * not_final * (next_values.hx_virt[i] - local_values.hx_virt[i]), + ); + } + + // the output state of local row must be the input state of next row + for i in 0..256 { + yield_constr.constraint( + sum_round_flags * not_final * (next_values.input_state[i] - local_values.output_state[i]) + ); + } + + // the address of w_i must be increased by 4 + yield_constr.constraint( + sum_round_flags * not_final * (next_values.w_virt - local_values.w_virt - FE::from_canonical_u8(4)), + ); + + + // if not the padding row, the hx address must be a sequence of numbers spaced 4 units apart + + for i in 0..7 { + yield_constr.constraint( + sum_round_flags * (local_values.hx_virt[i + 1] - local_values.hx_virt[i] - FE::from_canonical_u8(4)), + ); + } + + // check the validation of key[i] + + for i in 0..32 { + let mut bit_i = P::ZEROS; + for j in 0..64 { + bit_i = bit_i + local_values.round[j] * FE::from_canonical_u8(SHA_COMPRESS_K_BINARY[j][i]); + } + yield_constr.constraint(local_values.k_i[i] - bit_i); + } + + // wrapping add constraints + + for i in 0..8 { + + wrapping_add_packed_constraints::( + local_values.hx[get_input_range(i)].try_into().unwrap(), + local_values.output_state[get_input_range(i)].try_into().unwrap(), + local_values.carry[get_input_range(i)].try_into().unwrap(), + local_values.output_hx[get_input_range(i)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(c)); + + } + } fn eval_ext_circuit( @@ -180,11 +282,147 @@ impl, const D: usize> Stark for ShaCompressSp vars: &Self::EvaluationFrameTarget, yield_constr: &mut 
RecursiveConstraintConsumer ) { - todo!() + let local_values: &[ExtensionTarget; NUM_SHA_COMPRESS_SPONGE_COLUMNS] = + vars.get_local_values().try_into().unwrap(); + let local_values: &ShaCompressSpongeColumnsView> = local_values.borrow(); + + let next_values: &[ExtensionTarget; NUM_SHA_COMPRESS_SPONGE_COLUMNS] = + vars.get_next_values().try_into().unwrap(); + let next_values: &ShaCompressSpongeColumnsView> = next_values.borrow(); + + let one_ext = builder.one_extension(); + let four_ext = builder.constant_extension(F::Extension::from_canonical_u8(4)); + + // check the bit values are zero or one in input + for i in 0..256 { + let constraint = builder.mul_sub_extension( + local_values.hx[i], local_values.hx[i], local_values.hx[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.input_state[i], local_values.input_state[i], local_values.input_state[i]); + yield_constr.constraint(builder, constraint); + } + for i in 0..32 { + let constraint = builder.mul_sub_extension( + local_values.w_i[i], local_values.w_i[i], local_values.w_i[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.k_i[i], local_values.k_i[i], local_values.k_i[i]); + yield_constr.constraint(builder, constraint); + + } + + // check the bit values are zero or one in output + for i in 0..256 { + + let constraint = builder.mul_sub_extension( + local_values.output_state[i], local_values.output_state[i], local_values.output_state[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.output_hx[i], local_values.output_hx[i], local_values.output_hx[i]); + yield_constr.constraint(builder, constraint); + + let constraint = builder.mul_sub_extension( + local_values.carry[i], local_values.carry[i], local_values.carry[i]); + yield_constr.constraint(builder, constraint); + } + + // check the round + for i in 0..NUM_ROUNDS { + let 
constraint = builder.mul_sub_extension( + local_values.round[i], local_values.round[i], local_values.round[i]); + yield_constr.constraint(builder, constraint); + } + + // check the filter + let is_final = local_values.round[NUM_COMPRESS_ROWS - 1]; + let constraint = builder.mul_sub_extension(is_final, is_final, is_final); + yield_constr.constraint(builder, constraint); + let not_final = builder.sub_extension(one_ext, is_final); + + let sum_round_flags = + builder.add_many_extension((0..NUM_COMPRESS_ROWS).map(|i| local_values.round[i])); + + let constraint = builder.mul_sub_extension( + sum_round_flags, sum_round_flags, sum_round_flags + ); + yield_constr.constraint(builder, constraint); + + + // If this is not the final step or a padding row: + + // the local and next timestamps must match. + + let diff = builder.sub_extension(next_values.timestamp, local_values.timestamp); + let constraint = builder.mul_many_extension([sum_round_flags, not_final, diff]); + yield_constr.constraint(builder, constraint); + + // the local and next context hx_virt must match + for i in 0..8 { + let diff = builder.sub_extension(next_values.hx_virt[i], local_values.hx_virt[i]); + let constraint = builder.mul_many_extension([sum_round_flags, not_final, diff]); + yield_constr.constraint(builder, constraint); + } + + // the output state of local row must be the input state of next row + for i in 0..256 { + let diff = builder.sub_extension(next_values.input_state[i], local_values.output_state[i]); + let constraint = builder.mul_many_extension([sum_round_flags, not_final, diff]); + yield_constr.constraint(builder, constraint); + } + + // the address of w_i must be increased by 4 + let increment = builder.sub_extension(next_values.w_virt, local_values.w_virt); + let address_increment = builder.sub_extension(increment, four_ext); + let constraint = builder.mul_many_extension( + [sum_round_flags, not_final, address_increment] + ); + yield_constr.constraint(builder, constraint); + + + // if not 
the padding row, the hx address must be a sequence of numbers spaced 4 units apart + + for i in 0..7 { + let increment = builder.sub_extension(local_values.hx_virt[i + 1], local_values.hx_virt[i]); + let address_increment = builder.sub_extension(increment, four_ext); + let constraint = builder.mul_extension( + sum_round_flags, address_increment + ); + yield_constr.constraint(builder, constraint); + } + + // check the validation of key[i] + + for i in 0..32 { + + let bit_i_comp: Vec<_> = (0..64).map(|j| { + let k_j_i = builder.constant_extension(F::Extension::from_canonical_u8(SHA_COMPRESS_K_BINARY[j][i])); + builder.mul_extension(local_values.round[j], k_j_i) + }).collect(); + let bit_i = builder.add_many_extension(bit_i_comp); + let constraint = builder.sub_extension(local_values.k_i[i], bit_i); + yield_constr.constraint(builder, constraint); + } + + // wrapping add constraints + + for i in 0..8 { + wrapping_add_ext_circuit_constraints::( + builder, + local_values.hx[get_input_range(i)].try_into().unwrap(), + local_values.output_state[get_input_range(i)].try_into().unwrap(), + local_values.carry[get_input_range(i)].try_into().unwrap(), + local_values.output_hx[get_input_range(i)].try_into().unwrap(), + ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + + } } fn constraint_degree(&self) -> usize { - todo!() + 3 } } @@ -194,9 +432,20 @@ mod test { use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::types::{Field}; use std::borrow::Borrow; + use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; + use plonky2::field::polynomial::PolynomialValues; + use plonky2::fri::oracle::PolynomialBatch; + use plonky2::iop::challenger::Challenger; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use plonky2::timed; + use plonky2::util::timing::TimingTree; + use crate::config::StarkConfig; + use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; + 
use crate::prover::prove_single_table; use crate::sha_compress_sponge::columns::ShaCompressSpongeColumnsView; use crate::sha_compress_sponge::sha_compress_sponge_stark::{ShaCompressSpongeOp, ShaCompressSpongeStark}; use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range}; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use crate::witness::memory::MemoryAddress; @@ -331,4 +580,146 @@ mod test { Ok(()) } + + #[test] + fn test_stark_circuit() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaCompressSpongeStark; + + let stark = S::default(); + test_stark_circuit_constraints::(stark) + } + + #[test] + fn test_stark_degree() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaCompressSpongeStark; + + let stark = S { + f: Default::default(), + }; + test_stark_low_degree(stark) + } + + + fn get_random_input() -> Vec { + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaCompressSpongeStark; + let stark = S::default(); + + let hx_addresses: Vec = (0..32).step_by(4).map(|i| { + MemoryAddress { + context: 0, + segment: 0, + virt: i, + } + }).collect(); + + let w_addresses: Vec = (32..288).step_by(4).map(|i| { + MemoryAddress { + context: 0, + segment: 0, + virt: i, + } + }).collect(); + + let mut res = vec![]; + let mut output_state = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + for i in 0..64 { + + let mut input = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + input.extend(from_u32_to_be_bits(W[i])); + let input_state = output_state.clone(); + + output_state = stark.compress(&input_state, &from_u32_to_be_bits(W[i]), i).to_vec(); + let op = ShaCompressSpongeOp { + base_address: hx_addresses.iter().chain([w_addresses[i]].iter()).cloned().collect(), + i: i, + timestamp: 0, + input_state, + input, + }; + + 
res.push(op); + } + + res + + } + #[test] + fn sha_extend_sponge_benchmark() -> anyhow::Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type S = ShaCompressSpongeStark; + let stark = S::default(); + let config = StarkConfig::standard_fast_config(); + + init_logger(); + + let input = get_random_input(); + let mut timing = TimingTree::new("prove", log::Level::Debug); + let trace_poly_values = stark.generate_trace(input, 8); + + // TODO: Cloning this isn't great; consider having `from_values` accept a reference, + // or having `compute_permutation_z_polys` read trace values from the `PolynomialBatch`. + let cloned_trace_poly_values = timed!(timing, "clone", trace_poly_values.clone()); + + let trace_commitments = timed!( + timing, + "compute trace commitment", + PolynomialBatch::::from_values( + cloned_trace_poly_values, + config.fri_config.rate_bits, + false, + config.fri_config.cap_height, + &mut timing, + None, + ) + ); + let degree = 1 << trace_commitments.degree_log; + + // Fake CTL data. 
+ let ctl_z_data = CtlZData { + helper_columns: vec![PolynomialValues::zero(degree)], + z: PolynomialValues::zero(degree), + challenge: GrandProductChallenge { + beta: F::ZERO, + gamma: F::ZERO, + }, + columns: vec![], + filter: vec![Some(Filter::new_simple(Column::constant(F::ZERO)))], + }; + let ctl_data = CtlData { + zs_columns: vec![ctl_z_data.clone(); config.num_challenges], + }; + + prove_single_table( + &stark, + &config, + &trace_poly_values, + &trace_commitments, + &ctl_data, + &GrandProductChallengeSet { + challenges: vec![ctl_z_data.challenge; config.num_challenges], + }, + &mut Challenger::new(), + &mut timing, + )?; + + timing.print(); + Ok(()) + } + + fn init_logger() { + let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); + } + } \ No newline at end of file From fcd9dbb2a93e0f8863a72ee503f6782fa1b3551f Mon Sep 17 00:00:00 2001 From: vanhger Date: Tue, 28 Jan 2025 00:01:57 +0700 Subject: [PATCH 19/25] chore: adjust the sha extend tables --- prover/src/sha_extend/columns.rs | 3 ++- prover/src/sha_extend/sha_extend_stark.rs | 9 +++---- prover/src/sha_extend_sponge/columns.rs | 3 ++- .../sha_extend_sponge_stark.rs | 24 ++++++++++++------- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/prover/src/sha_extend/columns.rs b/prover/src/sha_extend/columns.rs index e4674d99..4d3982f6 100644 --- a/prover/src/sha_extend/columns.rs +++ b/prover/src/sha_extend/columns.rs @@ -28,9 +28,10 @@ pub(crate) struct ShaExtendColumnsView { pub w_i: [T; 32], // w_i_inter_1 + w_i_minus_16 /// The timestamp at which inputs should be read from memory. 
pub timestamp: T, + pub is_normal_round: T, } -pub const NUM_SHA_EXTEND_COLUMNS: usize = size_of::>(); +pub const NUM_SHA_EXTEND_COLUMNS: usize = size_of::>(); //577 impl From<[T; NUM_SHA_EXTEND_COLUMNS]> for ShaExtendColumnsView { fn from(value: [T; NUM_SHA_EXTEND_COLUMNS]) -> Self { diff --git a/prover/src/sha_extend/sha_extend_stark.rs b/prover/src/sha_extend/sha_extend_stark.rs index 5110e24f..69cc131e 100644 --- a/prover/src/sha_extend/sha_extend_stark.rs +++ b/prover/src/sha_extend/sha_extend_stark.rs @@ -3,25 +3,26 @@ use std::marker::PhantomData; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::{Column, Filter}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::keccak::logic::{xor3_gen, xor3_gen_circuit}; -use crate::sha_extend::columns::{ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS}; +use crate::sha_extend::columns::{ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS, SHA_EXTEND_COL_MAP}; use crate::sha_extend::logic::{get_input_range, rotate_right, rotate_right_ext_circuit_constraint, rotate_right_packed_constraints, shift_right, shift_right_ext_circuit_constraints, shift_right_packed_constraints, wrapping_add, wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints, xor3}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; -const NUM_INPUTS: usize = 4 * 32; // w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 +pub const NUM_INPUTS: usize = 4 * 32; // w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 #[derive(Copy, Clone, Default)] pub struct ShaExtendStark { pub(crate) f: PhantomData, } - impl, const D: usize> 
ShaExtendStark { pub(crate) fn generate_trace( &self, @@ -70,7 +71,7 @@ impl, const D: usize> ShaExtendStark { .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.w_i_minus_7 = input_and_timestamp.0[get_input_range(3)] .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - + row.is_normal_round = F::ONE; self.generate_trace_row_for_round(&mut row); row } diff --git a/prover/src/sha_extend_sponge/columns.rs b/prover/src/sha_extend_sponge/columns.rs index e39e0b4f..9b849386 100644 --- a/prover/src/sha_extend_sponge/columns.rs +++ b/prover/src/sha_extend_sponge/columns.rs @@ -3,6 +3,7 @@ use std::intrinsics::transmute; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; pub(crate) const NUM_EXTEND_INPUT: usize = 4; +pub(crate) const SHA_EXTEND_SPONGE_READ_BITS: usize = NUM_EXTEND_INPUT * 32; pub(crate) struct ShaExtendSpongeColumnsView { /// Input @@ -30,7 +31,7 @@ pub(crate) struct ShaExtendSpongeColumnsView { pub timestamp: T, } -pub const NUM_SHA_EXTEND_SPONGE_COLUMNS: usize = size_of::>(); //170 +pub const NUM_SHA_EXTEND_SPONGE_COLUMNS: usize = size_of::>(); //216 impl From<[T; NUM_SHA_EXTEND_SPONGE_COLUMNS]> for ShaExtendSpongeColumnsView { fn from(value: [T; NUM_SHA_EXTEND_SPONGE_COLUMNS]) -> Self { diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs index 8c7be9e6..56713dc9 100644 --- a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -9,10 +9,12 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cpu::membus::NUM_CHANNELS; +use crate::cross_table_lookup::{Column, Filter}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use 
crate::memory::segments::Segment; -use crate::sha_extend::logic::{get_input_range, from_be_bits_to_u32, from_u32_to_be_bits}; -use crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_EXTEND_INPUT, NUM_SHA_EXTEND_SPONGE_COLUMNS}; +use crate::sha_extend::logic::{get_input_range, from_u32_to_be_bits, from_be_bits_to_u32}; +use crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_EXTEND_INPUT, NUM_SHA_EXTEND_SPONGE_COLUMNS, SHA_EXTEND_SPONGE_COL_MAP}; use crate::sha_extend_sponge::logic::{diff_address_ext_circuit_constraint, round_increment_ext_circuit_constraint}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; @@ -20,11 +22,12 @@ use crate::witness::memory::MemoryAddress; pub const NUM_ROUNDS: usize = 48; +#[derive(Clone, Debug)] pub(crate) struct ShaExtendSpongeOp { /// The base address at which inputs are read pub(crate) base_address: Vec, - /// The timestamp at which inputs are read and output are written (same for both). + /// The timestamp at which inputs are read pub(crate) timestamp: usize, /// The input that was read. @@ -168,9 +171,10 @@ impl, const D: usize> Stark for ShaExtendSpon .sum::

(); // If this is not the final step or a padding row, - // the local and next timestamps must match. + // the timestamp must be increased by 2 * NUM_CHANNELS. yield_constr.constraint( - sum_round_flags * not_final * (next_values.timestamp - local_values.timestamp), + sum_round_flags * not_final * + (next_values.timestamp - local_values.timestamp - FE::from_canonical_usize(2 * NUM_CHANNELS)), ); // If this is not the final step or a padding row, @@ -234,7 +238,7 @@ impl, const D: usize> Stark for ShaExtendSpon let one_ext = builder.one_extension(); let four_ext = builder.constant_extension(F::Extension::from_canonical_u32(4)); - + let num_channel = builder.constant_extension(F::Extension::from_canonical_usize(2 * NUM_CHANNELS)); // check the binary form for i in 0..32 { let constraint = builder.mul_sub_extension( @@ -276,8 +280,9 @@ impl, const D: usize> Stark for ShaExtendSpon builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values.round[i])); // If this is not the final step or a padding row, - // the local and next timestamps must match. + // the timestamp must be increased by 2 * NUM_CHANNELS. 
let diff = builder.sub_extension(next_values.timestamp, local_values.timestamp); + let diff = builder.sub_extension(diff, num_channel); let constraint = builder.mul_many_extension([sum_round_flags, not_final, diff]); yield_constr.constraint(builder, constraint); @@ -377,6 +382,7 @@ mod test { use plonky2::util::timing::TimingTree; use crate::config::StarkConfig; use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; + use crate::memory::NUM_CHANNELS; use crate::memory::segments::Segment; use crate::prover::prove_single_table; use crate::sha_extend_sponge::sha_extend_sponge_stark::{ShaExtendSpongeOp, ShaExtendSpongeStark}; @@ -500,6 +506,7 @@ mod test { let mut res = vec![]; + let mut time = 0; for i in 16..64 { let mut input_values = vec![]; input_values.extend(to_be_bits(w[i - 15])); @@ -509,13 +516,14 @@ mod test { let op = ShaExtendSpongeOp { base_address: vec![addresses[i - 15], addresses[i - 2], addresses[i - 16], addresses[i - 7]], - timestamp: 0, + timestamp: time, input: input_values, i: i - 16, output_address: addresses[i], }; res.push(op); + time += 2 * NUM_CHANNELS; } res From bb5171b94fe627d6fba19667aa6272987275fffb Mon Sep 17 00:00:00 2001 From: vanhger Date: Tue, 28 Jan 2025 00:07:57 +0700 Subject: [PATCH 20/25] feat: add CTL for sha extend tables. 
--- prover/examples/utils/src/utils.rs | 2 +- prover/src/all_stark.rs | 77 ++++++++++++++++- prover/src/cpu/columns/general.rs | 16 ++++ prover/src/cpu/columns/mod.rs | 1 + prover/src/cpu/cpu_stark.rs | 22 +++++ prover/src/fixed_recursive_verifier.rs | 19 +++++ prover/src/prover.rs | 32 ++++++++ prover/src/sha_extend/sha_extend_stark.rs | 35 ++++++++ .../sha_extend_sponge_stark.rs | 82 +++++++++++++++++++ prover/src/verifier.rs | 25 ++++++ prover/src/witness/operation.rs | 30 +++++++ prover/src/witness/traces.rs | 37 ++++++++- prover/src/witness/util.rs | 47 +++++++++++ 13 files changed, 420 insertions(+), 5 deletions(-) diff --git a/prover/examples/utils/src/utils.rs b/prover/examples/utils/src/utils.rs index af5dc87f..e70f6a8e 100644 --- a/prover/examples/utils/src/utils.rs +++ b/prover/examples/utils/src/utils.rs @@ -16,7 +16,7 @@ use zkm_prover::cpu::kernel::assembler::segment_kernel; use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; use zkm_prover::generation::state::{AssumptionReceipts, Receipt}; -const DEGREE_BITS_RANGE: [Range; 8] = [10..21, 12..22, 11..21, 8..21, 6..21, 6..21, 6..21, 13..23]; +const DEGREE_BITS_RANGE: [Range; 10] = [10..21, 12..22, 11..21, 8..21, 6..21, 6..21, 6..21, 6..21, 6..21, 13..23]; const D: usize = 2; type C = PoseidonGoldilocksConfig; diff --git a/prover/src/all_stark.rs b/prover/src/all_stark.rs index 14191577..5ad6d11f 100644 --- a/prover/src/all_stark.rs +++ b/prover/src/all_stark.rs @@ -23,6 +23,11 @@ use crate::poseidon::poseidon_stark::PoseidonStark; use crate::poseidon_sponge::columns::POSEIDON_RATE_BYTES; use crate::poseidon_sponge::poseidon_sponge_stark; use crate::poseidon_sponge::poseidon_sponge_stark::PoseidonSpongeStark; +use crate::sha_extend::sha_extend_stark; +use crate::sha_extend::sha_extend_stark::ShaExtendStark; +use crate::sha_extend_sponge::columns::SHA_EXTEND_SPONGE_READ_BITS; +use crate::sha_extend_sponge::sha_extend_sponge_stark; +use 
crate::sha_extend_sponge::sha_extend_sponge_stark::ShaExtendSpongeStark; use crate::stark::Stark; #[derive(Clone)] @@ -33,6 +38,8 @@ pub struct AllStark, const D: usize> { pub poseidon_sponge_stark: PoseidonSpongeStark, pub keccak_stark: KeccakStark, pub keccak_sponge_stark: KeccakSpongeStark, + pub sha_extend_stark: ShaExtendStark, + pub sha_extend_sponge_stark: ShaExtendSpongeStark, pub logic_stark: LogicStark, pub memory_stark: MemoryStark, pub cross_table_lookups: Vec>, @@ -47,6 +54,8 @@ impl, const D: usize> Default for AllStark { poseidon_sponge_stark: PoseidonSpongeStark::default(), keccak_stark: KeccakStark::default(), keccak_sponge_stark: KeccakSpongeStark::default(), + sha_extend_stark: ShaExtendStark::default(), + sha_extend_sponge_stark: ShaExtendSpongeStark::default(), logic_stark: LogicStark::default(), memory_stark: MemoryStark::default(), cross_table_lookups: all_cross_table_lookups(), @@ -63,6 +72,8 @@ impl, const D: usize> AllStark { self.poseidon_sponge_stark.num_lookup_helper_columns(config), self.keccak_stark.num_lookup_helper_columns(config), self.keccak_sponge_stark.num_lookup_helper_columns(config), + self.sha_extend_stark.num_lookup_helper_columns(config), + self.sha_extend_sponge_stark.num_lookup_helper_columns(config), self.logic_stark.num_lookup_helper_columns(config), self.memory_stark.num_lookup_helper_columns(config), ] @@ -77,8 +88,10 @@ pub enum Table { PoseidonSponge = 3, Keccak = 4, KeccakSponge = 5, - Logic = 6, - Memory = 7, + ShaExtend = 6, + ShaExtendSponge = 7, + Logic = 8, + Memory = 9, } pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; @@ -95,6 +108,8 @@ impl Table { Self::PoseidonSponge, Self::Keccak, Self::KeccakSponge, + Self::ShaExtend, + Self::ShaExtendSponge, Self::Logic, Self::Memory, ] @@ -110,6 +125,9 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { ctl_keccak_sponge(), ctl_keccak_inputs(), ctl_keccak_outputs(), + ctl_sha_extend_sponge(), + ctl_sha_extend_inputs(), + ctl_sha_extend_outputs(), 
ctl_logic(), ctl_memory(), ] @@ -215,6 +233,48 @@ fn ctl_keccak_sponge() -> CrossTableLookup { CrossTableLookup::new(vec![cpu_looking], keccak_sponge_looked) } +fn ctl_sha_extend_inputs() -> CrossTableLookup { + let sha_extend_sponge_looking = TableWithColumns::new( + Table::ShaExtendSponge, + sha_extend_sponge_stark::ctl_looking_sha_extend_inputs(), + Some(sha_extend_sponge_stark::ctl_looking_sha_extend_filter()), + ); + let sha_extend_looked = TableWithColumns::new( + Table::ShaExtend, + sha_extend_stark::ctl_data_inputs(), + Some(sha_extend_stark::ctl_filter_inputs()), + ); + CrossTableLookup::new(vec![sha_extend_sponge_looking], sha_extend_looked) +} + +fn ctl_sha_extend_outputs() -> CrossTableLookup { + let sha_extend_sponge_looking = TableWithColumns::new( + Table::ShaExtendSponge, + sha_extend_sponge_stark::ctl_looking_sha_extend_outputs(), + Some(sha_extend_sponge_stark::ctl_looking_sha_extend_filter()), + ); + let sha_extend_looked = TableWithColumns::new( + Table::ShaExtend, + sha_extend_stark::ctl_data_outputs(), + Some(sha_extend_stark::ctl_filter_outputs()), + ); + CrossTableLookup::new(vec![sha_extend_sponge_looking], sha_extend_looked) +} + +fn ctl_sha_extend_sponge() -> CrossTableLookup { + let cpu_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_sha_extend_sponge(), + Some(cpu_stark::ctl_filter_sha_extend_sponge()), + ); + let sha_extend_sponge_looked = TableWithColumns::new( + Table::ShaExtendSponge, + sha_extend_sponge_stark::ctl_looked_data(), + Some(sha_extend_sponge_stark::ctl_looking_sha_extend_filter()), + ); + CrossTableLookup::new(vec![cpu_looking], sha_extend_sponge_looked) +} + pub(crate) fn ctl_logic() -> CrossTableLookup { let cpu_looking = TableWithColumns::new( Table::Cpu, @@ -261,11 +321,22 @@ fn ctl_memory() -> CrossTableLookup { Some(keccak_sponge_stark::ctl_looking_memory_filter(i)), ) }); + + let sha_extend_sponge_reads = (0..SHA_EXTEND_SPONGE_READ_BITS).map(|i| { + TableWithColumns::new( + 
Table::ShaExtendSponge, + sha_extend_sponge_stark::ctl_looking_memory(i), + Some(sha_extend_sponge_stark::ctl_looking_sha_extend_filter()), + ) + }); + let all_lookers = [] .into_iter() .chain(cpu_memory_gp_ops) - .chain(poseidon_sponge_reads) .chain(keccak_sponge_reads) + .chain(poseidon_sponge_reads) + .chain(sha_extend_sponge_reads) + .collect(); let memory_looked = TableWithColumns::new( Table::Memory, diff --git a/prover/src/cpu/columns/general.rs b/prover/src/cpu/columns/general.rs index 90ffc2b5..d2d1e2b6 100644 --- a/prover/src/cpu/columns/general.rs +++ b/prover/src/cpu/columns/general.rs @@ -12,6 +12,7 @@ pub(crate) union CpuGeneralColumnsView { io: CpuIOAuxView, hash: CpuHashView, khash: CpuKHashView, + element: CpuElementView, misc: CpuMiscView, } @@ -36,6 +37,16 @@ impl CpuGeneralColumnsView { unsafe { &mut self.khash } } + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn element(&self) -> &CpuElementView { + unsafe { &self.element } + } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn element_mut(&mut self) -> &mut CpuElementView { + unsafe { &mut self.element } + } + // SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn syscall(&self) -> &CpuSyscallView { unsafe { &self.syscall } @@ -168,5 +179,10 @@ pub(crate) struct CpuKHashView { pub(crate) value: [T; 8], } +#[derive(Copy, Clone)] +pub(crate) struct CpuElementView { + pub(crate) value: T, +} + // `u8` is guaranteed to have a `size_of` of 1. pub const NUM_SHARED_COLUMNS: usize = size_of::>(); diff --git a/prover/src/cpu/columns/mod.rs b/prover/src/cpu/columns/mod.rs index f92fd311..0909cb0e 100644 --- a/prover/src/cpu/columns/mod.rs +++ b/prover/src/cpu/columns/mod.rs @@ -103,6 +103,7 @@ pub struct CpuColumnsView { /// Filter. 1 iff a Poseidon sponge lookup is performed on this row. 
pub is_poseidon_sponge: T, pub is_keccak_sponge: T, + pub is_sha_extend_sponge: T, pub(crate) general: CpuGeneralColumnsView, diff --git a/prover/src/cpu/cpu_stark.rs b/prover/src/cpu/cpu_stark.rs index de688cb3..00bdebb6 100644 --- a/prover/src/cpu/cpu_stark.rs +++ b/prover/src/cpu/cpu_stark.rs @@ -43,10 +43,32 @@ pub fn ctl_data_keccak_sponge() -> Vec> { cols } +pub fn ctl_data_sha_extend_sponge() -> Vec> { + // For SHA extend sponge rows, the CTL columns are read as follows (TODO confirm against the witness generation): + // GP channel 0: value = context + // GP channel 1: value = segment + // GP channel 2: value = virt + // general.element(): value = output (not taken from a GP channel) + let context = Column::single(COL_MAP.mem_channels[0].value); + let segment = Column::single(COL_MAP.mem_channels[1].value); + let virt = Column::single(COL_MAP.mem_channels[2].value); + + let num_channels = F::from_canonical_usize(NUM_CHANNELS); + let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); + + let mut cols = vec![context, segment, virt, timestamp]; + cols.push(Column::single(COL_MAP.general.element().value)); + cols +} + pub fn ctl_filter_keccak_sponge() -> Filter { Filter::new_simple(Column::single(COL_MAP.is_keccak_sponge)) } +pub fn ctl_filter_sha_extend_sponge() -> Filter { + Filter::new_simple(Column::single(COL_MAP.is_sha_extend_sponge)) +} + pub fn ctl_data_poseidon_sponge() -> Vec> { // When executing POSEIDON_GENERAL, the GP memory channels are used as follows: // GP channel 0: stack[-1] = context diff --git a/prover/src/fixed_recursive_verifier.rs b/prover/src/fixed_recursive_verifier.rs index 5cf9550a..e19101ed 100644 --- a/prover/src/fixed_recursive_verifier.rs +++ b/prover/src/fixed_recursive_verifier.rs @@ -406,6 +406,23 @@ where &all_stark.cross_table_lookups, stark_config, ); + + let sha_extend = RecursiveCircuitsForTable::new( + Table::ShaExtend, + &all_stark.sha_extend_stark, + degree_bits_ranges[Table::ShaExtend as usize].clone(), + &all_stark.cross_table_lookups, + stark_config,
+ ); + + let sha_extend_sponge = RecursiveCircuitsForTable::new( + Table::ShaExtendSponge, + &all_stark.sha_extend_sponge_stark, + degree_bits_ranges[Table::ShaExtendSponge as usize].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let logic = RecursiveCircuitsForTable::new( Table::Logic, &all_stark.logic_stark, @@ -428,6 +445,8 @@ where poseidon_sponge, keccak, keccak_sponge, + sha_extend, + sha_extend_sponge, logic, memory, ]; diff --git a/prover/src/prover.rs b/prover/src/prover.rs index c9681c24..e3576a6d 100644 --- a/prover/src/prover.rs +++ b/prover/src/prover.rs @@ -332,6 +332,36 @@ where )? ); + let sha_extend_proof = timed!( + timing, + "prove SHA Extend STARK", + prove_single_table( + &all_stark.sha_extend_stark, + config, + &trace_poly_values[Table::ShaExtend as usize], + &trace_commitments[Table::ShaExtend as usize], + &ctl_data_per_table[Table::ShaExtend as usize], + ctl_challenges, + challenger, + timing, + )? + ); + + let sha_extend_sponge_proof = timed!( + timing, + "prove SHA Extend sponge STARK", + prove_single_table( + &all_stark.sha_extend_sponge_stark, + config, + &trace_poly_values[Table::ShaExtendSponge as usize], + &trace_commitments[Table::ShaExtendSponge as usize], + &ctl_data_per_table[Table::ShaExtendSponge as usize], + ctl_challenges, + challenger, + timing, + )? 
+ ); + let logic_proof = timed!( timing, "prove Logic STARK", @@ -368,6 +398,8 @@ where poseidon_sponge_proof, keccak_proof, keccak_sponge_proof, + sha_extend_proof, + sha_extend_sponge_proof, logic_proof, memory_proof, ]) diff --git a/prover/src/sha_extend/sha_extend_stark.rs b/prover/src/sha_extend/sha_extend_stark.rs index 69cc131e..2165ee7b 100644 --- a/prover/src/sha_extend/sha_extend_stark.rs +++ b/prover/src/sha_extend/sha_extend_stark.rs @@ -18,6 +18,41 @@ use crate::util::trace_rows_to_poly_values; pub const NUM_INPUTS: usize = 4 * 32; // w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 +pub fn ctl_data_inputs() -> Vec> { + let cols = SHA_EXTEND_COL_MAP; + let mut res: Vec<_> = Column::singles( + [ + cols.w_i_minus_15.as_slice(), + cols.w_i_minus_2.as_slice(), + cols.w_i_minus_16.as_slice(), + cols.w_i_minus_7.as_slice(), + ] + .concat(), + ) + .collect(); + res.push(Column::single(cols.timestamp)); + res +} + +pub fn ctl_data_outputs() -> Vec> { + let cols = SHA_EXTEND_COL_MAP; + let mut res: Vec<_> = Column::singles(&cols.w_i).collect(); + res.push(Column::single(cols.timestamp)); + res +} + +pub fn ctl_filter_inputs() -> Filter { + let cols = SHA_EXTEND_COL_MAP; + // not the padding rows. + Filter::new_simple(Column::single(cols.is_normal_round)) +} +pub fn ctl_filter_outputs() -> Filter { + let cols = SHA_EXTEND_COL_MAP; + // not the padding rows. 
+ Filter::new_simple(Column::single(cols.is_normal_round)) +} + + #[derive(Copy, Clone, Default)] pub struct ShaExtendStark { pub(crate) f: PhantomData, diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs index 56713dc9..1abbbead 100644 --- a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -22,6 +22,88 @@ use crate::witness::memory::MemoryAddress; pub const NUM_ROUNDS: usize = 48; +pub(crate) fn ctl_looking_sha_extend_inputs() -> Vec> { + let cols = SHA_EXTEND_SPONGE_COL_MAP; + let mut res: Vec<_> = Column::singles( + [ + cols.w_i_minus_15.as_slice(), + cols.w_i_minus_2.as_slice(), + cols.w_i_minus_16.as_slice(), + cols.w_i_minus_7.as_slice(), + ] + .concat(), + ) + .collect(); + res.push(Column::single(cols.timestamp)); + res +} + +pub(crate) fn ctl_looking_sha_extend_outputs() -> Vec> { + let cols = SHA_EXTEND_SPONGE_COL_MAP; + + let mut res = vec![]; + res.extend(Column::singles(&cols.w_i)); + res.push(Column::single(cols.timestamp)); + res +} + +pub(crate) fn ctl_looked_data() -> Vec> { + let cols = SHA_EXTEND_SPONGE_COL_MAP; + let w_i_usize = Column::linear_combination( + cols.w_i.iter() + .enumerate() + .map(|(i, &b)| (b, F::from_canonical_usize(1 << i))), + ); + + Column::singles([ + cols.context, + cols.segment, + cols.output_virt, + cols.timestamp, + ]).chain([w_i_usize]).collect() +} + +pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { + let cols = SHA_EXTEND_SPONGE_COL_MAP; + + let mut res = vec![Column::constant(F::ONE)]; // is_read + + res.extend(Column::singles([cols.context, cols.segment])); + res.push(Column::single(cols.input_virt[i / 32])); + + // The u32 of i'th input bit being read. 
+ let start = i / 32; + let mut le_bit; + if start == 0 { + le_bit = cols.w_i_minus_15; + } else if start == 1 { + le_bit = cols.w_i_minus_2; + } else if start == 2 { + le_bit = cols.w_i_minus_16; + } else { + le_bit = cols.w_i_minus_7; + } + // le_bit.reverse(); + let u32_value: Column = Column::le_bits(&le_bit); + res.push(u32_value); + + res.push(Column::single(cols.timestamp)); + + assert_eq!( + res.len(), + crate::memory::memory_stark::ctl_data::().len() + ); + res +} + +pub(crate) fn ctl_looking_sha_extend_filter() -> Filter { + let cols = SHA_EXTEND_SPONGE_COL_MAP; + // not the padding rows. + Filter::new_simple(Column::sum( + &cols.round, + )) +} + #[derive(Clone, Debug)] pub(crate) struct ShaExtendSpongeOp { /// The base address at which inputs are read diff --git a/prover/src/verifier.rs b/prover/src/verifier.rs index 3554169f..e4534be9 100644 --- a/prover/src/verifier.rs +++ b/prover/src/verifier.rs @@ -47,6 +47,8 @@ where poseidon_sponge_stark, keccak_stark, keccak_sponge_stark, + sha_extend_stark, + sha_extend_sponge_stark, logic_stark, memory_stark, cross_table_lookups, @@ -116,6 +118,27 @@ where config, )?; + verify_stark_proof_with_challenges( + sha_extend_stark, + &all_proof.stark_proofs[Table::ShaExtend as usize].proof, + &stark_challenges[Table::ShaExtend as usize], + &ctl_vars_per_table[Table::ShaExtend as usize], + &ctl_challenges, + config, + )?; + log::info!("ShaExtend Stark proof verified"); + + verify_stark_proof_with_challenges( + sha_extend_sponge_stark, + &all_proof.stark_proofs[Table::ShaExtendSponge as usize].proof, + &stark_challenges[Table::ShaExtendSponge as usize], + &ctl_vars_per_table[Table::ShaExtendSponge as usize], + &ctl_challenges, + config, + )?; + log::info!("ShaExtendSponge Stark proof verified"); + + verify_stark_proof_with_challenges( logic_stark, &all_proof.stark_proofs[Table::Logic as usize].proof, @@ -124,6 +147,7 @@ where &ctl_challenges, config, )?; + log::info!("Logic Stark proof verified"); 
verify_stark_proof_with_challenges( memory_stark, &all_proof.stark_proofs[Table::Memory as usize].proof, @@ -132,6 +156,7 @@ where &ctl_challenges, config, )?; + log::info!("Memory Stark proof verified"); verify_cross_table_lookups::( cross_table_lookups, diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index be448823..966f6ebf 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -20,6 +20,7 @@ use plonky2::field::extension::Extendable; use plonky2::hash::hash_types::RichField; use plonky2::plonk::config::GenericConfig; use std::fs; +use crate::sha_extend::logic::from_u32_to_be_bits; pub const WORD_SIZE: usize = core::mem::size_of::(); @@ -1194,16 +1195,23 @@ pub(crate) fn generate_sha_extend< for i in 16..64 { let mut cpu_row = CpuColumnsView::default(); cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + let mut input_addresses = vec![]; + let mut input_value_bit_be = vec![]; let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 15) * 4); let (w_i_minus_15, mem_op) = mem_read_gp_with_log_and_fill(0, addr, state, &mut cpu_row); state.traces.push_memory(mem_op); let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + input_addresses.push(addr); + input_value_bit_be.push(from_u32_to_be_bits(w_i_minus_15)); // Read w[i-2]. let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 2) * 4); let (w_i_minus_2, mem_op) = mem_read_gp_with_log_and_fill(1, addr, state, &mut cpu_row); state.traces.push_memory(mem_op); + + input_addresses.push(addr); + input_value_bit_be.push(from_u32_to_be_bits(w_i_minus_2)); // Compute `s1`. 
let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); @@ -1211,11 +1219,16 @@ pub(crate) fn generate_sha_extend< let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 16) * 4); let (w_i_minus_16, mem_op) = mem_read_gp_with_log_and_fill(2, addr, state, &mut cpu_row); state.traces.push_memory(mem_op); + input_addresses.push(addr); + input_value_bit_be.push(from_u32_to_be_bits(w_i_minus_16)); + // Read w[i-7]. let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 7) * 4); let (w_i_minus_7, mem_op) = mem_read_gp_with_log_and_fill(3, addr, state, &mut cpu_row); state.traces.push_memory(mem_op); + input_addresses.push(addr); + input_value_bit_be.push(from_u32_to_be_bits(w_i_minus_7)); // Compute `w_i`. let w_i = s1 @@ -1228,11 +1241,28 @@ pub(crate) fn generate_sha_extend< let addr = MemoryAddress::new(0, Segment::Code, w_ptr + i * 4); log::debug!("extend write {:X} {:X}", w_ptr + i * 4, w_i); let mem_op = mem_write_gp_log_and_fill(4, addr, state, &mut cpu_row, w_i); + state.traces.push_memory(mem_op); state.traces.push_cpu(cpu_row); state.memory.apply_ops(&state.traces.memory_ops); + + cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + cpu_row.is_sha_extend_sponge = F::ONE; + + // The SHA extend sponge CTL uses memory value columns for its inputs and outputs. 
+ cpu_row.mem_channels[0].value = F::ZERO; // context + cpu_row.mem_channels[1].value = F::from_canonical_usize(Segment::Code as usize); + cpu_row.mem_channels[2].value = F::from_canonical_usize(addr.virt); + cpu_row.general.element_mut().value = F::from_canonical_u32(w_i); + sha_extend_sponge_log(state, input_addresses, input_value_bit_be, addr, i - 16); + state.traces.push_cpu(cpu_row); + + } + + Ok(()) } diff --git a/prover/src/witness/traces.rs b/prover/src/witness/traces.rs index 5cf0c8c1..4b6245df 100644 --- a/prover/src/witness/traces.rs +++ b/prover/src/witness/traces.rs @@ -23,6 +23,8 @@ use crate::util::join; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryOp; use crate::{arithmetic, logic}; +use crate::sha_extend::sha_extend_stark; +use crate::sha_extend_sponge::sha_extend_sponge_stark::ShaExtendSpongeOp; #[derive(Clone, Copy, Debug)] pub struct TraceCheckpoint { @@ -32,6 +34,8 @@ pub struct TraceCheckpoint { pub(self) poseidon_sponge_len: usize, pub(self) keccak_len: usize, pub(self) keccak_sponge_len: usize, + pub(self) sha_extend_len: usize, + pub(self) sha_extend_sponge_len: usize, pub(self) logic_len: usize, pub(self) memory_len: usize, } @@ -46,6 +50,8 @@ pub(crate) struct Traces { pub(crate) poseidon_sponge_ops: Vec, pub(crate) keccak_inputs: Vec<([u64; keccak_stark::NUM_INPUTS], usize)>, pub(crate) keccak_sponge_ops: Vec, + pub(crate) sha_extend_inputs: Vec<([u8; sha_extend_stark::NUM_INPUTS], usize)>, + pub(crate) sha_extend_sponge_ops: Vec, } impl Traces { @@ -59,6 +65,8 @@ impl Traces { poseidon_sponge_ops: vec![], keccak_inputs: vec![], keccak_sponge_ops: vec![], + sha_extend_inputs: vec![], + sha_extend_sponge_ops: vec![], } } @@ -89,6 +97,9 @@ impl Traces { .iter() .map(|op| op.input.len() / keccak_sponge::columns::KECCAK_RATE_BYTES + 1) .sum(), + sha_extend_len: self.sha_extend_inputs.len(), + sha_extend_sponge_len: self + .sha_extend_sponge_ops.len(), logic_len: self.logic_ops.len(), // This is technically a 
lower-bound, as we may fill gaps, // but this gives a relatively good estimate. @@ -105,6 +116,8 @@ impl Traces { poseidon_sponge_len: self.poseidon_sponge_ops.len(), keccak_len: self.keccak_inputs.len(), keccak_sponge_len: self.keccak_sponge_ops.len(), + sha_extend_len: self.sha_extend_inputs.len(), + sha_extend_sponge_len: self.sha_extend_sponge_ops.len(), logic_len: self.logic_ops.len(), memory_len: self.memory_ops.len(), } @@ -119,6 +132,9 @@ impl Traces { self.keccak_inputs.truncate(checkpoint.keccak_len); self.keccak_sponge_ops .truncate(checkpoint.keccak_sponge_len); + self.sha_extend_inputs.truncate(checkpoint.sha_extend_len); + self.sha_extend_sponge_ops + .truncate(checkpoint.sha_extend_sponge_len); self.logic_ops.truncate(checkpoint.logic_len); self.memory_ops.truncate(checkpoint.memory_len); } @@ -169,6 +185,14 @@ impl Traces { self.keccak_sponge_ops.push(op); } + pub fn push_sha_extend(&mut self, input: [u8; sha_extend_stark::NUM_INPUTS], clock: usize) { + self.sha_extend_inputs.push((input, clock)); + } + + pub fn push_sha_extend_sponge(&mut self, op: ShaExtendSpongeOp) { + self.sha_extend_sponge_ops.push(op); + } + pub fn clock(&self) -> usize { self.cpu.len() } @@ -193,6 +217,8 @@ impl Traces { poseidon_sponge_ops, keccak_inputs, keccak_sponge_ops, + sha_extend_inputs, + sha_extend_sponge_ops, } = self; let mut memory_trace = vec![]; @@ -203,7 +229,8 @@ impl Traces { let mut keccak_trace = vec![]; let mut keccak_sponge_trace = vec![]; let mut logic_trace = vec![]; - + let mut sha_extend_trace = vec![]; + let mut sha_extend_sponge_trace = vec![]; timed!( timing, "convert trace to table parallelly", @@ -224,6 +251,12 @@ impl Traces { || keccak_sponge_trace = all_stark .keccak_sponge_stark .generate_trace(keccak_sponge_ops, min_rows), + || sha_extend_trace = all_stark + .sha_extend_stark + .generate_trace(sha_extend_inputs, min_rows), + || sha_extend_sponge_trace = all_stark + .sha_extend_sponge_stark + .generate_trace(sha_extend_sponge_ops, min_rows), 
|| logic_trace = all_stark.logic_stark.generate_trace(logic_ops, min_rows), ) ); @@ -235,6 +268,8 @@ impl Traces { poseidon_sponge_trace, keccak_trace, keccak_sponge_trace, + sha_extend_trace, + sha_extend_sponge_trace, logic_trace, memory_trace, ] diff --git a/prover/src/witness/util.rs b/prover/src/witness/util.rs index 379f7d36..8b924b3f 100644 --- a/prover/src/witness/util.rs +++ b/prover/src/witness/util.rs @@ -23,6 +23,8 @@ use crate::witness::errors::ProgramError; use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; use plonky2::field::extension::Extendable; use plonky2::plonk::config::GenericConfig; +use crate::sha_compress::logic::from_be_bits_to_u32; +use crate::sha_extend_sponge::sha_extend_sponge_stark::ShaExtendSpongeOp; fn to_byte_checked(n: u32) -> u8 { let res: u8 = n.to_le_bytes()[0]; @@ -553,6 +555,51 @@ pub(crate) fn keccak_sponge_log< }); } +pub(crate) fn sha_extend_sponge_log< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> ( + state: &mut GenerationState, + base_address: Vec, + inputs: Vec<[u8; 32]>, // BE bits + output_address: MemoryAddress, + round: usize, +) { + // Since the Sha extend reads bit by bit, and the memory unit is of 4-byte, we just need to read + // the same memory for 32 sha-extend ops + + let clock = state.traces.clock(); + let mut n_gp = 0; + let mut addr_idx = 0; + let extend_input: Vec = inputs.iter().flatten().cloned().collect(); + + for input in inputs { + let val = from_be_bits_to_u32(input); + for _ in 0..32 { + state.traces.push_memory(MemoryOp::new( + MemoryChannel::GeneralPurpose(n_gp), + clock, + base_address[addr_idx], + MemoryOpKind::Read, + val, + )); + n_gp += 1; + n_gp %= NUM_GP_CHANNELS - 1; + } + addr_idx += 1; + } + state.traces.push_sha_extend(extend_input.clone().try_into().unwrap(), clock * NUM_CHANNELS); + + state.traces.push_sha_extend_sponge(ShaExtendSpongeOp { + base_address, + timestamp: clock * NUM_CHANNELS, + input: extend_input, + i: round, + 
output_address + }); +} + fn xor_into_sponge, C: GenericConfig, const D: usize>( state: &mut GenerationState, sponge_state: &mut [u8; KECCAK_WIDTH_BYTES], From 021b7642e1082001b78870801b571f55d5367028 Mon Sep 17 00:00:00 2001 From: vanhger Date: Sat, 1 Feb 2025 21:58:40 +0700 Subject: [PATCH 21/25] chore: rename function. --- prover/src/sha_extend/logic.rs | 4 ++-- .../src/sha_extend_sponge/sha_extend_sponge_stark.rs | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/prover/src/sha_extend/logic.rs b/prover/src/sha_extend/logic.rs index d3e79350..eab8f703 100644 --- a/prover/src/sha_extend/logic.rs +++ b/prover/src/sha_extend/logic.rs @@ -100,7 +100,7 @@ pub(crate) fn wrapping_add, const D: usize, const N ) -> ([F; N], [F; N]) { let mut result = [F::ZERO; N]; let mut carries = [F::ZERO; N]; - let mut sum = F::ZERO; + let mut sum; let mut carry = F::ZERO; for i in 0..N { debug_assert!(a[i].is_zero() || a[i].is_one()); @@ -115,7 +115,7 @@ pub(crate) fn wrapping_add, const D: usize, const N (result, carries) } -pub(crate) fn from_be_bits_to_u32, const D: usize>(value: [F; 32]) -> u32 { +pub(crate) fn from_be_fbits_to_u32, const D: usize>(value: [F; 32]) -> u32 { let mut result = 0; for i in 0..32 { debug_assert!(value[i].is_zero() || value[i].is_one()); diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs index 1abbbead..a64cd310 100644 --- a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -13,7 +13,7 @@ use crate::cpu::membus::NUM_CHANNELS; use crate::cross_table_lookup::{Column, Filter}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; -use crate::sha_extend::logic::{get_input_range, from_u32_to_be_bits, from_be_bits_to_u32}; +use crate::sha_extend::logic::{get_input_range, from_u32_to_be_bits, from_be_fbits_to_u32}; use 
crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_EXTEND_INPUT, NUM_SHA_EXTEND_SPONGE_COLUMNS, SHA_EXTEND_SPONGE_COL_MAP}; use crate::sha_extend_sponge::logic::{diff_address_ext_circuit_constraint, round_increment_ext_circuit_constraint}; use crate::stark::Stark; @@ -73,7 +73,7 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { // The u32 of i'th input bit being read. let start = i / 32; - let mut le_bit; + let le_bit; if start == 0 { le_bit = cols.w_i_minus_15; } else if start == 1 { @@ -188,10 +188,10 @@ impl, const D: usize> ShaExtendSpongeStark { } fn compute_w_i(&self, row: &mut ShaExtendSpongeColumnsView) -> [F; 32] { - let w_i_minus_15 = from_be_bits_to_u32(row.w_i_minus_15); - let w_i_minus_2 = from_be_bits_to_u32(row.w_i_minus_2); - let w_i_minus_16 = from_be_bits_to_u32(row.w_i_minus_16); - let w_i_minus_7 = from_be_bits_to_u32(row.w_i_minus_7); + let w_i_minus_15 = from_be_fbits_to_u32(row.w_i_minus_15); + let w_i_minus_2 = from_be_fbits_to_u32(row.w_i_minus_2); + let w_i_minus_16 = from_be_fbits_to_u32(row.w_i_minus_16); + let w_i_minus_7 = from_be_fbits_to_u32(row.w_i_minus_7); let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); let w_i_u32 = s1 From a5549399be94f4f31371f69ef29ac96fbb6d5d5b Mon Sep 17 00:00:00 2001 From: vanhger Date: Sat, 1 Feb 2025 22:06:15 +0700 Subject: [PATCH 22/25] feat: add columns for sha compress --- prover/src/sha_compress/columns.rs | 2 +- prover/src/sha_compress/sha_compress_stark.rs | 27 ++++++++++--------- prover/src/sha_compress_sponge/mod.rs | 2 +- .../sha_compress_sponge_stark.rs | 20 +++++++------- 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/prover/src/sha_compress/columns.rs b/prover/src/sha_compress/columns.rs index 739404df..fc0eeb24 100644 --- a/prover/src/sha_compress/columns.rs +++ b/prover/src/sha_compress/columns.rs @@ -47,9 +47,9 @@ 
pub(crate) struct ShaCompressColumnsView { pub carry_a: [T; 32], pub carry_e: [T; 32], - /// The timestamp at which inputs should be read from memory. pub timestamp: T, + pub is_normal_round: T, } diff --git a/prover/src/sha_compress/sha_compress_stark.rs b/prover/src/sha_compress/sha_compress_stark.rs index 98bc4515..ad69b6aa 100644 --- a/prover/src/sha_compress/sha_compress_stark.rs +++ b/prover/src/sha_compress/sha_compress_stark.rs @@ -3,13 +3,15 @@ use std::borrow::Borrow; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::{Column, Filter}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::keccak::logic::{xor3_gen, xor3_gen_circuit, xor_gen, xor_gen_circuit}; -use crate::sha_compress::columns::{ShaCompressColumnsView, NUM_SHA_COMPRESS_COLUMNS}; +use crate::sha_compress::columns::{ShaCompressColumnsView, NUM_SHA_COMPRESS_COLUMNS, SHA_COMPRESS_COL_MAP}; use crate::sha_compress::logic::{and_op, and_op_ext_circuit_constraints, and_op_packed_constraints, andn_op, andn_op_ext_circuit_constraints, andn_op_packed_constraints, equal_ext_circuit_constraints, equal_packed_constraint, xor_op}; use crate::sha_extend::logic::{rotate_right, get_input_range, xor3, wrapping_add, rotate_right_packed_constraints, wrapping_add_packed_constraints, rotate_right_ext_circuit_constraint, wrapping_add_ext_circuit_constraints}; use crate::stark::Stark; @@ -17,7 +19,7 @@ use crate::util::trace_rows_to_poly_values; pub const NUM_ROUND_CONSTANTS: usize = 64; -pub const NUM_INPUTS: usize = 10; // 8 states + w_i + key_i +pub const NUM_INPUTS: usize = 10 * 32; // 8 states (a, b, 
..., h) + w_i + key_i #[derive(Copy, Clone, Default)] pub struct ShaCompressStark { @@ -27,7 +29,7 @@ pub struct ShaCompressStark { impl, const D: usize> ShaCompressStark { pub(crate) fn generate_trace( &self, - inputs: Vec<([u8; NUM_INPUTS * 32], usize)>, + inputs: Vec<([u8; NUM_INPUTS], usize)>, min_rows: usize, ) -> Vec> { // Generate the witness row-wise @@ -37,7 +39,7 @@ impl, const D: usize> ShaCompressStark { fn generate_trace_rows( &self, - inputs_and_timestamps: Vec<([u8; NUM_INPUTS * 32], usize)>, + inputs_and_timestamps: Vec<([u8; NUM_INPUTS], usize)>, min_rows: usize, ) -> Vec<[F; NUM_SHA_COMPRESS_COLUMNS]> { let num_rows = inputs_and_timestamps.len() @@ -58,7 +60,7 @@ impl, const D: usize> ShaCompressStark { fn generate_trace_rows_for_compress( &self, - input_and_timestamp: ([u8; NUM_INPUTS * 32], usize), + input_and_timestamp: ([u8; NUM_INPUTS], usize), ) -> [F; NUM_SHA_COMPRESS_COLUMNS] { let timestamp = input_and_timestamp.1; @@ -66,6 +68,7 @@ impl, const D: usize> ShaCompressStark { let mut row = ShaCompressColumnsView::::default(); row.timestamp = F::from_canonical_usize(timestamp); + row.is_normal_round = F::ONE; // read inputs row.input_state = inputs[0..256].iter().map(|x| F::from_canonical_u8(*x)).collect::>().try_into().unwrap(); row.w_i = inputs[256..288].iter().map(|x| F::from_canonical_u8(*x)).collect::>().try_into().unwrap(); @@ -140,8 +143,8 @@ impl, const D: usize> ShaCompressStark { row.output_state[i] = row.input_state[i - 32]; } - let mut new_e; - let mut new_a; + let new_e; + let new_a; (new_e, row.carry_e) = wrapping_add( row.input_state[get_input_range(3)].try_into().unwrap(), @@ -652,8 +655,6 @@ mod test { use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; use crate::prover::prove_single_table; use crate::sha_compress_sponge::constants::SHA_COMPRESS_K; - use crate::sha_extend::sha_extend_stark::ShaExtendStark; - use 
crate::sha_extend_sponge::columns::NUM_EXTEND_INPUT; use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; const W: [u32; 64] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, @@ -670,9 +671,9 @@ mod test { 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, ]; - fn get_random_input() -> [u8; NUM_INPUTS * 32] { - let mut input = [0u8; NUM_INPUTS * 32]; - for i in 0..NUM_INPUTS * 32 { + fn get_random_input() -> [u8; NUM_INPUTS] { + let mut input = [0u8; NUM_INPUTS]; + for i in 0..NUM_INPUTS { input[i] = rand::random::() % 2; debug_assert!(input[i] == 0 || input[i] == 1); } @@ -774,7 +775,7 @@ mod test { init_logger(); - let input: Vec<([u8; NUM_INPUTS * 32], usize)> = + let input: Vec<([u8; NUM_INPUTS], usize)> = (0..NUM_EXTEND).map(|_| (get_random_input(), 0)).collect(); let mut timing = TimingTree::new("prove", log::Level::Debug); diff --git a/prover/src/sha_compress_sponge/mod.rs b/prover/src/sha_compress_sponge/mod.rs index 1c6919dc..c47afc31 100644 --- a/prover/src/sha_compress_sponge/mod.rs +++ b/prover/src/sha_compress_sponge/mod.rs @@ -1,3 +1,3 @@ pub mod columns; -mod sha_compress_sponge_stark; +pub mod sha_compress_sponge_stark; pub mod constants; \ No newline at end of file diff --git a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs index 74b47bae..7e1e5037 100644 --- a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs +++ b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs @@ -9,18 +9,20 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::{Column, Filter}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; use crate::sha_compress::logic::from_be_bits_to_u32; -use 
crate::sha_compress_sponge::columns::{ShaCompressSpongeColumnsView, NUM_SHA_COMPRESS_SPONGE_COLUMNS}; +use crate::sha_compress_sponge::columns::{ShaCompressSpongeColumnsView, NUM_SHA_COMPRESS_SPONGE_COLUMNS, SHA_COMPRESS_SPONGE_COL_MAP}; use crate::sha_compress_sponge::constants::{NUM_COMPRESS_ROWS, SHA_COMPRESS_K_BINARY}; use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range, wrapping_add, wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints}; -use crate::sha_extend_sponge::sha_extend_sponge_stark::NUM_ROUNDS; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryAddress; use crate::witness::operation::SHA_COMPRESS_K; +pub(crate) const NUM_ROUNDS: usize = 64; + #[derive(Clone, Debug)] pub(crate) struct ShaCompressSpongeOp { /// The base address at which inputs are read. @@ -30,8 +32,8 @@ pub(crate) struct ShaCompressSpongeOp { /// The timestamp at which inputs are read. pub(crate) timestamp: usize, - /// The input state - pub(crate) input_state: Vec, + /// The input state: a, b, c, d, e, f, g, h. + pub(crate) input_states: Vec, /// The index of round pub(crate) i: usize, @@ -101,9 +103,9 @@ impl, const D: usize> ShaCompressSpongeStark row.k_i = SHA_COMPRESS_K_BINARY[op.i].map(|k| F::from_canonical_u8(k)); row.w_i = op.input[256..288].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); row.hx = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - row.input_state = op.input_state.iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + row.input_state = op.input_states.iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - let output = self.compress(&op.input_state, &op.input[256..288], op.i); + let output = self.compress(&op.input_states, &op.input[256..288], op.i); row.output_state = output.map(F::from_canonical_u8); // We use the result if only we are at the final round. 
@@ -492,7 +494,7 @@ mod test { base_address: hx_addresses.iter().chain([w_addresses[0]].iter()).cloned().collect(), i: 0, timestamp: 0, - input_state, + input_states: input_state, input, }; let row = stark.generate_rows_for_op(op); @@ -538,7 +540,7 @@ mod test { base_address: hx_addresses.iter().chain([w_addresses[0]].iter()).cloned().collect(), i: 63, timestamp: 0, - input_state, + input_states: input_state, input, }; let row = stark.generate_rows_for_op(op); @@ -643,7 +645,7 @@ mod test { base_address: hx_addresses.iter().chain([w_addresses[i]].iter()).cloned().collect(), i: i, timestamp: 0, - input_state, + input_states: input_state, input, }; From 4a41b74c50a3fb9cc3e6650e7cf598f49ef59087 Mon Sep 17 00:00:00 2001 From: vanhger Date: Sat, 1 Feb 2025 22:29:14 +0700 Subject: [PATCH 23/25] feat: add CTL for Sha compress --- prover/examples/utils/src/utils.rs | 3 +- prover/src/all_stark.rs | 74 ++++++++++++- prover/src/cpu/columns/general.rs | 16 +++ prover/src/cpu/columns/mod.rs | 1 + prover/src/cpu/cpu_stark.rs | 25 ++++- prover/src/fixed_recursive_verifier.rs | 19 +++- prover/src/prover.rs | 32 ++++++ prover/src/sha_compress/sha_compress_stark.rs | 35 ++++++ .../sha_compress_sponge_stark.rs | 101 ++++++++++++++++++ prover/src/verifier.rs | 24 ++++- prover/src/witness/operation.rs | 44 ++++++-- prover/src/witness/traces.rs | 36 +++++++ prover/src/witness/util.rs | 72 +++++++++++++ 13 files changed, 466 insertions(+), 16 deletions(-) diff --git a/prover/examples/utils/src/utils.rs b/prover/examples/utils/src/utils.rs index e70f6a8e..befb291e 100644 --- a/prover/examples/utils/src/utils.rs +++ b/prover/examples/utils/src/utils.rs @@ -16,7 +16,8 @@ use zkm_prover::cpu::kernel::assembler::segment_kernel; use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; use zkm_prover::generation::state::{AssumptionReceipts, Receipt}; -const DEGREE_BITS_RANGE: [Range; 10] = [10..21, 12..22, 11..21, 8..21, 6..21, 6..21, 6..21, 6..21, 6..21, 13..23]; +const 
DEGREE_BITS_RANGE: [Range; 12] = [10..21, 12..22, 11..21, 8..21, + 6..21, 6..21, 6..13, 6..13, 6..13, 6..13, 6..21, 13..23]; const D: usize = 2; type C = PoseidonGoldilocksConfig; diff --git a/prover/src/all_stark.rs b/prover/src/all_stark.rs index 5ad6d11f..b15fd893 100644 --- a/prover/src/all_stark.rs +++ b/prover/src/all_stark.rs @@ -23,6 +23,10 @@ use crate::poseidon::poseidon_stark::PoseidonStark; use crate::poseidon_sponge::columns::POSEIDON_RATE_BYTES; use crate::poseidon_sponge::poseidon_sponge_stark; use crate::poseidon_sponge::poseidon_sponge_stark::PoseidonSpongeStark; +use crate::sha_compress::sha_compress_stark; +use crate::sha_compress::sha_compress_stark::ShaCompressStark; +use crate::sha_compress_sponge::sha_compress_sponge_stark; +use crate::sha_compress_sponge::sha_compress_sponge_stark::{ShaCompressSpongeStark, SHA_COMPRESS_SPONGE_READ_BITS}; use crate::sha_extend::sha_extend_stark; use crate::sha_extend::sha_extend_stark::ShaExtendStark; use crate::sha_extend_sponge::columns::SHA_EXTEND_SPONGE_READ_BITS; @@ -40,6 +44,8 @@ pub struct AllStark, const D: usize> { pub keccak_sponge_stark: KeccakSpongeStark, pub sha_extend_stark: ShaExtendStark, pub sha_extend_sponge_stark: ShaExtendSpongeStark, + pub sha_compress_stark: ShaCompressStark, + pub sha_compress_sponge_stark: ShaCompressSpongeStark, pub logic_stark: LogicStark, pub memory_stark: MemoryStark, pub cross_table_lookups: Vec>, @@ -56,6 +62,8 @@ impl, const D: usize> Default for AllStark { keccak_sponge_stark: KeccakSpongeStark::default(), sha_extend_stark: ShaExtendStark::default(), sha_extend_sponge_stark: ShaExtendSpongeStark::default(), + sha_compress_stark: ShaCompressStark::default(), + sha_compress_sponge_stark: ShaCompressSpongeStark::default(), logic_stark: LogicStark::default(), memory_stark: MemoryStark::default(), cross_table_lookups: all_cross_table_lookups(), @@ -74,6 +82,8 @@ impl, const D: usize> AllStark { self.keccak_sponge_stark.num_lookup_helper_columns(config), 
self.sha_extend_stark.num_lookup_helper_columns(config), self.sha_extend_sponge_stark.num_lookup_helper_columns(config), + self.sha_compress_stark.num_lookup_helper_columns(config), + self.sha_compress_sponge_stark.num_lookup_helper_columns(config), self.logic_stark.num_lookup_helper_columns(config), self.memory_stark.num_lookup_helper_columns(config), ] @@ -90,8 +100,10 @@ pub enum Table { KeccakSponge = 5, ShaExtend = 6, ShaExtendSponge = 7, - Logic = 8, - Memory = 9, + ShaCompress = 8, + ShaCompressSponge = 9, + Logic = 10, + Memory = 11, } pub(crate) const NUM_TABLES: usize = Table::Memory as usize + 1; @@ -110,6 +122,8 @@ impl Table { Self::KeccakSponge, Self::ShaExtend, Self::ShaExtendSponge, + Self::ShaCompress, + Self::ShaCompressSponge, Self::Logic, Self::Memory, ] @@ -128,6 +142,9 @@ pub(crate) fn all_cross_table_lookups() -> Vec> { ctl_sha_extend_sponge(), ctl_sha_extend_inputs(), ctl_sha_extend_outputs(), + ctl_sha_compress_sponge(), + ctl_sha_compress_inputs(), + ctl_sha_compress_outputs(), ctl_logic(), ctl_memory(), ] @@ -275,6 +292,49 @@ fn ctl_sha_extend_sponge() -> CrossTableLookup { CrossTableLookup::new(vec![cpu_looking], sha_extend_sponge_looked) } + +fn ctl_sha_compress_inputs() -> CrossTableLookup { + let sha_compress_sponge_looking = TableWithColumns::new( + Table::ShaCompressSponge, + sha_compress_sponge_stark::ctl_looking_sha_compress_inputs(), + Some(sha_compress_sponge_stark::ctl_looking_sha_compress_filter()), + ); + let sha_compress_looked = TableWithColumns::new( + Table::ShaCompress, + sha_compress_stark::ctl_data_inputs(), + Some(sha_compress_stark::ctl_filter_inputs()), + ); + CrossTableLookup::new(vec![sha_compress_sponge_looking], sha_compress_looked) +} + +fn ctl_sha_compress_outputs() -> CrossTableLookup { + let sha_compress_sponge_looking = TableWithColumns::new( + Table::ShaCompressSponge, + sha_compress_sponge_stark::ctl_looking_sha_compress_outputs(), + Some(sha_compress_sponge_stark::ctl_looking_sha_compress_filter()), + ); 
+ let sha_compress_looked = TableWithColumns::new( + Table::ShaCompress, + sha_compress_stark::ctl_data_outputs(), + Some(sha_compress_stark::ctl_filter_outputs()), + ); + CrossTableLookup::new(vec![sha_compress_sponge_looking], sha_compress_looked) +} + +fn ctl_sha_compress_sponge() -> CrossTableLookup { + let cpu_looking = TableWithColumns::new( + Table::Cpu, + cpu_stark::ctl_data_sha_compress_sponge(), + Some(cpu_stark::ctl_filter_sha_compress_sponge()), + ); + let sha_compress_sponge_looked = TableWithColumns::new( + Table::ShaCompressSponge, + sha_compress_sponge_stark::ctl_looked_data(), + Some(sha_compress_sponge_stark::ctl_looked_filter()), + ); + CrossTableLookup::new(vec![cpu_looking], sha_compress_sponge_looked) +} + pub(crate) fn ctl_logic() -> CrossTableLookup { let cpu_looking = TableWithColumns::new( Table::Cpu, @@ -330,13 +390,21 @@ fn ctl_memory() -> CrossTableLookup { ) }); + let sha_compress_sponge_reads = (0..SHA_COMPRESS_SPONGE_READ_BITS).map(|i| { + TableWithColumns::new( + Table::ShaCompressSponge, + sha_compress_sponge_stark::ctl_looking_memory(i), + Some(sha_compress_sponge_stark::ctl_looking_sha_compress_filter()), + ) + }); + let all_lookers = [] .into_iter() .chain(cpu_memory_gp_ops) .chain(keccak_sponge_reads) .chain(poseidon_sponge_reads) .chain(sha_extend_sponge_reads) - + .chain(sha_compress_sponge_reads) .collect(); let memory_looked = TableWithColumns::new( Table::Memory, diff --git a/prover/src/cpu/columns/general.rs b/prover/src/cpu/columns/general.rs index d2d1e2b6..84e8faf3 100644 --- a/prover/src/cpu/columns/general.rs +++ b/prover/src/cpu/columns/general.rs @@ -12,6 +12,7 @@ pub(crate) union CpuGeneralColumnsView { io: CpuIOAuxView, hash: CpuHashView, khash: CpuKHashView, + shash: CpuSHashView, element: CpuElementView, misc: CpuMiscView, } @@ -47,6 +48,16 @@ impl CpuGeneralColumnsView { unsafe { &mut self.element } } + // SAFETY: Each view is a valid interpretation of the underlying array. 
+ pub(crate) fn shash(&self) -> &CpuSHashView { + unsafe { &self.shash } + } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn shash_mut(&mut self) -> &mut CpuSHashView { + unsafe { &mut self.shash } + } + // SAFETY: Each view is a valid interpretation of the underlying array. pub(crate) fn syscall(&self) -> &CpuSyscallView { unsafe { &self.syscall } @@ -179,6 +190,11 @@ pub(crate) struct CpuKHashView { pub(crate) value: [T; 8], } +#[derive(Copy, Clone)] +pub(crate) struct CpuSHashView { + pub(crate) value: [T; 8], +} + #[derive(Copy, Clone)] pub(crate) struct CpuElementView { pub(crate) value: T, diff --git a/prover/src/cpu/columns/mod.rs b/prover/src/cpu/columns/mod.rs index 0909cb0e..45c5a8b4 100644 --- a/prover/src/cpu/columns/mod.rs +++ b/prover/src/cpu/columns/mod.rs @@ -104,6 +104,7 @@ pub struct CpuColumnsView { pub is_poseidon_sponge: T, pub is_keccak_sponge: T, pub is_sha_extend_sponge: T, + pub is_sha_compress_sponge: T, pub(crate) general: CpuGeneralColumnsView, diff --git a/prover/src/cpu/cpu_stark.rs b/prover/src/cpu/cpu_stark.rs index 00bdebb6..1440bdbf 100644 --- a/prover/src/cpu/cpu_stark.rs +++ b/prover/src/cpu/cpu_stark.rs @@ -44,7 +44,7 @@ pub fn ctl_data_keccak_sponge() -> Vec> { } pub fn ctl_data_sha_extend_sponge() -> Vec> { - // When executing KECCAK_GENERAL, the GP memory channels are used as follows: + // When executing SHA_EXTEND_GENERAL, the GP memory channels are used as follows: // GP channel 0: stack[-1] = context // GP channel 1: stack[-2] = segment // GP channel 2: stack[-3] = virt @@ -61,6 +61,25 @@ pub fn ctl_data_sha_extend_sponge() -> Vec> { cols } +pub fn ctl_data_sha_compress_sponge() -> Vec> { + // When executing SHA_COMPRESS_GENERAL, the GP memory channels are used as follows: + // GP channel 0: stack[-1] = context + // GP channel 1: stack[-2] = segment + // GP channel 2: stack[-3] = start virt + // GP channel 3: pushed = outputs + let context = 
Column::single(COL_MAP.mem_channels[0].value); + let segment = Column::single(COL_MAP.mem_channels[1].value); + let virt = Column::single(COL_MAP.mem_channels[2].value); + + let num_channels = F::from_canonical_usize(NUM_CHANNELS); + let timestamp = Column::linear_combination([(COL_MAP.clock, num_channels)]); + + let mut cols = vec![context, segment, virt, timestamp]; + // let mut cols = vec![context, segment, virt]; + cols.extend(COL_MAP.general.shash().value.map(Column::single)); + cols +} + pub fn ctl_filter_keccak_sponge() -> Filter { Filter::new_simple(Column::single(COL_MAP.is_keccak_sponge)) } @@ -69,6 +88,10 @@ pub fn ctl_filter_sha_extend_sponge() -> Filter { Filter::new_simple(Column::single(COL_MAP.is_sha_extend_sponge)) } +pub fn ctl_filter_sha_compress_sponge() -> Filter { + Filter::new_simple(Column::single(COL_MAP.is_sha_compress_sponge)) +} + pub fn ctl_data_poseidon_sponge() -> Vec> { // When executing POSEIDON_GENERAL, the GP memory channels are used as follows: // GP channel 0: stack[-1] = context diff --git a/prover/src/fixed_recursive_verifier.rs b/prover/src/fixed_recursive_verifier.rs index e19101ed..7182acc4 100644 --- a/prover/src/fixed_recursive_verifier.rs +++ b/prover/src/fixed_recursive_verifier.rs @@ -423,6 +423,22 @@ where stark_config, ); + let sha_compress = RecursiveCircuitsForTable::new( + Table::ShaCompress, + &all_stark.sha_compress_stark, + degree_bits_ranges[Table::ShaCompress as usize].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + + let sha_compress_sponge = RecursiveCircuitsForTable::new( + Table::ShaCompressSponge, + &all_stark.sha_compress_sponge_stark, + degree_bits_ranges[Table::ShaCompressSponge as usize].clone(), + &all_stark.cross_table_lookups, + stark_config, + ); + let logic = RecursiveCircuitsForTable::new( Table::Logic, &all_stark.logic_stark, @@ -447,6 +463,8 @@ where keccak_sponge, sha_extend, sha_extend_sponge, + sha_compress, + sha_compress_sponge, logic, memory, ]; @@ -743,7 +761,6 @@ 
where let (all_proof, output) = prove_with_outputs::(all_stark, kernel, config, timing)?; verify_proof(all_stark, all_proof.clone(), config).unwrap(); let mut root_inputs = PartialWitness::new(); - for table in 0..NUM_TABLES { let stark_proof = &all_proof.stark_proofs[table]; let original_degree_bits = stark_proof.proof.recover_degree_bits(config); diff --git a/prover/src/prover.rs b/prover/src/prover.rs index e3576a6d..fd03149e 100644 --- a/prover/src/prover.rs +++ b/prover/src/prover.rs @@ -362,6 +362,36 @@ where )? ); + let sha_compress_proof = timed!( + timing, + "prove SHA Compress STARK", + prove_single_table( + &all_stark.sha_compress_stark, + config, + &trace_poly_values[Table::ShaCompress as usize], + &trace_commitments[Table::ShaCompress as usize], + &ctl_data_per_table[Table::ShaCompress as usize], + ctl_challenges, + challenger, + timing, + )? + ); + + let sha_compress_sponge_proof = timed!( + timing, + "prove SHA Compress sponge STARK", + prove_single_table( + &all_stark.sha_compress_sponge_stark, + config, + &trace_poly_values[Table::ShaCompressSponge as usize], + &trace_commitments[Table::ShaCompressSponge as usize], + &ctl_data_per_table[Table::ShaCompressSponge as usize], + ctl_challenges, + challenger, + timing, + )? 
+ ); + let logic_proof = timed!( timing, "prove Logic STARK", @@ -400,6 +430,8 @@ where keccak_sponge_proof, sha_extend_proof, sha_extend_sponge_proof, + sha_compress_proof, + sha_compress_sponge_proof, logic_proof, memory_proof, ]) diff --git a/prover/src/sha_compress/sha_compress_stark.rs b/prover/src/sha_compress/sha_compress_stark.rs index ad69b6aa..f0078e91 100644 --- a/prover/src/sha_compress/sha_compress_stark.rs +++ b/prover/src/sha_compress/sha_compress_stark.rs @@ -21,6 +21,41 @@ pub const NUM_ROUND_CONSTANTS: usize = 64; pub const NUM_INPUTS: usize = 10 * 32; // 8 states (a, b, ..., h) + w_i + key_i + +pub fn ctl_data_inputs() -> Vec> { + let cols = SHA_COMPRESS_COL_MAP; + let mut res: Vec<_> = Column::singles( + [ + cols.input_state.as_slice(), + cols.w_i.as_slice(), + cols.k_i.as_slice(), + ] + .concat(), + ) + .collect(); + res.push(Column::single(cols.timestamp)); + res +} + +pub fn ctl_data_outputs() -> Vec> { + let cols = SHA_COMPRESS_COL_MAP; + let mut res: Vec<_> = Column::singles(&cols.output_state).collect(); + res.push(Column::single(cols.timestamp)); + res +} + +pub fn ctl_filter_inputs() -> Filter { + let cols = SHA_COMPRESS_COL_MAP; + // not the padding rows. + Filter::new_simple(Column::single(cols.is_normal_round)) +} + +pub fn ctl_filter_outputs() -> Filter { + let cols = SHA_COMPRESS_COL_MAP; + // not the padding rows. 
+ Filter::new_simple(Column::single(cols.is_normal_round)) +} + #[derive(Copy, Clone, Default)] pub struct ShaCompressStark { pub(crate) f: PhantomData, diff --git a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs index 7e1e5037..54465df7 100644 --- a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs +++ b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs @@ -23,6 +23,107 @@ use crate::witness::operation::SHA_COMPRESS_K; pub(crate) const NUM_ROUNDS: usize = 64; +pub(crate) const SHA_COMPRESS_SPONGE_READ_BITS: usize = 9 * 32; // h[0],...,h[7], w[i]. +pub(crate) fn ctl_looking_sha_compress_inputs() -> Vec> { + let cols = SHA_COMPRESS_SPONGE_COL_MAP; + let mut res: Vec<_> = Column::singles( + [ + cols.input_state.as_slice(), + cols.w_i.as_slice(), + cols.k_i.as_slice(), + ] + .concat(), + ) + .collect(); + res.push(Column::single(cols.timestamp)); + res +} + +pub(crate) fn ctl_looking_sha_compress_outputs() -> Vec> { + let cols = SHA_COMPRESS_SPONGE_COL_MAP; + + let mut res = vec![]; + res.extend(Column::singles(&cols.output_state)); + res.push(Column::single(cols.timestamp)); + res +} + +pub(crate) fn ctl_looked_data() -> Vec> { + let cols = SHA_COMPRESS_SPONGE_COL_MAP; + let mut outputs = Vec::with_capacity(8); + + for i in 0..8 { + let cur_col = Column::linear_combination( + cols.output_hx[get_input_range(i)] + .iter() + .enumerate() + .map(|(j, &c)| (c, F::from_canonical_u64(1 << (j)))), + ); + outputs.push(cur_col); + } + + Column::singles([ + cols.context, + cols.segment, + cols.hx_virt[0], + cols.timestamp, + ]) + .chain(outputs) + .collect() +} + +pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { + let cols = SHA_COMPRESS_SPONGE_COL_MAP; + + let mut res = vec![Column::constant(F::ONE)]; // is_read + + res.extend(Column::singles([cols.context, cols.segment])); + if i >= 256 { + res.push(Column::single(cols.w_virt)); + } else { + 
res.push(Column::single(cols.hx_virt[i / 32])); + } + + // The u32 of i'th input bit being read. + let start = i / 32; + let le_bit; + if start < 8 { + le_bit = cols.hx[get_input_range(start)].try_into().unwrap(); + } else { + le_bit = cols.w_i; + } + // le_bit.reverse(); + let u32_value: Column = Column::le_bits(&le_bit); + res.push(u32_value); + + res.push(Column::single(cols.timestamp)); + + assert_eq!( + res.len(), + crate::memory::memory_stark::ctl_data::().len() + ); + res +} + + +pub(crate) fn ctl_looking_sha_compress_filter() -> Filter { + let cols = SHA_COMPRESS_SPONGE_COL_MAP; + // not the padding rows. + Filter::new_simple(Column::sum( + &cols.round, + )) +} + +pub(crate) fn ctl_looked_filter() -> Filter { + // The CPU table is only interested in our final rows, since those contain the final + // compress sponge output. + let cols = SHA_COMPRESS_SPONGE_COL_MAP; + // the final row only. + Filter::new_simple(Column::single( + cols.round[63], + )) +} + #[derive(Clone, Debug)] pub(crate) struct ShaCompressSpongeOp { /// The base address at which inputs are read. 
diff --git a/prover/src/verifier.rs b/prover/src/verifier.rs index e4534be9..a34e8029 100644 --- a/prover/src/verifier.rs +++ b/prover/src/verifier.rs @@ -49,6 +49,8 @@ where keccak_sponge_stark, sha_extend_stark, sha_extend_sponge_stark, + sha_compress_stark, + sha_compress_sponge_stark, logic_stark, memory_stark, cross_table_lookups, @@ -126,7 +128,6 @@ where &ctl_challenges, config, )?; - log::info!("ShaExtend Stark proof verified"); verify_stark_proof_with_challenges( sha_extend_sponge_stark, @@ -136,8 +137,24 @@ where &ctl_challenges, config, )?; - log::info!("ShaExtendSponge Stark proof verified"); + verify_stark_proof_with_challenges( + sha_compress_stark, + &all_proof.stark_proofs[Table::ShaCompress as usize].proof, + &stark_challenges[Table::ShaCompress as usize], + &ctl_vars_per_table[Table::ShaCompress as usize], + &ctl_challenges, + config, + )?; + + verify_stark_proof_with_challenges( + sha_compress_sponge_stark, + &all_proof.stark_proofs[Table::ShaCompressSponge as usize].proof, + &stark_challenges[Table::ShaCompressSponge as usize], + &ctl_vars_per_table[Table::ShaCompressSponge as usize], + &ctl_challenges, + config, + )?; verify_stark_proof_with_challenges( logic_stark, @@ -147,7 +164,7 @@ where &ctl_challenges, config, )?; - log::info!("Logic Stark proof verified"); + verify_stark_proof_with_challenges( memory_stark, &all_proof.stark_proofs[Table::Memory as usize].proof, @@ -156,7 +173,6 @@ where &ctl_challenges, config, )?; - log::info!("Memory Stark proof verified"); verify_cross_table_lookups::( cross_table_lookups, diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index 966f6ebf..df4c2fb7 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -1257,12 +1257,8 @@ pub(crate) fn generate_sha_extend< cpu_row.general.element_mut().value = F::from_canonical_u32(w_i); sha_extend_sponge_log(state, input_addresses, input_value_bit_be, addr, i - 16); state.traces.push_cpu(cpu_row); - - } - - 
Ok(()) } @@ -1289,14 +1285,22 @@ pub(crate) fn generate_sha_compress< let mut hx = [0u32; 8]; let mut cpu_row = CpuColumnsView::default(); cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + + let mut hx_addresses = vec![]; + let mut hx_value_bit_be = vec![]; + let mut w_i_value_bit_be = vec![]; + let mut w_i_addresses = vec![]; + let mut state_values = vec![]; + for i in 0..8 { let addr = MemoryAddress::new(0, Segment::Code, h_ptr + i * 4); let (value, mem_op) = mem_read_gp_with_log_and_fill(i, addr, state, &mut cpu_row); state.traces.push_memory(mem_op); hx[i] = value; + hx_addresses.push(addr); + hx_value_bit_be.push(from_u32_to_be_bits(value)); } state.traces.push_cpu(cpu_row); - let mut original_w = Vec::new(); // Execute the "compress" phase. let mut a = hx[0]; let mut b = hx[1]; @@ -1310,6 +1314,9 @@ pub(crate) fn generate_sha_compress< cpu_row = CpuColumnsView::default(); cpu_row.clock = F::from_canonical_usize(state.traces.clock()); for i in 0..64 { + let input_state = [a, b, c, d, e, f, g, h].iter().map(|x| from_u32_to_be_bits(*x)).collect_vec(); + state_values.push(input_state); + let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); let ch = (e & f) ^ (!e & g); if j == 8 { @@ -1322,7 +1329,9 @@ pub(crate) fn generate_sha_compress< let (w_i, mem_op) = mem_read_gp_with_log_and_fill(j, addr, state, &mut cpu_row); state.traces.push_memory(mem_op); j += 1; - original_w.push(w_i); + w_i_value_bit_be.push(from_u32_to_be_bits(w_i)); + w_i_addresses.push(addr); + let temp1 = h .wrapping_add(s1) .wrapping_add(ch) @@ -1343,6 +1352,29 @@ pub(crate) fn generate_sha_compress< } state.traces.push_cpu(cpu_row); // Execute the "finalize" phase. 
+ + let mut cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + cpu_row.is_sha_compress_sponge = F::ONE; + + cpu_row.mem_channels[0].value = F::ZERO; // context + cpu_row.mem_channels[1].value = F::from_canonical_usize(Segment::Code as usize); + cpu_row.mem_channels[2].value = F::from_canonical_usize(hx_addresses[0].virt); // start address of hx + + let u32_result: Vec = [a, b, c, d, e, f, g, h].iter().enumerate().map(|(i, x)| hx[i].wrapping_add(*x)).collect_vec(); + + cpu_row.general.shash_mut().value = u32_result.into_iter().map(F::from_canonical_u32).collect_vec().try_into().unwrap(); + // cpu_row.general.shash_mut().value.reverse(); + sha_compress_sponge_log( + state, + hx_value_bit_be, + hx_addresses, + w_i_value_bit_be, + w_i_addresses, + state_values + ); + state.traces.push_cpu(cpu_row); + let v = [a, b, c, d, e, f, g, h]; let mut cpu_row = CpuColumnsView::default(); cpu_row.clock = F::from_canonical_usize(state.traces.clock()); diff --git a/prover/src/witness/traces.rs b/prover/src/witness/traces.rs index 4b6245df..3cdab506 100644 --- a/prover/src/witness/traces.rs +++ b/prover/src/witness/traces.rs @@ -23,6 +23,8 @@ use crate::util::join; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryOp; use crate::{arithmetic, logic}; +use crate::sha_compress::sha_compress_stark; +use crate::sha_compress_sponge::sha_compress_sponge_stark::ShaCompressSpongeOp; use crate::sha_extend::sha_extend_stark; use crate::sha_extend_sponge::sha_extend_sponge_stark::ShaExtendSpongeOp; @@ -36,6 +38,8 @@ pub struct TraceCheckpoint { pub(self) keccak_sponge_len: usize, pub(self) sha_extend_len: usize, pub(self) sha_extend_sponge_len: usize, + pub(self) sha_compress_len: usize, + pub(self) sha_compress_sponge_len: usize, pub(self) logic_len: usize, pub(self) memory_len: usize, } @@ -52,6 +56,8 @@ pub(crate) struct Traces { pub(crate) keccak_sponge_ops: Vec, pub(crate) sha_extend_inputs: Vec<([u8; 
sha_extend_stark::NUM_INPUTS], usize)>, pub(crate) sha_extend_sponge_ops: Vec, + pub(crate) sha_compress_inputs: Vec<([u8; sha_compress_stark::NUM_INPUTS], usize)>, + pub(crate) sha_compress_sponge_ops: Vec, } impl Traces { @@ -67,6 +73,8 @@ impl Traces { keccak_sponge_ops: vec![], sha_extend_inputs: vec![], sha_extend_sponge_ops: vec![], + sha_compress_inputs: vec![], + sha_compress_sponge_ops: vec![], } } @@ -100,6 +108,9 @@ impl Traces { sha_extend_len: self.sha_extend_inputs.len(), sha_extend_sponge_len: self .sha_extend_sponge_ops.len(), + sha_compress_len: self.sha_compress_inputs.len(), + sha_compress_sponge_len: self + .sha_compress_sponge_ops.len(), logic_len: self.logic_ops.len(), // This is technically a lower-bound, as we may fill gaps, // but this gives a relatively good estimate. @@ -118,6 +129,8 @@ impl Traces { keccak_sponge_len: self.keccak_sponge_ops.len(), sha_extend_len: self.sha_extend_inputs.len(), sha_extend_sponge_len: self.sha_extend_sponge_ops.len(), + sha_compress_len: self.sha_compress_inputs.len(), + sha_compress_sponge_len: self.sha_compress_sponge_ops.len(), logic_len: self.logic_ops.len(), memory_len: self.memory_ops.len(), } @@ -135,6 +148,9 @@ impl Traces { self.sha_extend_inputs.truncate(checkpoint.sha_extend_len); self.sha_extend_sponge_ops .truncate(checkpoint.sha_extend_sponge_len); + self.sha_compress_inputs.truncate(checkpoint.sha_compress_len); + self.sha_compress_sponge_ops + .truncate(checkpoint.sha_compress_sponge_len); self.logic_ops.truncate(checkpoint.logic_len); self.memory_ops.truncate(checkpoint.memory_len); } @@ -193,6 +209,14 @@ impl Traces { self.sha_extend_sponge_ops.push(op); } + pub fn push_sha_compress(&mut self, input: [u8; sha_compress_stark::NUM_INPUTS], clock: usize) { + self.sha_compress_inputs.push((input, clock)); + } + + pub fn push_sha_compress_sponge(&mut self, op: ShaCompressSpongeOp) { + self.sha_compress_sponge_ops.push(op); + } + pub fn clock(&self) -> usize { self.cpu.len() } @@ -219,6 +243,8 
@@ impl Traces { keccak_sponge_ops, sha_extend_inputs, sha_extend_sponge_ops, + sha_compress_inputs, + sha_compress_sponge_ops, } = self; let mut memory_trace = vec![]; @@ -231,6 +257,8 @@ impl Traces { let mut logic_trace = vec![]; let mut sha_extend_trace = vec![]; let mut sha_extend_sponge_trace = vec![]; + let mut sha_compress_trace = vec![]; + let mut sha_compress_sponge_trace = vec![]; timed!( timing, "convert trace to table parallelly", @@ -257,6 +285,12 @@ impl Traces { || sha_extend_sponge_trace = all_stark .sha_extend_sponge_stark .generate_trace(sha_extend_sponge_ops, min_rows), + || sha_compress_trace = all_stark + .sha_compress_stark + .generate_trace(sha_compress_inputs, min_rows), + || sha_compress_sponge_trace = all_stark + .sha_compress_sponge_stark + .generate_trace(sha_compress_sponge_ops, min_rows), || logic_trace = all_stark.logic_stark.generate_trace(logic_ops, min_rows), ) ); @@ -270,6 +304,8 @@ impl Traces { keccak_sponge_trace, sha_extend_trace, sha_extend_sponge_trace, + sha_compress_trace, + sha_compress_sponge_trace, logic_trace, memory_trace, ] diff --git a/prover/src/witness/util.rs b/prover/src/witness/util.rs index 8b924b3f..de4c3041 100644 --- a/prover/src/witness/util.rs +++ b/prover/src/witness/util.rs @@ -24,6 +24,8 @@ use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKin use plonky2::field::extension::Extendable; use plonky2::plonk::config::GenericConfig; use crate::sha_compress::logic::from_be_bits_to_u32; +use crate::sha_compress_sponge::constants::SHA_COMPRESS_K_BINARY; +use crate::sha_compress_sponge::sha_compress_sponge_stark::ShaCompressSpongeOp; use crate::sha_extend_sponge::sha_extend_sponge_stark::ShaExtendSpongeOp; fn to_byte_checked(n: u32) -> u8 { @@ -600,6 +602,76 @@ pub(crate) fn sha_extend_sponge_log< }); } +pub(crate) fn sha_compress_sponge_log < + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> ( + state: &mut GenerationState, + hx_values: Vec<[u8; 32]>, // BE bits 
+ hx_addresses: Vec, + w_i_values: Vec<[u8; 32]>, // BE bits + w_i_addresses: Vec, + input_state_list: Vec>, // BE bits +) { + // Since the Sha compress reads bit by bit, and the memory unit is of 4-byte, we just need to read + // the same memory for 32 sha-extend ops + + let clock = state.traces.clock(); + let mut n_gp = 0; + + for i in 0..64 { + + // read hx as input + for (j, hx) in hx_values.iter().enumerate() { + let val = from_be_bits_to_u32(*hx); + for _ in 0..32 { + state.traces.push_memory(MemoryOp::new( + MemoryChannel::GeneralPurpose(n_gp), + clock, + hx_addresses[j], + MemoryOpKind::Read, + val, + )); + n_gp += 1; + n_gp %= NUM_GP_CHANNELS - 1; + } + } + // read w_i as input + let w_i_u32 = from_be_bits_to_u32(w_i_values[i]); + for _ in 0..32 { + state.traces.push_memory(MemoryOp::new( + MemoryChannel::GeneralPurpose(n_gp), + clock, + w_i_addresses[i], + MemoryOpKind::Read, + w_i_u32, + )); + n_gp += 1; + n_gp %= NUM_GP_CHANNELS - 1; + } + + + let w_i = w_i_values[i]; + let k_i = SHA_COMPRESS_K_BINARY[i]; + let base_address = hx_addresses.clone().into_iter().chain([w_i_addresses[i]]).collect_vec(); + let compress_sponge_input: Vec = hx_values.iter().chain(&[w_i]).flatten().cloned().collect(); + let compress_input: Vec = input_state_list[i].iter().chain(&[w_i, k_i]).flatten().cloned().collect(); + let input_states: Vec = input_state_list[i].clone().iter().flatten().cloned().collect(); + + state.traces.push_sha_compress(compress_input.try_into().unwrap(), clock * NUM_CHANNELS); + + state.traces.push_sha_compress_sponge(ShaCompressSpongeOp { + base_address, + timestamp: clock * NUM_CHANNELS, + input_states, + i, + input: compress_sponge_input, + }); + + } +} + fn xor_into_sponge, C: GenericConfig, const D: usize>( state: &mut GenerationState, sponge_state: &mut [u8; KECCAK_WIDTH_BYTES], From c241bc51ed75f1e82285df66d60d0029bcae8814 Mon Sep 17 00:00:00 2001 From: vanhger Date: Mon, 3 Feb 2025 13:04:24 +0700 Subject: [PATCH 24/25] chore: refactor code --- 
emulator/src/state.rs | 62 +- prover/src/all_stark.rs | 11 +- prover/src/cpu/cpu_stark.rs | 5 +- prover/src/lib.rs | 8 +- prover/src/sha_compress/columns.rs | 47 +- prover/src/sha_compress/logic.rs | 40 +- prover/src/sha_compress/mod.rs | 2 +- prover/src/sha_compress/sha_compress_stark.rs | 751 ++++++++++++------ prover/src/sha_compress_sponge/columns.rs | 12 +- prover/src/sha_compress_sponge/constants.rs | 322 ++++++-- prover/src/sha_compress_sponge/mod.rs | 2 +- .../sha_compress_sponge_stark.rs | 497 +++++++----- prover/src/sha_extend/columns.rs | 6 +- prover/src/sha_extend/logic.rs | 70 +- prover/src/sha_extend/mod.rs | 2 +- prover/src/sha_extend/sha_extend_stark.rs | 319 +++++--- prover/src/sha_extend_sponge/columns.rs | 12 +- prover/src/sha_extend_sponge/logic.rs | 28 +- prover/src/sha_extend_sponge/mod.rs | 2 +- .../sha_extend_sponge_stark.rs | 318 +++++--- prover/src/witness/operation.rs | 39 +- prover/src/witness/traces.rs | 17 +- prover/src/witness/util.rs | 55 +- runtime/entrypoint/src/syscalls/mod.rs | 4 +- runtime/precompiles/src/io.rs | 18 +- 25 files changed, 1739 insertions(+), 910 deletions(-) diff --git a/emulator/src/state.rs b/emulator/src/state.rs index a5b5fbea..e155e1ac 100644 --- a/emulator/src/state.rs +++ b/emulator/src/state.rs @@ -541,44 +541,59 @@ impl InstrumentedState { log::debug!("syscall {} {} {} {}", syscall_num, a0, a1, a2); match syscall_num { - 0x300105 => { // SHA_EXTEND + 0x300105 => { + // SHA_EXTEND let w_ptr = a0; assert!(a1 == 0, "arg2 must be 0"); for i in 16..64 { // Read w[i-15]. - let w_i_minus_15 = self.state.memory.get_memory(w_ptr + (i - 15) * 4); + let w_i_minus_15 = self.state.memory.get_memory(w_ptr + (i - 15) * 4); // Compute `s0`. - let s0 = - w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + let s0 = w_i_minus_15.rotate_right(7) + ^ w_i_minus_15.rotate_right(18) + ^ (w_i_minus_15 >> 3); // Read w[i-2]. 
- let w_i_minus_2 = self.state.memory.get_memory(w_ptr + (i - 2) * 4); + let w_i_minus_2 = self.state.memory.get_memory(w_ptr + (i - 2) * 4); // Compute `s1`. - let s1 = - w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); + let s1 = w_i_minus_2.rotate_right(17) + ^ w_i_minus_2.rotate_right(19) + ^ (w_i_minus_2 >> 10); // Read w[i-16]. - let w_i_minus_16 = self.state.memory.get_memory(w_ptr + (i - 16) * 4); + let w_i_minus_16 = self.state.memory.get_memory(w_ptr + (i - 16) * 4); // Read w[i-7]. - let w_i_minus_7 = self.state.memory.get_memory(w_ptr + (i - 7) * 4); + let w_i_minus_7 = self.state.memory.get_memory(w_ptr + (i - 7) * 4); // Compute `w_i`. - let w_i = s1.wrapping_add(w_i_minus_16).wrapping_add(s0).wrapping_add(w_i_minus_7); + let w_i = s1 + .wrapping_add(w_i_minus_16) + .wrapping_add(s0) + .wrapping_add(w_i_minus_7); // Write w[i]. - log::debug!("{:X}, {:X}, {:X} {:X} {:X} {:X}", s1, s0, w_i_minus_16, w_i_minus_7, w_i_minus_15, w_i_minus_2); + log::debug!( + "{:X}, {:X}, {:X} {:X} {:X} {:X}", + s1, + s0, + w_i_minus_16, + w_i_minus_7, + w_i_minus_15, + w_i_minus_2 + ); self.state.memory.set_memory(w_ptr + i * 4, w_i); log::debug!("extend write {:X} {:X}", w_ptr + i * 4, w_i); } - }, - 0x010106 => { // SHA_COMPRESS + } + 0x010106 => { + // SHA_COMPRESS let w_ptr = a0; let h_ptr = a1; let mut hx = [0u32; 8]; - for i in 0..8 { - hx[i] = self.state.memory.get_memory(h_ptr + i as u32 * 4); + for (i, hx_item) in hx.iter_mut().enumerate() { + *hx_item = self.state.memory.get_memory(h_ptr + i as u32 * 4); } let mut original_w = Vec::new(); @@ -617,11 +632,18 @@ impl InstrumentedState { // Execute the "finalize" phase. 
let v = [a, b, c, d, e, f, g, h]; for i in 0..8 { - self.state.memory.set_memory(h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); - log::debug!("write {:X} {:X}", h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); + self.state + .memory + .set_memory(h_ptr + i as u32 * 4, hx[i].wrapping_add(v[i])); + log::debug!( + "write {:X} {:X}", + h_ptr + i as u32 * 4, + hx[i].wrapping_add(v[i]) + ); } - }, - 0x010109 => { //keccak + } + 0x010109 => { + //keccak assert!((a0 & 3) == 0); assert!((a2 & 3) == 0); let bytes = (0..a1) @@ -664,7 +686,7 @@ impl InstrumentedState { log::debug!("input: {:?}", vec); assert_eq!(a0 % 4, 0, "hint read address not aligned to 4 bytes"); if a1 >= 1 { - self.state.cycle += (a1 as u64 + 31) / 32; + self.state.cycle += (a1 as u64).div_ceil(32); } for i in (0..a1).step_by(4) { // Get each byte in the chunk diff --git a/prover/src/all_stark.rs b/prover/src/all_stark.rs index b15fd893..dc867624 100644 --- a/prover/src/all_stark.rs +++ b/prover/src/all_stark.rs @@ -26,7 +26,9 @@ use crate::poseidon_sponge::poseidon_sponge_stark::PoseidonSpongeStark; use crate::sha_compress::sha_compress_stark; use crate::sha_compress::sha_compress_stark::ShaCompressStark; use crate::sha_compress_sponge::sha_compress_sponge_stark; -use crate::sha_compress_sponge::sha_compress_sponge_stark::{ShaCompressSpongeStark, SHA_COMPRESS_SPONGE_READ_BITS}; +use crate::sha_compress_sponge::sha_compress_sponge_stark::{ + ShaCompressSpongeStark, SHA_COMPRESS_SPONGE_READ_BITS, +}; use crate::sha_extend::sha_extend_stark; use crate::sha_extend::sha_extend_stark::ShaExtendStark; use crate::sha_extend_sponge::columns::SHA_EXTEND_SPONGE_READ_BITS; @@ -81,9 +83,11 @@ impl, const D: usize> AllStark { self.keccak_stark.num_lookup_helper_columns(config), self.keccak_sponge_stark.num_lookup_helper_columns(config), self.sha_extend_stark.num_lookup_helper_columns(config), - self.sha_extend_sponge_stark.num_lookup_helper_columns(config), + self.sha_extend_sponge_stark + 
.num_lookup_helper_columns(config), self.sha_compress_stark.num_lookup_helper_columns(config), - self.sha_compress_sponge_stark.num_lookup_helper_columns(config), + self.sha_compress_sponge_stark + .num_lookup_helper_columns(config), self.logic_stark.num_lookup_helper_columns(config), self.memory_stark.num_lookup_helper_columns(config), ] @@ -292,7 +296,6 @@ fn ctl_sha_extend_sponge() -> CrossTableLookup { CrossTableLookup::new(vec![cpu_looking], sha_extend_sponge_looked) } - fn ctl_sha_compress_inputs() -> CrossTableLookup { let sha_compress_sponge_looking = TableWithColumns::new( Table::ShaCompressSponge, diff --git a/prover/src/cpu/cpu_stark.rs b/prover/src/cpu/cpu_stark.rs index 1440bdbf..ab366982 100644 --- a/prover/src/cpu/cpu_stark.rs +++ b/prover/src/cpu/cpu_stark.rs @@ -1,5 +1,5 @@ use std::borrow::Borrow; -use std::iter::repeat; +use std::iter::repeat_n; use std::marker::PhantomData; use itertools::Itertools; @@ -13,7 +13,6 @@ use super::columns::CpuColumnsView; use crate::all_stark::Table; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{COL_MAP, NUM_CPU_COLUMNS}; -//use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cpu::{ bits, bootstrap_kernel, count, decode, jumps, membus, memio, misc, shift, syscall, }; @@ -212,7 +211,7 @@ pub fn ctl_data_code_memory() -> Vec> { cols.push(Column::le_bits(base)); // High limbs of the value are all zero. 
- cols.extend(repeat(Column::constant(F::ZERO)).take(VALUE_LIMBS - 1)); + cols.extend(repeat_n(Column::constant(F::ZERO), VALUE_LIMBS - 1)); cols.push(mem_time_and_channel(MEM_CODE_CHANNEL_IDX)); diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 9eea1ca6..868f1829 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -28,13 +28,13 @@ pub mod poseidon_sponge; pub mod proof; pub mod prover; pub mod recursive_verifier; +pub mod sha_compress; +pub mod sha_compress_sponge; +pub mod sha_extend; +pub mod sha_extend_sponge; pub mod stark; pub mod stark_testing; pub mod util; pub mod vanishing_poly; pub mod verifier; pub mod witness; -pub mod sha_extend; -pub mod sha_extend_sponge; -pub mod sha_compress; -pub mod sha_compress_sponge; diff --git a/prover/src/sha_compress/columns.rs b/prover/src/sha_compress/columns.rs index fc0eeb24..cdf114da 100644 --- a/prover/src/sha_compress/columns.rs +++ b/prover/src/sha_compress/columns.rs @@ -1,10 +1,8 @@ -use std::borrow::{Borrow, BorrowMut}; -use std::intrinsics::transmute; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; +use std::borrow::{Borrow, BorrowMut}; +use std::mem::transmute; #[derive(Clone)] pub(crate) struct ShaCompressColumnsView { - - /// input state: a,b,c,d,e,f,g,h in binary form pub input_state: [T; 256], /// Out @@ -20,37 +18,36 @@ pub(crate) struct ShaCompressColumnsView { pub s_1: [T; 32], pub e_and_f: [T; 32], pub not_e_and_g: [T; 32], - pub ch: [T;32], + pub ch: [T; 32], // h.wrapping_add(s1) - pub inter_1: [T;32], - pub carry_1: [T;32], + pub inter_1: [T; 32], + pub carry_1: [T; 32], // inter_1.wrapping_add(ch) - pub inter_2: [T;32], - pub carry_2: [T;32], + pub inter_2: [T; 32], + pub carry_2: [T; 32], // inter_2.wrapping_add(SHA_COMPRESS_K[i]) - pub inter_3: [T;32], - pub carry_3: [T;32], + pub inter_3: [T; 32], + pub carry_3: [T; 32], // inter_3.wrapping_add(w_i) - pub temp1: [T;32], - pub carry_4: [T;32], - - pub a_rr_2: [T;32], - pub a_rr_13: [T;32], - pub a_rr_22: 
[T;32], - pub s_0: [T;32], - pub a_and_b: [T;32], - pub a_and_c: [T;32], - pub b_and_c: [T;32], - pub maj: [T;32], - pub temp2: [T;32], - pub carry_5: [T;32], + pub temp1: [T; 32], + pub carry_4: [T; 32], + + pub a_rr_2: [T; 32], + pub a_rr_13: [T; 32], + pub a_rr_22: [T; 32], + pub s_0: [T; 32], + pub a_and_b: [T; 32], + pub a_and_c: [T; 32], + pub b_and_c: [T; 32], + pub maj: [T; 32], + pub temp2: [T; 32], + pub carry_5: [T; 32], pub carry_a: [T; 32], pub carry_e: [T; 32], /// The timestamp at which inputs should be read from memory. pub timestamp: T, pub is_normal_round: T, - } pub const NUM_SHA_COMPRESS_COLUMNS: usize = size_of::>(); diff --git a/prover/src/sha_compress/logic.rs b/prover/src/sha_compress/logic.rs index dbd7da1c..bccd962e 100644 --- a/prover/src/sha_compress/logic.rs +++ b/prover/src/sha_compress/logic.rs @@ -1,13 +1,13 @@ +use crate::keccak::logic::andn_gen_circuit; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; -use crate::keccak::logic::andn_gen_circuit; pub(crate) fn and_op, const D: usize, const N: usize>( x: [F; N], - y: [F; N] + y: [F; N], ) -> [F; N] { let mut result = [F::ZERO; N]; for i in 0..N { @@ -21,7 +21,7 @@ pub(crate) fn and_op, const D: usize, const N: usiz pub(crate) fn and_op_packed_constraints( x: [P; N], y: [P; N], - out: [P; N] + out: [P; N], ) -> Vec

{ let mut result = vec![]; for i in 0..N { @@ -31,15 +31,18 @@ pub(crate) fn and_op_packed_constraints( result } -pub(crate) fn and_op_ext_circuit_constraints, const D: usize, const N: usize>( +pub(crate) fn and_op_ext_circuit_constraints< + F: RichField + Extendable, + const D: usize, + const N: usize, +>( builder: &mut CircuitBuilder, x: [ExtensionTarget; N], y: [ExtensionTarget; N], - out: [ExtensionTarget; N] + out: [ExtensionTarget; N], ) -> Vec> { let mut result = vec![]; for i in 0..N { - let expected_out = builder.mul_extension(x[i], y[i]); let out_constraint = builder.sub_extension(expected_out, out[i]); result.push(out_constraint); @@ -49,7 +52,7 @@ pub(crate) fn and_op_ext_circuit_constraints, const pub(crate) fn andn_op, const D: usize, const N: usize>( x: [F; N], - y: [F; N] + y: [F; N], ) -> [F; N] { let mut result = [F::ZERO; N]; for i in 0..N { @@ -63,7 +66,7 @@ pub(crate) fn andn_op, const D: usize, const N: usi pub(crate) fn andn_op_packed_constraints( x: [P; N], y: [P; N], - out: [P; N] + out: [P; N], ) -> Vec

{ let mut result = vec![]; for i in 0..N { @@ -73,15 +76,18 @@ pub(crate) fn andn_op_packed_constraints( result } -pub(crate) fn andn_op_ext_circuit_constraints, const D: usize, const N: usize>( +pub(crate) fn andn_op_ext_circuit_constraints< + F: RichField + Extendable, + const D: usize, + const N: usize, +>( builder: &mut CircuitBuilder, x: [ExtensionTarget; N], y: [ExtensionTarget; N], - out: [ExtensionTarget; N] + out: [ExtensionTarget; N], ) -> Vec> { let mut result = vec![]; for i in 0..N { - let expected_out = andn_gen_circuit(builder, x[i], y[i]); let out_constraint = builder.sub_extension(expected_out, out[i]); result.push(out_constraint); @@ -91,7 +97,7 @@ pub(crate) fn andn_op_ext_circuit_constraints, cons pub(crate) fn xor_op, const D: usize, const N: usize>( x: [F; N], - y: [F; N] + y: [F; N], ) -> [F; N] { let mut result = [F::ZERO; N]; for i in 0..N { @@ -113,7 +119,11 @@ pub(crate) fn equal_packed_constraint( result } -pub(crate) fn equal_ext_circuit_constraints, const D: usize, const N: usize>( +pub(crate) fn equal_ext_circuit_constraints< + F: RichField + Extendable, + const D: usize, + const N: usize, +>( builder: &mut CircuitBuilder, x: [ExtensionTarget; N], y: [ExtensionTarget; N], @@ -126,10 +136,10 @@ pub(crate) fn equal_ext_circuit_constraints, const result } -pub(crate) fn from_be_bits_to_u32( bits: [u8; 32]) -> u32 { +pub(crate) fn from_be_bits_to_u32(bits: [u8; 32]) -> u32 { let mut result = 0; for i in 0..32 { result |= (bits[i] as u32) << i; } result -} \ No newline at end of file +} diff --git a/prover/src/sha_compress/mod.rs b/prover/src/sha_compress/mod.rs index f48755ac..e8b4ff9c 100644 --- a/prover/src/sha_compress/mod.rs +++ b/prover/src/sha_compress/mod.rs @@ -1,3 +1,3 @@ pub mod columns; +pub mod logic; pub mod sha_compress_stark; -pub mod logic; \ No newline at end of file diff --git a/prover/src/sha_compress/sha_compress_stark.rs b/prover/src/sha_compress/sha_compress_stark.rs index f0078e91..04bc968c 100644 --- 
a/prover/src/sha_compress/sha_compress_stark.rs +++ b/prover/src/sha_compress/sha_compress_stark.rs @@ -1,5 +1,22 @@ -use std::marker::PhantomData; -use std::borrow::Borrow; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::{Column, Filter}; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; +use crate::keccak::logic::{xor3_gen, xor3_gen_circuit, xor_gen, xor_gen_circuit}; +use crate::sha_compress::columns::{ + ShaCompressColumnsView, NUM_SHA_COMPRESS_COLUMNS, SHA_COMPRESS_COL_MAP, +}; +use crate::sha_compress::logic::{ + and_op, and_op_ext_circuit_constraints, and_op_packed_constraints, andn_op, + andn_op_ext_circuit_constraints, andn_op_packed_constraints, equal_ext_circuit_constraints, + equal_packed_constraint, xor_op, +}; +use crate::sha_extend::logic::{ + get_input_range, rotate_right, rotate_right_ext_circuit_constraint, + rotate_right_packed_constraints, wrapping_add, wrapping_add_ext_circuit_constraints, + wrapping_add_packed_constraints, xor3, +}; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; @@ -7,21 +24,13 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; -use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cross_table_lookup::{Column, Filter}; -use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; -use crate::keccak::logic::{xor3_gen, xor3_gen_circuit, xor_gen, xor_gen_circuit}; -use crate::sha_compress::columns::{ShaCompressColumnsView, NUM_SHA_COMPRESS_COLUMNS, SHA_COMPRESS_COL_MAP}; -use crate::sha_compress::logic::{and_op, and_op_ext_circuit_constraints, and_op_packed_constraints, andn_op, 
andn_op_ext_circuit_constraints, andn_op_packed_constraints, equal_ext_circuit_constraints, equal_packed_constraint, xor_op}; -use crate::sha_extend::logic::{rotate_right, get_input_range, xor3, wrapping_add, rotate_right_packed_constraints, wrapping_add_packed_constraints, rotate_right_ext_circuit_constraint, wrapping_add_ext_circuit_constraints}; -use crate::stark::Stark; -use crate::util::trace_rows_to_poly_values; +use std::borrow::Borrow; +use std::marker::PhantomData; pub const NUM_ROUND_CONSTANTS: usize = 64; pub const NUM_INPUTS: usize = 10 * 32; // 8 states (a, b, ..., h) + w_i + key_i - pub fn ctl_data_inputs() -> Vec> { let cols = SHA_COMPRESS_COL_MAP; let mut res: Vec<_> = Column::singles( @@ -30,9 +39,9 @@ pub fn ctl_data_inputs() -> Vec> { cols.w_i.as_slice(), cols.k_i.as_slice(), ] - .concat(), + .concat(), ) - .collect(); + .collect(); res.push(Column::single(cols.timestamp)); res } @@ -77,7 +86,8 @@ impl, const D: usize> ShaCompressStark { inputs_and_timestamps: Vec<([u8; NUM_INPUTS], usize)>, min_rows: usize, ) -> Vec<[F; NUM_SHA_COMPRESS_COLUMNS]> { - let num_rows = inputs_and_timestamps.len() + let num_rows = inputs_and_timestamps + .len() .max(min_rows) .next_power_of_two(); @@ -97,7 +107,6 @@ impl, const D: usize> ShaCompressStark { &self, input_and_timestamp: ([u8; NUM_INPUTS], usize), ) -> [F; NUM_SHA_COMPRESS_COLUMNS] { - let timestamp = input_and_timestamp.1; let inputs = input_and_timestamp.0; @@ -105,9 +114,24 @@ impl, const D: usize> ShaCompressStark { row.timestamp = F::from_canonical_usize(timestamp); row.is_normal_round = F::ONE; // read inputs - row.input_state = inputs[0..256].iter().map(|x| F::from_canonical_u8(*x)).collect::>().try_into().unwrap(); - row.w_i = inputs[256..288].iter().map(|x| F::from_canonical_u8(*x)).collect::>().try_into().unwrap(); - row.k_i = inputs[288..320].iter().map(|x| F::from_canonical_u8(*x)).collect::>().try_into().unwrap(); + row.input_state = inputs[0..256] + .iter() + .map(|x| 
F::from_canonical_u8(*x)) + .collect::>() + .try_into() + .unwrap(); + row.w_i = inputs[256..288] + .iter() + .map(|x| F::from_canonical_u8(*x)) + .collect::>() + .try_into() + .unwrap(); + row.k_i = inputs[288..320] + .iter() + .map(|x| F::from_canonical_u8(*x)) + .collect::>() + .try_into() + .unwrap(); // compute row.e_rr_6 = rotate_right(row.input_state[get_input_range(4)].try_into().unwrap(), 6); @@ -132,20 +156,11 @@ impl, const D: usize> ShaCompressStark { row.s_1, ); - (row.inter_2, row.carry_2) = wrapping_add( - row.inter_1, - row.ch, - ); + (row.inter_2, row.carry_2) = wrapping_add(row.inter_1, row.ch); - (row.inter_3, row.carry_3) = wrapping_add( - row.inter_2, - row.k_i, - ); + (row.inter_3, row.carry_3) = wrapping_add(row.inter_2, row.k_i); - (row.temp1, row.carry_4) = wrapping_add( - row.inter_3, - row.w_i, - ); + (row.temp1, row.carry_4) = wrapping_add(row.inter_3, row.w_i); row.a_rr_2 = rotate_right(row.input_state[get_input_range(0)].try_into().unwrap(), 2); row.a_rr_13 = rotate_right(row.input_state[get_input_range(0)].try_into().unwrap(), 13); @@ -168,11 +183,7 @@ impl, const D: usize> ShaCompressStark { ); row.maj = xor3(row.a_and_b, row.a_and_c, row.b_and_c); - (row.temp2, row.carry_5) = wrapping_add( - row.s_0, - row.maj, - ); - + (row.temp2, row.carry_5) = wrapping_add(row.s_0, row.maj); for i in 32..256 { row.output_state[i] = row.input_state[i - 32]; @@ -186,10 +197,7 @@ impl, const D: usize> ShaCompressStark { row.temp1, ); - (new_a, row.carry_a) = wrapping_add( - row.temp1, - row.temp2, - ); + (new_a, row.carry_a) = wrapping_add(row.temp1, row.temp2); for i in 0..32 { row.output_state[i] = new_a[i]; @@ -202,19 +210,19 @@ impl, const D: usize> ShaCompressStark { impl, const D: usize> Stark for ShaCompressStark { type EvaluationFrame - = StarkFrame + = StarkFrame where - FE: FieldExtension, - P: PackedField; + FE: FieldExtension, + P: PackedField; type EvaluationFrameTarget = StarkFrame, NUM_SHA_COMPRESS_COLUMNS>; fn eval_packed_generic( 
&self, vars: &Self::EvaluationFrame, - yield_constr: &mut ConstraintConsumer

+ yield_constr: &mut ConstraintConsumer

, ) where - FE: FieldExtension, - P: PackedField + FE: FieldExtension, + P: PackedField, { let local_values: &[P; NUM_SHA_COMPRESS_COLUMNS] = vars.get_local_values().try_into().unwrap(); @@ -222,7 +230,8 @@ impl, const D: usize> Stark for ShaCompressSt // check the input binary form for i in 0..256 { - yield_constr.constraint(local_values.input_state[i] * (local_values.input_state[i] - P::ONES)); + yield_constr + .constraint(local_values.input_state[i] * (local_values.input_state[i] - P::ONES)); } for i in 0..32 { yield_constr.constraint(local_values.w_i[i] * (local_values.w_i[i] - P::ONES)); @@ -231,68 +240,91 @@ impl, const D: usize> Stark for ShaCompressSt // check the bit values are zero or one in output for i in 0..256 { - yield_constr.constraint(local_values.output_state[i] * (local_values.output_state[i] - P::ONES)); + yield_constr.constraint( + local_values.output_state[i] * (local_values.output_state[i] - P::ONES), + ); } // check the rotation rotate_right_packed_constraints( - local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), local_values.e_rr_6, - 6 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 6, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); rotate_right_packed_constraints( - local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), local_values.e_rr_11, - 11 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 11, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); rotate_right_packed_constraints( - local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), local_values.e_rr_25, - 25 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 25, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); rotate_right_packed_constraints( - 
local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), local_values.a_rr_2, - 2 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 2, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); rotate_right_packed_constraints( - local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), local_values.a_rr_13, - 13 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 13, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); rotate_right_packed_constraints( - local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), local_values.a_rr_22, - 22 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 22, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); // check the xor for i in 0..32 { let s1 = xor3_gen( local_values.e_rr_6[i], local_values.e_rr_11[i], - local_values.e_rr_25[i] + local_values.e_rr_25[i], ); yield_constr.constraint(local_values.s_1[i] - s1); let s0 = xor3_gen( local_values.a_rr_2[i], local_values.a_rr_13[i], - local_values.a_rr_22[i] + local_values.a_rr_22[i], ); yield_constr.constraint(local_values.s_0[i] - s0); - let ch = xor_gen( - local_values.e_and_f[i], - local_values.not_e_and_g[i] - ); + let ch = xor_gen(local_values.e_and_f[i], local_values.not_e_and_g[i]); yield_constr.constraint(local_values.ch[i] - ch); let maj = xor3_gen( local_values.a_and_b[i], local_values.a_and_c[i], - local_values.b_and_c[i] + local_values.b_and_c[i], ); yield_constr.constraint(local_values.maj[i] - maj); } @@ -300,101 +332,170 @@ impl, const D: usize> Stark for ShaCompressSt // wrapping add constraints wrapping_add_packed_constraints( - local_values.input_state[get_input_range(7)].try_into().unwrap(), + local_values.input_state[get_input_range(7)] + .try_into() + .unwrap(), 
local_values.s_1, local_values.carry_1, - local_values.inter_1 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.inter_1, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); wrapping_add_packed_constraints( local_values.inter_1, local_values.ch, local_values.carry_2, - local_values.inter_2 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.inter_2, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); wrapping_add_packed_constraints( local_values.inter_2, local_values.k_i, local_values.carry_3, - local_values.inter_3 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.inter_3, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); wrapping_add_packed_constraints( local_values.inter_3, local_values.w_i, local_values.carry_4, - local_values.temp1 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.temp1, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); wrapping_add_packed_constraints( local_values.s_0, local_values.maj, local_values.carry_5, - local_values.temp2 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.temp2, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); wrapping_add_packed_constraints( - local_values.input_state[get_input_range(3)].try_into().unwrap(), + local_values.input_state[get_input_range(3)] + .try_into() + .unwrap(), local_values.temp1, local_values.carry_e, - local_values.output_state[get_input_range(4)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.output_state[get_input_range(4)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); wrapping_add_packed_constraints( local_values.temp1, local_values.temp2, local_values.carry_a, - local_values.output_state[get_input_range(0)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 
local_values.output_state[get_input_range(0)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); // The op constraints and_op_packed_constraints( - local_values.input_state[get_input_range(4)].try_into().unwrap(), - local_values.input_state[get_input_range(5)].try_into().unwrap(), - local_values.e_and_f - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(5)] + .try_into() + .unwrap(), + local_values.e_and_f, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); and_op_packed_constraints( - local_values.input_state[get_input_range(0)].try_into().unwrap(), - local_values.input_state[get_input_range(1)].try_into().unwrap(), - local_values.a_and_b - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(1)] + .try_into() + .unwrap(), + local_values.a_and_b, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); and_op_packed_constraints( - local_values.input_state[get_input_range(0)].try_into().unwrap(), - local_values.input_state[get_input_range(2)].try_into().unwrap(), - local_values.a_and_c - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(2)] + .try_into() + .unwrap(), + local_values.a_and_c, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); and_op_packed_constraints( - local_values.input_state[get_input_range(1)].try_into().unwrap(), - local_values.input_state[get_input_range(2)].try_into().unwrap(), - local_values.b_and_c - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.input_state[get_input_range(1)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(2)] + .try_into() + .unwrap(), + 
local_values.b_and_c, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); andn_op_packed_constraints( - local_values.input_state[get_input_range(4)].try_into().unwrap(), - local_values.input_state[get_input_range(6)].try_into().unwrap(), - local_values.not_e_and_g - ).into_iter().for_each(|c| yield_constr.constraint(c)); - + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(6)] + .try_into() + .unwrap(), + local_values.not_e_and_g, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); // output constraint equal_packed_constraint::( - local_values.output_state[get_input_range(1)].try_into().unwrap(), - local_values.input_state[get_input_range(0)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.output_state[get_input_range(1)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); equal_packed_constraint::( - local_values.output_state[get_input_range(2)].try_into().unwrap(), - local_values.input_state[get_input_range(1)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.output_state[get_input_range(2)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(1)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); equal_packed_constraint::( - local_values.output_state[get_input_range(3)].try_into().unwrap(), - local_values.input_state[get_input_range(2)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.output_state[get_input_range(3)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(2)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); // equal_packed_constraint( // 
local_values.output_state[get_input_range(4)].try_into().unwrap(), @@ -402,26 +503,44 @@ impl, const D: usize> Stark for ShaCompressSt // ).into_iter().for_each(|c| yield_constr.constraint(c)); equal_packed_constraint::( - local_values.output_state[get_input_range(5)].try_into().unwrap(), - local_values.input_state[get_input_range(4)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.output_state[get_input_range(5)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); equal_packed_constraint::( - local_values.output_state[get_input_range(6)].try_into().unwrap(), - local_values.input_state[get_input_range(5)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.output_state[get_input_range(6)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(5)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); equal_packed_constraint::( - local_values.output_state[get_input_range(7)].try_into().unwrap(), - local_values.input_state[get_input_range(6)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.output_state[get_input_range(7)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(6)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); } fn eval_ext_circuit( &self, builder: &mut CircuitBuilder, vars: &Self::EvaluationFrameTarget, - yield_constr: &mut RecursiveConstraintConsumer + yield_constr: &mut RecursiveConstraintConsumer, ) { let local_values: &[ExtensionTarget; NUM_SHA_COMPRESS_COLUMNS] = vars.get_local_values().try_into().unwrap(); @@ -430,65 +549,100 @@ impl, const D: usize> Stark for ShaCompressSt // check the input binary form for i in 0..256 { let constraint = builder.mul_sub_extension( - 
local_values.input_state[i], local_values.input_state[i], local_values.input_state[i]); + local_values.input_state[i], + local_values.input_state[i], + local_values.input_state[i], + ); yield_constr.constraint(builder, constraint); - } for i in 0..32 { let constraint = builder.mul_sub_extension( - local_values.w_i[i], local_values.w_i[i], local_values.w_i[i]); + local_values.w_i[i], + local_values.w_i[i], + local_values.w_i[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.k_i[i], local_values.k_i[i], local_values.k_i[i]); + local_values.k_i[i], + local_values.k_i[i], + local_values.k_i[i], + ); yield_constr.constraint(builder, constraint); } // check the bit values are zero or one in output for i in 0..256 { let constraint = builder.mul_sub_extension( - local_values.output_state[i], local_values.output_state[i], local_values.output_state[i]); + local_values.output_state[i], + local_values.output_state[i], + local_values.output_state[i], + ); yield_constr.constraint(builder, constraint); } // check the rotation rotate_right_ext_circuit_constraint( builder, - local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), local_values.e_rr_6, - 6 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 6, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); rotate_right_ext_circuit_constraint( builder, - local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), local_values.e_rr_11, - 11 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 11, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); rotate_right_ext_circuit_constraint( builder, - local_values.input_state[get_input_range(4)].try_into().unwrap(), + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), 
local_values.e_rr_25, - 25 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 25, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); rotate_right_ext_circuit_constraint( builder, - local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), local_values.a_rr_2, - 2 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 2, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); rotate_right_ext_circuit_constraint( builder, - local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), local_values.a_rr_13, - 13 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 13, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); rotate_right_ext_circuit_constraint( builder, - local_values.input_state[get_input_range(0)].try_into().unwrap(), + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), local_values.a_rr_22, - 22 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 22, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); // check the xor for i in 0..32 { @@ -496,7 +650,7 @@ impl, const D: usize> Stark for ShaCompressSt builder, local_values.e_rr_6[i], local_values.e_rr_11[i], - local_values.e_rr_25[i] + local_values.e_rr_25[i], ); let constraint = builder.sub_extension(local_values.s_1[i], s1); yield_constr.constraint(builder, constraint); @@ -505,7 +659,7 @@ impl, const D: usize> Stark for ShaCompressSt builder, local_values.a_rr_2[i], local_values.a_rr_13[i], - local_values.a_rr_22[i] + local_values.a_rr_22[i], ); let constraint = builder.sub_extension(local_values.s_0[i], s0); yield_constr.constraint(builder, constraint); @@ -513,7 +667,7 @@ impl, const D: usize> Stark for ShaCompressSt let ch = xor_gen_circuit( builder, local_values.e_and_f[i], - 
local_values.not_e_and_g[i] + local_values.not_e_and_g[i], ); let constraint = builder.sub_extension(local_values.ch[i], ch); yield_constr.constraint(builder, constraint); @@ -522,7 +676,7 @@ impl, const D: usize> Stark for ShaCompressSt builder, local_values.a_and_b[i], local_values.a_and_c[i], - local_values.b_and_c[i] + local_values.b_and_c[i], ); let constraint = builder.sub_extension(local_values.maj[i], maj); yield_constr.constraint(builder, constraint); @@ -532,115 +686,184 @@ impl, const D: usize> Stark for ShaCompressSt wrapping_add_ext_circuit_constraints( builder, - local_values.input_state[get_input_range(7)].try_into().unwrap(), + local_values.input_state[get_input_range(7)] + .try_into() + .unwrap(), local_values.s_1, local_values.carry_1, - local_values.inter_1 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.inter_1, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); wrapping_add_ext_circuit_constraints( builder, local_values.inter_1, local_values.ch, local_values.carry_2, - local_values.inter_2 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.inter_2, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); wrapping_add_ext_circuit_constraints( builder, local_values.inter_2, local_values.k_i, local_values.carry_3, - local_values.inter_3 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.inter_3, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); wrapping_add_ext_circuit_constraints( builder, local_values.inter_3, local_values.w_i, local_values.carry_4, - local_values.temp1 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.temp1, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); wrapping_add_ext_circuit_constraints( builder, local_values.s_0, local_values.maj, local_values.carry_5, - local_values.temp2 - ).into_iter().for_each(|c| 
yield_constr.constraint(builder, c)); + local_values.temp2, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); wrapping_add_ext_circuit_constraints( builder, - local_values.input_state[get_input_range(3)].try_into().unwrap(), + local_values.input_state[get_input_range(3)] + .try_into() + .unwrap(), local_values.temp1, local_values.carry_e, - local_values.output_state[get_input_range(4)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.output_state[get_input_range(4)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); wrapping_add_ext_circuit_constraints( builder, local_values.temp1, local_values.temp2, local_values.carry_a, - local_values.output_state[get_input_range(0)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.output_state[get_input_range(0)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); // The op constraints and_op_ext_circuit_constraints( builder, - local_values.input_state[get_input_range(4)].try_into().unwrap(), - local_values.input_state[get_input_range(5)].try_into().unwrap(), - local_values.e_and_f - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(5)] + .try_into() + .unwrap(), + local_values.e_and_f, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); and_op_ext_circuit_constraints( builder, - local_values.input_state[get_input_range(0)].try_into().unwrap(), - local_values.input_state[get_input_range(1)].try_into().unwrap(), - local_values.a_and_b - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(1)] + .try_into() + .unwrap(), + 
local_values.a_and_b, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); and_op_ext_circuit_constraints( builder, - local_values.input_state[get_input_range(0)].try_into().unwrap(), - local_values.input_state[get_input_range(2)].try_into().unwrap(), - local_values.a_and_c - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(2)] + .try_into() + .unwrap(), + local_values.a_and_c, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); and_op_ext_circuit_constraints( builder, - local_values.input_state[get_input_range(1)].try_into().unwrap(), - local_values.input_state[get_input_range(2)].try_into().unwrap(), - local_values.b_and_c - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.input_state[get_input_range(1)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(2)] + .try_into() + .unwrap(), + local_values.b_and_c, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); andn_op_ext_circuit_constraints( builder, - local_values.input_state[get_input_range(4)].try_into().unwrap(), - local_values.input_state[get_input_range(6)].try_into().unwrap(), - local_values.not_e_and_g - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); - + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(6)] + .try_into() + .unwrap(), + local_values.not_e_and_g, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); // output constraint equal_ext_circuit_constraints::( builder, - local_values.output_state[get_input_range(1)].try_into().unwrap(), - local_values.input_state[get_input_range(0)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.output_state[get_input_range(1)] + .try_into() + .unwrap(), + 
local_values.input_state[get_input_range(0)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); equal_ext_circuit_constraints::( builder, - local_values.output_state[get_input_range(2)].try_into().unwrap(), - local_values.input_state[get_input_range(1)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.output_state[get_input_range(2)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(1)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); equal_ext_circuit_constraints::( builder, - local_values.output_state[get_input_range(3)].try_into().unwrap(), - local_values.input_state[get_input_range(2)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.output_state[get_input_range(3)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(2)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); // equal_packed_constraint( // local_values.output_state[get_input_range(4)].try_into().unwrap(), @@ -649,21 +872,39 @@ impl, const D: usize> Stark for ShaCompressSt equal_ext_circuit_constraints::( builder, - local_values.output_state[get_input_range(5)].try_into().unwrap(), - local_values.input_state[get_input_range(4)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.output_state[get_input_range(5)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(4)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); equal_ext_circuit_constraints::( builder, - local_values.output_state[get_input_range(6)].try_into().unwrap(), - local_values.input_state[get_input_range(5)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 
local_values.output_state[get_input_range(6)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(5)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); equal_ext_circuit_constraints::( builder, - local_values.output_state[get_input_range(7)].try_into().unwrap(), - local_values.input_state[get_input_range(6)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.output_state[get_input_range(7)] + .try_into() + .unwrap(), + local_values.input_state[get_input_range(6)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); } fn constraint_degree(&self) -> usize { @@ -673,12 +914,18 @@ impl, const D: usize> Stark for ShaCompressSt #[cfg(test)] mod test { - use plonky2::field::goldilocks_field::GoldilocksField; + use crate::config::StarkConfig; + use crate::cross_table_lookup::{ + Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet, + }; + use crate::prover::prove_single_table; use crate::sha_compress::columns::ShaCompressColumnsView; use crate::sha_compress::sha_compress_stark::{ShaCompressStark, NUM_INPUTS}; + use crate::sha_compress_sponge::constants::SHA_COMPRESS_K; use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range}; - use std::borrow::Borrow; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; + use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::fri::oracle::PolynomialBatch; @@ -686,24 +933,22 @@ mod test { use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::timed; use plonky2::util::timing::TimingTree; - use crate::config::StarkConfig; - use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, 
GrandProductChallengeSet}; - use crate::prover::prove_single_table; - use crate::sha_compress_sponge::constants::SHA_COMPRESS_K; - use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + use std::borrow::Borrow; + + const W: [u32; 64] = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200, + 3020350282, 1447362251, 3118632270, 4004188394, 690615167, 6070360, 1105370215, 2385558114, + 2348232513, 507799627, 2098764358, 5845374, 823657968, 2969863067, 3903496557, 4274682881, + 2059629362, 1849247231, 2656047431, 835162919, 2096647516, 2259195856, 1779072524, + 3152121987, 4210324067, 1557957044, 376930560, 982142628, 3926566666, 4164334963, + 789545383, 1028256580, 2867933222, 3843938318, 1135234440, 390334875, 2025924737, + 3318322046, 3436065867, 652746999, 4261492214, 2543173532, 3334668051, 3166416553, + 634956631, + ]; - const W: [u32; 64] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, - 67559435, 1711661200, 3020350282, 1447362251, 3118632270, 4004188394, 690615167, - 6070360, 1105370215, 2385558114, 2348232513, 507799627, 2098764358, 5845374, 823657968, - 2969863067, 3903496557, 4274682881, 2059629362, 1849247231, 2656047431, 835162919, - 2096647516, 2259195856, 1779072524, 3152121987, 4210324067, 1557957044, 376930560, - 982142628, 3926566666, 4164334963, 789545383, 1028256580, 2867933222, 3843938318, 1135234440, - 390334875, 2025924737, 3318322046, 3436065867, 652746999, 4261492214, 2543173532, 3334668051, - 3166416553, 634956631]; - - pub const H256_256: [u32;8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, - 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, + pub const H256_256: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, + 0x5be0cd19, ]; fn get_random_input() -> [u8; NUM_INPUTS] { @@ -716,8 +961,7 @@ mod test { } #[test] - fn test_generation() -> Result<(), String>{ - + fn test_generation() -> Result<(), String> 
{ const D: usize = 2; type F = GoldilocksField; type S = ShaCompressStark; @@ -732,42 +976,65 @@ mod test { input.extend(from_u32_to_be_bits(w[0])); input.extend(from_u32_to_be_bits(SHA_COMPRESS_K[0])); - let stark = S::default(); let row = stark.generate_trace_rows_for_compress((input.try_into().unwrap(), 0)); let local_values: &ShaCompressColumnsView = row.borrow(); assert_eq!( local_values.output_state[get_input_range(0)], - from_u32_to_be_bits(4228417613).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(4228417613) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(1)], - from_u32_to_be_bits(1779033703).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(1779033703) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(2)], - from_u32_to_be_bits(3144134277).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(3144134277) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(3)], - from_u32_to_be_bits(1013904242).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(1013904242) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(4)], - from_u32_to_be_bits(2563236514).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(2563236514) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(5)], - from_u32_to_be_bits(1359893119).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(1359893119) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(6)], - from_u32_to_be_bits(2600822924).iter().map(|&x| F::from_canonical_u8(x)).collect::>() 
+ from_u32_to_be_bits(2600822924) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(7)], - from_u32_to_be_bits(528734635).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(528734635) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); Ok(()) } @@ -869,4 +1136,4 @@ mod test { fn init_logger() { let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); } -} \ No newline at end of file +} diff --git a/prover/src/sha_compress_sponge/columns.rs b/prover/src/sha_compress_sponge/columns.rs index 7485a6a9..539248fc 100644 --- a/prover/src/sha_compress_sponge/columns.rs +++ b/prover/src/sha_compress_sponge/columns.rs @@ -1,10 +1,9 @@ -use std::borrow::{Borrow, BorrowMut}; -use std::intrinsics::transmute; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; +use std::borrow::{Borrow, BorrowMut}; +use std::mem::transmute; pub(crate) struct ShaCompressSpongeColumnsView { - - pub hx: [T;256], + pub hx: [T; 256], pub input_state: [T; 256], pub output_state: [T; 256], pub output_hx: [T; 256], @@ -23,7 +22,6 @@ pub(crate) struct ShaCompressSpongeColumnsView { pub segment: T, } - pub const NUM_SHA_COMPRESS_SPONGE_COLUMNS: usize = size_of::>(); //1420 impl From<[T; NUM_SHA_COMPRESS_SPONGE_COLUMNS]> for ShaCompressSpongeColumnsView { @@ -71,7 +69,9 @@ impl Default for ShaCompressSpongeColumnsView { const fn make_col_map() -> ShaCompressSpongeColumnsView { let indices_arr = indices_arr::(); unsafe { - transmute::<[usize; NUM_SHA_COMPRESS_SPONGE_COLUMNS], ShaCompressSpongeColumnsView>(indices_arr) + transmute::<[usize; NUM_SHA_COMPRESS_SPONGE_COLUMNS], ShaCompressSpongeColumnsView>( + indices_arr, + ) } } diff --git a/prover/src/sha_compress_sponge/constants.rs b/prover/src/sha_compress_sponge/constants.rs index d50bde1f..74c9c1d9 100644 --- a/prover/src/sha_compress_sponge/constants.rs +++ b/prover/src/sha_compress_sponge/constants.rs 
@@ -13,68 +13,260 @@ pub const NUM_COMPRESS_ROWS: usize = 64; // big-endian form pub const SHA_COMPRESS_K_BINARY: [[u8; 32]; 64] = [ - [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0], - [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0], - [1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1], - [1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1], - [1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0], - [1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0], - [0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1], - [1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1], - [0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0], - [0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], - [1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], - [0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0], - [0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1], - [1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1], - [0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1], - [1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1], - [0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1], - [0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 
0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], - [0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], - [1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0], - [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0], - [0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0], - [0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0], - [0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1], - [1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], - [0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1], - [1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1], - [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1], - [1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1], - [1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0], - [1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], - [1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0], - [0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0], - [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0], - [1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0], - [0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0], - [1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0], - [0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 
0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1], - [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], - [1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], - [0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1], - [1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1], - [1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1], - [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1], - [1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1], - [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0], - [0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0], - [0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0], - [1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0], - [0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0], - [1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0], - [1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0], - [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0], - [1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0], - [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 
0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1], - [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1], - [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1], - [1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1], - [1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1], - [0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1], -]; \ No newline at end of file + [ + 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, + 1, 0, + ], + [ + 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, + 1, 0, + ], + [ + 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, + 0, 1, + ], + [ + 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, + 1, 1, + ], + [ + 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, + 0, 0, + ], + [ + 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, + 1, 0, + ], + [ + 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, + 0, 1, + ], + [ + 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, + 0, 1, + ], + [ + 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, + 1, 1, + ], + [ + 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, + 0, 0, + ], + [ + 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, + 0, 0, + ], + [ + 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, + 1, 0, + ], + [ + 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, + 1, 0, + ], + [ + 0, 1, 1, 
1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, + 0, 1, + ], + [ + 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, + 0, 1, + ], + [ + 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, + 1, 1, + ], + [ + 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, + 1, 1, + ], + [ + 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, + 1, 1, + ], + [ + 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, + 0, 0, + ], + [ + 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, + 0, 0, + ], + [ + 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, + 0, 0, + ], + [ + 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, + 1, 0, + ], + [ + 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, + 1, 0, + ], + [ + 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, + 1, 0, + ], + [ + 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, + ], + [ + 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, + 0, 1, + ], + [ + 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 0, 1, + ], + [ + 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, + 0, 1, + ], + [ + 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, + 1, 1, + ], + [ + 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, + 1, 1, + ], + [ + 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, + 0, 0, + ], + [ + 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 
0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, + 0, 0, + ], + [ + 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, + 0, 0, + ], + [ + 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, + 0, 0, + ], + [ + 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, + 1, 0, + ], + [ + 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, + 1, 0, + ], + [ + 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, + 1, 0, + ], + [ + 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, + 1, 0, + ], + [ + 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 1, + ], + [ + 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, + 0, 1, + ], + [ + 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, + 0, 1, + ], + [ + 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, + 0, 1, + ], + [ + 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, + 1, 1, + ], + [ + 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, + 1, 1, + ], + [ + 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, + 1, 1, + ], + [ + 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, + 1, 1, + ], + [ + 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, + 1, 1, + ], + [ + 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, + 0, 0, + ], + [ + 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, + 0, 0, + ], + [ + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 
1, 1, 0, + 0, 0, + ], + [ + 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, + 0, 0, + ], + [ + 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, + 0, 0, + ], + [ + 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, + 0, 0, + ], + [ + 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, + 1, 0, + ], + [ + 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, + 1, 0, + ], + [ + 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, + 1, 0, + ], + [ + 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, + 1, 0, + ], + [ + 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, + 1, 0, + ], + [ + 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, + 0, 1, + ], + [ + 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, + 0, 1, + ], + [ + 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, + 0, 1, + ], + [ + 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, + 0, 1, + ], + [ + 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, + 0, 1, + ], + [ + 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, + 1, 1, + ], +]; diff --git a/prover/src/sha_compress_sponge/mod.rs b/prover/src/sha_compress_sponge/mod.rs index c47afc31..d1fb8826 100644 --- a/prover/src/sha_compress_sponge/mod.rs +++ b/prover/src/sha_compress_sponge/mod.rs @@ -1,3 +1,3 @@ pub mod columns; +pub mod constants; pub mod sha_compress_sponge_stark; -pub mod constants; \ No newline at end of file diff --git a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs 
b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs index 54465df7..467b9941 100644 --- a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs +++ b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs @@ -1,29 +1,34 @@ -use std::marker::PhantomData; -use std::borrow::Borrow; -use itertools::Itertools; -use plonky2::field::extension::{Extendable, FieldExtension}; -use plonky2::field::packed::PackedField; -use plonky2::field::polynomial::PolynomialValues; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cross_table_lookup::{Column, Filter}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; use crate::sha_compress::logic::from_be_bits_to_u32; -use crate::sha_compress_sponge::columns::{ShaCompressSpongeColumnsView, NUM_SHA_COMPRESS_SPONGE_COLUMNS, SHA_COMPRESS_SPONGE_COL_MAP}; +use crate::sha_compress_sponge::columns::{ + ShaCompressSpongeColumnsView, NUM_SHA_COMPRESS_SPONGE_COLUMNS, SHA_COMPRESS_SPONGE_COL_MAP, +}; use crate::sha_compress_sponge::constants::{NUM_COMPRESS_ROWS, SHA_COMPRESS_K_BINARY}; -use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range, wrapping_add, wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints}; +use crate::sha_extend::logic::{ + from_u32_to_be_bits, get_input_range, wrapping_add, wrapping_add_ext_circuit_constraints, + wrapping_add_packed_constraints, +}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryAddress; use crate::witness::operation::SHA_COMPRESS_K; +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use 
plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use std::borrow::Borrow; +use std::marker::PhantomData; pub(crate) const NUM_ROUNDS: usize = 64; -pub(crate) const SHA_COMPRESS_SPONGE_READ_BITS: usize = 9 * 32; // h[0],...,h[7], w[i]. +pub(crate) const SHA_COMPRESS_SPONGE_READ_BITS: usize = 9 * 32; // h[0],...,h[7], w[i]. pub(crate) fn ctl_looking_sha_compress_inputs() -> Vec> { let cols = SHA_COMPRESS_SPONGE_COL_MAP; let mut res: Vec<_> = Column::singles( @@ -32,9 +37,9 @@ pub(crate) fn ctl_looking_sha_compress_inputs() -> Vec> { cols.w_i.as_slice(), cols.k_i.as_slice(), ] - .concat(), + .concat(), ) - .collect(); + .collect(); res.push(Column::single(cols.timestamp)); res } @@ -62,12 +67,7 @@ pub(crate) fn ctl_looked_data() -> Vec> { outputs.push(cur_col); } - Column::singles([ - cols.context, - cols.segment, - cols.hx_virt[0], - cols.timestamp, - ]) + Column::singles([cols.context, cols.segment, cols.hx_virt[0], cols.timestamp]) .chain(outputs) .collect() } @@ -86,14 +86,13 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { // The u32 of i'th input bit being read. let start = i / 32; - let le_bit; - if start < 8 { - le_bit = cols.hx[get_input_range(start)].try_into().unwrap(); + let le_bit: [usize; 32] = if start < 8 { + cols.hx[get_input_range(start)].try_into().unwrap() } else { - le_bit = cols.w_i; - } + cols.w_i + }; // le_bit.reverse(); - let u32_value: Column = Column::le_bits(&le_bit); + let u32_value: Column = Column::le_bits(le_bit); res.push(u32_value); res.push(Column::single(cols.timestamp)); @@ -105,13 +104,10 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { res } - pub(crate) fn ctl_looking_sha_compress_filter() -> Filter { let cols = SHA_COMPRESS_SPONGE_COL_MAP; // not the padding rows. 
- Filter::new_simple(Column::sum( - &cols.round, - )) + Filter::new_simple(Column::sum(cols.round)) } pub(crate) fn ctl_looked_filter() -> Filter { @@ -119,9 +115,7 @@ pub(crate) fn ctl_looked_filter() -> Filter { // compress sponge output. let cols = SHA_COMPRESS_SPONGE_COL_MAP; // the final row only. - Filter::new_simple(Column::single( - cols.round[63], - )) + Filter::new_simple(Column::single(cols.round[63])) } #[derive(Clone, Debug)] @@ -180,31 +174,42 @@ impl, const D: usize> ShaCompressSpongeStark rows } - fn generate_rows_for_op( - &self, - op: ShaCompressSpongeOp, - ) -> ShaCompressSpongeColumnsView { + fn generate_rows_for_op(&self, op: ShaCompressSpongeOp) -> ShaCompressSpongeColumnsView { let mut row = ShaCompressSpongeColumnsView::default(); row.timestamp = F::from_canonical_usize(op.timestamp); row.context = F::from_canonical_usize(op.base_address[0].context); row.segment = F::from_canonical_usize(op.base_address[Segment::Code as usize].segment); - let hx_virt = (0..8) - .map(|i| op.base_address[i].virt) - .collect_vec(); + let hx_virt = (0..8).map(|i| op.base_address[i].virt).collect_vec(); let hx_virt: [usize; 8] = hx_virt.try_into().unwrap(); row.hx_virt = hx_virt.map(F::from_canonical_usize); - let w_virt = op.base_address[8].virt; + let w_virt = op.base_address[8].virt; row.w_virt = F::from_canonical_usize(w_virt); row.round = [F::ZEROS; 64]; row.round[op.i] = F::ONE; row.k_i = SHA_COMPRESS_K_BINARY[op.i].map(|k| F::from_canonical_u8(k)); - row.w_i = op.input[256..288].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - row.hx = op.input[..256].iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); - row.input_state = op.input_states.iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + row.w_i = op.input[256..288] + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); + row.hx = op.input[..256] + .iter() + .map(|&x| F::from_canonical_u8(x)) + 
.collect::>() + .try_into() + .unwrap(); + row.input_state = op + .input_states + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); let output = self.compress(&op.input_states, &op.input[256..288], op.i); row.output_state = output.map(F::from_canonical_u8); @@ -213,10 +218,9 @@ impl, const D: usize> ShaCompressSpongeStark // The computation in other rounds are ensure the constraint degree // not to be exceeded 3. for i in 0..8 { - let (output_hx, carry) = wrapping_add::( row.hx[get_input_range(i)].try_into().unwrap(), - row.output_state[get_input_range(i)].try_into().unwrap() + row.output_state[get_input_range(i)].try_into().unwrap(), ); row.output_hx[get_input_range(i)].copy_from_slice(&output_hx[0..]); @@ -227,14 +231,23 @@ impl, const D: usize> ShaCompressSpongeStark } fn compress(&self, input_state: &[u8], w_i: &[u8], round: usize) -> [u8; 256] { - let values: Vec<[u8; 32]> = input_state.chunks(32).map(|chunk| chunk.try_into().unwrap()).collect(); - let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = values.into_iter().map( - |x| from_be_bits_to_u32(x) - ).collect::>().try_into().unwrap(); + let values: Vec<[u8; 32]> = input_state + .chunks(32) + .map(|chunk| chunk.try_into().unwrap()) + .collect(); + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = values + .into_iter() + .map(from_be_bits_to_u32) + .collect::>() + .try_into() + .unwrap(); let w_i = from_be_bits_to_u32(w_i.try_into().unwrap()); - let t1 = h.wrapping_add(e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25)) - .wrapping_add((e & f) ^ ((!e) & g)).wrapping_add(SHA_COMPRESS_K[round]).wrapping_add(w_i); + let t1 = h + .wrapping_add(e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25)) + .wrapping_add((e & f) ^ ((!e) & g)) + .wrapping_add(SHA_COMPRESS_K[round]) + .wrapping_add(w_i); let t2 = (a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22)) .wrapping_add((a & b) ^ (a & c) ^ (b & c)); h = g; @@ -262,22 +275,21 @@ 
impl, const D: usize> ShaCompressSpongeStark impl, const D: usize> Stark for ShaCompressSpongeStark { type EvaluationFrame - = StarkFrame + = StarkFrame where - FE: FieldExtension, - P: PackedField; + FE: FieldExtension, + P: PackedField; type EvaluationFrameTarget = StarkFrame, NUM_SHA_COMPRESS_SPONGE_COLUMNS>; fn eval_packed_generic( &self, vars: &Self::EvaluationFrame, - yield_constr: &mut ConstraintConsumer

+ yield_constr: &mut ConstraintConsumer

, ) where - FE: FieldExtension, - P: PackedField + FE: FieldExtension, + P: PackedField, { - let local_values: &[P; NUM_SHA_COMPRESS_SPONGE_COLUMNS] = vars.get_local_values().try_into().unwrap(); let local_values: &ShaCompressSpongeColumnsView

= local_values.borrow(); @@ -289,7 +301,8 @@ impl, const D: usize> Stark for ShaCompressSp // check the bit values are zero or one in input for i in 0..256 { yield_constr.constraint(local_values.hx[i] * (local_values.hx[i] - P::ONES)); - yield_constr.constraint(local_values.input_state[i] * (local_values.input_state[i] - P::ONES)); + yield_constr + .constraint(local_values.input_state[i] * (local_values.input_state[i] - P::ONES)); } for i in 0..32 { yield_constr.constraint(local_values.w_i[i] * (local_values.w_i[i] - P::ONES)); @@ -298,8 +311,11 @@ impl, const D: usize> Stark for ShaCompressSp // check the bit values are zero or one in output for i in 0..256 { - yield_constr.constraint(local_values.output_state[i] * (local_values.output_state[i] - P::ONES)); - yield_constr.constraint(local_values.output_hx[i] * (local_values.output_hx[i] - P::ONES)); + yield_constr.constraint( + local_values.output_state[i] * (local_values.output_state[i] - P::ONES), + ); + yield_constr + .constraint(local_values.output_hx[i] * (local_values.output_hx[i] - P::ONES)); yield_constr.constraint(local_values.carry[i] * (local_values.carry[i] - P::ONES)); } @@ -318,7 +334,6 @@ impl, const D: usize> Stark for ShaCompressSp .sum::

(); yield_constr.constraint(sum_round_flags * (sum_round_flags - P::ONES)); - // If this is not the final step or a padding row: // the local and next timestamps must match. @@ -336,21 +351,27 @@ impl, const D: usize> Stark for ShaCompressSp // the output state of local row must be the input state of next row for i in 0..256 { yield_constr.constraint( - sum_round_flags * not_final * (next_values.input_state[i] - local_values.output_state[i]) + sum_round_flags + * not_final + * (next_values.input_state[i] - local_values.output_state[i]), ); } // the address of w_i must be increased by 4 yield_constr.constraint( - sum_round_flags * not_final * (next_values.w_virt - local_values.w_virt - FE::from_canonical_u8(4)), + sum_round_flags + * not_final + * (next_values.w_virt - local_values.w_virt - FE::from_canonical_u8(4)), ); - // if not the padding row, the hx address must be a sequence of numbers spaced 4 units apart for i in 0..7 { yield_constr.constraint( - sum_round_flags * (local_values.hx_virt[i + 1] - local_values.hx_virt[i] - FE::from_canonical_u8(4)), + sum_round_flags + * (local_values.hx_virt[i + 1] + - local_values.hx_virt[i] + - FE::from_canonical_u8(4)), ); } @@ -359,7 +380,7 @@ impl, const D: usize> Stark for ShaCompressSp for i in 0..32 { let mut bit_i = P::ZEROS; for j in 0..64 { - bit_i = bit_i + local_values.round[j] * FE::from_canonical_u8(SHA_COMPRESS_K_BINARY[j][i]); + bit_i += local_values.round[j] * FE::from_canonical_u8(SHA_COMPRESS_K_BINARY[j][i]) } yield_constr.constraint(local_values.k_i[i] - bit_i); } @@ -367,23 +388,26 @@ impl, const D: usize> Stark for ShaCompressSp // wrapping add constraints for i in 0..8 { - wrapping_add_packed_constraints::( local_values.hx[get_input_range(i)].try_into().unwrap(), - local_values.output_state[get_input_range(i)].try_into().unwrap(), + local_values.output_state[get_input_range(i)] + .try_into() + .unwrap(), local_values.carry[get_input_range(i)].try_into().unwrap(), - 
local_values.output_hx[get_input_range(i)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(c)); - + local_values.output_hx[get_input_range(i)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); } - } fn eval_ext_circuit( &self, builder: &mut CircuitBuilder, vars: &Self::EvaluationFrameTarget, - yield_constr: &mut RecursiveConstraintConsumer + yield_constr: &mut RecursiveConstraintConsumer, ) { let local_values: &[ExtensionTarget; NUM_SHA_COMPRESS_SPONGE_COLUMNS] = vars.get_local_values().try_into().unwrap(); @@ -399,44 +423,66 @@ impl, const D: usize> Stark for ShaCompressSp // check the bit values are zero or one in input for i in 0..256 { let constraint = builder.mul_sub_extension( - local_values.hx[i], local_values.hx[i], local_values.hx[i]); + local_values.hx[i], + local_values.hx[i], + local_values.hx[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.input_state[i], local_values.input_state[i], local_values.input_state[i]); + local_values.input_state[i], + local_values.input_state[i], + local_values.input_state[i], + ); yield_constr.constraint(builder, constraint); } for i in 0..32 { let constraint = builder.mul_sub_extension( - local_values.w_i[i], local_values.w_i[i], local_values.w_i[i]); + local_values.w_i[i], + local_values.w_i[i], + local_values.w_i[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.k_i[i], local_values.k_i[i], local_values.k_i[i]); + local_values.k_i[i], + local_values.k_i[i], + local_values.k_i[i], + ); yield_constr.constraint(builder, constraint); - } // check the bit values are zero or one in output for i in 0..256 { - let constraint = builder.mul_sub_extension( - local_values.output_state[i], local_values.output_state[i], local_values.output_state[i]); + local_values.output_state[i], + local_values.output_state[i], + 
local_values.output_state[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.output_hx[i], local_values.output_hx[i], local_values.output_hx[i]); + local_values.output_hx[i], + local_values.output_hx[i], + local_values.output_hx[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.carry[i], local_values.carry[i], local_values.carry[i]); + local_values.carry[i], + local_values.carry[i], + local_values.carry[i], + ); yield_constr.constraint(builder, constraint); } // check the round for i in 0..NUM_ROUNDS { let constraint = builder.mul_sub_extension( - local_values.round[i], local_values.round[i], local_values.round[i]); + local_values.round[i], + local_values.round[i], + local_values.round[i], + ); yield_constr.constraint(builder, constraint); } @@ -449,12 +495,10 @@ impl, const D: usize> Stark for ShaCompressSp let sum_round_flags = builder.add_many_extension((0..NUM_COMPRESS_ROWS).map(|i| local_values.round[i])); - let constraint = builder.mul_sub_extension( - sum_round_flags, sum_round_flags, sum_round_flags - ); + let constraint = + builder.mul_sub_extension(sum_round_flags, sum_round_flags, sum_round_flags); yield_constr.constraint(builder, constraint); - // If this is not the final step or a padding row: // the local and next timestamps must match. 
@@ -472,7 +516,8 @@ impl, const D: usize> Stark for ShaCompressSp // the output state of local row must be the input state of next row for i in 0..256 { - let diff = builder.sub_extension(next_values.input_state[i], local_values.output_state[i]); + let diff = + builder.sub_extension(next_values.input_state[i], local_values.output_state[i]); let constraint = builder.mul_many_extension([sum_round_flags, not_final, diff]); yield_constr.constraint(builder, constraint); } @@ -480,31 +525,31 @@ impl, const D: usize> Stark for ShaCompressSp // the address of w_i must be increased by 4 let increment = builder.sub_extension(next_values.w_virt, local_values.w_virt); let address_increment = builder.sub_extension(increment, four_ext); - let constraint = builder.mul_many_extension( - [sum_round_flags, not_final, address_increment] - ); + let constraint = + builder.mul_many_extension([sum_round_flags, not_final, address_increment]); yield_constr.constraint(builder, constraint); - // if not the padding row, the hx address must be a sequence of numbers spaced 4 units apart for i in 0..7 { - let increment = builder.sub_extension(local_values.hx_virt[i + 1], local_values.hx_virt[i]); + let increment = + builder.sub_extension(local_values.hx_virt[i + 1], local_values.hx_virt[i]); let address_increment = builder.sub_extension(increment, four_ext); - let constraint = builder.mul_extension( - sum_round_flags, address_increment - ); + let constraint = builder.mul_extension(sum_round_flags, address_increment); yield_constr.constraint(builder, constraint); } // check the validation of key[i] for i in 0..32 { - - let bit_i_comp: Vec<_> = (0..64).map(|j| { - let k_j_i = builder.constant_extension(F::Extension::from_canonical_u8(SHA_COMPRESS_K_BINARY[j][i])); - builder.mul_extension(local_values.round[j], k_j_i) - }).collect(); + let bit_i_comp: Vec<_> = (0..64) + .map(|j| { + let k_j_i = builder.constant_extension(F::Extension::from_canonical_u8( + SHA_COMPRESS_K_BINARY[j][i], + )); + 
builder.mul_extension(local_values.round[j], k_j_i) + }) + .collect(); let bit_i = builder.add_many_extension(bit_i_comp); let constraint = builder.sub_extension(local_values.k_i[i], bit_i); yield_constr.constraint(builder, constraint); @@ -516,11 +561,16 @@ impl, const D: usize> Stark for ShaCompressSp wrapping_add_ext_circuit_constraints::( builder, local_values.hx[get_input_range(i)].try_into().unwrap(), - local_values.output_state[get_input_range(i)].try_into().unwrap(), + local_values.output_state[get_input_range(i)] + .try_into() + .unwrap(), local_values.carry[get_input_range(i)].try_into().unwrap(), - local_values.output_hx[get_input_range(i)].try_into().unwrap(), - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); - + local_values.output_hx[get_input_range(i)] + .try_into() + .unwrap(), + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); } } @@ -529,41 +579,45 @@ impl, const D: usize> Stark for ShaCompressSp } } - #[cfg(test)] mod test { - use plonky2::field::goldilocks_field::GoldilocksField; - use plonky2::field::types::{Field}; - use std::borrow::Borrow; + use crate::config::StarkConfig; + use crate::cross_table_lookup::{ + Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet, + }; + use crate::prover::prove_single_table; + use crate::sha_compress_sponge::columns::ShaCompressSpongeColumnsView; + use crate::sha_compress_sponge::sha_compress_sponge_stark::{ + ShaCompressSpongeOp, ShaCompressSpongeStark, + }; + use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range}; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + use crate::witness::memory::MemoryAddress; use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; + use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::polynomial::PolynomialValues; + use plonky2::field::types::Field; use plonky2::fri::oracle::PolynomialBatch; use plonky2::iop::challenger::Challenger; 
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::timed; use plonky2::util::timing::TimingTree; - use crate::config::StarkConfig; - use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; - use crate::prover::prove_single_table; - use crate::sha_compress_sponge::columns::ShaCompressSpongeColumnsView; - use crate::sha_compress_sponge::sha_compress_sponge_stark::{ShaCompressSpongeOp, ShaCompressSpongeStark}; - use crate::sha_extend::logic::{from_u32_to_be_bits, get_input_range}; - use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; - use crate::witness::memory::MemoryAddress; - + use std::borrow::Borrow; - const W: [u32; 64] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, - 67559435, 1711661200, 3020350282, 1447362251, 3118632270, 4004188394, 690615167, - 6070360, 1105370215, 2385558114, 2348232513, 507799627, 2098764358, 5845374, 823657968, - 2969863067, 3903496557, 4274682881, 2059629362, 1849247231, 2656047431, 835162919, - 2096647516, 2259195856, 1779072524, 3152121987, 4210324067, 1557957044, 376930560, - 982142628, 3926566666, 4164334963, 789545383, 1028256580, 2867933222, 3843938318, 1135234440, - 390334875, 2025924737, 3318322046, 3436065867, 652746999, 4261492214, 2543173532, 3334668051, - 3166416553, 634956631]; + const W: [u32; 64] = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200, + 3020350282, 1447362251, 3118632270, 4004188394, 690615167, 6070360, 1105370215, 2385558114, + 2348232513, 507799627, 2098764358, 5845374, 823657968, 2969863067, 3903496557, 4274682881, + 2059629362, 1849247231, 2656047431, 835162919, 2096647516, 2259195856, 1779072524, + 3152121987, 4210324067, 1557957044, 376930560, 982142628, 3926566666, 4164334963, + 789545383, 1028256580, 2867933222, 3843938318, 1135234440, 390334875, 2025924737, + 3318322046, 3436065867, 652746999, 4261492214, 2543173532, 
3334668051, 3166416553, + 634956631, + ]; - pub const H256_256: [u32;8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, - 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, + pub const H256_256: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, + 0x5be0cd19, ]; #[test] fn test_generation() -> Result<(), String> { @@ -573,26 +627,38 @@ mod test { type S = ShaCompressSpongeStark; let stark = S::default(); - let hx_addresses: Vec = (0..32).step_by(4).map(|i| { - MemoryAddress { + let hx_addresses: Vec = (0..32) + .step_by(4) + .map(|i| MemoryAddress { context: 0, segment: 0, virt: i, - } - }).collect(); + }) + .collect(); - let w_addresses: Vec = (32..288).step_by(4).map(|i| { - MemoryAddress { + let w_addresses: Vec = (32..288) + .step_by(4) + .map(|i| MemoryAddress { context: 0, segment: 0, virt: i, - } - }).collect(); - let mut input = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + }) + .collect(); + let mut input = H256_256 + .iter() + .flat_map(|x| from_u32_to_be_bits(*x)) + .collect::>(); input.extend(from_u32_to_be_bits(W[0])); - let input_state = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + let input_state = H256_256 + .iter() + .flat_map(|x| from_u32_to_be_bits(*x)) + .collect::>(); let op = ShaCompressSpongeOp { - base_address: hx_addresses.iter().chain([w_addresses[0]].iter()).cloned().collect(), + base_address: hx_addresses + .iter() + .chain([w_addresses[0]].iter()) + .cloned() + .collect(), i: 0, timestamp: 0, input_states: input_state, @@ -603,42 +669,76 @@ mod test { assert_eq!( local_values.output_state[get_input_range(0)], - from_u32_to_be_bits(4228417613).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(4228417613) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(1)], - from_u32_to_be_bits(1779033703).iter().map(|&x| 
F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(1779033703) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(2)], - from_u32_to_be_bits(3144134277).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(3144134277) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(3)], - from_u32_to_be_bits(1013904242).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(1013904242) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(4)], - from_u32_to_be_bits(2563236514).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(2563236514) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(5)], - from_u32_to_be_bits(1359893119).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(1359893119) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(6)], - from_u32_to_be_bits(2600822924).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(2600822924) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_state[get_input_range(7)], - from_u32_to_be_bits(528734635).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(528734635) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); - let mut input = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + let mut input = H256_256 + .iter() + .flat_map(|x| from_u32_to_be_bits(*x)) + .collect::>(); input.extend(from_u32_to_be_bits(W[63])); - let input_state = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + let input_state = H256_256 + .iter() + .flat_map(|x| 
from_u32_to_be_bits(*x)) + .collect::>(); let op = ShaCompressSpongeOp { - base_address: hx_addresses.iter().chain([w_addresses[0]].iter()).cloned().collect(), + base_address: hx_addresses + .iter() + .chain([w_addresses[0]].iter()) + .cloned() + .collect(), i: 63, timestamp: 0, input_states: input_state, @@ -647,43 +747,65 @@ mod test { let row = stark.generate_rows_for_op(op); let local_values: &ShaCompressSpongeColumnsView = row.borrow(); - assert_eq!( local_values.output_hx[get_input_range(0)], - from_u32_to_be_bits(H256_256[0].wrapping_add(2781379838 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(H256_256[0].wrapping_add(2781379838_u32)) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_hx[get_input_range(1)], - from_u32_to_be_bits(H256_256[1].wrapping_add(1779033703 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(H256_256[1].wrapping_add(1779033703_u32)) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_hx[get_input_range(2)], - from_u32_to_be_bits(H256_256[2].wrapping_add(3144134277 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(H256_256[2].wrapping_add(3144134277_u32)) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_hx[get_input_range(3)], - from_u32_to_be_bits(H256_256[3].wrapping_add(1013904242 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(H256_256[3].wrapping_add(1013904242_u32)) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_hx[get_input_range(4)], - from_u32_to_be_bits(H256_256[4].wrapping_add(1116198739 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(H256_256[4].wrapping_add(1116198739_u32)) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( 
local_values.output_hx[get_input_range(5)], - from_u32_to_be_bits(H256_256[5].wrapping_add(1359893119 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(H256_256[5].wrapping_add(1359893119_u32)) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_hx[get_input_range(6)], - from_u32_to_be_bits(H256_256[6].wrapping_add(2600822924 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(H256_256[6].wrapping_add(2600822924_u32)) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); assert_eq!( local_values.output_hx[get_input_range(7)], - from_u32_to_be_bits(H256_256[7].wrapping_add(528734635 as u32)).iter().map(|&x| F::from_canonical_u8(x)).collect::>() + from_u32_to_be_bits(H256_256[7].wrapping_add(528734635_u32)) + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() ); Ok(()) } - #[test] fn test_stark_circuit() -> anyhow::Result<()> { const D: usize = 2; @@ -708,43 +830,54 @@ mod test { test_stark_low_degree(stark) } - fn get_random_input() -> Vec { - const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; type S = ShaCompressSpongeStark; let stark = S::default(); - let hx_addresses: Vec = (0..32).step_by(4).map(|i| { - MemoryAddress { + let hx_addresses: Vec = (0..32) + .step_by(4) + .map(|i| MemoryAddress { context: 0, segment: 0, virt: i, - } - }).collect(); + }) + .collect(); - let w_addresses: Vec = (32..288).step_by(4).map(|i| { - MemoryAddress { + let w_addresses: Vec = (32..288) + .step_by(4) + .map(|i| MemoryAddress { context: 0, segment: 0, virt: i, - } - }).collect(); + }) + .collect(); let mut res = vec![]; - let mut output_state = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + let mut output_state = H256_256 + .iter() + .flat_map(|x| from_u32_to_be_bits(*x)) + .collect::>(); for i in 0..64 { - - let mut input = H256_256.iter().map(|x| from_u32_to_be_bits(*x)).flatten().collect::>(); + 
let mut input = H256_256 + .iter() + .flat_map(|x| from_u32_to_be_bits(*x)) + .collect::>(); input.extend(from_u32_to_be_bits(W[i])); let input_state = output_state.clone(); - output_state = stark.compress(&input_state, &from_u32_to_be_bits(W[i]), i).to_vec(); + output_state = stark + .compress(&input_state, &from_u32_to_be_bits(W[i]), i) + .to_vec(); let op = ShaCompressSpongeOp { - base_address: hx_addresses.iter().chain([w_addresses[i]].iter()).cloned().collect(), - i: i, + base_address: hx_addresses + .iter() + .chain([w_addresses[i]].iter()) + .cloned() + .collect(), + i, timestamp: 0, input_states: input_state, input, @@ -754,7 +887,6 @@ mod test { } res - } #[test] fn sha_extend_sponge_benchmark() -> anyhow::Result<()> { @@ -824,5 +956,4 @@ mod test { fn init_logger() { let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); } - -} \ No newline at end of file +} diff --git a/prover/src/sha_extend/columns.rs b/prover/src/sha_extend/columns.rs index 4d3982f6..4dd8e5e4 100644 --- a/prover/src/sha_extend/columns.rs +++ b/prover/src/sha_extend/columns.rs @@ -1,9 +1,8 @@ -use std::borrow::{Borrow, BorrowMut}; -use std::intrinsics::transmute; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; +use std::borrow::{Borrow, BorrowMut}; +use std::mem::transmute; pub(crate) struct ShaExtendColumnsView { - /// Input in big-endian order pub w_i_minus_15: [T; 32], pub w_i_minus_2: [T; 32], @@ -83,4 +82,3 @@ const fn make_col_map() -> ShaExtendColumnsView { } pub(crate) const SHA_EXTEND_COL_MAP: ShaExtendColumnsView = make_col_map(); - diff --git a/prover/src/sha_extend/logic.rs b/prover/src/sha_extend/logic.rs index eab8f703..2932cf4a 100644 --- a/prover/src/sha_extend/logic.rs +++ b/prover/src/sha_extend/logic.rs @@ -5,14 +5,15 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; - pub(crate) fn get_input_range(i: usize) -> 
std::ops::Range { - (0 + i * 32)..(32 + i * 32) + (i * 32)..(32 + i * 32) } - // these operators are applied in big-endian form -pub(crate) fn rotate_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { +pub(crate) fn rotate_right, const D: usize>( + value: [F; 32], + amount: usize, +) -> [F; 32] { let mut result = [F::ZERO; 32]; for i in 0..32 { result[i] = value[(i + amount) % 32]; @@ -22,7 +23,7 @@ pub(crate) fn rotate_right, const D: usize>(value: pub(crate) fn rotate_right_packed_constraints( value: [P; 32], - rotated_value: [P;32], + rotated_value: [P; 32], amount: usize, ) -> Vec

{ let mut result = Vec::new(); @@ -34,9 +35,9 @@ pub(crate) fn rotate_right_packed_constraints( pub(crate) fn rotate_right_ext_circuit_constraint, const D: usize>( builder: &mut CircuitBuilder, - value: [ExtensionTarget;32], + value: [ExtensionTarget; 32], rotated_value: [ExtensionTarget; 32], - amount: usize + amount: usize, ) -> Vec> { let mut result = Vec::new(); for i in 0..32 { @@ -45,19 +46,20 @@ pub(crate) fn rotate_right_ext_circuit_constraint, result } -pub(crate) fn shift_right, const D: usize>(value: [F; 32], amount: usize) -> [F; 32] { +pub(crate) fn shift_right, const D: usize>( + value: [F; 32], + amount: usize, +) -> [F; 32] { let mut result = [F::ZERO; 32]; if amount < 32 { - for i in 0..32 - amount { - result[i] = value[i + amount]; - } + result[..(32 - amount)].copy_from_slice(&value[amount..((32 - amount) + amount)]); } result } pub(crate) fn shift_right_packed_constraints( value: [P; 32], - shifted_value: [P;32], + shifted_value: [P; 32], amount: usize, ) -> Vec

{ let mut result = Vec::new(); @@ -72,9 +74,9 @@ pub(crate) fn shift_right_packed_constraints( pub(crate) fn shift_right_ext_circuit_constraints, const D: usize>( builder: &mut CircuitBuilder, - value: [ExtensionTarget;32], + value: [ExtensionTarget; 32], shifted_value: [ExtensionTarget; 32], - amount: usize + amount: usize, ) -> Vec> { let mut result = Vec::new(); for i in 0..32 - amount { @@ -86,7 +88,11 @@ pub(crate) fn shift_right_ext_circuit_constraints, result } -pub(crate) fn xor3 , const D: usize, const N: usize>(a: [F; N], b: [F; N], c: [F; N]) -> [F; N] { +pub(crate) fn xor3, const D: usize, const N: usize>( + a: [F; N], + b: [F; N], + c: [F; N], +) -> [F; N] { let mut result = [F::ZERO; N]; for i in 0..N { result[i] = crate::keccak::logic::xor([a[i], b[i], c[i]]); @@ -96,7 +102,7 @@ pub(crate) fn xor3 , const D: usize, const N: usize pub(crate) fn wrapping_add, const D: usize, const N: usize>( a: [F; N], - b: [F; N] + b: [F; N], ) -> ([F; N], [F; N]) { let mut result = [F::ZERO; N]; let mut carries = [F::ZERO; N]; @@ -115,7 +121,9 @@ pub(crate) fn wrapping_add, const D: usize, const N (result, carries) } -pub(crate) fn from_be_fbits_to_u32, const D: usize>(value: [F; 32]) -> u32 { +pub(crate) fn from_be_fbits_to_u32, const D: usize>( + value: [F; 32], +) -> u32 { let mut result = 0; for i in 0..32 { debug_assert!(value[i].is_zero() || value[i].is_one()); @@ -137,9 +145,8 @@ pub(crate) fn wrapping_add_packed_constraints( x: [P; N], y: [P; N], carry: [P; N], - out: [P; N] + out: [P; N], ) -> Vec

{ - let mut result = vec![]; let mut pre_carry = P::ZEROS; for i in 0..N { @@ -156,16 +163,19 @@ pub(crate) fn wrapping_add_packed_constraints( result } -pub(crate) fn wrapping_add_ext_circuit_constraints, const D: usize, const N: usize>( +pub(crate) fn wrapping_add_ext_circuit_constraints< + F: RichField + Extendable, + const D: usize, + const N: usize, +>( builder: &mut CircuitBuilder, x: [ExtensionTarget; N], y: [ExtensionTarget; N], carry: [ExtensionTarget; N], - out: [ExtensionTarget; N] + out: [ExtensionTarget; N], ) -> Vec> { - let mut result = vec![]; - let mut pre_carry= builder.zero_extension(); + let mut pre_carry = builder.zero_extension(); let one_ext = builder.one_extension(); let two_ext = builder.two_extension(); let three_ext = builder.constant_extension(F::Extension::from_canonical_u8(3)); @@ -174,23 +184,17 @@ pub(crate) fn wrapping_add_ext_circuit_constraints, let inner_1 = builder.sub_extension(sum, one_ext); let inner_2 = builder.sub_extension(sum, three_ext); - let tmp1 = builder.mul_many_extension( - [inner_1, inner_2, out[i]] - ); + let tmp1 = builder.mul_many_extension([inner_1, inner_2, out[i]]); let inner_1 = builder.sub_extension(sum, two_ext); let inner_2 = builder.sub_extension(out[i], one_ext); - let tmp2 = builder.mul_many_extension( - [sum, inner_1, inner_2] - ); + let tmp2 = builder.mul_many_extension([sum, inner_1, inner_2]); result.push(builder.add_extension(tmp1, tmp2)); - let tmp3 = builder.add_many_extension( - [carry[i], carry[i], out[i]] - ); + let tmp3 = builder.add_many_extension([carry[i], carry[i], out[i]]); result.push(builder.sub_extension(tmp3, sum)); pre_carry = carry[i]; } result -} \ No newline at end of file +} diff --git a/prover/src/sha_extend/mod.rs b/prover/src/sha_extend/mod.rs index 8fb26d6f..a7cb89b8 100644 --- a/prover/src/sha_extend/mod.rs +++ b/prover/src/sha_extend/mod.rs @@ -1,3 +1,3 @@ pub mod columns; +pub mod logic; pub mod sha_extend_stark; -pub mod logic; \ No newline at end of file diff --git 
a/prover/src/sha_extend/sha_extend_stark.rs b/prover/src/sha_extend/sha_extend_stark.rs index 2165ee7b..2d25599c 100644 --- a/prover/src/sha_extend/sha_extend_stark.rs +++ b/prover/src/sha_extend/sha_extend_stark.rs @@ -1,5 +1,18 @@ -use std::borrow::Borrow; -use std::marker::PhantomData; +use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; +use crate::cross_table_lookup::{Column, Filter}; +use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; +use crate::keccak::logic::{xor3_gen, xor3_gen_circuit}; +use crate::sha_extend::columns::{ + ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS, SHA_EXTEND_COL_MAP, +}; +use crate::sha_extend::logic::{ + get_input_range, rotate_right, rotate_right_ext_circuit_constraint, + rotate_right_packed_constraints, shift_right, shift_right_ext_circuit_constraints, + shift_right_packed_constraints, wrapping_add, wrapping_add_ext_circuit_constraints, + wrapping_add_packed_constraints, xor3, +}; +use crate::stark::Stark; +use crate::util::trace_rows_to_poly_values; use plonky2::field::extension::{Extendable, FieldExtension}; use plonky2::field::packed::PackedField; use plonky2::field::polynomial::PolynomialValues; @@ -7,14 +20,8 @@ use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; -use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; -use crate::cross_table_lookup::{Column, Filter}; -use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; -use crate::keccak::logic::{xor3_gen, xor3_gen_circuit}; -use crate::sha_extend::columns::{ShaExtendColumnsView, NUM_SHA_EXTEND_COLUMNS, SHA_EXTEND_COL_MAP}; -use crate::sha_extend::logic::{get_input_range, rotate_right, rotate_right_ext_circuit_constraint, rotate_right_packed_constraints, shift_right, shift_right_ext_circuit_constraints, shift_right_packed_constraints, wrapping_add, 
wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints, xor3}; -use crate::stark::Stark; -use crate::util::trace_rows_to_poly_values; +use std::borrow::Borrow; +use std::marker::PhantomData; pub const NUM_INPUTS: usize = 4 * 32; // w_i_minus_15, w_i_minus_2, w_i_minus_16, w_i_minus_7 @@ -27,9 +34,9 @@ pub fn ctl_data_inputs() -> Vec> { cols.w_i_minus_16.as_slice(), cols.w_i_minus_7.as_slice(), ] - .concat(), + .concat(), ) - .collect(); + .collect(); res.push(Column::single(cols.timestamp)); res } @@ -52,7 +59,6 @@ pub fn ctl_filter_outputs() -> Filter { Filter::new_simple(Column::single(cols.is_normal_round)) } - #[derive(Copy, Clone, Default)] pub struct ShaExtendStark { pub(crate) f: PhantomData, @@ -74,8 +80,10 @@ impl, const D: usize> ShaExtendStark { inputs_and_timestamps: Vec<([u8; NUM_INPUTS], usize)>, min_rows: usize, ) -> Vec<[F; NUM_SHA_EXTEND_COLUMNS]> { - let num_rows = inputs_and_timestamps.len() - .max(min_rows).next_power_of_two(); + let num_rows = inputs_and_timestamps + .len() + .max(min_rows) + .next_power_of_two(); let mut rows = Vec::with_capacity(num_rows); for input_and_timestamp in inputs_and_timestamps.iter() { @@ -99,13 +107,29 @@ impl, const D: usize> ShaExtendStark { row.timestamp = F::from_canonical_usize(input_and_timestamp.1); row.w_i_minus_15 = input_and_timestamp.0[get_input_range(0)] - .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); row.w_i_minus_2 = input_and_timestamp.0[get_input_range(1)] - .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); row.w_i_minus_16 = input_and_timestamp.0[get_input_range(2)] - .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); row.w_i_minus_7 = 
input_and_timestamp.0[get_input_range(3)] - .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); row.is_normal_round = F::ONE; self.generate_trace_row_for_round(&mut row); row @@ -117,14 +141,22 @@ impl, const D: usize> ShaExtendStark { row.w_i_minus_15_rs_3 = shift_right(row.w_i_minus_15, 3); // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) - row.s_0 = xor3(row.w_i_minus_15_rr_7, row.w_i_minus_15_rr_18, row.w_i_minus_15_rs_3); + row.s_0 = xor3( + row.w_i_minus_15_rr_7, + row.w_i_minus_15_rr_18, + row.w_i_minus_15_rs_3, + ); row.w_i_minus_2_rr_17 = rotate_right(row.w_i_minus_2, 17); row.w_i_minus_2_rr_19 = rotate_right(row.w_i_minus_2, 19); row.w_i_minus_2_rs_10 = shift_right(row.w_i_minus_2, 10); // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) - row.s_1 = xor3(row.w_i_minus_2_rr_17, row.w_i_minus_2_rr_19, row.w_i_minus_2_rs_10); + row.s_1 = xor3( + row.w_i_minus_2_rr_17, + row.w_i_minus_2_rr_19, + row.w_i_minus_2_rs_10, + ); // (w_i_inter_0, carry) = w[i-7] + s1. (row.w_i_inter_0, row.carry_0) = wrapping_add(row.w_i_minus_7, row.s_1); @@ -134,23 +166,22 @@ impl, const D: usize> ShaExtendStark { } } - impl, const D: usize> Stark for ShaExtendStark { type EvaluationFrame - = StarkFrame + = StarkFrame where - FE: FieldExtension, - P: PackedField; + FE: FieldExtension, + P: PackedField; - type EvaluationFrameTarget = StarkFrame, NUM_SHA_EXTEND_COLUMNS>; + type EvaluationFrameTarget = StarkFrame, NUM_SHA_EXTEND_COLUMNS>; fn eval_packed_generic( &self, vars: &Self::EvaluationFrame, - yield_constr: &mut ConstraintConsumer

+ yield_constr: &mut ConstraintConsumer

, ) where - FE: FieldExtension, - P: PackedField + FE: FieldExtension, + P: PackedField, { let local_values: &[P; NUM_SHA_EXTEND_COLUMNS] = vars.get_local_values().try_into().unwrap(); @@ -158,16 +189,24 @@ impl, const D: usize> Stark for ShaExtendStar // check the bit values are zero or one in input for i in 0..32 { - yield_constr.constraint(local_values.w_i_minus_15[i] * (local_values.w_i_minus_15[i] - P::ONES)); - yield_constr.constraint(local_values.w_i_minus_2[i] * (local_values.w_i_minus_2[i] - P::ONES)); - yield_constr.constraint(local_values.w_i_minus_16[i] * (local_values.w_i_minus_16[i] - P::ONES)); - yield_constr.constraint(local_values.w_i_minus_7[i] * (local_values.w_i_minus_7[i] - P::ONES)); + yield_constr.constraint( + local_values.w_i_minus_15[i] * (local_values.w_i_minus_15[i] - P::ONES), + ); + yield_constr + .constraint(local_values.w_i_minus_2[i] * (local_values.w_i_minus_2[i] - P::ONES)); + yield_constr.constraint( + local_values.w_i_minus_16[i] * (local_values.w_i_minus_16[i] - P::ONES), + ); + yield_constr + .constraint(local_values.w_i_minus_7[i] * (local_values.w_i_minus_7[i] - P::ONES)); } // check the bit values are zero or one in intermediate values for i in 0..32 { - yield_constr.constraint(local_values.w_i_inter_0[i] * (local_values.w_i_inter_0[i] - P::ONES)); - yield_constr.constraint(local_values.w_i_inter_1[i] * (local_values.w_i_inter_1[i] - P::ONES)); + yield_constr + .constraint(local_values.w_i_inter_0[i] * (local_values.w_i_inter_0[i] - P::ONES)); + yield_constr + .constraint(local_values.w_i_inter_1[i] * (local_values.w_i_inter_1[i] - P::ONES)); yield_constr.constraint(local_values.carry_0[i] * (local_values.carry_0[i] - P::ONES)); yield_constr.constraint(local_values.carry_1[i] * (local_values.carry_1[i] - P::ONES)); yield_constr.constraint(local_values.carry_2[i] * (local_values.carry_2[i] - P::ONES)); @@ -182,51 +221,62 @@ impl, const D: usize> Stark for ShaExtendStar rotate_right_packed_constraints( 
local_values.w_i_minus_15, local_values.w_i_minus_15_rr_7, - 7 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 7, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); rotate_right_packed_constraints( local_values.w_i_minus_15, local_values.w_i_minus_15_rr_18, - 18 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 18, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); rotate_right_packed_constraints( local_values.w_i_minus_2, local_values.w_i_minus_2_rr_17, - 17 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 17, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); rotate_right_packed_constraints( local_values.w_i_minus_2, local_values.w_i_minus_2_rr_19, - 19 - ).into_iter().for_each(|c| yield_constr.constraint(c)); - + 19, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); // check the shift shift_right_packed_constraints( local_values.w_i_minus_15, local_values.w_i_minus_15_rs_3, - 3 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + 3, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); shift_right_packed_constraints( local_values.w_i_minus_2, local_values.w_i_minus_2_rs_10, - 10 - ).into_iter().for_each(|c| yield_constr.constraint(c)); - + 10, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); // check the computation of s0 and s1 for i in 0..32 { - let s0 = xor3_gen(local_values.w_i_minus_15_rr_7[i], - local_values.w_i_minus_15_rr_18[i], - local_values.w_i_minus_15_rs_3[i] + let s0 = xor3_gen( + local_values.w_i_minus_15_rr_7[i], + local_values.w_i_minus_15_rr_18[i], + local_values.w_i_minus_15_rs_3[i], ); yield_constr.constraint(local_values.s_0[i] - s0); let s1 = xor3_gen( local_values.w_i_minus_2_rr_17[i], local_values.w_i_minus_2_rr_19[i], - local_values.w_i_minus_2_rs_10[i] + local_values.w_i_minus_2_rs_10[i], ); yield_constr.constraint(local_values.s_1[i] - s1); } @@ -236,33 +286,38 @@ impl, const D: usize> Stark for ShaExtendStar 
local_values.w_i_minus_7, local_values.s_1, local_values.carry_0, - local_values.w_i_inter_0 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.w_i_inter_0, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); // check the computation of w_i_inter_1 = w_i_inter_0 + s0 wrapping_add_packed_constraints( local_values.w_i_inter_0, local_values.s_0, local_values.carry_1, - local_values.w_i_inter_1 - ).into_iter().for_each(|c| yield_constr.constraint(c)); + local_values.w_i_inter_1, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); // check the computation of w_i = w_i_inter_1 + w_i_minus_16 wrapping_add_packed_constraints( local_values.w_i_inter_1, local_values.w_i_minus_16, local_values.carry_2, - local_values.w_i - ).into_iter().for_each(|c| yield_constr.constraint(c)); - + local_values.w_i, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(c)); } fn eval_ext_circuit( &self, builder: &mut CircuitBuilder, vars: &Self::EvaluationFrameTarget, - yield_constr: &mut RecursiveConstraintConsumer) { - + yield_constr: &mut RecursiveConstraintConsumer, + ) { let local_values: &[ExtensionTarget; NUM_SHA_EXTEND_COLUMNS] = vars.get_local_values().try_into().unwrap(); let local_values: &ShaExtendColumnsView> = local_values.borrow(); @@ -270,49 +325,79 @@ impl, const D: usize> Stark for ShaExtendStar // check the bit values are zero or one in input for i in 0..32 { let constraint = builder.mul_sub_extension( - local_values.w_i_minus_15[i], local_values.w_i_minus_15[i], local_values.w_i_minus_15[i]); + local_values.w_i_minus_15[i], + local_values.w_i_minus_15[i], + local_values.w_i_minus_15[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.w_i_minus_2[i], local_values.w_i_minus_2[i], local_values.w_i_minus_2[i]); + local_values.w_i_minus_2[i], + local_values.w_i_minus_2[i], + local_values.w_i_minus_2[i], + ); yield_constr.constraint(builder, constraint); let 
constraint = builder.mul_sub_extension( - local_values.w_i_minus_16[i], local_values.w_i_minus_16[i], local_values.w_i_minus_16[i]); + local_values.w_i_minus_16[i], + local_values.w_i_minus_16[i], + local_values.w_i_minus_16[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.w_i_minus_7[i], local_values.w_i_minus_7[i], local_values.w_i_minus_7[i]); + local_values.w_i_minus_7[i], + local_values.w_i_minus_7[i], + local_values.w_i_minus_7[i], + ); yield_constr.constraint(builder, constraint); } // check the bit values are zero or one in intermediate values for i in 0..32 { let constraint = builder.mul_sub_extension( - local_values.w_i_inter_0[i], local_values.w_i_inter_0[i], local_values.w_i_inter_0[i]); + local_values.w_i_inter_0[i], + local_values.w_i_inter_0[i], + local_values.w_i_inter_0[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.w_i_inter_1[i], local_values.w_i_inter_1[i], local_values.w_i_inter_1[i]); + local_values.w_i_inter_1[i], + local_values.w_i_inter_1[i], + local_values.w_i_inter_1[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.carry_0[i], local_values.carry_0[i], local_values.carry_0[i]); + local_values.carry_0[i], + local_values.carry_0[i], + local_values.carry_0[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.carry_1[i], local_values.carry_1[i], local_values.carry_1[i]); + local_values.carry_1[i], + local_values.carry_1[i], + local_values.carry_1[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.carry_2[i], local_values.carry_2[i], local_values.carry_2[i]); + local_values.carry_2[i], + local_values.carry_2[i], + local_values.carry_2[i], + ); yield_constr.constraint(builder, constraint); } // check the bit values are zero or one 
in output for i in 0..32 { let constraint = builder.mul_sub_extension( - local_values.w_i[i], local_values.w_i[i], local_values.w_i[i]); + local_values.w_i[i], + local_values.w_i[i], + local_values.w_i[i], + ); yield_constr.constraint(builder, constraint); } @@ -321,41 +406,53 @@ impl, const D: usize> Stark for ShaExtendStar builder, local_values.w_i_minus_15, local_values.w_i_minus_15_rr_7, - 7 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 7, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); rotate_right_ext_circuit_constraint( builder, local_values.w_i_minus_15, local_values.w_i_minus_15_rr_18, - 18 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 18, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); rotate_right_ext_circuit_constraint( builder, local_values.w_i_minus_2, local_values.w_i_minus_2_rr_17, - 17 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 17, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); rotate_right_ext_circuit_constraint( builder, local_values.w_i_minus_2, local_values.w_i_minus_2_rr_19, - 19 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 19, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); // check the shift shift_right_ext_circuit_constraints( builder, local_values.w_i_minus_15, local_values.w_i_minus_15_rs_3, - 3 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 3, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); shift_right_ext_circuit_constraints( builder, local_values.w_i_minus_2, local_values.w_i_minus_2_rs_10, - 10 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + 10, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); // check the computation of s0 and s1 for i in 0..32 { @@ -363,7 +460,7 @@ impl, const D: usize> Stark for ShaExtendStar builder, 
local_values.w_i_minus_15_rr_7[i], local_values.w_i_minus_15_rr_18[i], - local_values.w_i_minus_15_rs_3[i] + local_values.w_i_minus_15_rs_3[i], ); let constraint = builder.sub_extension(local_values.s_0[i], s0); yield_constr.constraint(builder, constraint); @@ -372,7 +469,7 @@ impl, const D: usize> Stark for ShaExtendStar builder, local_values.w_i_minus_2_rr_17[i], local_values.w_i_minus_2_rr_19[i], - local_values.w_i_minus_2_rs_10[i] + local_values.w_i_minus_2_rs_10[i], ); let constraint = builder.sub_extension(local_values.s_1[i], s1); yield_constr.constraint(builder, constraint); @@ -384,8 +481,10 @@ impl, const D: usize> Stark for ShaExtendStar local_values.w_i_minus_7, local_values.s_1, local_values.carry_0, - local_values.w_i_inter_0 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.w_i_inter_0, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); // check the computation of w_i_inter_1 = w_i_inter_0 + s0 wrapping_add_ext_circuit_constraints( @@ -393,8 +492,10 @@ impl, const D: usize> Stark for ShaExtendStar local_values.w_i_inter_0, local_values.s_0, local_values.carry_1, - local_values.w_i_inter_1 - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.w_i_inter_1, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); // check the computation of w_i = w_i_inter_1 + w_i_minus_16 wrapping_add_ext_circuit_constraints( @@ -402,8 +503,10 @@ impl, const D: usize> Stark for ShaExtendStar local_values.w_i_inter_1, local_values.w_i_minus_16, local_values.carry_2, - local_values.w_i - ).into_iter().for_each(|c| yield_constr.constraint(builder, c)); + local_values.w_i, + ) + .into_iter() + .for_each(|c| yield_constr.constraint(builder, c)); } fn constraint_degree(&self) -> usize { @@ -411,24 +514,25 @@ impl, const D: usize> Stark for ShaExtendStar } } - #[cfg(test)] mod test { + use crate::config::StarkConfig; + use crate::cross_table_lookup::{ + Column, CtlData, CtlZData, 
Filter, GrandProductChallenge, GrandProductChallengeSet, + }; + use crate::prover::prove_single_table; + use crate::sha_extend::sha_extend_stark::ShaExtendStark; + use crate::sha_extend_sponge::columns::NUM_EXTEND_INPUT; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::polynomial::PolynomialValues; - use plonky2::field::types::{Field}; + use plonky2::field::types::Field; use plonky2::fri::oracle::PolynomialBatch; use plonky2::iop::challenger::Challenger; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::timed; use plonky2::util::timing::TimingTree; - use crate::config::StarkConfig; - use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; - use crate::prover::prove_single_table; - use crate::sha_extend::sha_extend_stark::ShaExtendStark; - use crate::sha_extend_sponge::columns::NUM_EXTEND_INPUT; - use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; fn to_be_bits(value: u32) -> [u8; 32] { let mut result = [0; 32]; @@ -441,7 +545,7 @@ mod test { fn get_random_input() -> [u8; NUM_EXTEND_INPUT * 32] { let mut input_values = vec![]; let rand = rand::random::(); - input_values.extend((rand..rand + 4).map(|i| to_be_bits(i as u32))); + input_values.extend((rand..rand + 4).map(to_be_bits)); let input_values = input_values.into_iter().flatten().collect::>(); input_values.try_into().unwrap() } @@ -459,18 +563,17 @@ mod test { let input_and_timestamp = (input_values, 0); let stark = S::default(); - let row = stark.generate_trace_rows_for_extend(input_and_timestamp.try_into().unwrap()); - + let row = stark.generate_trace_rows_for_extend(input_and_timestamp); // extend phase - let w_i_minus_15 = 0 as u32; + let w_i_minus_15 = 0_u32; let s0 = w_i_minus_15.rotate_right(7) ^ 
w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); - let w_i_minus_2 = 1 as u32; + let w_i_minus_2 = 1_u32; // Compute `s1`. let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); - let w_i_minus_16 = 2 as u32; - let w_i_minus_7 = 3 as u32; + let w_i_minus_16 = 2_u32; + let w_i_minus_7 = 3_u32; // Compute `w_i`. let w_i = s1 .wrapping_add(w_i_minus_16) @@ -580,4 +683,4 @@ mod test { fn init_logger() { let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); } -} \ No newline at end of file +} diff --git a/prover/src/sha_extend_sponge/columns.rs b/prover/src/sha_extend_sponge/columns.rs index 9b849386..2286a957 100644 --- a/prover/src/sha_extend_sponge/columns.rs +++ b/prover/src/sha_extend_sponge/columns.rs @@ -1,11 +1,10 @@ -use std::borrow::{Borrow, BorrowMut}; -use std::intrinsics::transmute; use crate::util::{indices_arr, transmute_no_compile_time_size_checks}; +use std::borrow::{Borrow, BorrowMut}; +use std::mem::transmute; pub(crate) const NUM_EXTEND_INPUT: usize = 4; -pub(crate) const SHA_EXTEND_SPONGE_READ_BITS: usize = NUM_EXTEND_INPUT * 32; +pub(crate) const SHA_EXTEND_SPONGE_READ_BITS: usize = NUM_EXTEND_INPUT * 32; pub(crate) struct ShaExtendSpongeColumnsView { - /// Input pub w_i_minus_15: [T; 32], pub w_i_minus_2: [T; 32], @@ -78,9 +77,10 @@ impl Default for ShaExtendSpongeColumnsView { const fn make_col_map() -> ShaExtendSpongeColumnsView { let indices_arr = indices_arr::(); unsafe { - transmute::<[usize; NUM_SHA_EXTEND_SPONGE_COLUMNS], ShaExtendSpongeColumnsView>(indices_arr) + transmute::<[usize; NUM_SHA_EXTEND_SPONGE_COLUMNS], ShaExtendSpongeColumnsView>( + indices_arr, + ) } } pub(crate) const SHA_EXTEND_SPONGE_COL_MAP: ShaExtendSpongeColumnsView = make_col_map(); - diff --git a/prover/src/sha_extend_sponge/logic.rs b/prover/src/sha_extend_sponge/logic.rs index be31b63f..67d74516 100644 --- a/prover/src/sha_extend_sponge/logic.rs +++ b/prover/src/sha_extend_sponge/logic.rs @@ -1,10 
+1,9 @@ +use crate::sha_extend_sponge::sha_extend_sponge_stark::NUM_ROUNDS; use plonky2::field::extension::Extendable; use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::plonk::circuit_builder::CircuitBuilder; -use crate::sha_extend_sponge::sha_extend_sponge_stark::NUM_ROUNDS; - // Compute (x - y - diff) * sum_round_flags pub(crate) fn diff_address_ext_circuit_constraint, const D: usize>( @@ -12,7 +11,7 @@ pub(crate) fn diff_address_ext_circuit_constraint, sum_round_flags: ExtensionTarget, x: ExtensionTarget, y: ExtensionTarget, - diff: usize + diff: usize, ) -> ExtensionTarget { let inter_1 = builder.sub_extension(x, y); let diff_ext = builder.constant_extension(F::Extension::from_canonical_u32(diff as u32)); @@ -21,30 +20,33 @@ pub(crate) fn diff_address_ext_circuit_constraint, } // Compute nxt_round - local_round - 1 -pub(crate) fn round_increment_ext_circuit_constraint, const D: usize>( +pub(crate) fn round_increment_ext_circuit_constraint< + F: RichField + Extendable, + const D: usize, +>( builder: &mut CircuitBuilder, local_round: [ExtensionTarget; NUM_ROUNDS], next_round: [ExtensionTarget; NUM_ROUNDS], ) -> ExtensionTarget { - let one_ext = builder.one_extension(); - let local_round_indices: Vec<_> = - (0..NUM_ROUNDS).map(|i| { + let local_round_indices: Vec<_> = (0..NUM_ROUNDS) + .map(|i| { let index = builder.constant_extension(F::Extension::from_canonical_u32(i as u32)); builder.mul_extension(local_round[i], index) - }).collect(); + }) + .collect(); let local_round_index = builder.add_many_extension(local_round_indices); - let next_round_indices: Vec<_> = - (0..NUM_ROUNDS).map(|i| { + let next_round_indices: Vec<_> = (0..NUM_ROUNDS) + .map(|i| { let index = builder.constant_extension(F::Extension::from_canonical_u32(i as u32)); builder.mul_extension(next_round[i], index) - }).collect(); + }) + .collect(); let next_round_index = 
builder.add_many_extension(next_round_indices); let increment = builder.sub_extension(next_round_index, local_round_index); builder.sub_extension(increment, one_ext) - -} \ No newline at end of file +} diff --git a/prover/src/sha_extend_sponge/mod.rs b/prover/src/sha_extend_sponge/mod.rs index afdca798..c08a6d5a 100644 --- a/prover/src/sha_extend_sponge/mod.rs +++ b/prover/src/sha_extend_sponge/mod.rs @@ -1,3 +1,3 @@ pub mod columns; -pub mod sha_extend_sponge_stark; pub mod logic; +pub mod sha_extend_sponge_stark; diff --git a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs index a64cd310..201af8d0 100644 --- a/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs +++ b/prover/src/sha_extend_sponge/sha_extend_sponge_stark.rs @@ -1,24 +1,29 @@ -use std::marker::PhantomData; -use std::borrow::Borrow; -use itertools::Itertools; -use plonky2::field::extension::{Extendable, FieldExtension}; -use plonky2::field::packed::PackedField; -use plonky2::field::polynomial::PolynomialValues; -use plonky2::field::types::Field; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::plonk::circuit_builder::CircuitBuilder; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::membus::NUM_CHANNELS; use crate::cross_table_lookup::{Column, Filter}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; -use crate::sha_extend::logic::{get_input_range, from_u32_to_be_bits, from_be_fbits_to_u32}; -use crate::sha_extend_sponge::columns::{ShaExtendSpongeColumnsView, NUM_EXTEND_INPUT, NUM_SHA_EXTEND_SPONGE_COLUMNS, SHA_EXTEND_SPONGE_COL_MAP}; -use crate::sha_extend_sponge::logic::{diff_address_ext_circuit_constraint, round_increment_ext_circuit_constraint}; +use crate::sha_extend::logic::{from_be_fbits_to_u32, from_u32_to_be_bits, get_input_range}; +use 
crate::sha_extend_sponge::columns::{ + ShaExtendSpongeColumnsView, NUM_EXTEND_INPUT, NUM_SHA_EXTEND_SPONGE_COLUMNS, + SHA_EXTEND_SPONGE_COL_MAP, +}; +use crate::sha_extend_sponge::logic::{ + diff_address_ext_circuit_constraint, round_increment_ext_circuit_constraint, +}; use crate::stark::Stark; use crate::util::trace_rows_to_poly_values; use crate::witness::memory::MemoryAddress; +use itertools::Itertools; +use plonky2::field::extension::{Extendable, FieldExtension}; +use plonky2::field::packed::PackedField; +use plonky2::field::polynomial::PolynomialValues; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use std::borrow::Borrow; +use std::marker::PhantomData; pub const NUM_ROUNDS: usize = 48; @@ -31,9 +36,9 @@ pub(crate) fn ctl_looking_sha_extend_inputs() -> Vec> { cols.w_i_minus_16.as_slice(), cols.w_i_minus_7.as_slice(), ] - .concat(), + .concat(), ) - .collect(); + .collect(); res.push(Column::single(cols.timestamp)); res } @@ -50,17 +55,15 @@ pub(crate) fn ctl_looking_sha_extend_outputs() -> Vec> { pub(crate) fn ctl_looked_data() -> Vec> { let cols = SHA_EXTEND_SPONGE_COL_MAP; let w_i_usize = Column::linear_combination( - cols.w_i.iter() + cols.w_i + .iter() .enumerate() .map(|(i, &b)| (b, F::from_canonical_usize(1 << i))), ); - Column::singles([ - cols.context, - cols.segment, - cols.output_virt, - cols.timestamp, - ]).chain([w_i_usize]).collect() + Column::singles([cols.context, cols.segment, cols.output_virt, cols.timestamp]) + .chain([w_i_usize]) + .collect() } pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { @@ -84,7 +87,7 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { le_bit = cols.w_i_minus_7; } // le_bit.reverse(); - let u32_value: Column = Column::le_bits(&le_bit); + let u32_value: Column = Column::le_bits(le_bit); res.push(u32_value); res.push(Column::single(cols.timestamp)); @@ -99,13 +102,11 @@ 
pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { pub(crate) fn ctl_looking_sha_extend_filter() -> Filter { let cols = SHA_EXTEND_SPONGE_COL_MAP; // not the padding rows. - Filter::new_simple(Column::sum( - &cols.round, - )) + Filter::new_simple(Column::sum(cols.round)) } #[derive(Clone, Debug)] -pub(crate) struct ShaExtendSpongeOp { +pub(crate) struct ShaExtendSpongeOp { /// The base address at which inputs are read pub(crate) base_address: Vec, @@ -159,11 +160,11 @@ impl, const D: usize> ShaExtendSpongeStark { rows } - fn generate_rows_for_op(&self, op: ShaExtendSpongeOp) -> ShaExtendSpongeColumnsView{ + fn generate_rows_for_op(&self, op: ShaExtendSpongeOp) -> ShaExtendSpongeColumnsView { let mut row = ShaExtendSpongeColumnsView::default(); row.timestamp = F::from_canonical_usize(op.timestamp); row.round = [F::ZEROS; 48]; - row.round[op.i as usize] = F::ONE; + row.round[op.i] = F::ONE; row.context = F::from_canonical_usize(op.base_address[0].context); row.segment = F::from_canonical_usize(op.base_address[Segment::Code as usize].segment); @@ -175,13 +176,29 @@ impl, const D: usize> ShaExtendSpongeStark { row.output_virt = F::from_canonical_usize(op.output_address.virt); row.w_i_minus_15 = op.input[get_input_range(0)] - .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); row.w_i_minus_2 = op.input[get_input_range(1)] - .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); row.w_i_minus_16 = op.input[get_input_range(2)] - .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); row.w_i_minus_7 = op.input[get_input_range(3)] - .iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap(); + .iter() + .map(|&x| 
F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap(); row.w_i = self.compute_w_i(&mut row); row @@ -200,28 +217,32 @@ impl, const D: usize> ShaExtendSpongeStark { .wrapping_add(w_i_minus_7); let w_i_bin = from_u32_to_be_bits(w_i_u32); - w_i_bin.iter().map(|&x| F::from_canonical_u8(x)).collect::>().try_into().unwrap() + w_i_bin + .iter() + .map(|&x| F::from_canonical_u8(x)) + .collect::>() + .try_into() + .unwrap() } } impl, const D: usize> Stark for ShaExtendSpongeStark { type EvaluationFrame - = StarkFrame + = StarkFrame where - FE: FieldExtension, - P: PackedField; + FE: FieldExtension, + P: PackedField; type EvaluationFrameTarget = StarkFrame, NUM_SHA_EXTEND_SPONGE_COLUMNS>; fn eval_packed_generic( &self, vars: &Self::EvaluationFrame, - yield_constr: &mut ConstraintConsumer

+ yield_constr: &mut ConstraintConsumer

, ) where - FE: FieldExtension, - P: PackedField + FE: FieldExtension, + P: PackedField, { - let local_values: &[P; NUM_SHA_EXTEND_SPONGE_COLUMNS] = vars.get_local_values().try_into().unwrap(); let local_values: &ShaExtendSpongeColumnsView

= local_values.borrow(); @@ -231,10 +252,16 @@ impl, const D: usize> Stark for ShaExtendSpon // check the binary form for i in 0..32 { - yield_constr.constraint(local_values.w_i_minus_15[i] * (local_values.w_i_minus_15[i] - P::ONES)); - yield_constr.constraint(local_values.w_i_minus_2[i] * (local_values.w_i_minus_2[i] - P::ONES)); - yield_constr.constraint(local_values.w_i_minus_16[i] * (local_values.w_i_minus_16[i] - P::ONES)); - yield_constr.constraint(local_values.w_i_minus_7[i] * (local_values.w_i_minus_7[i] - P::ONES)); + yield_constr.constraint( + local_values.w_i_minus_15[i] * (local_values.w_i_minus_15[i] - P::ONES), + ); + yield_constr + .constraint(local_values.w_i_minus_2[i] * (local_values.w_i_minus_2[i] - P::ONES)); + yield_constr.constraint( + local_values.w_i_minus_16[i] * (local_values.w_i_minus_16[i] - P::ONES), + ); + yield_constr + .constraint(local_values.w_i_minus_7[i] * (local_values.w_i_minus_7[i] - P::ONES)); yield_constr.constraint(local_values.w_i[i] * (local_values.w_i[i] - P::ONES)); } @@ -248,15 +275,16 @@ impl, const D: usize> Stark for ShaExtendSpon yield_constr.constraint(is_final * (is_final - P::ONES)); let not_final = P::ONES - is_final; - let sum_round_flags = (0..NUM_ROUNDS) - .map(|i| local_values.round[i]) - .sum::

(); + let sum_round_flags = (0..NUM_ROUNDS).map(|i| local_values.round[i]).sum::

(); // If this is not the final step or a padding row, // the timestamp must be increased by 2 * NUM_CHANNELS. yield_constr.constraint( - sum_round_flags * not_final * - (next_values.timestamp - local_values.timestamp - FE::from_canonical_usize(2 * NUM_CHANNELS)), + sum_round_flags + * not_final + * (next_values.timestamp + - local_values.timestamp + - FE::from_canonical_usize(2 * NUM_CHANNELS)), ); // If this is not the final step or a padding row, @@ -269,18 +297,24 @@ impl, const D: usize> Stark for ShaExtendSpon .map(|i| next_values.round[i] * FE::from_canonical_u32(i as u32)) .sum::

(); yield_constr.constraint( - sum_round_flags * not_final * (next_round_index - local_round_index - P::ONES) + sum_round_flags * not_final * (next_round_index - local_round_index - P::ONES), ); // If this is not the final step or a padding row, // input and output addresses should be increased by 4 each (0..NUM_EXTEND_INPUT).for_each(|i| { yield_constr.constraint( - sum_round_flags * not_final * (next_values.input_virt[i] - local_values.input_virt[i] - FE::from_canonical_u32(4)) + sum_round_flags + * not_final + * (next_values.input_virt[i] + - local_values.input_virt[i] + - FE::from_canonical_u32(4)), ); }); yield_constr.constraint( - sum_round_flags * not_final * (next_values.output_virt - local_values.output_virt - FE::from_canonical_u32(4)) + sum_round_flags + * not_final + * (next_values.output_virt - local_values.output_virt - FE::from_canonical_u32(4)), ); // If it's not the padding row, check the virtual addresses @@ -288,19 +322,31 @@ impl, const D: usize> Stark for ShaExtendSpon // add_w[i-15] = add_w[i-16] + 4 yield_constr.constraint( - sum_round_flags * (local_values.input_virt[0] - local_values.input_virt[2] - FE::from_canonical_u32(4)) + sum_round_flags + * (local_values.input_virt[0] + - local_values.input_virt[2] + - FE::from_canonical_u32(4)), ); // add_w[i-2] = add_w[i-16] + 56 yield_constr.constraint( - sum_round_flags * (local_values.input_virt[1] - local_values.input_virt[2] - FE::from_canonical_u32(56)) + sum_round_flags + * (local_values.input_virt[1] + - local_values.input_virt[2] + - FE::from_canonical_u32(56)), ); // add_w[i-7] = add_w[i-16] + 36 yield_constr.constraint( - sum_round_flags * (local_values.input_virt[3] - local_values.input_virt[2] - FE::from_canonical_u32(36)) + sum_round_flags + * (local_values.input_virt[3] + - local_values.input_virt[2] + - FE::from_canonical_u32(36)), ); // add_w[i] = add_w[i-16] + 64 yield_constr.constraint( - sum_round_flags * (local_values.output_virt - local_values.input_virt[2] - 
FE::from_canonical_u32(64)) + sum_round_flags + * (local_values.output_virt + - local_values.input_virt[2] + - FE::from_canonical_u32(64)), ); } @@ -308,9 +354,8 @@ impl, const D: usize> Stark for ShaExtendSpon &self, builder: &mut CircuitBuilder, vars: &Self::EvaluationFrameTarget, - yield_constr: &mut RecursiveConstraintConsumer + yield_constr: &mut RecursiveConstraintConsumer, ) { - let local_values: &[ExtensionTarget; NUM_SHA_EXTEND_SPONGE_COLUMNS] = vars.get_local_values().try_into().unwrap(); let local_values: &ShaExtendSpongeColumnsView> = local_values.borrow(); @@ -320,34 +365,52 @@ impl, const D: usize> Stark for ShaExtendSpon let one_ext = builder.one_extension(); let four_ext = builder.constant_extension(F::Extension::from_canonical_u32(4)); - let num_channel = builder.constant_extension(F::Extension::from_canonical_usize(2 * NUM_CHANNELS)); + let num_channel = + builder.constant_extension(F::Extension::from_canonical_usize(2 * NUM_CHANNELS)); // check the binary form for i in 0..32 { let constraint = builder.mul_sub_extension( - local_values.w_i_minus_15[i], local_values.w_i_minus_15[i], local_values.w_i_minus_15[i]); + local_values.w_i_minus_15[i], + local_values.w_i_minus_15[i], + local_values.w_i_minus_15[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.w_i_minus_2[i], local_values.w_i_minus_2[i], local_values.w_i_minus_2[i]); + local_values.w_i_minus_2[i], + local_values.w_i_minus_2[i], + local_values.w_i_minus_2[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.w_i_minus_16[i], local_values.w_i_minus_16[i], local_values.w_i_minus_16[i]); + local_values.w_i_minus_16[i], + local_values.w_i_minus_16[i], + local_values.w_i_minus_16[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.w_i_minus_7[i], local_values.w_i_minus_7[i], local_values.w_i_minus_7[i]); + 
local_values.w_i_minus_7[i], + local_values.w_i_minus_7[i], + local_values.w_i_minus_7[i], + ); yield_constr.constraint(builder, constraint); let constraint = builder.mul_sub_extension( - local_values.w_i[i], local_values.w_i[i], local_values.w_i[i]); + local_values.w_i[i], + local_values.w_i[i], + local_values.w_i[i], + ); yield_constr.constraint(builder, constraint); } // check the round for i in 0..NUM_ROUNDS { let constraint = builder.mul_sub_extension( - local_values.round[i], local_values.round[i], local_values.round[i] + local_values.round[i], + local_values.round[i], + local_values.round[i], ); yield_constr.constraint(builder, constraint); } @@ -371,36 +434,28 @@ impl, const D: usize> Stark for ShaExtendSpon // If this is not the final step or a padding row, // round index should be increased by one - let round_increment = round_increment_ext_circuit_constraint( - builder, - local_values.round, - next_values.round - ); - let constraint = builder.mul_many_extension( - [sum_round_flags, not_final, round_increment] - ); + let round_increment = + round_increment_ext_circuit_constraint(builder, local_values.round, next_values.round); + let constraint = builder.mul_many_extension([sum_round_flags, not_final, round_increment]); yield_constr.constraint(builder, constraint); // If this is not the final step or a padding row, // input and output addresses should be increased by 4 each (0..NUM_EXTEND_INPUT).for_each(|i| { - - let increment = builder.sub_extension(next_values.input_virt[i], local_values.input_virt[i]); + let increment = + builder.sub_extension(next_values.input_virt[i], local_values.input_virt[i]); let address_increment = builder.sub_extension(increment, four_ext); - let constraint = builder.mul_many_extension( - [sum_round_flags, not_final, address_increment] - ); + let constraint = + builder.mul_many_extension([sum_round_flags, not_final, address_increment]); yield_constr.constraint(builder, constraint); }); let increment = 
builder.sub_extension(next_values.output_virt, local_values.output_virt); let address_increment = builder.sub_extension(increment, four_ext); - let constraint = builder.mul_many_extension( - [sum_round_flags, not_final, address_increment] - ); + let constraint = + builder.mul_many_extension([sum_round_flags, not_final, address_increment]); yield_constr.constraint(builder, constraint); - // If it's not the padding row, check the virtual addresses // The list of input addresses are: w[i-15], w[i-2], w[i-16], w[i-7] @@ -410,7 +465,7 @@ impl, const D: usize> Stark for ShaExtendSpon sum_round_flags, local_values.input_virt[0], local_values.input_virt[2], - 4 + 4, ); yield_constr.constraint(builder, constraint); @@ -420,7 +475,7 @@ impl, const D: usize> Stark for ShaExtendSpon sum_round_flags, local_values.input_virt[1], local_values.input_virt[2], - 56 + 56, ); yield_constr.constraint(builder, constraint); @@ -430,7 +485,7 @@ impl, const D: usize> Stark for ShaExtendSpon sum_round_flags, local_values.input_virt[3], local_values.input_virt[2], - 36 + 36, ); yield_constr.constraint(builder, constraint); @@ -440,7 +495,7 @@ impl, const D: usize> Stark for ShaExtendSpon sum_round_flags, local_values.output_virt, local_values.input_virt[2], - 64 + 64, ); yield_constr.constraint(builder, constraint); } @@ -450,9 +505,20 @@ impl, const D: usize> Stark for ShaExtendSpon } } - #[cfg(test)] mod test { + use crate::config::StarkConfig; + use crate::cross_table_lookup::{ + Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet, + }; + use crate::memory::segments::Segment; + use crate::memory::NUM_CHANNELS; + use crate::prover::prove_single_table; + use crate::sha_extend_sponge::sha_extend_sponge_stark::{ + ShaExtendSpongeOp, ShaExtendSpongeStark, + }; + use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; + use crate::witness::memory::MemoryAddress; use env_logger::{try_init_from_env, Env, DEFAULT_FILTER_ENV}; use 
plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::polynomial::PolynomialValues; @@ -462,14 +528,6 @@ mod test { use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::timed; use plonky2::util::timing::TimingTree; - use crate::config::StarkConfig; - use crate::cross_table_lookup::{Column, CtlData, CtlZData, Filter, GrandProductChallenge, GrandProductChallengeSet}; - use crate::memory::NUM_CHANNELS; - use crate::memory::segments::Segment; - use crate::prover::prove_single_table; - use crate::sha_extend_sponge::sha_extend_sponge_stark::{ShaExtendSpongeOp, ShaExtendSpongeStark}; - use crate::stark_testing::{test_stark_circuit_constraints, test_stark_low_degree}; - use crate::witness::memory::MemoryAddress; fn to_be_bits(value: u32) -> [u8; 32] { let mut result = [0; 32]; @@ -491,23 +549,28 @@ mod test { let input_values = input_values.into_iter().flatten().collect::>(); let op = ShaExtendSpongeOp { - base_address: vec![MemoryAddress { - context: 0, - segment: Segment::Code as usize, - virt: 4, - }, MemoryAddress { - context: 0, - segment: Segment::Code as usize, - virt: 56, - }, MemoryAddress { - context: 0, - segment: Segment::Code as usize, - virt: 0, - }, MemoryAddress { - context: 0, - segment: Segment::Code as usize, - virt: 36, - }], + base_address: vec![ + MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 4, + }, + MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 56, + }, + MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 0, + }, + MemoryAddress { + context: 0, + segment: Segment::Code as usize, + virt: 36, + }, + ], timestamp: 0, input: input_values, i: 0, @@ -557,18 +620,19 @@ mod test { w[i] = rand::random::(); } for i in 16..64 { - - let w_i_minus_15 = w[i-15]; - let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + let w_i_minus_15 = w[i - 15]; + let s0 = + w_i_minus_15.rotate_right(7) ^ 
w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); // Read w[i-2]. - let w_i_minus_2 = w[i-2]; + let w_i_minus_2 = w[i - 2]; // Compute `s1`. - let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); + let s1 = + w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); // Read w[i-16]. - let w_i_minus_16 = w[i-16]; - let w_i_minus_7 = w[i-7]; + let w_i_minus_16 = w[i - 16]; + let w_i_minus_7 = w[i - 7]; // Compute `w_i`. w[i] = s1 @@ -579,10 +643,10 @@ mod test { let mut addresses = vec![]; for i in 0..64 { - addresses.push(MemoryAddress{ + addresses.push(MemoryAddress { context: 0, segment: Segment::Code as usize, - virt: i * 4 + virt: i * 4, }); } @@ -597,7 +661,12 @@ mod test { input_values.extend(to_be_bits(w[i - 7])); let op = ShaExtendSpongeOp { - base_address: vec![addresses[i - 15], addresses[i - 2], addresses[i - 16], addresses[i - 7]], + base_address: vec![ + addresses[i - 15], + addresses[i - 2], + addresses[i - 16], + addresses[i - 7], + ], timestamp: time, input: input_values, i: i - 16, @@ -609,7 +678,6 @@ mod test { } res - } #[test] fn sha_extend_sponge_benchmark() -> anyhow::Result<()> { @@ -679,4 +747,4 @@ mod test { fn init_logger() { let _ = try_init_from_env(Env::default().filter_or(DEFAULT_FILTER_ENV, "debug")); } -} \ No newline at end of file +} diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index df4c2fb7..b6377acd 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -14,13 +14,13 @@ use plonky2::field::types::Field; use super::util::keccak_sponge_log; use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_RATE_U32S}; use crate::poseidon_sponge::columns::POSEIDON_RATE_BYTES; +use crate::sha_extend::logic::from_u32_to_be_bits; use itertools::Itertools; use keccak_hash::keccak; use plonky2::field::extension::Extendable; use plonky2::hash::hash_types::RichField; use plonky2::plonk::config::GenericConfig; use 
std::fs; -use crate::sha_extend::logic::from_u32_to_be_bits; pub const WORD_SIZE: usize = core::mem::size_of::(); @@ -1222,7 +1222,6 @@ pub(crate) fn generate_sha_extend< input_addresses.push(addr); input_value_bit_be.push(from_u32_to_be_bits(w_i_minus_16)); - // Read w[i-7]. let addr = MemoryAddress::new(0, Segment::Code, w_ptr + (i - 7) * 4); let (w_i_minus_7, mem_op) = mem_read_gp_with_log_and_fill(3, addr, state, &mut cpu_row); @@ -1237,9 +1236,17 @@ pub(crate) fn generate_sha_extend< .wrapping_add(w_i_minus_7); // Write w[i]. - log::debug!("{:X}, {:X}, {:X} {:X} {:X} {:X}", s1, s0, w_i_minus_16, w_i_minus_7, w_i_minus_15, w_i_minus_2); + log::debug!( + "{:X}, {:X}, {:X} {:X} {:X} {:X}", + s1, + s0, + w_i_minus_16, + w_i_minus_7, + w_i_minus_15, + w_i_minus_2 + ); let addr = MemoryAddress::new(0, Segment::Code, w_ptr + i * 4); - log::debug!("extend write {:X} {:X}", w_ptr + i * 4, w_i); + log::debug!("extend write {:X} {:X}", w_ptr + i * 4, w_i); let mem_op = mem_write_gp_log_and_fill(4, addr, state, &mut cpu_row, w_i); state.traces.push_memory(mem_op); @@ -1314,7 +1321,10 @@ pub(crate) fn generate_sha_compress< cpu_row = CpuColumnsView::default(); cpu_row.clock = F::from_canonical_usize(state.traces.clock()); for i in 0..64 { - let input_state = [a, b, c, d, e, f, g, h].iter().map(|x| from_u32_to_be_bits(*x)).collect_vec(); + let input_state = [a, b, c, d, e, f, g, h] + .iter() + .map(|x| from_u32_to_be_bits(*x)) + .collect_vec(); state_values.push(input_state); let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25); @@ -1335,7 +1345,7 @@ pub(crate) fn generate_sha_compress< let temp1 = h .wrapping_add(s1) .wrapping_add(ch) - .wrapping_add(SHA_COMPRESS_K[i as usize]) + .wrapping_add(SHA_COMPRESS_K[i]) .wrapping_add(w_i); let s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22); let maj = (a & b) ^ (a & c) ^ (b & c); @@ -1361,9 +1371,18 @@ pub(crate) fn generate_sha_compress< cpu_row.mem_channels[1].value = 
F::from_canonical_usize(Segment::Code as usize); cpu_row.mem_channels[2].value = F::from_canonical_usize(hx_addresses[0].virt); // start address of hx - let u32_result: Vec = [a, b, c, d, e, f, g, h].iter().enumerate().map(|(i, x)| hx[i].wrapping_add(*x)).collect_vec(); + let u32_result: Vec = [a, b, c, d, e, f, g, h] + .iter() + .enumerate() + .map(|(i, x)| hx[i].wrapping_add(*x)) + .collect_vec(); - cpu_row.general.shash_mut().value = u32_result.into_iter().map(F::from_canonical_u32).collect_vec().try_into().unwrap(); + cpu_row.general.shash_mut().value = u32_result + .into_iter() + .map(F::from_canonical_u32) + .collect_vec() + .try_into() + .unwrap(); // cpu_row.general.shash_mut().value.reverse(); sha_compress_sponge_log( state, @@ -1371,7 +1390,7 @@ pub(crate) fn generate_sha_compress< hx_addresses, w_i_value_bit_be, w_i_addresses, - state_values + state_values, ); state.traces.push_cpu(cpu_row); @@ -1383,7 +1402,7 @@ pub(crate) fn generate_sha_compress< let mem_op = mem_write_gp_log_and_fill(i, addr, state, &mut cpu_row, hx[i].wrapping_add(v[i])); state.traces.push_memory(mem_op); - log::debug!("write {:X} {:X}", h_ptr + i * 4, hx[i].wrapping_add(v[i])); + log::debug!("write {:X} {:X}", h_ptr + i * 4, hx[i].wrapping_add(v[i])); } state.traces.push_cpu(cpu_row); Ok(()) diff --git a/prover/src/witness/traces.rs b/prover/src/witness/traces.rs index 3cdab506..f775fc42 100644 --- a/prover/src/witness/traces.rs +++ b/prover/src/witness/traces.rs @@ -19,14 +19,14 @@ use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp; use crate::poseidon::constants::SPONGE_WIDTH; use crate::poseidon_sponge::columns::POSEIDON_RATE_BYTES; use crate::poseidon_sponge::poseidon_sponge_stark::PoseidonSpongeOp; -use crate::util::join; -use crate::util::trace_rows_to_poly_values; -use crate::witness::memory::MemoryOp; -use crate::{arithmetic, logic}; use crate::sha_compress::sha_compress_stark; use crate::sha_compress_sponge::sha_compress_sponge_stark::ShaCompressSpongeOp; use 
crate::sha_extend::sha_extend_stark; use crate::sha_extend_sponge::sha_extend_sponge_stark::ShaExtendSpongeOp; +use crate::util::join; +use crate::util::trace_rows_to_poly_values; +use crate::witness::memory::MemoryOp; +use crate::{arithmetic, logic}; #[derive(Clone, Copy, Debug)] pub struct TraceCheckpoint { @@ -106,11 +106,9 @@ impl Traces { .map(|op| op.input.len() / keccak_sponge::columns::KECCAK_RATE_BYTES + 1) .sum(), sha_extend_len: self.sha_extend_inputs.len(), - sha_extend_sponge_len: self - .sha_extend_sponge_ops.len(), + sha_extend_sponge_len: self.sha_extend_sponge_ops.len(), sha_compress_len: self.sha_compress_inputs.len(), - sha_compress_sponge_len: self - .sha_compress_sponge_ops.len(), + sha_compress_sponge_len: self.sha_compress_sponge_ops.len(), logic_len: self.logic_ops.len(), // This is technically a lower-bound, as we may fill gaps, // but this gives a relatively good estimate. @@ -148,7 +146,8 @@ impl Traces { self.sha_extend_inputs.truncate(checkpoint.sha_extend_len); self.sha_extend_sponge_ops .truncate(checkpoint.sha_extend_sponge_len); - self.sha_compress_inputs.truncate(checkpoint.sha_compress_len); + self.sha_compress_inputs + .truncate(checkpoint.sha_compress_len); self.sha_compress_sponge_ops .truncate(checkpoint.sha_compress_sponge_len); self.logic_ops.truncate(checkpoint.logic_len); diff --git a/prover/src/witness/util.rs b/prover/src/witness/util.rs index de4c3041..7699f6ab 100644 --- a/prover/src/witness/util.rs +++ b/prover/src/witness/util.rs @@ -19,14 +19,14 @@ use crate::poseidon::constants::{SPONGE_RATE, SPONGE_WIDTH}; use crate::poseidon::poseidon_stark::poseidon_with_witness; use crate::poseidon_sponge::columns::POSEIDON_RATE_BYTES; use crate::poseidon_sponge::poseidon_sponge_stark::PoseidonSpongeOp; -use crate::witness::errors::ProgramError; -use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; -use plonky2::field::extension::Extendable; -use plonky2::plonk::config::GenericConfig; use 
crate::sha_compress::logic::from_be_bits_to_u32; use crate::sha_compress_sponge::constants::SHA_COMPRESS_K_BINARY; use crate::sha_compress_sponge::sha_compress_sponge_stark::ShaCompressSpongeOp; use crate::sha_extend_sponge::sha_extend_sponge_stark::ShaExtendSpongeOp; +use crate::witness::errors::ProgramError; +use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind}; +use plonky2::field::extension::Extendable; +use plonky2::plonk::config::GenericConfig; fn to_byte_checked(n: u32) -> u8 { let res: u8 = n.to_le_bytes()[0]; @@ -561,7 +561,7 @@ pub(crate) fn sha_extend_sponge_log< F: RichField + Extendable, C: GenericConfig, const D: usize, -> ( +>( state: &mut GenerationState, base_address: Vec, inputs: Vec<[u8; 32]>, // BE bits @@ -573,10 +573,9 @@ pub(crate) fn sha_extend_sponge_log< let clock = state.traces.clock(); let mut n_gp = 0; - let mut addr_idx = 0; let extend_input: Vec = inputs.iter().flatten().cloned().collect(); - for input in inputs { + for (addr_idx, input) in inputs.into_iter().enumerate() { let val = from_be_bits_to_u32(input); for _ in 0..32 { state.traces.push_memory(MemoryOp::new( @@ -589,24 +588,26 @@ pub(crate) fn sha_extend_sponge_log< n_gp += 1; n_gp %= NUM_GP_CHANNELS - 1; } - addr_idx += 1; } - state.traces.push_sha_extend(extend_input.clone().try_into().unwrap(), clock * NUM_CHANNELS); + state.traces.push_sha_extend( + extend_input.clone().try_into().unwrap(), + clock * NUM_CHANNELS, + ); state.traces.push_sha_extend_sponge(ShaExtendSpongeOp { base_address, timestamp: clock * NUM_CHANNELS, input: extend_input, i: round, - output_address + output_address, }); } -pub(crate) fn sha_compress_sponge_log < +pub(crate) fn sha_compress_sponge_log< F: RichField + Extendable, C: GenericConfig, const D: usize, -> ( +>( state: &mut GenerationState, hx_values: Vec<[u8; 32]>, // BE bits hx_addresses: Vec, @@ -621,7 +622,6 @@ pub(crate) fn sha_compress_sponge_log < let mut n_gp = 0; for i in 0..64 { - // read hx as input for 
(j, hx) in hx_values.iter().enumerate() { let val = from_be_bits_to_u32(*hx); @@ -651,15 +651,31 @@ pub(crate) fn sha_compress_sponge_log < n_gp %= NUM_GP_CHANNELS - 1; } - let w_i = w_i_values[i]; let k_i = SHA_COMPRESS_K_BINARY[i]; - let base_address = hx_addresses.clone().into_iter().chain([w_i_addresses[i]]).collect_vec(); - let compress_sponge_input: Vec = hx_values.iter().chain(&[w_i]).flatten().cloned().collect(); - let compress_input: Vec = input_state_list[i].iter().chain(&[w_i, k_i]).flatten().cloned().collect(); - let input_states: Vec = input_state_list[i].clone().iter().flatten().cloned().collect(); + let base_address = hx_addresses + .clone() + .into_iter() + .chain([w_i_addresses[i]]) + .collect_vec(); + let compress_sponge_input: Vec = + hx_values.iter().chain(&[w_i]).flatten().cloned().collect(); + let compress_input: Vec = input_state_list[i] + .iter() + .chain(&[w_i, k_i]) + .flatten() + .cloned() + .collect(); + let input_states: Vec = input_state_list[i] + .clone() + .iter() + .flatten() + .cloned() + .collect(); - state.traces.push_sha_compress(compress_input.try_into().unwrap(), clock * NUM_CHANNELS); + state + .traces + .push_sha_compress(compress_input.try_into().unwrap(), clock * NUM_CHANNELS); state.traces.push_sha_compress_sponge(ShaCompressSpongeOp { base_address, @@ -668,7 +684,6 @@ pub(crate) fn sha_compress_sponge_log < i, input: compress_sponge_input, }); - } } diff --git a/runtime/entrypoint/src/syscalls/mod.rs b/runtime/entrypoint/src/syscalls/mod.rs index ca3c6dbf..cdf0fc2c 100644 --- a/runtime/entrypoint/src/syscalls/mod.rs +++ b/runtime/entrypoint/src/syscalls/mod.rs @@ -3,15 +3,15 @@ mod halt; mod io; mod keccak; -mod sha256; mod memory; +mod sha256; mod sys; pub use halt::*; pub use io::*; pub use keccak::*; -pub use sha256::*; pub use memory::*; +pub use sha256::*; pub use sys::*; /// These codes MUST match the codes in `core/src/runtime/syscall.rs`. 
There is a derived test diff --git a/runtime/precompiles/src/io.rs b/runtime/precompiles/src/io.rs index 346786b0..1e7e7cc6 100644 --- a/runtime/precompiles/src/io.rs +++ b/runtime/precompiles/src/io.rs @@ -5,7 +5,7 @@ use crate::syscall_keccak; use crate::syscall_verify; use crate::syscall_write; use crate::{syscall_hint_len, syscall_hint_read}; -use crate::{syscall_sha256_extend, syscall_sha256_compress}; +use crate::{syscall_sha256_compress, syscall_sha256_extend}; use serde::de::DeserializeOwned; use serde::Serialize; use sha2::{Digest, Sha256}; @@ -39,7 +39,7 @@ impl std::io::Write for SyscallWriter { pub fn read_vec() -> Vec { let len = unsafe { syscall_hint_len() }; // Round up to the nearest multiple of 4 so that the memory allocated is in whole words - let capacity = (len + 3) / 4 * 4; + let capacity = (len + 3).div_ceil(4) * 4; // Allocate a buffer of the required length that is 4 byte aligned let layout = Layout::from_size_align(capacity, 4).expect("vec is too large"); @@ -154,14 +154,14 @@ pub fn keccak(data: &[u8]) -> [u8; 32] { pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { unsafe { - for i in 0..blocks.len() { + for block in blocks { let mut w = [0u32; 64]; - for j in 0..16 { - w[j] = u32::from_be_bytes([ - blocks[i][j * 4], - blocks[i][j * 4 + 1], - blocks[i][j * 4 + 2], - blocks[i][j * 4 + 3], + for (j, item) in w.iter_mut().enumerate().take(16) { + *item = u32::from_be_bytes([ + block[j * 4], + block[j * 4 + 1], + block[j * 4 + 2], + block[j * 4 + 3], ]); } syscall_sha256_extend(w.as_mut_ptr()); From 1fe77344fe42441008ae53411ab1394781c229fa Mon Sep 17 00:00:00 2001 From: vanhger Date: Tue, 4 Feb 2025 16:26:31 +0700 Subject: [PATCH 25/25] feat: wrap SHA Ch, Ma, Sigma0, Sigma1 to functions --- prover/src/sha_compress_sponge/logic.rs | 15 +++++++++++++++ prover/src/sha_compress_sponge/mod.rs | 1 + .../sha_compress_sponge_stark.rs | 12 ++++++++---- 3 files changed, 24 insertions(+), 4 deletions(-) create mode 100644 
prover/src/sha_compress_sponge/logic.rs diff --git a/prover/src/sha_compress_sponge/logic.rs b/prover/src/sha_compress_sponge/logic.rs new file mode 100644 index 00000000..1b53c407 --- /dev/null +++ b/prover/src/sha_compress_sponge/logic.rs @@ -0,0 +1,15 @@ +pub(crate) fn sha_ch(a: u32, b: u32, c: u32) -> u32 { + (a & b) ^ (!a & c) +} + +pub(crate) fn sha_ma(a: u32, b: u32, c: u32) -> u32 { + (a & b) ^ (a & c) ^ (b & c) +} + +pub(crate) fn sha_sigma0(x: u32) -> u32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) +} + +pub(crate) fn sha_sigma1(x: u32) -> u32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) +} diff --git a/prover/src/sha_compress_sponge/mod.rs b/prover/src/sha_compress_sponge/mod.rs index d1fb8826..9dc3cd79 100644 --- a/prover/src/sha_compress_sponge/mod.rs +++ b/prover/src/sha_compress_sponge/mod.rs @@ -1,3 +1,4 @@ pub mod columns; pub mod constants; +pub mod logic; pub mod sha_compress_sponge_stark; diff --git a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs index 467b9941..fe4f659b 100644 --- a/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs +++ b/prover/src/sha_compress_sponge/sha_compress_sponge_stark.rs @@ -7,6 +7,7 @@ use crate::sha_compress_sponge::columns::{ ShaCompressSpongeColumnsView, NUM_SHA_COMPRESS_SPONGE_COLUMNS, SHA_COMPRESS_SPONGE_COL_MAP, }; use crate::sha_compress_sponge::constants::{NUM_COMPRESS_ROWS, SHA_COMPRESS_K_BINARY}; +use crate::sha_compress_sponge::logic::{sha_ch, sha_ma, sha_sigma0, sha_sigma1}; use crate::sha_extend::logic::{ from_u32_to_be_bits, get_input_range, wrapping_add, wrapping_add_ext_circuit_constraints, wrapping_add_packed_constraints, @@ -243,13 +244,16 @@ impl, const D: usize> ShaCompressSpongeStark .unwrap(); let w_i = from_be_bits_to_u32(w_i.try_into().unwrap()); + let ch_efg = sha_ch(e, f, g); + let sigma1_e = sha_sigma1(e); + let sigma0_a = sha_sigma0(a); + let major = sha_ma(a, b, c); 
let t1 = h - .wrapping_add(e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25)) - .wrapping_add((e & f) ^ ((!e) & g)) + .wrapping_add(sigma1_e) + .wrapping_add(ch_efg) .wrapping_add(SHA_COMPRESS_K[round]) .wrapping_add(w_i); - let t2 = (a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22)) - .wrapping_add((a & b) ^ (a & c) ^ (b & c)); + let t2 = sigma0_a.wrapping_add(major); h = g; g = f; f = e;