Skip to content

Commit

Permalink
test: add support for generate aggregation prove for total minigeth (#75
Browse files Browse the repository at this point in the history
)

* test: add support for generate aggregation prove for total minigeth

* fix cargo fmt

* fix problem when segment number = 3
  • Loading branch information
weilzkm authored Jan 30, 2024
1 parent 5f26d6a commit b5ab48c
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 2 deletions.
7 changes: 7 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,10 @@ BASEDIR=test-vectors RUST_LOG=trace BLOCK_NO=13284491 SEG_FILE="/tmp/output/0" S
BASEDIR=test-vectors RUST_LOG=trace BLOCK_NO=13284491 SEG_FILE="/tmp/output/0" SEG_FILE2="/tmp/output/1" SEG_SIZE=262144 \
cargo run --release --example zkmips aggregate_proof
```

* Aggregate proof all

```
BASEDIR=test-vectors RUST_LOG=trace BLOCK_NO=13284491 SEG_FILE_DIR="/tmp/output" SEG_FILE_NUM=55 SEG_SIZE=262144 \
cargo run --release --example zkmips aggregate_proof_all
```
146 changes: 144 additions & 2 deletions examples/zkmips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ fn main() {
let args: Vec<String> = env::args().collect();
let helper = || {
println!(
"Help: {} split | prove | aggregate_proof | prove_groth16",
"Help: {} split | prove | aggregate_proof | aggregate_proof_all | prove_groth16",
args[0]
);
std::process::exit(-1);
Expand All @@ -108,6 +108,7 @@ fn main() {
"split" => split_elf_into_segs(),
"prove" => prove_single_seg(),
"aggregate_proof" => aggregate_proof().unwrap(),
"aggregate_proof_all" => aggregate_proof_all().unwrap(),
"prove_groth16" => prove_groth16(),
_ => helper(),
};
Expand All @@ -131,7 +132,7 @@ fn aggregate_proof() -> anyhow::Result<()> {
// Preprocess all circuits.
let all_circuits = AllRecursiveCircuits::<F, C, D>::new(
&all_stark,
&[10..20, 15..22, 14..19, 9..17, 12..20, 15..22],
&[10..20, 15..22, 14..19, 9..17, 12..20, 15..23],
&config,
);

Expand Down Expand Up @@ -175,3 +176,144 @@ fn aggregate_proof() -> anyhow::Result<()> {
);
all_circuits.verify_block(&block_proof)
}

fn aggregate_proof_all() -> anyhow::Result<()> {
type F = GoldilocksField;
const D: usize = 2;
type C = PoseidonGoldilocksConfig;

let basedir = env::var("BASEDIR").unwrap_or("/tmp/cannon".to_string());
let block = env::var("BLOCK_NO").expect("Block number is missing");
let file = env::var("BLOCK_FILE").unwrap_or(String::from(""));
let seg_dir = env::var("SEG_FILE_DIR").expect("segment file dir is missing");
let seg_file_number = env::var("SEG_FILE_NUM").expect("The segment file number is missing");
let seg_file_number = seg_file_number.parse::<_>().unwrap_or(2usize);
let seg_size = env::var("SEG_SIZE").unwrap_or(format!("{SEGMENT_STEPS}"));
let seg_size = seg_size.parse::<_>().unwrap_or(SEGMENT_STEPS);

if seg_file_number < 2 {
panic!("seg file number must >= 2\n");
}

let total_timing = TimingTree::new("prove total time", log::Level::Info);
let all_stark = AllStark::<F, D>::default();
let config = StarkConfig::standard_fast_config();
// Preprocess all circuits.
let all_circuits = AllRecursiveCircuits::<F, C, D>::new(
&all_stark,
&[10..20, 15..22, 14..19, 9..17, 12..20, 15..23],
&config,
);

let seg_file = format!("{}/{}", seg_dir, 0);
let input_first = segment_kernel(&basedir, &block, &file, &seg_file, seg_size);
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (mut agg_proof, mut updated_agg_public_values) =
all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;

timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_root(agg_proof.clone())?;

let mut base_seg = 1;
let mut is_agg = false;

if seg_file_number % 2 == 0 {
let seg_file = format!("{}/{}", seg_dir, 1);
let input = segment_kernel(&basedir, &block, &file, &seg_file, seg_size);
timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
timing.filter(Duration::from_millis(100)).print();

all_circuits.verify_root(root_proof.clone())?;

// Update public values for the aggregation.
let agg_public_values = PublicValues {
roots_before: updated_agg_public_values.roots_before,
roots_after: public_values.roots_after,
};
timing = TimingTree::new("prove aggression", log::Level::Info);
// We can duplicate the proofs here because the state hasn't mutated.
(agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation(
false,
&agg_proof,
false,
&root_proof,
agg_public_values.clone(),
)?;
timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_aggregation(&agg_proof)?;

is_agg = true;
base_seg = 2;
}

for i in 0..(seg_file_number - base_seg) / 2 {
let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1));
let input_first = segment_kernel(&basedir, &block, &file, &seg_file, seg_size);
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (root_proof_first, first_public_values) =
all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;

timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_root(root_proof_first.clone())?;

let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1);
let input = segment_kernel(&basedir, &block, &file, &seg_file, seg_size);
let mut timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
timing.filter(Duration::from_millis(100)).print();

all_circuits.verify_root(root_proof.clone())?;

// Update public values for the aggregation.
let new_agg_public_values = PublicValues {
roots_before: first_public_values.roots_before,
roots_after: public_values.roots_after,
};
timing = TimingTree::new("prove aggression", log::Level::Info);
// We can duplicate the proofs here because the state hasn't mutated.
let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation(
false,
&root_proof_first,
false,
&root_proof,
new_agg_public_values,
)?;
timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_aggregation(&new_agg_proof)?;

// Update public values for the nested aggregation.
let agg_public_values = PublicValues {
roots_before: updated_agg_public_values.roots_before,
roots_after: new_updated_agg_public_values.roots_after,
};
timing = TimingTree::new("prove nested aggression", log::Level::Info);

// We can duplicate the proofs here because the state hasn't mutated.
(agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation(
is_agg,
&agg_proof,
true,
&new_agg_proof,
agg_public_values.clone(),
)?;
is_agg = true;
timing.filter(Duration::from_millis(100)).print();

all_circuits.verify_aggregation(&agg_proof)?;
}

let (block_proof, _block_public_values) =
all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?;

log::info!(
"proof size: {:?}",
serde_json::to_string(&block_proof.proof).unwrap().len()
);
let result = all_circuits.verify_block(&block_proof);

total_timing.filter(Duration::from_millis(100)).print();
result
}

0 comments on commit b5ab48c

Please sign in to comment.