Skip to content

Commit

Permalink
Fix the number of segments (#170)
Browse files Browse the repository at this point in the history
* fix: fix the computation of segment number

* fix segfile for single segment
  • Loading branch information
weilzkm authored Sep 27, 2024
1 parent 80a0f88 commit 58f7960
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 71 deletions.
2 changes: 1 addition & 1 deletion emulator/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ pub struct InstrumentedState {
/// writer for stderr
stderr_writer: Box<dyn Write>,

pre_segment_id: u32,
pub pre_segment_id: u32,
pre_pc: u32,
pre_image_id: [u8; 32],
pre_hash_root: [u8; 32],
Expand Down
3 changes: 2 additions & 1 deletion emulator/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn split_prog_into_segs(
seg_path: &str,
block_path: &str,
seg_size: usize,
) -> (usize, Box<State>) {
) -> (usize, usize, Box<State>) {
let mut instrumented_state = InstrumentedState::new(state, block_path.to_string());
std::fs::create_dir_all(seg_path).unwrap();
let new_writer = |_: &str| -> Option<std::fs::File> { None };
Expand All @@ -50,6 +50,7 @@ pub fn split_prog_into_segs(
instrumented_state.dump_memory();
(
instrumented_state.state.total_step as usize,
instrumented_state.pre_segment_id as usize,
instrumented_state.state,
)
}
79 changes: 26 additions & 53 deletions prover/examples/zkmips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,9 @@ fn split_segments() {
let _ = split_prog_into_segs(state, &seg_path, &block_path, seg_size);
}

fn prove_single_seg_common(
seg_file: &str,
basedir: &str,
block: &str,
file: &str,
seg_size: usize,
) {
fn prove_single_seg_common(seg_file: &str, basedir: &str, block: &str, file: &str) {
let seg_reader = BufReader::new(File::open(seg_file).unwrap());
let kernel = segment_kernel(basedir, block, file, seg_reader, seg_size);
let kernel = segment_kernel(basedir, block, file, seg_reader);

const D: usize = 2;
type C = PoseidonGoldilocksConfig;
Expand Down Expand Up @@ -82,7 +76,6 @@ fn prove_multi_seg_common(
basedir: &str,
block: &str,
file: &str,
seg_size: usize,
seg_file_number: usize,
seg_start_id: usize,
) -> anyhow::Result<()> {
Expand All @@ -107,7 +100,7 @@ fn prove_multi_seg_common(
let seg_file = format!("{}/{}", seg_dir, seg_start_id);
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(seg_file)?);
let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size);
let input_first = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (mut agg_proof, mut updated_agg_public_values) =
all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;
Expand All @@ -123,7 +116,7 @@ fn prove_multi_seg_common(
let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1);
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(seg_file)?);
let input = segment_kernel(basedir, block, file, seg_reader, seg_size);
let input = segment_kernel(basedir, block, file, seg_reader);
timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
Expand Down Expand Up @@ -158,7 +151,7 @@ fn prove_multi_seg_common(
let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1));
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(&seg_file)?);
let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size);
let input_first = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (root_proof_first, first_public_values) =
all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;
Expand All @@ -169,7 +162,7 @@ fn prove_multi_seg_common(
let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1);
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(&seg_file)?);
let input = segment_kernel(basedir, block, file, seg_reader, seg_size);
let input = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
Expand Down Expand Up @@ -270,17 +263,18 @@ fn prove_sha2_rust() {
log::info!("private input value: {:X?}", private_input);
state.add_input_stream(&private_input);

let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);
let (_total_steps, seg_num, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);

let value = state.read_public_values::<[u8; 32]>();
log::info!("public value: {:X?}", value);
log::info!("public value: {} in hex", hex::encode(value));

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
prove_single_seg_common(&seg_file, "", "", "")
} else {
prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap()
}
prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
}

fn prove_sha2_go() {
Expand Down Expand Up @@ -312,17 +306,17 @@ fn prove_sha2_go() {
);
log::info!("public input: {:X?}", data);

let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);
let (_total_steps, seg_num, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);

let value = state.read_public_values::<Data>();
log::info!("public value: {:X?}", value);

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
prove_single_seg_common(&seg_file, "", "", "")
} else {
prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap()
}

prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
}

fn prove_revm() {
Expand All @@ -340,18 +334,13 @@ fn prove_revm() {
// load input
state.add_input_stream(&data);

let (total_steps, mut _state) = split_prog_into_segs(state, &seg_path, "", seg_size);

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}
let (_total_steps, seg_num, mut _state) = split_prog_into_segs(state, &seg_path, "", seg_size);

if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
prove_single_seg_common(&seg_file, "", "", "", total_steps)
prove_single_seg_common(&seg_file, "", "", "")
} else {
prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap()
}
}

Expand Down Expand Up @@ -423,21 +412,16 @@ fn prove_add_example() {
);
log::info!("public input: {:X?}", data);

let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);
let (_total_steps, seg_num, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);

let value = state.read_public_values::<Data>();
log::info!("public value: {:X?}", value);

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}

if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
prove_single_seg_common(&seg_file, "", "", "", total_steps)
prove_single_seg_common(&seg_file, "", "", "")
} else {
prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap()
}
}

Expand All @@ -461,23 +445,12 @@ fn prove_segments() {
let seg_num = seg_num.parse::<_>().unwrap_or(1usize);
let seg_start_id = env::var("SEG_START_ID").unwrap_or("0".to_string());
let seg_start_id = seg_start_id.parse::<_>().unwrap_or(0usize);
let seg_size = env::var("SEG_SIZE").unwrap_or(format!("{SEGMENT_STEPS}"));
let seg_size = seg_size.parse::<_>().unwrap_or(SEGMENT_STEPS);

if seg_num == 1 {
let seg_file = format!("{seg_dir}/{}", seg_start_id);
prove_single_seg_common(&seg_file, &basedir, &block, &file, seg_size)
prove_single_seg_common(&seg_file, &basedir, &block, &file)
} else {
prove_multi_seg_common(
&seg_dir,
&basedir,
&block,
&file,
seg_size,
seg_num,
seg_start_id,
)
.unwrap()
prove_multi_seg_common(&seg_dir, &basedir, &block, &file, seg_num, seg_start_id).unwrap()
}
}

Expand Down
16 changes: 1 addition & 15 deletions prover/src/cpu/kernel/assembler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,19 @@ pub struct Kernel {
// should be preprocessed after loading code
pub(crate) global_labels: HashMap<String, usize>,
pub blockpath: String,
pub steps: usize,
}

pub const MAX_MEM: u32 = 0x80000000;

pub fn segment_kernel<T: Read>(
basedir: &str,
block: &str,
file: &str,
seg_reader: T,
steps: usize,
) -> Kernel {
pub fn segment_kernel<T: Read>(basedir: &str, block: &str, file: &str, seg_reader: T) -> Kernel {
let p: Program = Program::load_segment(seg_reader).unwrap();
let blockpath = get_block_path(basedir, block, file);

let mut final_step = steps;
if p.step != 0 {
assert!(p.step <= steps);
final_step = p.step;
}

Kernel {
program: p,
ordered_labels: vec![],
global_labels: HashMap::new(),
blockpath,
steps: final_step,
}
}

Expand Down
2 changes: 1 addition & 1 deletion prover/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
// Decode the trace record
// 1. Decode instruction and fill in cpu columns
// 2. Decode memory and fill in memory columns
let mut state = GenerationState::<F>::new(kernel.steps, kernel).unwrap();
let mut state = GenerationState::<F>::new(kernel.program.step, kernel).unwrap();
generate_bootstrap_kernel::<F>(&mut state, kernel);

timed!(timing, "simulate CPU", simulate_cpu(&mut state, kernel)?);
Expand Down

0 comments on commit 58f7960

Please sign in to comment.