Skip to content

Commit

Permalink
feat: read public input from input stream instead (#177)
Browse files Browse the repository at this point in the history
* feat: read public input from input stream instead

* fix warnings

* fix input for common program

* fix clippy

* use hash of first input as public input

* fix clippy
  • Loading branch information
weilzkm authored Oct 24, 2024
1 parent 1122bb7 commit 415b293
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 32 deletions.
1 change: 1 addition & 0 deletions prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ hashbrown = { version = "0.14.0", default-features = false, features = ["ahash",
lazy_static = "1.4.0"

elf = { version = "0.7", default-features = false }
sha2 = { version = "0.10.8", default-features = false }


[dev-dependencies]
Expand Down
19 changes: 17 additions & 2 deletions prover/examples/zkmips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,24 @@ fn split_segments() {
let seg_size = env::var("SEG_SIZE").unwrap_or(format!("{SEGMENT_STEPS}"));
let seg_size = seg_size.parse::<_>().unwrap_or(SEGMENT_STEPS);
let args = env::var("ARGS").unwrap_or("".to_string());
let args = args.split_whitespace().collect();
// assume the first arg is the hash output(which is a public input), and the others are the input.
let args: Vec<&str> = args.split_whitespace().collect();
let mut state = load_elf_with_patch(&elf_path, vec![]);

if !args.is_empty() {
let public_input: Vec<u8> = args[0].as_bytes().to_vec();
log::info!("public input value {:X?}", public_input);
state.add_input_stream(&public_input);
}

if args.len() > 1 {
for (i, arg) in args.iter().enumerate().skip(1) {
let private_input = arg.as_bytes().to_vec();
log::info!("private input value {}: {:X?}", i, private_input);
state.add_input_stream(&private_input);
}
}

let mut state = load_elf_with_patch(&elf_path, args);
let block_path = get_block_path(&basedir, &block_no, "");
if !block_no.is_empty() {
state.load_input(&block_path);
Expand Down
37 changes: 10 additions & 27 deletions prover/src/cpu/kernel/assembler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use super::elf::Program;
use zkm_emulator::utils::get_block_path;

use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::{collections::HashMap, io::Read};
use zkm_emulator::memory::INIT_SP;

#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct Kernel {
Expand Down Expand Up @@ -49,33 +49,16 @@ impl Kernel {
.find_map(|(k, v)| (*v == offset).then(|| k.clone()))
}

/// Read public input from memory at page INIT_SP
/// Read public input from input stream index 0
pub fn read_public_inputs(&self) -> Vec<u8> {
let arg_size = self.program.image.get(&INIT_SP).unwrap();
if *arg_size == 0 {
return vec![];
if let Some(first) = self.program.input_stream.first() {
// bincode::deserialize::<Vec<u8>>(first).expect("deserialization failed")
let mut hasher = Sha256::new();
hasher.update(first);
let result = hasher.finalize();
result.to_vec()
} else {
vec![]
}

let paddr = INIT_SP + 4;
let daddr = self.program.image.get(&paddr).unwrap();
log::trace!("Try read input at {}", daddr.to_be());
let mut args = vec![];
let mut value_addr = daddr.to_be();
let mut b = false;
while !b {
let value = self.program.image.get(&value_addr).unwrap();
let bytes = value.to_le_bytes();
for c in bytes.iter() {
if *c != 0 {
args.push(*c)
} else {
b = true;
break;
}
}
value_addr += 4;
}
log::trace!("Read public input: {:?}", args);
args
}
}
9 changes: 6 additions & 3 deletions prover/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ pub fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
// Execute the trace record

// Generate the public values and outputs
let mut userdata = kernel.read_public_inputs();
assert!(userdata.len() <= NUM_PUBLIC_INPUT_USERDATA);
userdata.resize(NUM_PUBLIC_INPUT_USERDATA, 0u8);
// let mut userdata = kernel.read_public_inputs();
// assert!(userdata.len() <= NUM_PUBLIC_INPUT_USERDATA);
// userdata.resize(NUM_PUBLIC_INPUT_USERDATA, 0u8);
let userdata = kernel.read_public_inputs();

assert!(userdata.len() == NUM_PUBLIC_INPUT_USERDATA);

let public_values = PublicValues {
roots_before: MemRoots {
Expand Down

0 comments on commit 415b293

Please sign in to comment.