diff --git a/Cargo.toml b/Cargo.toml index 2742a8a2..fb088e6d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,9 @@ members = [ ] resolver = "2" +[workspace.package] +version = "0.2.0" +edition = "2021" [profile.release] opt-level = 3 diff --git a/build/Cargo.toml b/build/Cargo.toml index 50d0926f..7e728cdf 100644 --- a/build/Cargo.toml +++ b/build/Cargo.toml @@ -2,8 +2,8 @@ name = "zkm-build" description = "Build an ZKM program." readme = "README.md" -version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } [dependencies] cargo_metadata = "0.18.1" diff --git a/emulator/Cargo.toml b/emulator/Cargo.toml index 153c5895..89215a19 100644 --- a/emulator/Cargo.toml +++ b/emulator/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "zkm-emulator" -version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } [dependencies] plonky2 = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } diff --git a/prover/Cargo.toml b/prover/Cargo.toml index e78f7dd8..1477262c 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "zkm-prover" -version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -39,7 +39,6 @@ lazy_static = "1.4.0" elf = { version = "0.7", default-features = false } sha2 = { version = "0.10.8", default-features = false } - [dev-dependencies] env_logger = "0.10.0" keccak-hash = "0.10.0" diff --git a/prover/examples/prove-seg/Cargo.toml b/prover/examples/prove-seg/Cargo.toml index ad0c4570..2a6bdaa7 100644 --- a/prover/examples/prove-seg/Cargo.toml +++ b/prover/examples/prove-seg/Cargo.toml @@ -7,10 +7,7 @@ publish = false [dependencies] zkm-prover = { workspace = true } zkm-emulator = { workspace = true } -plonky2 = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_util = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_maybe_rayon = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2x = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x", branch = "zkm" } +zkm-utils = { path = "../utils" } log = { version = "0.4.14", default-features = false } serde = { version = "1.0.144", features = ["derive"] } diff --git a/prover/examples/prove-seg/src/main.rs b/prover/examples/prove-seg/src/main.rs index ab106cfa..cfea2045 100644 --- a/prover/examples/prove-seg/src/main.rs +++ b/prover/examples/prove-seg/src/main.rs @@ -1,217 +1,5 @@ use std::env; -use std::fs::File; -use std::io::BufReader; -use std::ops::Range; -use std::time::Duration; - -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; -use plonky2::util::timing::TimingTree; -use plonky2x::backend::circuit::Groth16WrapperParameters; -use plonky2x::backend::wrapper::wrap::WrappedCircuit; -use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; -use plonky2x::prelude::DefaultParameters; - -use zkm_prover::all_stark::AllStark; -use zkm_prover::config::StarkConfig; -use zkm_prover::cpu::kernel::assembler::segment_kernel; -use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; -use zkm_prover::proof; -use zkm_prover::proof::PublicValues; -use zkm_prover::prover::prove; -use zkm_prover::verifier::verify_proof; - -const D: usize = 2; -type C = PoseidonGoldilocksConfig; -type F = >::F; - -const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; - -fn prove_single_seg_common(seg_file: &str, basedir: &str, block: &str, file: &str) { - let seg_reader = BufReader::new(File::open(seg_file).unwrap()); - let kernel = segment_kernel(basedir, block, file, seg_reader); - - let allstark: AllStark = AllStark::default(); - let config = StarkConfig::standard_fast_config(); - let mut timing = TimingTree::new("prove", log::Level::Info); - let allproof: proof::AllProof = - prove(&allstark, &kernel, &config, &mut timing).unwrap(); - let mut count_bytes = 0; - for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { - let proof_str = serde_json::to_string(&proof.proof).unwrap(); - log::info!("row:{} proof bytes:{}", row, proof_str.len()); - count_bytes += proof_str.len(); - } - timing.filter(Duration::from_millis(100)).print(); - log::info!("total proof bytes:{}KB", count_bytes / 1024); - verify_proof(&allstark, allproof, &config).unwrap(); - log::info!("Prove done"); -} - -fn prove_multi_seg_common( - seg_dir: &str, - basedir: &str, - block: &str, - file: &str, - seg_file_number: usize, - seg_start_id: usize, -) -> anyhow::Result<()> { - type InnerParameters = DefaultParameters; - type OuterParameters = Groth16WrapperParameters; - - if seg_file_number < 2 { - panic!("seg file number must >= 2\n"); - } - - let total_timing = TimingTree::new("prove total time", log::Level::Info); - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - // Preprocess all circuits. - let all_circuits = - AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - - let seg_file = format!("{}/{}", seg_dir, seg_start_id); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (mut agg_proof, mut updated_agg_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(agg_proof.clone())?; - - let mut base_seg = seg_start_id + 1; - let mut seg_num = seg_file_number - 1; - let mut is_agg = false; - - if seg_file_number % 2 == 0 { - let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader); - timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &agg_proof, - false, - &root_proof, - agg_public_values.clone(), - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&agg_proof)?; - - is_agg = true; - base_seg = seg_start_id + 2; - seg_num -= 1; - } - - for i in 0..seg_num / 2 { - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (root_proof_first, first_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(root_proof_first.clone())?; - - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let new_agg_public_values = PublicValues { - roots_before: first_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &root_proof_first, - false, - &root_proof, - new_agg_public_values, - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&new_agg_proof)?; - - // Update public values for the nested aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: new_updated_agg_public_values.roots_after, - userdata: new_updated_agg_public_values.userdata, - }; - timing = TimingTree::new("prove nested aggression", log::Level::Info); - - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - is_agg, - &agg_proof, - true, - &new_agg_proof, - agg_public_values.clone(), - )?; - is_agg = true; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_aggregation(&agg_proof)?; - } - - let (block_proof, _block_public_values) = - all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; - - log::info!( - "proof size: {:?}", - serde_json::to_string(&block_proof.proof).unwrap().len() - ); - let result = all_circuits.verify_block(&block_proof); - - let build_path = "../verifier/data".to_string(); - let path = format!("{}/test_circuit/", build_path); - let builder = WrapperBuilder::::new(); - let mut circuit = builder.build(); - circuit.set_data(all_circuits.block.circuit); - let mut bit_size = vec![32usize; 16]; - bit_size.extend(vec![8; 32]); - bit_size.extend(vec![64; 68]); - let wrapped_circuit = WrappedCircuit::::build( - circuit, - Some((vec![], bit_size)), - ); - log::info!("build finish"); - - let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); - wrapped_proof.save(path).unwrap(); - - total_timing.filter(Duration::from_millis(100)).print(); - result -} +use zkm_prover::utils; fn prove_segments() { let basedir = env::var("BASEDIR").unwrap_or("/tmp/cannon".to_string()); @@ -223,15 +11,18 @@ fn prove_segments() { let seg_start_id = env::var("SEG_START_ID").unwrap_or("0".to_string()); let seg_start_id = seg_start_id.parse::<_>().unwrap_or(0usize); - if seg_num == 1 { - let seg_file = format!("{seg_dir}/{}", seg_start_id); - prove_single_seg_common(&seg_file, &basedir, &block, &file) - } else { - prove_multi_seg_common(&seg_dir, &basedir, &block, &file, seg_num, seg_start_id).unwrap() - } + let _ = utils::prove_segments( + &seg_dir, + &basedir, + &block, + &file, + seg_num, + seg_start_id, + vec![], + ); } fn main() { env_logger::try_init().unwrap_or_default(); prove_segments(); -} \ No newline at end of file +} diff --git a/prover/examples/revme/guest/Cargo.lock b/prover/examples/revme/guest/Cargo.lock index 4d8318dc..983fc2dc 100644 --- a/prover/examples/revme/guest/Cargo.lock +++ b/prover/examples/revme/guest/Cargo.lock @@ -993,7 +993,7 @@ dependencies = [ [[package]] name = "zkm-precompiles" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bincode", "cfg-if", @@ -1003,7 +1003,7 @@ dependencies = [ [[package]] name = "zkm-runtime" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bincode", "cfg-if", diff --git a/prover/examples/revme/host/Cargo.toml b/prover/examples/revme/host/Cargo.toml index b9c10523..726a3b95 100644 --- a/prover/examples/revme/host/Cargo.toml +++ b/prover/examples/revme/host/Cargo.toml @@ -7,10 +7,7 @@ publish = false [dependencies] zkm-prover = { workspace = true } zkm-emulator = { workspace = true } -plonky2 = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_util = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_maybe_rayon = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2x = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x", branch = "zkm" } +zkm-utils = { path = "../../utils" } log = { version = "0.4.14", default-features = false } serde = { version = "1.0.144", features = ["derive"] } diff --git a/prover/examples/revme/host/src/main.rs b/prover/examples/revme/host/src/main.rs index 6e29d717..cb82163a 100644 --- a/prover/examples/revme/host/src/main.rs +++ b/prover/examples/revme/host/src/main.rs @@ -1,219 +1,9 @@ use std::env; -use std::io::Read; use std::fs::File; -use std::io::BufReader; -use std::ops::Range; -use std::time::Duration; - -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; -use plonky2::util::timing::TimingTree; -use plonky2x::backend::circuit::Groth16WrapperParameters; -use plonky2x::backend::wrapper::wrap::WrappedCircuit; -use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; -use plonky2x::prelude::DefaultParameters; +use std::io::Read; use zkm_emulator::utils::{load_elf_with_patch, split_prog_into_segs}; -use zkm_prover::all_stark::AllStark; -use zkm_prover::config::StarkConfig; -use zkm_prover::cpu::kernel::assembler::segment_kernel; -use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; -use zkm_prover::proof; -use zkm_prover::proof::PublicValues; -use zkm_prover::prover::prove; -use zkm_prover::verifier::verify_proof; - -const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; - -const D: usize = 2; -type C = PoseidonGoldilocksConfig; -type F = >::F; - -fn prove_single_seg_common(seg_file: &str, basedir: &str, block: &str, file: &str) { - let seg_reader = BufReader::new(File::open(seg_file).unwrap()); - let kernel = segment_kernel(basedir, block, file, seg_reader); - - let allstark: AllStark = AllStark::default(); - let config = StarkConfig::standard_fast_config(); - let mut timing = TimingTree::new("prove", log::Level::Info); - let allproof: proof::AllProof = - prove(&allstark, &kernel, &config, &mut timing).unwrap(); - let mut count_bytes = 0; - for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { - let proof_str = serde_json::to_string(&proof.proof).unwrap(); - log::info!("row:{} proof bytes:{}", row, proof_str.len()); - count_bytes += proof_str.len(); - } - timing.filter(Duration::from_millis(100)).print(); - log::info!("total proof bytes:{}KB", count_bytes / 1024); - verify_proof(&allstark, allproof, &config).unwrap(); - log::info!("Prove done"); -} - -fn prove_multi_seg_common( - seg_dir: &str, - basedir: &str, - block: &str, - file: &str, - seg_file_number: usize, - seg_start_id: usize, -) -> anyhow::Result<()> { - type InnerParameters = DefaultParameters; - type OuterParameters = Groth16WrapperParameters; - - if seg_file_number < 2 { - panic!("seg file number must >= 2\n"); - } - - let total_timing = TimingTree::new("prove total time", log::Level::Info); - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - // Preprocess all circuits. - let all_circuits = - AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - - let seg_file = format!("{}/{}", seg_dir, seg_start_id); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (mut agg_proof, mut updated_agg_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(agg_proof.clone())?; - - let mut base_seg = seg_start_id + 1; - let mut seg_num = seg_file_number - 1; - let mut is_agg = false; - - if seg_file_number % 2 == 0 { - let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader); - timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &agg_proof, - false, - &root_proof, - agg_public_values.clone(), - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&agg_proof)?; - - is_agg = true; - base_seg = seg_start_id + 2; - seg_num -= 1; - } - - for i in 0..seg_num / 2 { - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (root_proof_first, first_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(root_proof_first.clone())?; - - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let new_agg_public_values = PublicValues { - roots_before: first_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &root_proof_first, - false, - &root_proof, - new_agg_public_values, - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&new_agg_proof)?; - - // Update public values for the nested aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: new_updated_agg_public_values.roots_after, - userdata: new_updated_agg_public_values.userdata, - }; - timing = TimingTree::new("prove nested aggression", log::Level::Info); - - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - is_agg, - &agg_proof, - true, - &new_agg_proof, - agg_public_values.clone(), - )?; - is_agg = true; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_aggregation(&agg_proof)?; - } - - let (block_proof, _block_public_values) = - all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; - - log::info!( - "proof size: {:?}", - serde_json::to_string(&block_proof.proof).unwrap().len() - ); - let result = all_circuits.verify_block(&block_proof); - - let build_path = "../verifier/data".to_string(); - let path = format!("{}/test_circuit/", build_path); - let builder = WrapperBuilder::::new(); - let mut circuit = builder.build(); - circuit.set_data(all_circuits.block.circuit); - let mut bit_size = vec![32usize; 16]; - bit_size.extend(vec![8; 32]); - bit_size.extend(vec![64; 68]); - let wrapped_circuit = WrappedCircuit::::build( - circuit, - Some((vec![], bit_size)), - ); - log::info!("build finish"); - - let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); - wrapped_proof.save(path).unwrap(); - - total_timing.filter(Duration::from_millis(100)).print(); - result -} +use zkm_utils::utils::prove_segments; const ELF_PATH: &str = "../guest/elf/mips-zkm-zkvm-elf"; @@ -233,12 +23,7 @@ fn prove_revm() { let (_total_steps, seg_num, mut _state) = split_prog_into_segs(state, &seg_path, "", seg_size); - if seg_num == 1 { - let seg_file = format!("{seg_path}/{}", 0); - prove_single_seg_common(&seg_file, "", "", "") - } else { - prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap() - } + let _ = prove_segments(&seg_path, "", "", "", seg_num, 0, vec![]); } fn main() { diff --git a/prover/examples/sha2-go/host/Cargo.toml b/prover/examples/sha2-go/host/Cargo.toml index 1abaee86..c83ac0c2 100644 --- a/prover/examples/sha2-go/host/Cargo.toml +++ b/prover/examples/sha2-go/host/Cargo.toml @@ -7,10 +7,7 @@ publish = false [dependencies] zkm-prover = { workspace = true } zkm-emulator = { workspace = true } -plonky2 = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_util = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_maybe_rayon = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2x = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x", branch = "zkm" } +zkm-utils = { path = "../../utils" } log = { version = "0.4.14", default-features = false } serde = { version = "1.0.144", features = ["derive"] } diff --git a/prover/examples/sha2-go/host/src/main.rs b/prover/examples/sha2-go/host/src/main.rs index cbc3aa89..6e775336 100644 --- a/prover/examples/sha2-go/host/src/main.rs +++ b/prover/examples/sha2-go/host/src/main.rs @@ -1,219 +1,8 @@ use serde::{Deserialize, Serialize}; use std::env; -use std::fs::File; -use std::io::BufReader; -use std::ops::Range; -use std::time::Duration; - -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; -use plonky2::util::timing::TimingTree; -use plonky2x::backend::circuit::Groth16WrapperParameters; -use plonky2x::backend::wrapper::wrap::WrappedCircuit; -use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; -use plonky2x::prelude::DefaultParameters; use zkm_emulator::utils::{load_elf_with_patch, split_prog_into_segs}; -use zkm_prover::all_stark::AllStark; -use zkm_prover::config::StarkConfig; -use zkm_prover::cpu::kernel::assembler::segment_kernel; -use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; -use zkm_prover::proof; -use zkm_prover::proof::PublicValues; -use zkm_prover::prover::prove; -use zkm_prover::verifier::verify_proof; - -const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; - -const D: usize = 2; -type C = PoseidonGoldilocksConfig; -type F = >::F; - -fn prove_single_seg_common(seg_file: &str, basedir: &str, block: &str, file: &str) { - let seg_reader = BufReader::new(File::open(seg_file).unwrap()); - let kernel = segment_kernel(basedir, block, file, seg_reader); - - let allstark: AllStark = AllStark::default(); - let config = StarkConfig::standard_fast_config(); - let mut timing = TimingTree::new("prove", log::Level::Info); - let allproof: proof::AllProof = - prove(&allstark, &kernel, &config, &mut timing).unwrap(); - let mut count_bytes = 0; - for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { - let proof_str = serde_json::to_string(&proof.proof).unwrap(); - log::info!("row:{} proof bytes:{}", row, proof_str.len()); - count_bytes += proof_str.len(); - } - timing.filter(Duration::from_millis(100)).print(); - log::info!("total proof bytes:{}KB", count_bytes / 1024); - verify_proof(&allstark, allproof, &config).unwrap(); - log::info!("Prove done"); -} - -fn prove_multi_seg_common( - seg_dir: &str, - basedir: &str, - block: &str, - file: &str, - seg_file_number: usize, - seg_start_id: usize, -) -> anyhow::Result<()> { - type InnerParameters = DefaultParameters; - type OuterParameters = Groth16WrapperParameters; - - if seg_file_number < 2 { - panic!("seg file number must >= 2\n"); - } - - let total_timing = TimingTree::new("prove total time", log::Level::Info); - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - // Preprocess all circuits. - let all_circuits = - AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - - let seg_file = format!("{}/{}", seg_dir, seg_start_id); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (mut agg_proof, mut updated_agg_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(agg_proof.clone())?; - - let mut base_seg = seg_start_id + 1; - let mut seg_num = seg_file_number - 1; - let mut is_agg = false; - - if seg_file_number % 2 == 0 { - let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader); - timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &agg_proof, - false, - &root_proof, - agg_public_values.clone(), - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&agg_proof)?; - - is_agg = true; - base_seg = seg_start_id + 2; - seg_num -= 1; - } - - for i in 0..seg_num / 2 { - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (root_proof_first, first_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(root_proof_first.clone())?; - - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let new_agg_public_values = PublicValues { - roots_before: first_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &root_proof_first, - false, - &root_proof, - new_agg_public_values, - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&new_agg_proof)?; - - // Update public values for the nested aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: new_updated_agg_public_values.roots_after, - userdata: new_updated_agg_public_values.userdata, - }; - timing = TimingTree::new("prove nested aggression", log::Level::Info); - - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - is_agg, - &agg_proof, - true, - &new_agg_proof, - agg_public_values.clone(), - )?; - is_agg = true; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_aggregation(&agg_proof)?; - } - - let (block_proof, _block_public_values) = - all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; - - log::info!( - "proof size: {:?}", - serde_json::to_string(&block_proof.proof).unwrap().len() - ); - let result = all_circuits.verify_block(&block_proof); - - let build_path = "../verifier/data".to_string(); - let path = format!("{}/test_circuit/", build_path); - let builder = WrapperBuilder::::new(); - let mut circuit = builder.build(); - circuit.set_data(all_circuits.block.circuit); - let mut bit_size = vec![32usize; 16]; - bit_size.extend(vec![8; 32]); - bit_size.extend(vec![64; 68]); - let wrapped_circuit = WrappedCircuit::::build( - circuit, - Some((vec![], bit_size)), - ); - log::info!("build finish"); - - let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); - wrapped_proof.save(path).unwrap(); - - total_timing.filter(Duration::from_millis(100)).print(); - result -} +use zkm_utils::utils::prove_segments; #[derive(Debug, Clone, Deserialize, Serialize)] pub enum DataId { @@ -299,12 +88,7 @@ fn prove_sha2_go() { let value = state.read_public_values::(); log::info!("public value: {:X?}", value); - if seg_num == 1 { - let seg_file = format!("{seg_path}/{}", 0); - prove_single_seg_common(&seg_file, "", "", "") - } else { - prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap() - } + let _ = prove_segments(&seg_path, "", "", "", seg_num, 0, vec![]); } fn main() { diff --git a/prover/examples/sha2-precompile/guest/Cargo.lock b/prover/examples/sha2-precompile/guest/Cargo.lock index cbbae727..00f2f1e6 100644 --- a/prover/examples/sha2-precompile/guest/Cargo.lock +++ b/prover/examples/sha2-precompile/guest/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "bincode" @@ -254,7 +254,7 @@ dependencies = [ [[package]] name = "zkm-precompiles" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bincode", "cfg-if", @@ -264,7 +264,7 @@ dependencies = [ [[package]] name = "zkm-runtime" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bincode", "cfg-if", diff --git a/prover/examples/sha2-precompile/guest/src/main.rs b/prover/examples/sha2-precompile/guest/src/main.rs index e713a400..35cfff0d 100644 --- a/prover/examples/sha2-precompile/guest/src/main.rs +++ b/prover/examples/sha2-precompile/guest/src/main.rs @@ -9,12 +9,12 @@ zkm_runtime::entrypoint!(main); pub fn main() { let public_input: Vec = zkm_runtime::io::read(); - let input: Vec = zkm_runtime::io::read(); + let input: [u8; 32] = zkm_runtime::io::read(); let elf_id: Vec = zkm_runtime::io::read(); zkm_runtime::io::verify(elf_id, &input); let mut hasher = Sha256::new(); - hasher.update(input); + hasher.update(input.to_vec()); let result = hasher.finalize(); let output: [u8; 32] = result.into(); diff --git a/prover/examples/sha2-precompile/host/Cargo.toml b/prover/examples/sha2-precompile/host/Cargo.toml index 98d3925f..e5e78bd2 100644 --- a/prover/examples/sha2-precompile/host/Cargo.toml +++ b/prover/examples/sha2-precompile/host/Cargo.toml @@ -7,9 +7,8 @@ publish = false [dependencies] zkm-prover = { workspace = true } zkm-emulator = { workspace = true } +zkm-utils = { path = "../../utils" } plonky2 = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_util = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_maybe_rayon = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } log = { version = "0.4.14", default-features = false } serde = { version = "1.0.144", features = ["derive"] } diff --git a/prover/examples/sha2-precompile/host/src/main.rs b/prover/examples/sha2-precompile/host/src/main.rs index dd9adc59..6c2db0ee 100644 --- a/prover/examples/sha2-precompile/host/src/main.rs +++ b/prover/examples/sha2-precompile/host/src/main.rs @@ -1,39 +1,16 @@ use std::env; -use std::fs::File; -use std::io::BufReader; -use std::ops::Range; -use std::time::Duration; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; -use plonky2::util::timing::TimingTree; use zkm_emulator::utils::{load_elf_with_patch, split_prog_into_segs}; -use zkm_prover::all_stark::AllStark; -use zkm_prover::config::StarkConfig; -use zkm_prover::cpu::kernel::assembler::segment_kernel; -use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; -use zkm_prover::generation::state::{AssumptionReceipt, AssumptionReceipts, Receipt}; - -const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; +use zkm_prover::generation::state::{AssumptionReceipts, Receipt}; +use zkm_utils::utils::prove_segments; const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; -const ELF_PATH: &str = "../guest/elf/mips-zkm-zkvm-elf"; - -fn u32_array_to_u8_vec(u32_array: &[u32; 8]) -> Vec { - let mut u8_vec = Vec::with_capacity(u32_array.len() * 4); - for &item in u32_array { - u8_vec.extend_from_slice(&item.to_le_bytes()); - } - u8_vec -} - -fn prove_sha_5_precompile( - elf_path: &str, - seg_path: &str, -) -> Receipt<>::F, C, D> { +fn prove_sha_5_precompile(elf_path: &str, seg_path: &str) -> Receipt { let mut state = load_elf_with_patch(elf_path, vec![]); let n: u32 = 5; let public_input: [u8; 32] = [ @@ -50,31 +27,11 @@ fn prove_sha_5_precompile( assert!(seg_num == 1); - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - // Preprocess all circuits. - let all_circuits = - AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - - let seg_file: String = format!("{}/{}", seg_path, 0); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file).unwrap()); - let input_first = segment_kernel("", "", "", seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (agg_proof, updated_agg_public_values) = all_circuits - .prove_root(&all_stark, &input_first, &config, &mut timing) - .unwrap(); - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(agg_proof.clone()).unwrap(); - - Receipt:: { - proof: agg_proof, - root_before: u32_array_to_u8_vec(&updated_agg_public_values.roots_before.root), - userdata: updated_agg_public_values.userdata.clone(), - } + prove_segments(seg_path, "", "", "", 1, 0, vec![]).unwrap() } +const ELF_PATH: &str = "../guest/elf/mips-unknown-linux-musl"; + fn prove_sha2_precompile() { // 1. split ELF into segs let precompile_path = env::var("PRECOMPILE_PATH").expect("PRECOMPILE ELF file is missing"); @@ -84,11 +41,11 @@ fn prove_sha2_precompile() { log::info!( "elf_id: {:?}, data: {:?}", - receipt.root_before, - receipt.userdata + receipt.claim().elf_id, + receipt.claim().commit, ); - let image_id = receipt.root_before.clone(); + let image_id = receipt.claim().elf_id; receipts.push(receipt.into()); let mut state = load_elf_with_patch(&ELF_PATH, vec![]); @@ -105,7 +62,7 @@ fn prove_sha2_precompile() { 233, 189, 123, 198, 181, 39, 175, 7, 129, 62, 199, 185, 16, ]; log::info!("private input value: {:?}", private_input); - state.add_input_stream(&private_input.to_vec()); + state.add_input_stream(&private_input); state.add_input_stream(&image_id); @@ -115,39 +72,7 @@ fn prove_sha2_precompile() { log::info!("public value: {:X?}", value); log::info!("public value: {} in hex", hex::encode(value)); - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - // Preprocess all circuits. - let all_circuits = - AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - - let seg_file: String = format!("{}/{}", seg_path, 0); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file).unwrap()); - let kernel = segment_kernel("", "", "", seg_reader); - - let mut timing = TimingTree::new("prove", log::Level::Info); - let (agg_proof, _updated_agg_public_values, receipts_used) = all_circuits - .prove_root_with_assumption(&all_stark, &kernel, &config, &mut timing, receipts) - .unwrap(); - - log::info!("Process assumptions"); - timing = TimingTree::new("prove aggression", log::Level::Info); - - for assumption in receipts_used.borrow_mut().iter_mut() { - let receipt = assumption.1.clone(); - match receipt { - AssumptionReceipt::Proven(receipt) => { - all_circuits.verify_root(receipt.proof.clone()).unwrap(); - } - AssumptionReceipt::Unresolved(assumpt) => { - log::error!("unresolved assumption: {:X?}", assumpt); - } - } - } - log::info!("verify"); - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(agg_proof.clone()).unwrap(); + let _ = prove_segments(&seg_path, "", "", "", 1, 0, receipts); } fn main() { diff --git a/prover/examples/sha2-rust/guest/Cargo.lock b/prover/examples/sha2-rust/guest/Cargo.lock index a4c316af..1bef5f4e 100644 --- a/prover/examples/sha2-rust/guest/Cargo.lock +++ b/prover/examples/sha2-rust/guest/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "bincode" @@ -254,7 +254,7 @@ dependencies = [ [[package]] name = "zkm-precompiles" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bincode", "cfg-if", @@ -264,7 +264,7 @@ dependencies = [ [[package]] name = "zkm-runtime" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bincode", "cfg-if", diff --git a/prover/examples/sha2-rust/host/Cargo.toml b/prover/examples/sha2-rust/host/Cargo.toml index eda86fa7..239e8a54 100644 --- a/prover/examples/sha2-rust/host/Cargo.toml +++ b/prover/examples/sha2-rust/host/Cargo.toml @@ -7,10 +7,8 @@ publish = false [dependencies] zkm-prover = { workspace = true } zkm-emulator = { workspace = true } -plonky2 = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_util = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2_maybe_rayon = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } -plonky2x = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x", branch = "zkm" } +zkm-utils = { path = "../../utils" } + log = { version = "0.4.14", default-features = false } serde = { version = "1.0.144", features = ["derive"] } diff --git a/prover/examples/sha2-rust/host/src/main.rs b/prover/examples/sha2-rust/host/src/main.rs index bae10e1f..be7041ed 100644 --- a/prover/examples/sha2-rust/host/src/main.rs +++ b/prover/examples/sha2-rust/host/src/main.rs @@ -1,218 +1,7 @@ use std::env; -use std::fs::File; -use std::io::BufReader; -use std::ops::Range; -use std::time::Duration; - -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; -use plonky2::util::timing::TimingTree; -use plonky2x::backend::circuit::Groth16WrapperParameters; -use plonky2x::backend::wrapper::wrap::WrappedCircuit; -use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; -use plonky2x::prelude::DefaultParameters; use zkm_emulator::utils::{load_elf_with_patch, split_prog_into_segs}; -use zkm_prover::all_stark::AllStark; -use zkm_prover::config::StarkConfig; -use zkm_prover::cpu::kernel::assembler::segment_kernel; -use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; -use zkm_prover::proof; -use zkm_prover::proof::PublicValues; -use zkm_prover::prover::prove; -use zkm_prover::verifier::verify_proof; - -const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; - -const D: usize = 2; -type C = PoseidonGoldilocksConfig; -type F = >::F; - -fn prove_single_seg_common(seg_file: &str, basedir: &str, block: &str, file: &str) { - let seg_reader = BufReader::new(File::open(seg_file).unwrap()); - let kernel = segment_kernel(basedir, block, file, seg_reader); - - let allstark: AllStark = AllStark::default(); - let config = StarkConfig::standard_fast_config(); - let mut timing = TimingTree::new("prove", log::Level::Info); - let allproof: proof::AllProof = - prove(&allstark, &kernel, &config, &mut timing).unwrap(); - let mut count_bytes = 0; - for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { - let proof_str = serde_json::to_string(&proof.proof).unwrap(); - log::info!("row:{} proof bytes:{}", row, proof_str.len()); - count_bytes += proof_str.len(); - } - timing.filter(Duration::from_millis(100)).print(); - log::info!("total proof bytes:{}KB", count_bytes / 1024); - verify_proof(&allstark, allproof, &config).unwrap(); - log::info!("Prove done"); -} - -fn prove_multi_seg_common( - seg_dir: &str, - basedir: &str, - block: &str, - file: &str, - seg_file_number: usize, - seg_start_id: usize, -) -> anyhow::Result<()> { - type InnerParameters = DefaultParameters; - type OuterParameters = Groth16WrapperParameters; - - if seg_file_number < 2 { - panic!("seg file number must >= 2\n"); - } - - let total_timing = TimingTree::new("prove total time", log::Level::Info); - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - // Preprocess all circuits. - let all_circuits = - AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - - let seg_file = format!("{}/{}", seg_dir, seg_start_id); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (mut agg_proof, mut updated_agg_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(agg_proof.clone())?; - - let mut base_seg = seg_start_id + 1; - let mut seg_num = seg_file_number - 1; - let mut is_agg = false; - - if seg_file_number % 2 == 0 { - let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader); - timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &agg_proof, - false, - &root_proof, - agg_public_values.clone(), - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&agg_proof)?; - - is_agg = true; - base_seg = seg_start_id + 2; - seg_num -= 1; - } - - for i in 0..seg_num / 2 { - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (root_proof_first, first_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(root_proof_first.clone())?; - - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader); - let mut timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let new_agg_public_values = PublicValues { - roots_before: first_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &root_proof_first, - false, - &root_proof, - new_agg_public_values, - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&new_agg_proof)?; - - // Update public values for the nested aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: new_updated_agg_public_values.roots_after, - userdata: new_updated_agg_public_values.userdata, - }; - timing = TimingTree::new("prove nested aggression", log::Level::Info); - - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - is_agg, - &agg_proof, - true, - &new_agg_proof, - agg_public_values.clone(), - )?; - is_agg = true; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_aggregation(&agg_proof)?; - } - - let (block_proof, _block_public_values) = - all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; - - log::info!( - "proof size: {:?}", - serde_json::to_string(&block_proof.proof).unwrap().len() - ); - let result = all_circuits.verify_block(&block_proof); - - let build_path = "../verifier/data".to_string(); - let path = format!("{}/test_circuit/", build_path); - let builder = WrapperBuilder::::new(); - let mut circuit = builder.build(); - circuit.set_data(all_circuits.block.circuit); - let mut bit_size = vec![32usize; 16]; - bit_size.extend(vec![8; 32]); - bit_size.extend(vec![64; 68]); - let wrapped_circuit = WrappedCircuit::::build( - circuit, - Some((vec![], bit_size)), - ); - log::info!("build finish"); - - let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); - wrapped_proof.save(path).unwrap(); - - total_timing.filter(Duration::from_millis(100)).print(); - result -} +use zkm_utils::utils::prove_segments; const ELF_PATH: &str = "../guest/elf/mips-zkm-zkvm-elf"; @@ -244,12 +33,7 @@ fn prove_sha2_rust() { log::info!("public value: {:X?}", value); log::info!("public value: {} in hex", hex::encode(value)); - if seg_num == 1 { - let seg_file = format!("{seg_path}/{}", 0); - prove_single_seg_common(&seg_file, "", "", "") - } else { - prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap() - } + let _ = prove_segments(&seg_path, "", "", "", seg_num, 0, vec![]).unwrap(); } fn main() { diff --git a/prover/examples/utils/Cargo.toml b/prover/examples/utils/Cargo.toml new file mode 100644 index 00000000..af5c4e3a --- /dev/null +++ b/prover/examples/utils/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "zkm-utils" +version = "0.2.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +zkm-prover = { path = "../../" } +plonky2 = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } +plonky2_util = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } +plonky2_maybe_rayon = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } +plonky2x = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x", branch = "zkm" } + +log = { version = "0.4.14", default-features = false } +serde = { version = "1.0.144", features = ["derive"] } +serde_json = "1.0" +byteorder = "1.5.0" +hex = "0.4" +env_logger = "0.11.5" +anyhow = "1.0.75" diff --git a/prover/examples/utils/src/lib.rs b/prover/examples/utils/src/lib.rs new file mode 100644 index 00000000..b5614dd8 --- /dev/null +++ b/prover/examples/utils/src/lib.rs @@ -0,0 +1 @@ +pub mod utils; diff --git a/prover/examples/utils/src/utils.rs b/prover/examples/utils/src/utils.rs new file mode 100644 index 00000000..76ea10f2 --- /dev/null +++ b/prover/examples/utils/src/utils.rs @@ -0,0 +1,159 @@ +use std::fs::File; +use std::io::BufReader; +use std::ops::Range; +use std::time::Duration; + +use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; +use plonky2::util::timing::TimingTree; +use plonky2x::backend::circuit::Groth16WrapperParameters; +use plonky2x::backend::wrapper::wrap::WrappedCircuit; +use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; +use plonky2x::prelude::DefaultParameters; + +use zkm_prover::all_stark::AllStark; +use zkm_prover::config::StarkConfig; +use zkm_prover::cpu::kernel::assembler::segment_kernel; +use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; +use zkm_prover::generation::state::{AssumptionReceipts, Receipt}; + +const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; + +const D: usize = 2; +type C = PoseidonGoldilocksConfig; +type F = >::F; + +pub fn prove_segments( + seg_dir: &str, + basedir: &str, + block: &str, + file: &str, + seg_file_number: usize, + seg_start_id: usize, + assumptions: AssumptionReceipts, +) -> anyhow::Result> { + type InnerParameters = DefaultParameters; + type OuterParameters = Groth16WrapperParameters; + + let total_timing = TimingTree::new("prove total time", log::Level::Info); + let all_stark = AllStark::::default(); + let config = StarkConfig::standard_fast_config(); + // Preprocess all circuits. + let all_circuits = + AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); + + let seg_file = format!("{}/{}", seg_dir, seg_start_id); + log::info!("Process segment {}", seg_file); + let seg_reader = BufReader::new(File::open(seg_file)?); + let input_first = segment_kernel(basedir, block, file, seg_reader); + let mut timing = TimingTree::new("prove root first", log::Level::Info); + let mut agg_receipt = all_circuits.prove_root_with_assumption( + &all_stark, + &input_first, + &config, + &mut timing, + assumptions, + )?; + + timing.filter(Duration::from_millis(100)).print(); + all_circuits.verify_root(agg_receipt.clone())?; + + let mut base_seg = seg_start_id + 1; + let mut seg_num = seg_file_number - 1; + let mut is_agg = false; + + if seg_file_number % 2 == 0 { + let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); + log::info!("Process segment {}", seg_file); + let seg_reader = BufReader::new(File::open(seg_file)?); + let input = segment_kernel(basedir, block, file, seg_reader); + timing = TimingTree::new("prove root second", log::Level::Info); + let receipt = all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; + timing.filter(Duration::from_millis(100)).print(); + + all_circuits.verify_root(receipt.clone())?; + + timing = TimingTree::new("prove aggression", log::Level::Info); + // We can duplicate the proofs here because the state hasn't mutated. + agg_receipt = all_circuits.prove_aggregation(false, &agg_receipt, false, &receipt)?; + timing.filter(Duration::from_millis(100)).print(); + all_circuits.verify_aggregation(&agg_receipt)?; + + is_agg = true; + base_seg = seg_start_id + 2; + seg_num -= 1; + } + + for i in 0..seg_num / 2 { + let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); + log::info!("Process segment {}", seg_file); + let seg_reader = BufReader::new(File::open(&seg_file)?); + let input_first = segment_kernel(basedir, block, file, seg_reader); + let mut timing = TimingTree::new("prove root first", log::Level::Info); + let root_receipt_first = + all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; + + timing.filter(Duration::from_millis(100)).print(); + all_circuits.verify_root(root_receipt_first.clone())?; + + let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); + log::info!("Process segment {}", seg_file); + let seg_reader = BufReader::new(File::open(&seg_file)?); + let input = segment_kernel(basedir, block, file, seg_reader); + let mut timing = TimingTree::new("prove root second", log::Level::Info); + let root_receipt = all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; + timing.filter(Duration::from_millis(100)).print(); + + all_circuits.verify_root(root_receipt.clone())?; + + timing = TimingTree::new("prove aggression", log::Level::Info); + // We can duplicate the proofs here because the state hasn't mutated. + let new_agg_receipt = + all_circuits.prove_aggregation(false, &root_receipt_first, false, &root_receipt)?; + timing.filter(Duration::from_millis(100)).print(); + all_circuits.verify_aggregation(&new_agg_receipt)?; + + timing = TimingTree::new("prove nested aggression", log::Level::Info); + + // We can duplicate the proofs here because the state hasn't mutated. + agg_receipt = + all_circuits.prove_aggregation(is_agg, &agg_receipt, true, &new_agg_receipt)?; + is_agg = true; + timing.filter(Duration::from_millis(100)).print(); + + all_circuits.verify_aggregation(&agg_receipt)?; + } + + log::info!( + "proof size: {:?}", + serde_json::to_string(&agg_receipt.proof().proof) + .unwrap() + .len() + ); + let final_receipt = if seg_file_number > 1 { + let block_receipt = all_circuits.prove_block(None, &agg_receipt)?; + all_circuits.verify_block(&block_receipt)?; + let build_path = "../verifier/data".to_string(); + let path = format!("{}/test_circuit/", build_path); + let builder = WrapperBuilder::::new(); + let mut circuit = builder.build(); + circuit.set_data(all_circuits.block.circuit); + let mut bit_size = vec![32usize; 16]; + bit_size.extend(vec![8; 32]); + bit_size.extend(vec![64; 68]); + let wrapped_circuit = WrappedCircuit::::build( + circuit, + Some((vec![], bit_size)), + ); + let wrapped_proof = wrapped_circuit.prove(&block_receipt.proof()).unwrap(); + wrapped_proof.save(path).unwrap(); + + block_receipt + } else { + agg_receipt + }; + + log::info!("build finish"); + + total_timing.filter(Duration::from_millis(100)).print(); + Ok(final_receipt) +} diff --git a/prover/src/fixed_recursive_verifier.rs b/prover/src/fixed_recursive_verifier.rs index 45cf7024..2269664d 100644 --- a/prover/src/fixed_recursive_verifier.rs +++ b/prover/src/fixed_recursive_verifier.rs @@ -33,18 +33,20 @@ use crate::cross_table_lookup::{ get_grand_product_challenge_set_target, verify_cross_table_lookups_circuit, CrossTableLookup, GrandProductChallengeSet, }; -use crate::generation::state::{AssumptionReceipts, AssumptionUsage}; +use crate::generation::state::{ + AssumptionReceipt, AssumptionReceipts, CompositeReceipt, InnerReceipt, Receipt, ReceiptClaim, +}; use crate::get_challenges::observe_public_values_target; use crate::proof::{MemRootsTarget, PublicValues, PublicValuesTarget, StarkProofWithMetadata}; -use crate::prover::{prove, prove_with_assumptions}; +use crate::prover::{prove_with_output_and_assumptions, prove_with_outputs}; use crate::recursive_verifier::{ add_common_recursion_gates, add_virtual_public_values, recursive_stark_circuit, set_public_value_targets, PlonkWrapperCircuit, PublicInputs, StarkWrapperCircuit, }; use crate::stark::Stark; +use crate::util::u32_array_to_u8_vec; use crate::verifier::verify_proof; //use crate::util::h256_limbs; -use std::{cell::RefCell, rc::Rc}; /// The recursion threshold. We end a chain of recursive proofs once we reach this size. const THRESHOLD_DEGREE_BITS: usize = 13; @@ -701,8 +703,8 @@ where kernel: &Kernel, config: &StarkConfig, timing: &mut TimingTree, - ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { - let all_proof = prove::(all_stark, kernel, config, timing)?; + ) -> anyhow::Result> { + let (all_proof, output) = prove_with_outputs::(all_stark, kernel, config, timing)?; verify_proof(all_stark, all_proof.clone(), config).unwrap(); let mut root_inputs = PartialWitness::new(); @@ -749,7 +751,14 @@ where let root_proof = self.root.circuit.prove(root_inputs)?; - Ok((root_proof, all_proof.public_values)) + Ok(Receipt::Segments(InnerReceipt { + proof: root_proof, + values: all_proof.public_values.clone(), + claim: ReceiptClaim { + elf_id: u32_array_to_u8_vec(&all_proof.public_values.roots_before.root), + commit: output.output.clone(), + }, + })) } pub fn prove_root_with_assumption( @@ -759,13 +768,17 @@ where config: &StarkConfig, timing: &mut TimingTree, assumptions: AssumptionReceipts, - ) -> anyhow::Result<( - ProofWithPublicInputs, - PublicValues, - Rc>>, - )> { - let (all_proof, receipts) = - prove_with_assumptions::(all_stark, kernel, config, timing, assumptions)?; + ) -> anyhow::Result> { + if assumptions.is_empty() { + return self.prove_root(all_stark, kernel, config, timing); + } + let (all_proof, output, assumption_used) = prove_with_output_and_assumptions::( + all_stark, + kernel, + config, + timing, + assumptions, + )?; verify_proof(all_stark, all_proof.clone(), config).unwrap(); let mut root_inputs = PartialWitness::new(); @@ -812,30 +825,46 @@ where let root_proof = self.root.circuit.prove(root_inputs)?; - Ok((root_proof, all_proof.public_values, receipts)) + let program_receipt = InnerReceipt { + proof: root_proof, + values: all_proof.public_values.clone(), + claim: ReceiptClaim { + elf_id: u32_array_to_u8_vec(&all_proof.public_values.roots_before.root), + commit: output.output.clone(), + }, + }; + Ok(Receipt::Composite(CompositeReceipt { + program_receipt, + assumption_used, + })) } - pub fn verify_root(&self, agg_proof: ProofWithPublicInputs) -> anyhow::Result<()> { - self.root.circuit.verify(agg_proof) + pub fn verify_root(&self, agg_receipt: Receipt) -> anyhow::Result<()> { + self.root.circuit.verify(agg_receipt.proof()) } pub fn prove_aggregation( &self, lhs_is_agg: bool, - lhs_proof: &ProofWithPublicInputs, + lhs_receipt: &Receipt, rhs_is_agg: bool, - rhs_proof: &ProofWithPublicInputs, - public_values: PublicValues, - ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { + rhs_receipt: &Receipt, + ) -> anyhow::Result> { let mut agg_inputs = PartialWitness::new(); + let public_values = PublicValues { + roots_before: lhs_receipt.values().roots_before, + roots_after: rhs_receipt.values().roots_after, + userdata: rhs_receipt.values().userdata, + }; + agg_inputs.set_bool_target(self.aggregation.lhs.is_agg, lhs_is_agg); - agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.agg_proof, lhs_proof); - agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.evm_proof, lhs_proof); + agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.agg_proof, &lhs_receipt.proof()); + agg_inputs.set_proof_with_pis_target(&self.aggregation.lhs.evm_proof, &lhs_receipt.proof()); agg_inputs.set_bool_target(self.aggregation.rhs.is_agg, rhs_is_agg); - agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.agg_proof, rhs_proof); - agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.evm_proof, rhs_proof); + agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.agg_proof, &rhs_receipt.proof()); + agg_inputs.set_proof_with_pis_target(&self.aggregation.rhs.evm_proof, &rhs_receipt.proof()); agg_inputs.set_verifier_data_target( &self.aggregation.cyclic_vk, @@ -852,16 +881,34 @@ where })?; let aggregation_proof = self.aggregation.circuit.prove(agg_inputs)?; - Ok((aggregation_proof, public_values)) + let inner = InnerReceipt { + proof: aggregation_proof, + values: public_values, + claim: ReceiptClaim { + elf_id: lhs_receipt.claim().clone().elf_id, + commit: rhs_receipt.claim().clone().commit, + }, + }; + + let assumptions = lhs_receipt.assumptions(); + for assumption in rhs_receipt.assumptions().borrow().iter() { + assumptions.borrow_mut().insert(0, assumption.clone()); + } + + if assumptions.borrow().is_empty() { + Ok(Receipt::Segments(inner)) + } else { + Ok(Receipt::Composite(CompositeReceipt { + program_receipt: inner, + assumption_used: assumptions, + })) + } } - pub fn verify_aggregation( - &self, - agg_proof: &ProofWithPublicInputs, - ) -> anyhow::Result<()> { - self.aggregation.circuit.verify(agg_proof.clone())?; + pub fn verify_aggregation(&self, receipt: &Receipt) -> anyhow::Result<()> { + self.aggregation.circuit.verify(receipt.proof())?; check_cyclic_proof_verifier_data( - agg_proof, + &receipt.proof(), &self.aggregation.circuit.verifier_only, &self.aggregation.circuit.common, ) @@ -869,37 +916,39 @@ where pub fn prove_block( &self, - opt_parent_block_proof: Option<&ProofWithPublicInputs>, - agg_root_proof: &ProofWithPublicInputs, - public_values: PublicValues, - ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { + opt_parent_block_receipt: Option<&Receipt>, + agg_root_receipt: &Receipt, + ) -> anyhow::Result> { let mut block_inputs = PartialWitness::new(); block_inputs.set_bool_target( self.block.has_parent_block, - opt_parent_block_proof.is_some(), + opt_parent_block_receipt.is_some(), ); - if let Some(parent_block_proof) = opt_parent_block_proof { - block_inputs - .set_proof_with_pis_target(&self.block.parent_block_proof, parent_block_proof); + if let Some(parent_block_receipt) = opt_parent_block_receipt { + block_inputs.set_proof_with_pis_target( + &self.block.parent_block_proof, + &parent_block_receipt.proof(), + ); } else { // Initialize `state_root_after`. let mut nonzero_pis = HashMap::new(); let state_trie_root_before_keys = 0..8; for (key, &value) in - state_trie_root_before_keys.zip_eq(&public_values.roots_before.root) + state_trie_root_before_keys.zip_eq(&agg_root_receipt.values().roots_before.root) { nonzero_pis.insert(key, F::from_canonical_u32(value)); } let state_trie_root_after_keys = 8..16; - for (key, &value) in state_trie_root_after_keys.zip_eq(&public_values.roots_before.root) + for (key, &value) in + state_trie_root_after_keys.zip_eq(&agg_root_receipt.values().roots_before.root) { nonzero_pis.insert(key, F::from_canonical_u32(value)); } - let userdata_keys = 16..16 + public_values.userdata.len(); - for (key, &value) in userdata_keys.zip_eq(&public_values.userdata) { + let userdata_keys = 16..16 + agg_root_receipt.values().userdata.len(); + for (key, &value) in userdata_keys.zip_eq(&agg_root_receipt.values().userdata) { nonzero_pis.insert(key, F::from_canonical_u8(value)); } @@ -913,24 +962,56 @@ where ); } - block_inputs.set_proof_with_pis_target(&self.block.agg_root_proof, agg_root_proof); + block_inputs + .set_proof_with_pis_target(&self.block.agg_root_proof, &agg_root_receipt.proof()); block_inputs .set_verifier_data_target(&self.block.cyclic_vk, &self.block.circuit.verifier_only); - set_public_value_targets(&mut block_inputs, &self.block.public_values, &public_values) - .map_err(|_| { - anyhow::Error::msg("Invalid conversion when setting public values targets.") - })?; + set_public_value_targets( + &mut block_inputs, + &self.block.public_values, + &agg_root_receipt.values(), + ) + .map_err(|_| { + anyhow::Error::msg("Invalid conversion when setting public values targets.") + })?; let block_proof = self.block.circuit.prove(block_inputs)?; - Ok((block_proof, public_values)) + let inner = InnerReceipt { + proof: block_proof, + values: agg_root_receipt.values(), + claim: agg_root_receipt.claim(), + }; + match agg_root_receipt { + Receipt::Segments(_receipt) => Ok(Receipt::Segments(inner)), + Receipt::Composite(receipt) => Ok(Receipt::Composite(CompositeReceipt { + program_receipt: inner, + assumption_used: receipt.assumption_used.clone(), + })), + } } - pub fn verify_block(&self, block_proof: &ProofWithPublicInputs) -> anyhow::Result<()> { - self.block.circuit.verify(block_proof.clone())?; + pub fn verify_block(&self, block_receipt: &Receipt) -> anyhow::Result<()> { + self.block.circuit.verify(block_receipt.proof())?; + match block_receipt { + Receipt::Segments(_receipt) => (), + Receipt::Composite(receipt) => { + for assumption in receipt.assumption_used.borrow_mut().iter_mut() { + let receipt = assumption.1.clone(); + match receipt { + AssumptionReceipt::::Proven(inner) => { + self.verify_root(Receipt::Segments(*inner))?; + } + AssumptionReceipt::Unresolved(assumpt) => { + log::error!("unresolved assumption: {:X?}", assumpt); + } + } + } + } + }; check_cyclic_proof_verifier_data( - block_proof, + &block_receipt.proof(), &self.block.circuit.verifier_only, &self.block.circuit.common, ) diff --git a/prover/src/generation/outputs.rs b/prover/src/generation/outputs.rs index 6b851e50..ab8a75f9 100644 --- a/prover/src/generation/outputs.rs +++ b/prover/src/generation/outputs.rs @@ -1,14 +1,13 @@ use plonky2::field::extension::Extendable; use plonky2::hash::hash_types::RichField; use plonky2::plonk::config::GenericConfig; -use std::collections::HashMap; use crate::generation::state::GenerationState; use crate::witness::errors::ProgramError; #[derive(Clone, Debug)] pub struct GenerationOutputs { - pub new_state: HashMap, + pub output: Vec, } pub(crate) fn get_outputs< @@ -16,10 +15,9 @@ pub(crate) fn get_outputs< C: GenericConfig, const D: usize, >( - _state: &mut GenerationState, + state: &mut GenerationState, ) -> Result { - // FIXME Ok(GenerationOutputs { - new_state: HashMap::new(), + output: state.public_values_stream.clone(), }) } diff --git a/prover/src/generation/state.rs b/prover/src/generation/state.rs index 93e89f32..d0fba29f 100644 --- a/prover/src/generation/state.rs +++ b/prover/src/generation/state.rs @@ -1,5 +1,6 @@ // use keccak_hash::keccak; use crate::cpu::kernel::assembler::Kernel; +use crate::proof::PublicValues; use crate::witness::errors::ProgramError; use crate::witness::memory::MemoryState; use crate::witness::state::RegistersState; @@ -13,14 +14,30 @@ use std::{cell::RefCell, rc::Rc}; pub const ZERO: [u8; 32] = [0u8; 32]; -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct Receipt, C: GenericConfig, const D: usize> { +pub(crate) struct GenerationStateCheckpoint { + pub(crate) registers: RegistersState, + pub(crate) traces: TraceCheckpoint, +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct Assumption { + pub claim: [u8; 32], +} + +#[derive(Clone, Debug)] +pub struct ReceiptClaim { + pub elf_id: Vec, // pre image id + pub commit: Vec, // commit info +} + +#[derive(Clone, Debug)] +pub struct InnerReceipt, C: GenericConfig, const D: usize> { pub proof: ProofWithPublicInputs, - pub root_before: Vec, - pub userdata: Vec, + pub values: PublicValues, + pub claim: ReceiptClaim, } -impl Receipt +impl InnerReceipt where F: RichField + Extendable, C: GenericConfig, @@ -28,29 +45,18 @@ where pub fn claim_digest(&self) -> [u8; 32] { let mut hasher = Sha256::new(); - hasher.update(self.root_before.clone()); - hasher.update(self.userdata.clone()); + hasher.update(self.claim.elf_id.clone()); + hasher.update(self.claim.commit.clone()); let digest: [u8; 32] = hasher.finalize().into(); digest } } -pub(crate) struct GenerationStateCheckpoint { - pub(crate) registers: RegistersState, - pub(crate) traces: TraceCheckpoint, -} - -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -pub struct Assumption { - pub claim: [u8; 32], - pub control_root: [u8; 32], -} - #[derive(Clone, Debug)] pub enum AssumptionReceipt, C: GenericConfig, const D: usize> { // A [Receipt] for a proven assumption. - Proven(Box>), + Proven(Box>), // An [Assumption] that is not directly proven to be true. Unresolved(Assumption), @@ -74,13 +80,13 @@ where pub type AssumptionReceipts = Vec>; pub type AssumptionUsage = Vec<(Assumption, AssumptionReceipt)>; -impl From> for AssumptionReceipt +impl From> for AssumptionReceipt where F: RichField + Extendable, C: GenericConfig, { /// Create a proven assumption from a [Receipt]. - fn from(receipt: Receipt) -> Self { + fn from(receipt: InnerReceipt) -> Self { Self::Proven(Box::new(receipt)) } } @@ -96,6 +102,104 @@ where } } +#[derive(Clone, Debug)] +pub struct CompositeReceipt< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + pub program_receipt: InnerReceipt, + pub assumption_used: Rc>>, +} + +impl CompositeReceipt +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub fn claim_digest(&self) -> [u8; 32] { + let mut hasher = Sha256::new(); + + hasher.update(self.program_receipt.claim.elf_id.clone()); + hasher.update(self.program_receipt.claim.commit.clone()); + let digest: [u8; 32] = hasher.finalize().into(); + digest + } +} + +#[derive(Clone, Debug)] +pub enum Receipt, C: GenericConfig, const D: usize> { + Segments(InnerReceipt), + Composite(CompositeReceipt), +} + +impl Receipt +where + F: RichField + Extendable, + C: GenericConfig, +{ + pub fn claim_digest(&self) -> [u8; 32] { + match self { + Self::Segments(receipt) => receipt.claim_digest(), + Self::Composite(receipt) => receipt.claim_digest(), + } + } + + pub fn proof(&self) -> ProofWithPublicInputs { + match self { + Self::Segments(receipt) => receipt.proof.clone(), + Self::Composite(receipt) => receipt.program_receipt.proof.clone(), + } + } + + pub fn values(&self) -> PublicValues { + match self { + Self::Segments(receipt) => receipt.values.clone(), + Self::Composite(receipt) => receipt.program_receipt.values.clone(), + } + } + + pub fn claim(&self) -> ReceiptClaim { + match self { + Self::Segments(receipt) => receipt.claim.clone(), + Self::Composite(receipt) => receipt.program_receipt.claim.clone(), + } + } + + pub fn assumptions(&self) -> Rc>> { + match self { + Self::Segments(_receipt) => Rc::new(RefCell::new(Vec::new())), + Self::Composite(receipt) => receipt.assumption_used.clone(), + } + } +} + +impl From> for InnerReceipt +where + F: RichField + Extendable, + C: GenericConfig, +{ + /// Create a proven assumption from a [Receipt]. + fn from(receipt: Receipt) -> Self { + match receipt { + Receipt::::Segments(segments_receipt) => segments_receipt, + Receipt::::Composite(composite_receipt) => composite_receipt.program_receipt, + } + } +} + +impl From> for AssumptionReceipt +where + F: RichField + Extendable, + C: GenericConfig, +{ + /// Create a proven assumption from a [Receipt]. + fn from(receipt: Receipt) -> Self { + let inner: InnerReceipt = receipt.into(); + inner.into() + } +} + #[derive(Clone)] pub(crate) struct GenerationState< F: RichField + Extendable, @@ -147,7 +251,6 @@ where pub(crate) fn find_assumption( &self, claim_digest: &[u8; 32], - control_root: &[u8; 32], ) -> Option<(Assumption, AssumptionReceipt)> { for assumption_receipt in self.assumptions.borrow().iter() { let cached_claim_digest = assumption_receipt.claim_digest(); @@ -163,7 +266,6 @@ where return Some(( Assumption { claim: *claim_digest, - control_root: *control_root, }, assumption_receipt.clone(), )); diff --git a/prover/src/util.rs b/prover/src/util.rs index 251e704c..97b629c1 100644 --- a/prover/src/util.rs +++ b/prover/src/util.rs @@ -1,6 +1,5 @@ -use std::mem::{size_of, transmute_copy, ManuallyDrop}; - use itertools::Itertools; +use std::mem::{size_of, transmute_copy, ManuallyDrop}; use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; @@ -61,3 +60,11 @@ pub(crate) unsafe fn transmute_no_compile_time_size_checks(value: T) -> U // Copy the bit pattern. The original value is no longer safe to use. transmute_copy(&value) } + +pub fn u32_array_to_u8_vec(u32_array: &[u32; 8]) -> Vec { + let mut u8_vec = Vec::with_capacity(u32_array.len() * 4); + for &item in u32_array { + u8_vec.extend_from_slice(&item.to_le_bytes()); + } + u8_vec +} diff --git a/prover/src/witness/operation.rs b/prover/src/witness/operation.rs index c33e4f15..f2d43901 100644 --- a/prover/src/witness/operation.rs +++ b/prover/src/witness/operation.rs @@ -987,7 +987,7 @@ pub(crate) fn verify, C: GenericConfig, c addr: usize, size: usize, ) -> Result<()> { - assert!(size == 64); + assert!(size == 32); let mut claim_digest = [0u8; 32]; { let mut cpu_row = CpuColumnsView::default(); @@ -1001,22 +1001,9 @@ pub(crate) fn verify, C: GenericConfig, c state.traces.push_cpu(cpu_row); } - let mut control_root = [0u8; 32]; - { - let mut cpu_row = CpuColumnsView::default(); - cpu_row.clock = F::from_canonical_usize(state.traces.clock()); - for i in 0..8 { - let address = MemoryAddress::new(0, Segment::Code, addr + 32 + i * 4); - let (mem, op) = mem_read_gp_with_log_and_fill(i, address, state, &mut cpu_row); - control_root[i * 4..i * 4 + 4].copy_from_slice(mem.to_be_bytes().as_ref()); - state.traces.push_memory(op); - } - state.traces.push_cpu(cpu_row); - } + log::debug!("SYS_VERIFY: ({:?})", claim_digest); - log::debug!("SYS_VERIFY: ({:?}, {:?})", claim_digest, control_root); - - let assumption = state.find_assumption(&claim_digest, &control_root); + let assumption = state.find_assumption(&claim_digest); // Mark the assumption as accessed, pushing it to the head of the list, and return the success code. match assumption { @@ -1073,6 +1060,38 @@ pub(crate) fn load_input< Ok(()) } +pub(crate) fn commit, C: GenericConfig, const D: usize>( + state: &mut GenerationState, + addr: usize, + size: usize, +) -> Result<()> { + let map_addr = addr; + + let mut cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + let mut j = 0; + for i in (0..size).step_by(4) { + if j == 8 { + state.traces.push_cpu(cpu_row); + cpu_row = CpuColumnsView::default(); + cpu_row.clock = F::from_canonical_usize(state.traces.clock()); + j = 0; + } + + // Get each byte in the chunk + let addr = MemoryAddress::new(0, Segment::Code, map_addr + i); + let (data, mem_op) = mem_read_gp_with_log_and_fill(j, addr, state, &mut cpu_row); + state.traces.push_memory(mem_op); + let len = if i + 3 >= size { size - i } else { 4 }; + state + .public_values_stream + .extend_from_slice(&data.to_be_bytes()[..len]); + j += 1; + } + state.traces.push_cpu(cpu_row); + Ok(()) +} + pub(crate) fn generate_syscall< F: RichField + Extendable, C: GenericConfig, @@ -1091,6 +1110,7 @@ pub(crate) fn generate_syscall< let mut is_load_preimage = false; let mut is_load_input = false; let mut is_verify = false; + let mut is_commit = false; let result = match sys_num { SYSGETPID => { row.general.syscall_mut().sysnum[0] = F::ONE; @@ -1178,11 +1198,17 @@ pub(crate) fn generate_syscall< row.general.syscall_mut().sysnum[6] = F::ONE; match a0 { // fdStdout - FD_STDOUT | FD_STDERR | FD_PUBLIC_VALUES | FD_HINT => { + FD_STDOUT | FD_STDERR | FD_HINT => { row.general.syscall_mut().a0[1] = F::ONE; row.general.syscall_mut().cond[7] = F::ONE; v0 = a2; } // fdStdout + FD_PUBLIC_VALUES => { + row.general.syscall_mut().a0[1] = F::ONE; + row.general.syscall_mut().cond[7] = F::ONE; + is_commit = true; + v0 = a2; + } _ => { row.general.syscall_mut().a0[2] = F::ONE; row.general.syscall_mut().cond[6] = F::ONE; @@ -1262,6 +1288,9 @@ pub(crate) fn generate_syscall< if is_verify { let _ = verify(state, a1, a2); } + if is_commit { + let _ = commit(state, a1, a2); + } result } diff --git a/runtime/entrypoint/Cargo.toml b/runtime/entrypoint/Cargo.toml index 6a3d16ee..1fb12042 100644 --- a/runtime/entrypoint/Cargo.toml +++ b/runtime/entrypoint/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "zkm-runtime" -version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } [dependencies] zkm-precompiles = { path = "../precompiles" } diff --git a/runtime/entrypoint/src/syscalls/io.rs b/runtime/entrypoint/src/syscalls/io.rs index 5ba44363..93a6b66f 100644 --- a/runtime/entrypoint/src/syscalls/io.rs +++ b/runtime/entrypoint/src/syscalls/io.rs @@ -64,10 +64,9 @@ pub extern "C" fn syscall_hint_read(ptr: *mut u8, len: usize) { #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_verify(claim_digest: &[u8; 32], control_root: &[u8; 32]) { - let mut to_host = [0u8; 64]; +pub extern "C" fn syscall_verify(claim_digest: &[u8; 32]) { + let mut to_host = [0u8; 32]; to_host[..32].copy_from_slice(claim_digest); - to_host[32..].copy_from_slice(control_root); cfg_if::cfg_if! { if #[cfg(target_os = "zkvm")] { @@ -76,7 +75,7 @@ pub extern "C" fn syscall_verify(claim_digest: &[u8; 32], control_root: &[u8; 32 "syscall", in("$2") crate::syscalls::VERIFY, in("$5") to_host.as_ptr() as u32, - in("$6") 64u32, + in("$6") 32u32, ) } } else { diff --git a/runtime/precompiles/Cargo.toml b/runtime/precompiles/Cargo.toml index 1b82dc1b..cd882b15 100644 --- a/runtime/precompiles/Cargo.toml +++ b/runtime/precompiles/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "zkm-precompiles" -version = "0.1.0" -edition = "2021" +version = { workspace = true } +edition = { workspace = true } [dependencies] bincode = "1.3.3" diff --git a/runtime/precompiles/src/io.rs b/runtime/precompiles/src/io.rs index 42ad325a..e213be53 100644 --- a/runtime/precompiles/src/io.rs +++ b/runtime/precompiles/src/io.rs @@ -80,16 +80,12 @@ pub fn verify(image_id: Vec, public_input: &T) { let mut buf = Vec::new(); bincode::serialize_into(&mut buf, public_input).expect("serialization failed"); - let mut hasher = Sha256::new(); - hasher.update(buf); - let input_digest: [u8; 32] = hasher.finalize().into(); - let mut hasher = Sha256::new(); hasher.update(image_id); - hasher.update(input_digest); + hasher.update(buf); let digest: [u8; 32] = hasher.finalize().into(); - unsafe { syscall_verify(&digest, &ZERO) } + unsafe { syscall_verify(&digest) } } pub fn hint_slice(buf: &[u8]) { diff --git a/runtime/precompiles/src/lib.rs b/runtime/precompiles/src/lib.rs index 5f61ef0d..17774d2d 100644 --- a/runtime/precompiles/src/lib.rs +++ b/runtime/precompiles/src/lib.rs @@ -16,5 +16,5 @@ extern "C" { pub fn syscall_hint_len() -> usize; pub fn syscall_hint_read(ptr: *mut u8, len: usize); pub fn sys_alloc_aligned(bytes: usize, align: usize) -> *mut u8; - pub fn syscall_verify(claim_digest: &[u8; 32], control_root: &[u8; 32]); + pub fn syscall_verify(claim_digest: &[u8; 32]); }