diff --git a/src/cpu/jumps.rs b/src/cpu/jumps.rs index d5612fff..eeba30c9 100644 --- a/src/cpu/jumps.rs +++ b/src/cpu/jumps.rs @@ -262,8 +262,8 @@ pub fn eval_packed_branch( let is_ne = lv.opcode_bits[2] * (P::ONES - lv.opcode_bits[1]) * lv.opcode_bits[0]; let is_le = lv.opcode_bits[2] * lv.opcode_bits[1] * (P::ONES - lv.opcode_bits[0]); let is_gt = lv.opcode_bits[2] * lv.opcode_bits[1] * lv.opcode_bits[0]; - let is_ge = (P::ONES - lv.opcode_bits[2]) * lv.rt_bits[0]; - let is_lt = (P::ONES - lv.opcode_bits[2]) * (P::ONES - lv.rt_bits[0]); + let is_ge = special_filter * lv.rt_bits[0]; + let is_lt = special_filter * (P::ONES - lv.rt_bits[0]); let overflow = P::Scalar::from_canonical_u64(1 << 32); let overflow_inv = P::Scalar::from_canonical_u64(GOLDILOCKS_INVERSE_2EXP32); @@ -273,10 +273,10 @@ pub fn eval_packed_branch( // Check `branch target value`: // constraints: - // * jump_dest = sign_extended(offset << 2) + pc + // * jump_dest = sign_extended(offset << 2) + pc + 4 + // * filter * should_jump *(next_program_coutner - jump_dest) * (next_program_coutner + 1 << 32 - jump_dest) == 0 // * next_addr = pc + 8 - // * next_pc = jump_dest * should_jump + next_addr * (1 - should_jump) - // * filter * (next_program_coutner - next_pc) == 0 + // * filter * (1 - should_jump) * (next_program_coutner - next_pc) == 0 { let mut branch_offset = [P::ZEROS; 32]; @@ -286,21 +286,33 @@ pub fn eval_packed_branch( branch_offset[18..32].copy_from_slice(&[lv.rd_bits[4]; 14]); // lv.insn_bits[15] let offset_dst = limb_from_bits_le(branch_offset.into_iter()); - let branch_dst = lv.program_counter + offset_dst; + let branch_dst = lv.program_counter + P::Scalar::from_canonical_u8(4) + offset_dst; + yield_constr.constraint( + filter + * jumps_lv.should_jump + * (nv.program_counter - branch_dst) + * (nv.program_counter + overflow - branch_dst), + ); + let next_inst = lv.program_counter + P::Scalar::from_canonical_u64(8); - let branch_dst = - jumps_lv.should_jump * branch_dst + next_inst * 
(P::ONES - jumps_lv.should_jump); - yield_constr.constraint(filter * (nv.program_counter - branch_dst)); + yield_constr.constraint( + filter * (P::ONES - jumps_lv.should_jump) * (nv.program_counter - next_inst), + ); } // Check Aux Reg // constraint: // * sum = aux1 + aux2 - // * filter * (1 - sum * overflow_inv) == 0 + // * filter * aux1 * (1 - sum * overflow_inv) == 0 + // * filter * aux3 * (1 - aux3) == 0 { - let aux1 = lv.mem_channels[2].addr_virtual; - let aux2 = lv.mem_channels[3].addr_virtual; - yield_constr.constraint(filter * (P::ONES - (aux1 + aux2) * overflow_inv)); + let aux1: P = lv.mem_channels[2].value; + let aux2 = lv.mem_channels[3].value; + let aux3 = lv.mem_channels[4].value; + + yield_constr.constraint(filter * aux1 * (P::ONES - (aux1 + aux2) * overflow_inv)); + + yield_constr.constraint(filter * aux3 * (P::ONES - aux3)); } // Check rs Reg @@ -326,6 +338,7 @@ pub fn eval_packed_branch( let src2 = lv.mem_channels[1].value; let aux1 = lv.mem_channels[2].value; let aux2 = lv.mem_channels[3].value; + let aux3 = lv.mem_channels[4].value; // constraints: // * z = src2 + aux - src1 @@ -355,12 +368,17 @@ pub fn eval_packed_branch( let ne = lt + gt; yield_constr.constraint(filter * ne * (P::ONES - ne)); + // invert lt/gt if aux3 = 1 (src1 and src2 have different sign bits) + let lt = lt * (P::ONES - aux3) + (P::ONES - lt) * aux3; + let gt = gt * (P::ONES - aux3) + (P::ONES - gt) * aux3; + // constraints: // * is_eq = 1 - is_ne // * is_ge = 1 - is_lt // * is_le = 1 - is_gt // * is_jump = eq * is_eq + ne * is_ne + le * is_le + ge * is_ge + lt * is_lt + gt * is_gt // * filter * (should_jump - is_jump) == 0 + let constr_eq = (P::ONES - ne) * is_eq; let constr_ne = ne * is_ne; let constr_le = (P::ONES - gt) * is_le; @@ -409,10 +427,10 @@ pub fn eval_ext_circuit_branch, const D: usize>( // Check `branch target value`: // constraints: - // * jump_dest = sign_extended(offset << 2) + pc + // * jump_dest = sign_extended(offset << 2) + pc + 4 + // * filter * 
should_jump *(next_program_coutner - jump_dest) * (next_program_coutner + 1 << 32 - jump_dest) == 0 // * next_addr = pc + 8 - // * next_pc = jump_dest * should_jump + next_addr * (1 - should_jump) - // * filter * (next_program_coutner - next_pc) == 0 + // * filter * (1 - should_jump) * (next_program_coutner - next_pc) == 0 { let mut branch_offset = [zero_extension; 32]; @@ -422,14 +440,21 @@ pub fn eval_ext_circuit_branch, const D: usize>( branch_offset[18..32].copy_from_slice(&[lv.rd_bits[4]; 14]); // lv.insn_bits[15] let offset_dst = limb_from_bits_le_recursive(builder, branch_offset.into_iter()); - let branch_dst = builder.add_extension(lv.program_counter, offset_dst); - let next_insn = builder.add_const_extension(lv.program_counter, F::from_canonical_u64(8)); - let constr_a = builder.mul_extension(branch_dst, jumps_lv.should_jump); + let base_pc = builder.add_const_extension(lv.program_counter, F::from_canonical_u64(4)); + let branch_dst = builder.add_extension(base_pc, offset_dst); - let constr_b = builder.sub_extension(one_extension, jumps_lv.should_jump); - let constr_b = builder.mul_extension(constr_b, next_insn); - let constr = builder.add_extension(constr_a, constr_b); - let constr = builder.sub_extension(nv.program_counter, constr); + let overflow_target = builder.add_extension(nv.program_counter, overflow); + let constr_a = builder.sub_extension(overflow_target, branch_dst); + let constr_b = builder.sub_extension(nv.program_counter, branch_dst); + let constr = builder.mul_extension(jumps_lv.should_jump, constr_a); + let constr = builder.mul_extension(constr, constr_b); + let constr = builder.mul_extension(constr, filter); + yield_constr.constraint(builder, constr); + + let next_insn = builder.add_const_extension(lv.program_counter, F::from_canonical_u64(8)); + let constr_a = builder.sub_extension(one_extension, jumps_lv.should_jump); + let constr_b = builder.sub_extension(nv.program_counter, next_insn); + let constr = builder.mul_extension(constr_a, 
constr_b); let constr = builder.mul_extension(constr, filter); yield_constr.constraint(builder, constr); } @@ -437,16 +462,25 @@ pub fn eval_ext_circuit_branch, const D: usize>( // Check Aux Reg // constraint: // * sum = aux1 + aux2 - // * filter * (1 - sum * overflow_inv) == 0 + // * filter * aux1 * (1 - sum * overflow_inv) == 0 + // * filter * aux3 * (1 - aux3) == 0 { - let aux1 = lv.mem_channels[2].addr_virtual; - let aux2 = lv.mem_channels[3].addr_virtual; + let aux1 = lv.mem_channels[2].value; + let aux2 = lv.mem_channels[3].value; + let aux3 = lv.mem_channels[4].value; + let constr = builder.add_extension(aux1, aux2); let constr = builder.mul_extension(constr, overflow_inv); let constr = builder.sub_extension(one_extension, constr); + let constr = builder.mul_extension(aux1, constr); let constr = builder.mul_extension(constr, filter); yield_constr.constraint(builder, constr); + + let constr = builder.sub_extension(one_extension, aux3); + let constr = builder.mul_extension(filter, constr); + let constr = builder.mul_extension(aux3, constr); + yield_constr.constraint(builder, constr); } // Check rs Reg @@ -480,6 +514,7 @@ pub fn eval_ext_circuit_branch, const D: usize>( let src2 = lv.mem_channels[1].value; let aux1 = lv.mem_channels[2].value; let aux2 = lv.mem_channels[3].value; + let aux3 = lv.mem_channels[4].value; // constraints: // * z = src2 + aux - src1 @@ -526,6 +561,17 @@ pub fn eval_ext_circuit_branch, const D: usize>( let constr = builder.mul_extension(constr, filter); yield_constr.constraint(builder, constr); + // invert lt/gt if aux3 = 1 (src1 and src2 have different sign bits) + let inv_aux3 = builder.sub_extension(one_extension, aux3); + let inv_lt = builder.sub_extension(one_extension, lt); + let inv_gt = builder.sub_extension(one_extension, gt); + let lt_norm = builder.mul_extension(lt, inv_aux3); + let lt_inv = builder.mul_extension(inv_lt, aux3); + let lt = builder.add_extension(lt_norm, lt_inv); + let gt_norm = builder.mul_extension(gt, 
inv_aux3); + let gt_inv = builder.mul_extension(inv_gt, aux3); + let gt = builder.add_extension(gt_norm, gt_inv); + // constraints: // * is_eq = 1 - is_ne // * is_ge = 1 - is_lt @@ -638,10 +684,10 @@ pub fn eval_packed( nv: &CpuColumnsView

<P>, yield_constr: &mut ConstraintConsumer<P>

, ) { - eval_packed_exit_kernel(lv, nv, yield_constr); + //eval_packed_exit_kernel(lv, nv, yield_constr); eval_packed_jump_jumpi(lv, nv, yield_constr); - //eval_packed_branch(lv, nv, yield_constr); - //eval_packed_condmov(lv, nv, yield_constr); + eval_packed_branch(lv, nv, yield_constr); + eval_packed_condmov(lv, nv, yield_constr); } pub fn eval_ext_circuit, const D: usize>( @@ -650,8 +696,8 @@ pub fn eval_ext_circuit, const D: usize>( nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - eval_ext_circuit_exit_kernel(builder, lv, nv, yield_constr); + //eval_ext_circuit_exit_kernel(builder, lv, nv, yield_constr); eval_ext_circuit_jump_jumpi(builder, lv, nv, yield_constr); - //eval_ext_circuit_branch(builder, lv, nv, yield_constr); - //eval_ext_circuit_condmov(builder, lv, nv, yield_constr); + eval_ext_circuit_branch(builder, lv, nv, yield_constr); + eval_ext_circuit_condmov(builder, lv, nv, yield_constr); } diff --git a/src/verifier.rs b/src/verifier.rs index 85ce8694..bc809b11 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -532,6 +532,7 @@ mod tests { } #[test] + #[ignore] fn test_mips_prove_and_verify() { env_logger::try_init().unwrap_or_default(); const D: usize = 2; diff --git a/src/witness/operation.rs b/src/witness/operation.rs index 3ae2a6cf..20678f1f 100644 --- a/src/witness/operation.rs +++ b/src/witness/operation.rs @@ -394,8 +394,14 @@ pub(crate) fn generate_branch( let (src1, src1_op) = reg_read_with_log(src1, 0, state, &mut row)?; let (src2, src2_op) = reg_read_with_log(src2, 1, state, &mut row)?; let should_jump = cond.result(src1 as i32, src2 as i32); - reg_write_with_log(0, 2, src1.wrapping_sub(src2), state, &mut row)?; - reg_write_with_log(0, 3, src2.wrapping_sub(src1), state, &mut row)?; + //println!("jump: {} c0: {}, c1: {}, aux1: {}, aux2: {}", should_jump, src1, src2, src1.wrapping_sub(src2), src2.wrapping_sub(src1)); + let aux1 = src1.wrapping_sub(src2); + let aux2 = src2.wrapping_sub(src1); + let aux3 = (src1 ^ 
src2) & 0x80000000 > 0; + + reg_write_with_log(0, 2, aux1, state, &mut row)?; + reg_write_with_log(0, 3, aux2, state, &mut row)?; + reg_write_with_log(0, 4, aux3 as usize, state, &mut row)?; let pc = state.registers.program_counter as u32; if should_jump { let target = sign_extend::<16>(target); @@ -410,7 +416,6 @@ pub(crate) fn generate_branch( state.traces.push_cpu(row); state.jump_to(next_pc as usize); } - state.traces.push_cpu(row); state.traces.push_memory(src1_op); state.traces.push_memory(src2_op); Ok(())