Skip to content

Commit

Permalink
Fix branch constraint (#47)
Browse files Browse the repository at this point in the history
* Fix bugs in constraints for branch instructions

* restore steps number

* fix fmt
  • Loading branch information
weilzkm authored Nov 23, 2023
1 parent a24e874 commit 0452441
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 35 deletions.
110 changes: 78 additions & 32 deletions src/cpu/jumps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ pub fn eval_packed_branch<P: PackedField>(
let is_ne = lv.opcode_bits[2] * (P::ONES - lv.opcode_bits[1]) * lv.opcode_bits[0];
let is_le = lv.opcode_bits[2] * lv.opcode_bits[1] * (P::ONES - lv.opcode_bits[0]);
let is_gt = lv.opcode_bits[2] * lv.opcode_bits[1] * lv.opcode_bits[0];
let is_ge = (P::ONES - lv.opcode_bits[2]) * lv.rt_bits[0];
let is_lt = (P::ONES - lv.opcode_bits[2]) * (P::ONES - lv.rt_bits[0]);
let is_ge = special_filter * lv.rt_bits[0];
let is_lt = special_filter * (P::ONES - lv.rt_bits[0]);
let overflow = P::Scalar::from_canonical_u64(1 << 32);
let overflow_inv = P::Scalar::from_canonical_u64(GOLDILOCKS_INVERSE_2EXP32);

Expand All @@ -273,10 +273,10 @@ pub fn eval_packed_branch<P: PackedField>(

// Check `branch target value`:
// constraints:
// * jump_dest = sign_extended(offset << 2) + pc
// * jump_dest = sign_extended(offset << 2) + pc + 4
// * filter * should_jump *(next_program_coutner - jump_dest) * (next_program_coutner + 1 << 32 - jump_dest) == 0
// * next_addr = pc + 8
// * next_pc = jump_dest * should_jump + next_addr * (1 - should_jump)
// * filter * (next_program_coutner - next_pc) == 0
// * filter * (1 - should_jump) * (next_program_coutner - next_pc) == 0
{
let mut branch_offset = [P::ZEROS; 32];

Expand All @@ -286,21 +286,33 @@ pub fn eval_packed_branch<P: PackedField>(
branch_offset[18..32].copy_from_slice(&[lv.rd_bits[4]; 14]); // lv.insn_bits[15]

let offset_dst = limb_from_bits_le(branch_offset.into_iter());
let branch_dst = lv.program_counter + offset_dst;
let branch_dst = lv.program_counter + P::Scalar::from_canonical_u8(4) + offset_dst;
yield_constr.constraint(
filter
* jumps_lv.should_jump
* (nv.program_counter - branch_dst)
* (nv.program_counter + overflow - branch_dst),
);

let next_inst = lv.program_counter + P::Scalar::from_canonical_u64(8);
let branch_dst =
jumps_lv.should_jump * branch_dst + next_inst * (P::ONES - jumps_lv.should_jump);
yield_constr.constraint(filter * (nv.program_counter - branch_dst));
yield_constr.constraint(
filter * (P::ONES - jumps_lv.should_jump) * (nv.program_counter - next_inst),
);
}

// Check Aux Reg
// constraint:
// * sum = aux1 + aux2
// * filter * (1 - sum * overflow_inv) == 0
// * filter * aux1 * (1 - sum * overflow_inv) == 0
// * filter * aux3 * (1 - aux3) == 0
{
let aux1 = lv.mem_channels[2].addr_virtual;
let aux2 = lv.mem_channels[3].addr_virtual;
yield_constr.constraint(filter * (P::ONES - (aux1 + aux2) * overflow_inv));
let aux1: P = lv.mem_channels[2].value;
let aux2 = lv.mem_channels[3].value;
let aux3 = lv.mem_channels[4].value;

yield_constr.constraint(filter * aux1 * (P::ONES - (aux1 + aux2) * overflow_inv));

yield_constr.constraint(filter * aux3 * (P::ONES - aux3));
}

// Check rs Reg
Expand All @@ -326,6 +338,7 @@ pub fn eval_packed_branch<P: PackedField>(
let src2 = lv.mem_channels[1].value;
let aux1 = lv.mem_channels[2].value;
let aux2 = lv.mem_channels[3].value;
let aux3 = lv.mem_channels[4].value;

// constraints:
// * z = src2 + aux - src1
Expand Down Expand Up @@ -355,12 +368,17 @@ pub fn eval_packed_branch<P: PackedField>(
let ne = lt + gt;
yield_constr.constraint(filter * ne * (P::ONES - ne));

// invert lt/gt if aux3 = 1 (src1 and src2 have different sign bits)
let lt = lt * (P::ONES - aux3) + (P::ONES - lt) * aux3;
let gt = gt * (P::ONES - aux3) + (P::ONES - gt) * aux3;

// constraints:
// * is_eq = 1 - is_ne
// * is_ge = 1 - is_lt
// * is_le = 1 - is_gt
// * is_jump = eq * is_eq + ne * is_ne + le * is_le + ge * is_ge + lt * is_lt + gt * is_gt
// * filter * (should_jump - is_jump) == 0

let constr_eq = (P::ONES - ne) * is_eq;
let constr_ne = ne * is_ne;
let constr_le = (P::ONES - gt) * is_le;
Expand Down Expand Up @@ -409,10 +427,10 @@ pub fn eval_ext_circuit_branch<F: RichField + Extendable<D>, const D: usize>(

// Check `branch target value`:
// constraints:
// * jump_dest = sign_extended(offset << 2) + pc
// * jump_dest = sign_extended(offset << 2) + pc + 4
// * filter * should_jump *(next_program_coutner - jump_dest) * (next_program_coutner + 1 << 32 - jump_dest) == 0
// * next_addr = pc + 8
// * next_pc = jump_dest * should_jump + next_addr * (1 - should_jump)
// * filter * (next_program_coutner - next_pc) == 0
// * filter * (1 - should_jump) * (next_program_coutner - next_pc) == 0
{
let mut branch_offset = [zero_extension; 32];

Expand All @@ -422,31 +440,47 @@ pub fn eval_ext_circuit_branch<F: RichField + Extendable<D>, const D: usize>(
branch_offset[18..32].copy_from_slice(&[lv.rd_bits[4]; 14]); // lv.insn_bits[15]
let offset_dst = limb_from_bits_le_recursive(builder, branch_offset.into_iter());

let branch_dst = builder.add_extension(lv.program_counter, offset_dst);
let next_insn = builder.add_const_extension(lv.program_counter, F::from_canonical_u64(8));
let constr_a = builder.mul_extension(branch_dst, jumps_lv.should_jump);
let base_pc = builder.add_const_extension(lv.program_counter, F::from_canonical_u64(4));
let branch_dst = builder.add_extension(base_pc, offset_dst);

let constr_b = builder.sub_extension(one_extension, jumps_lv.should_jump);
let constr_b = builder.mul_extension(constr_b, next_insn);
let constr = builder.add_extension(constr_a, constr_b);
let constr = builder.sub_extension(nv.program_counter, constr);
let overflow_target = builder.add_extension(nv.program_counter, overflow);
let constr_a = builder.sub_extension(overflow_target, branch_dst);
let constr_b = builder.sub_extension(nv.program_counter, branch_dst);
let constr = builder.mul_extension(jumps_lv.should_jump, constr_a);
let constr = builder.mul_extension(constr, constr_b);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

let next_insn = builder.add_const_extension(lv.program_counter, F::from_canonical_u64(8));
let constr_a = builder.sub_extension(one_extension, jumps_lv.should_jump);
let constr_b = builder.sub_extension(nv.program_counter, next_insn);
let constr = builder.mul_extension(constr_a, constr_b);
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);
}

// Check Aux Reg
// constraint:
// * sum = aux1 + aux2
// * filter * (1 - sum * overflow_inv) == 0
// * filter * aux1 * (1 - sum * overflow_inv) == 0
// * filter * aux3 * (1 - aux3) == 0
{
let aux1 = lv.mem_channels[2].addr_virtual;
let aux2 = lv.mem_channels[3].addr_virtual;
let aux1 = lv.mem_channels[2].value;
let aux2 = lv.mem_channels[3].value;
let aux3 = lv.mem_channels[4].value;

let constr = builder.add_extension(aux1, aux2);
let constr = builder.mul_extension(constr, overflow_inv);
let constr = builder.sub_extension(one_extension, constr);
let constr = builder.mul_extension(aux1, constr);
let constr = builder.mul_extension(constr, filter);

yield_constr.constraint(builder, constr);

let constr = builder.sub_extension(one_extension, aux3);
let constr = builder.mul_extension(filter, constr);
let constr = builder.mul_extension(aux3, constr);
yield_constr.constraint(builder, constr);
}

// Check rs Reg
Expand Down Expand Up @@ -480,6 +514,7 @@ pub fn eval_ext_circuit_branch<F: RichField + Extendable<D>, const D: usize>(
let src2 = lv.mem_channels[1].value;
let aux1 = lv.mem_channels[2].value;
let aux2 = lv.mem_channels[3].value;
let aux3 = lv.mem_channels[4].value;

// constraints:
// * z = src2 + aux - src1
Expand Down Expand Up @@ -526,6 +561,17 @@ pub fn eval_ext_circuit_branch<F: RichField + Extendable<D>, const D: usize>(
let constr = builder.mul_extension(constr, filter);
yield_constr.constraint(builder, constr);

// invert lt/gt if aux3 = 1 (src1 and src2 have different sign bits)
let inv_aux3 = builder.sub_extension(one_extension, aux3);
let inv_lt = builder.sub_extension(one_extension, lt);
let inv_gt = builder.sub_extension(one_extension, gt);
let lt_norm = builder.mul_extension(lt, inv_aux3);
let lt_inv = builder.mul_extension(inv_lt, aux3);
let lt = builder.add_extension(lt_norm, lt_inv);
let gt_norm = builder.mul_extension(gt, inv_aux3);
let gt_inv = builder.mul_extension(inv_gt, aux3);
let gt = builder.add_extension(gt_norm, gt_inv);

// constraints:
// * is_eq = 1 - is_ne
// * is_ge = 1 - is_lt
Expand Down Expand Up @@ -638,10 +684,10 @@ pub fn eval_packed<P: PackedField>(
nv: &CpuColumnsView<P>,
yield_constr: &mut ConstraintConsumer<P>,
) {
eval_packed_exit_kernel(lv, nv, yield_constr);
//eval_packed_exit_kernel(lv, nv, yield_constr);
eval_packed_jump_jumpi(lv, nv, yield_constr);
//eval_packed_branch(lv, nv, yield_constr);
//eval_packed_condmov(lv, nv, yield_constr);
eval_packed_branch(lv, nv, yield_constr);
eval_packed_condmov(lv, nv, yield_constr);
}

pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
Expand All @@ -650,8 +696,8 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
eval_ext_circuit_exit_kernel(builder, lv, nv, yield_constr);
//eval_ext_circuit_exit_kernel(builder, lv, nv, yield_constr);
eval_ext_circuit_jump_jumpi(builder, lv, nv, yield_constr);
//eval_ext_circuit_branch(builder, lv, nv, yield_constr);
//eval_ext_circuit_condmov(builder, lv, nv, yield_constr);
eval_ext_circuit_branch(builder, lv, nv, yield_constr);
eval_ext_circuit_condmov(builder, lv, nv, yield_constr);
}
1 change: 1 addition & 0 deletions src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ mod tests {
}

#[test]
#[ignore]
fn test_mips_prove_and_verify() {
env_logger::try_init().unwrap_or_default();
const D: usize = 2;
Expand Down
11 changes: 8 additions & 3 deletions src/witness/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,14 @@ pub(crate) fn generate_branch<F: Field>(
let (src1, src1_op) = reg_read_with_log(src1, 0, state, &mut row)?;
let (src2, src2_op) = reg_read_with_log(src2, 1, state, &mut row)?;
let should_jump = cond.result(src1 as i32, src2 as i32);
reg_write_with_log(0, 2, src1.wrapping_sub(src2), state, &mut row)?;
reg_write_with_log(0, 3, src2.wrapping_sub(src1), state, &mut row)?;
//println!("jump: {} c0: {}, c1: {}, aux1: {}, aux2: {}", should_jump, src1, src2, src1.wrapping_sub(src2), src2.wrapping_sub(src1));
let aux1 = src1.wrapping_sub(src2);
let aux2 = src2.wrapping_sub(src1);
let aux3 = (src1 ^ src2) & 0x80000000 > 0;

reg_write_with_log(0, 2, aux1, state, &mut row)?;
reg_write_with_log(0, 3, aux2, state, &mut row)?;
reg_write_with_log(0, 4, aux3 as usize, state, &mut row)?;
let pc = state.registers.program_counter as u32;
if should_jump {
let target = sign_extend::<16>(target);
Expand All @@ -410,7 +416,6 @@ pub(crate) fn generate_branch<F: Field>(
state.traces.push_cpu(row);
state.jump_to(next_pc as usize);
}
state.traces.push_cpu(row);
state.traces.push_memory(src1_op);
state.traces.push_memory(src2_op);
Ok(())
Expand Down

0 comments on commit 0452441

Please sign in to comment.