From 59e762edd6a8115f37c8392209b6d4852d800c6e Mon Sep 17 00:00:00 2001 From: Stephen <81497928+eigmax@users.noreply.github.com> Date: Tue, 28 Nov 2023 09:00:13 +0800 Subject: [PATCH] Fix memio (#50) * fix: mmio --- Cargo.toml | 1 + src/all_stark.rs | 6 +- src/cpu/bootstrap_kernel.rs | 4 +- src/cpu/columns/general.rs | 23 +- src/cpu/columns/ops.rs | 3 +- src/cpu/control_flow.rs | 5 +- src/cpu/cpu_stark.rs | 44 +- src/cpu/decode.rs | 9 +- src/cpu/memio.rs | 1280 ++++++++++++++++++++++++++++------- src/cpu/syscall.rs | 1 - src/logic.rs | 22 +- src/stark_testing.rs | 4 +- src/witness/operation.rs | 142 +++- src/witness/traces.rs | 8 +- src/witness/transition.rs | 7 +- src/witness/util.rs | 8 +- 16 files changed, 1236 insertions(+), 331 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index af8f6a36..102017e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ plonky2 = { git = "https://github.com/0xPolygonZero/plonky2", branch = "main", f starky = { git = "https://github.com/0xPolygonZero/plonky2", branch = "main" , features = ["timing"] } plonky2_util = { git = "https://github.com/0xPolygonZero/plonky2", branch = "main" } plonky2_maybe_rayon = { git = "https://github.com/0xPolygonZero/plonky2", branch = "main" } + itertools = "0.11.0" log = { version = "0.4.14", default-features = false } anyhow = { version = "1.0.40", default-features = false } diff --git a/src/all_stark.rs b/src/all_stark.rs index 308dda46..d6910481 100644 --- a/src/all_stark.rs +++ b/src/all_stark.rs @@ -10,11 +10,7 @@ use crate::cpu::cpu_stark; use crate::cpu::cpu_stark::CpuStark; use crate::cpu::membus::NUM_GP_CHANNELS; use crate::cross_table_lookup::{CrossTableLookup, TableWithColumns}; -use crate::keccak::keccak_stark; -use crate::keccak::keccak_stark::KeccakStark; -use crate::keccak_sponge::columns::KECCAK_RATE_BYTES; -use crate::keccak_sponge::keccak_sponge_stark; -use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeStark; + use crate::logic; use crate::logic::LogicStark; use crate::memory::memory_stark; diff --git a/src/cpu/bootstrap_kernel.rs b/src/cpu/bootstrap_kernel.rs index 463f2974..bff3aa02 100644 --- a/src/cpu/bootstrap_kernel.rs +++ b/src/cpu/bootstrap_kernel.rs @@ -83,7 +83,7 @@ pub(crate) fn eval_bootstrap_kernel_packed> // If this is a bootloading row and the i'th memory channel is used, it must have the right // address, name context = 0, segment = Code, virt + 4 = next_virt let code_segment = F::from_canonical_usize(Segment::Code as usize); - for (i, channel) in local_values.mem_channels.iter().enumerate() { + for (_i, channel) in local_values.mem_channels.iter().enumerate() { let filter = local_is_bootstrap * channel.used; yield_constr.constraint(filter * channel.addr_context); yield_constr.constraint(filter * (channel.addr_segment - code_segment)); @@ -137,7 +137,7 @@ pub(crate) fn eval_bootstrap_kernel_ext_circuit, co // address, name context = 0, segment = Code, virt + 4 = next_virt let code_segment = builder.constant_extension(F::Extension::from_canonical_usize(Segment::Code as usize)); - for (i, channel) in local_values.mem_channels.iter().enumerate() { + for (_i, channel) in local_values.mem_channels.iter().enumerate() { let filter = builder.mul_extension(local_is_bootstrap, channel.used); let constraint = builder.mul_extension(filter, channel.addr_context); yield_constr.constraint(builder, constraint); diff --git a/src/cpu/columns/general.rs b/src/cpu/columns/general.rs index 9869286a..35b1672f 100644 --- a/src/cpu/columns/general.rs +++ b/src/cpu/columns/general.rs @@ -10,6 
+10,7 @@ pub(crate) union CpuGeneralColumnsView { logic: CpuLogicView, jumps: CpuJumpsView, shift: CpuShiftView, + io: CpuIoView, } impl CpuGeneralColumnsView { @@ -52,6 +53,16 @@ impl CpuGeneralColumnsView { pub(crate) fn shift_mut(&mut self) -> &mut CpuShiftView { unsafe { &mut self.shift } } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn io(&self) -> &CpuIoView { + unsafe { &self.io } + } + + // SAFETY: Each view is a valid interpretation of the underlying array. + pub(crate) fn io_mut(&mut self) -> &mut CpuIoView { + unsafe { &mut self.io } + } } impl PartialEq for CpuGeneralColumnsView { @@ -88,8 +99,6 @@ pub(crate) struct CpuSyscallView { pub(crate) sysnum: [T; 11], pub(crate) a0: [T; 3], pub(crate) a1: T, - // pub(crate) a1: [T;2], - // pub(crate) sz: [T;2], } #[derive(Copy, Clone)] @@ -112,10 +121,12 @@ pub(crate) struct CpuShiftView { } #[derive(Copy, Clone)] -pub(crate) struct CpuGPRView { - // For a shift amount of displacement: [T], this is the inverse of - // sum(displacement[1..]) or zero if the sum is zero. - pub(crate) regs: [T; 32], +pub(crate) struct CpuIoView { + pub(crate) rs_le: [T; 32], + pub(crate) rt_le: [T; 32], + pub(crate) mem_le: [T; 32], + pub(crate) micro_op: [T; 8], + pub(crate) diff_inv: T, } // `u8` is guaranteed to have a `size_of` of 1. diff --git a/src/cpu/columns/ops.rs b/src/cpu/columns/ops.rs index 921ee1e2..a3b99fa1 100644 --- a/src/cpu/columns/ops.rs +++ b/src/cpu/columns/ops.rs @@ -24,7 +24,8 @@ pub struct OpsColumnsView { pub get_context: T, pub set_context: T, pub exit_kernel: T, - pub m_op_general: T, + pub m_op_load: T, + pub m_op_store: T, pub syscall: T, } diff --git a/src/cpu/control_flow.rs b/src/cpu/control_flow.rs index a7cae170..7fb45416 100644 --- a/src/cpu/control_flow.rs +++ b/src/cpu/control_flow.rs @@ -8,7 +8,7 @@ use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer use crate::cpu::columns::{CpuColumnsView, COL_MAP}; // use crate::cpu::kernel::aggregator::KERNEL; -const NATIVE_INSTRUCTIONS: [usize; 12] = [ +const NATIVE_INSTRUCTIONS: [usize; 13] = [ COL_MAP.op.binary_op, COL_MAP.op.eq_iszero, COL_MAP.op.logic_op, @@ -23,7 +23,8 @@ const NATIVE_INSTRUCTIONS: [usize; 12] = [ COL_MAP.op.get_context, COL_MAP.op.set_context, // not EXIT_KERNEL (performs a jump) - COL_MAP.op.m_op_general, + COL_MAP.op.m_op_load, + COL_MAP.op.m_op_store, // not SYSCALL (performs a jump) // not exceptions (also jump) ]; diff --git a/src/cpu/cpu_stark.rs b/src/cpu/cpu_stark.rs index 3fcc7a24..71930b9b 100644 --- a/src/cpu/cpu_stark.rs +++ b/src/cpu/cpu_stark.rs @@ -14,9 +14,7 @@ use super::columns::CpuColumnsView; use crate::all_stark::Table; use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::{COL_MAP, NUM_CPU_COLUMNS}; -use crate::cpu::{ - bootstrap_kernel, control_flow, count, decode, jumps, membus, memio, pc, shift, syscall, -}; +use crate::cpu::{bootstrap_kernel, count, decode, jumps, membus, memio, pc, syscall}; use crate::cross_table_lookup::{Column, TableWithColumns}; use crate::evaluation_frame::{StarkEvaluationFrame, StarkFrame}; use crate::memory::segments::Segment; @@ -191,24 +189,17 @@ impl, const D: usize> Stark for CpuStark = next_values.borrow(); - /* bootstrap_kernel::eval_bootstrap_kernel_packed(local_values, next_values, yield_constr); - contextops::eval_packed(local_values, next_values, yield_constr); - control_flow::eval_packed_generic(local_values, next_values, yield_constr); - */ - 
syscall::eval_packed(local_values, yield_constr); - - /* + //contextops::eval_packed(local_values, next_values, yield_constr); + //control_flow::eval_packed_generic(local_values, next_values, yield_constr); decode::eval_packed_generic(local_values, yield_constr); jumps::eval_packed(local_values, next_values, yield_constr); membus::eval_packed(local_values, yield_constr); memio::eval_packed(local_values, next_values, yield_constr); pc::eval_packed(local_values, next_values, yield_constr); - shift::eval_packed(local_values, yield_constr); - syscall::eval_packed(local_values, yield_constr); + //shift::eval_packed(local_values, yield_constr); count::eval_packed(local_values, yield_constr); - - */ + syscall::eval_packed(local_values, yield_constr); } fn eval_ext_circuit( @@ -224,27 +215,21 @@ impl, const D: usize> Stark for CpuStark> = next_values.borrow(); - /* bootstrap_kernel::eval_bootstrap_kernel_ext_circuit( builder, local_values, next_values, yield_constr, ); - contextops::eval_ext_circuit(builder, local_values, next_values, yield_constr); - control_flow::eval_ext_circuit(builder, local_values, next_values, yield_constr); - */ - /* + //contextops::eval_ext_circuit(builder, local_values, next_values, yield_constr); + //control_flow::eval_ext_circuit(builder, local_values, next_values, yield_constr); decode::eval_ext_circuit(builder, local_values, yield_constr); jumps::eval_ext_circuit(builder, local_values, next_values, yield_constr); membus::eval_ext_circuit(builder, local_values, yield_constr); memio::eval_ext_circuit(builder, local_values, next_values, yield_constr); pc::eval_ext_circuit(builder, local_values, next_values, yield_constr); - shift::eval_ext_circuit(builder, local_values, yield_constr); - syscall::eval_ext_circuit(builder, local_values, yield_constr); + //shift::eval_ext_circuit(builder, local_values, yield_constr); count::eval_ext_circuit(builder, local_values, yield_constr); - - */ syscall::eval_ext_circuit(builder, local_values, yield_constr); } @@ -256,19 +241,18 @@ impl, const D: usize> Stark for CpuStark>(); for i in 0..(vals.len() - 1) { - println!("vals: {:?}, cpu column: {:?}", vals[i], state.traces.cpu[i]); + println!( + "[] vals: {:?},\ncpu column: {:?}", + vals[i], state.traces.cpu[i] + ); test_stark_cpu_check_constraints::(stark, &vals[i], &vals[i + 1]); } } diff --git a/src/cpu/decode.rs b/src/cpu/decode.rs index c8db4bf8..ec1e0a18 100644 --- a/src/cpu/decode.rs +++ b/src/cpu/decode.rs @@ -1,6 +1,6 @@ use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; + use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -46,13 +46,14 @@ const OPCODES: [(u32, usize, bool, usize); 10] = [ /// List of combined opcodes requiring a special handling. /// Each index in the list corresponds to an arbitrary combination /// of opcodes defined in evm/src/cpu/columns/ops.rs. -const COMBINED_OPCODES: [usize; 6] = [ +const COMBINED_OPCODES: [usize; 7] = [ COL_MAP.op.logic_op, COL_MAP.op.binary_op, COL_MAP.op.binary_imm_op, COL_MAP.op.shift, COL_MAP.op.shift_imm, - COL_MAP.op.m_op_general, + COL_MAP.op.m_op_load, + COL_MAP.op.m_op_store, ]; /// Break up an opcode (which is 32 bits long) into its 32 bits. @@ -105,8 +106,6 @@ pub fn eval_ext_circuit, const D: usize>( lv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let one = builder.one_extension(); - // Note: The constraints below do not need to be restricted to CPU cycles. 
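// For reference, the bit columns these decode constraints consume are
// little-endian decompositions of 32-bit words, recombined with
// `limb_from_bits_le`. A plain-Rust sketch of that round trip (illustrative
// only, standing in for crate::util::limb_from_bits_le, not part of this
// patch):

fn bits_le(x: u32) -> [u32; 32] {
    core::array::from_fn(|i| (x >> i) & 1)
}

fn limb_from_bits_le(bits: &[u32; 32]) -> u32 {
    bits.iter().rev().fold(0, |acc, &b| (acc << 1) | b)
}

fn main() {
    let x = 0xdead_beef;
    assert_eq!(limb_from_bits_le(&bits_le(x)), x);
}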
// Ensure that the kernel flag is valid (either 0 or 1). diff --git a/src/cpu/memio.rs b/src/cpu/memio.rs index 40783265..dc425e9a 100644 --- a/src/cpu/memio.rs +++ b/src/cpu/memio.rs @@ -1,105 +1,651 @@ -use itertools::izip; use plonky2::field::extension::Extendable; -use plonky2::field::packed::PackedField; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; + use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use crate::cpu::columns::CpuColumnsView; use crate::cpu::membus::NUM_GP_CHANNELS; +use crate::memory::segments::Segment; +use crate::util::{limb_from_bits_le, limb_from_bits_le_recursive}; + +#[inline] +fn load_offset(lv: &CpuColumnsView
<P>
) -> P { + let mut mem_offset = [P::ZEROS; 32]; + mem_offset[0..6].copy_from_slice(&lv.func_bits); // 6 bits + mem_offset[6..11].copy_from_slice(&lv.shamt_bits); // 5 bits + mem_offset[11..16].copy_from_slice(&lv.rd_bits); // 5 bits + //mem_offset[16..].copy_from_slice(&[lv.rd_bits[4]; 16]); + let mem_offset = sign_extend::<_, 16>(&mem_offset); + limb_from_bits_le(mem_offset.into_iter()) +} -fn get_addr(lv: &CpuColumnsView) -> (T, T, T) { - let addr_context = lv.mem_channels[0].value; - let addr_segment = lv.mem_channels[1].value; - let addr_virtual = lv.mem_channels[2].value; - (addr_context, addr_segment, addr_virtual) +#[inline] +fn load_offset_ext, const D: usize>( + builder: &mut CircuitBuilder, + lv: &CpuColumnsView>, +) -> ExtensionTarget { + let mut mem_offset = [builder.zero_extension(); 32]; + mem_offset[0..6].copy_from_slice(&lv.func_bits); // 6 bits + mem_offset[6..11].copy_from_slice(&lv.shamt_bits); // 5 bits + mem_offset[11..16].copy_from_slice(&lv.rd_bits); // 5 bits + //mem_offset[16..].copy_from_slice(&[lv.rd_bits[4]; 16]); + let mem_offset = sign_extend_ext::<_, D, 16>(builder, &mem_offset); + limb_from_bits_le_recursive(builder, mem_offset.into_iter()) +} + +#[inline] +fn sign_extend(limbs: &[P; 32]) -> [P; 32] { + let mut out = [P::ZEROS; 32]; + for i in 0..N { + out[i] = limbs[i]; + } + for i in N..32 { + out[i] = limbs[N - 1]; + } + out } +#[inline] +fn sign_extend_ext, const D: usize, const N: usize>( + builder: &mut CircuitBuilder, + limbs: &[ExtensionTarget; 32], +) -> [ExtensionTarget; 32] { + let mut out = [builder.zero_extension(); 32]; + for i in 0..N { + out[i] = limbs[i]; + } + for i in N..32 { + out[i] = limbs[N - 1]; + } + out +} + +/// Constant -4 +const GOLDILOCKS_INVERSE_NEG4: u64 = 18446744069414584317; + fn eval_packed_load( lv: &CpuColumnsView
<P>
, _nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - // The opcode for MLOAD_GENERAL is 0xfb. If the operation is MLOAD_GENERAL, lv.opcode_bits[0] = 1 - let filter = lv.op.m_op_general * lv.opcode_bits[0]; + // If the operation is MLOAD_GENERAL, lv.opcode_bits[5] = 1 + let filter = lv.op.m_op_load * lv.opcode_bits[5]; - let (addr_context, addr_segment, addr_virtual) = get_addr(lv); + // Check mem channel segment is register + let diff = lv.mem_channels[0].addr_segment + - P::Scalar::from_canonical_u64(Segment::RegisterFile as u64); + yield_constr.constraint(filter * diff); + let diff = lv.mem_channels[1].addr_segment + - P::Scalar::from_canonical_u64(Segment::RegisterFile as u64); + yield_constr.constraint(filter * diff); - let load_channel = lv.mem_channels[3]; - yield_constr.constraint(filter * (load_channel.used - P::ONES)); - yield_constr.constraint(filter * (load_channel.is_read - P::ONES)); - yield_constr.constraint(filter * (load_channel.addr_context - addr_context)); - yield_constr.constraint(filter * (load_channel.addr_segment - addr_segment)); - yield_constr.constraint(filter * (load_channel.addr_virtual - addr_virtual)); + // Check memory is used + // Check is_read is 0/1 + + let rs = lv.mem_channels[0].value; + let rt = lv.mem_channels[1].value; + let mem = lv.mem_channels[3].value; + let rs_limbs = lv.general.io().rs_le; + let rt_limbs = lv.general.io().rt_le; + let mem_limbs = lv.general.io().mem_le; + + // Calculate rs: + // let virt_raw = (rs as u32).wrapping_add(sign_extend::<16>(offset)); + let offset = load_offset(lv); + let virt_raw = rs + offset; + + // May raise overflow here since wrapping_add used in simulator + let rs_from_bits = limb_from_bits_le(rs_limbs.into_iter()); + let u32max = P::Scalar::from_canonical_u64(1u64 << 32); + yield_constr + .constraint(filter * (rs_from_bits - virt_raw) * (rs_from_bits + u32max - virt_raw)); + + let rt_from_bits = limb_from_bits_le(rt_limbs.into_iter()); + yield_constr.constraint(filter * (rt_from_bits - rt)); + + // Constrain mem address + // let virt = virt_raw & 0xFFFF_FFFC; + let mut tmp = rs_limbs.clone(); + tmp[0] = P::ZEROS; + tmp[1] = P::ZEROS; + let virt = limb_from_bits_le(tmp.into_iter()); + + let mem_virt = lv.mem_channels[2].addr_virtual; + yield_constr.constraint(filter * (virt - mem_virt)); + + // Verify op + let op_inv = lv.general.io().diff_inv; + let op = lv.mem_channels[4].value; + yield_constr.constraint(filter * (P::ONES - op * op_inv)); + + // Constrain mem value + // LH: micro_op[0] * sign_extend::<16>((mem >> (16 - (rs & 2) * 8)) & 0xffff) + { + // Range value(rs[1]): rs[1] == 1 + let mut mem_val_1 = [P::ZEROS; 32]; + mem_val_1[0..16].copy_from_slice(&mem_limbs[0..16]); + let mem_val_1 = sign_extend::<_, 16>(&mem_val_1); + + // Range value(rs[1]): rs[1] == 0 + let mut mem_val_0 = [P::ZEROS; 32]; + mem_val_0[0..16].copy_from_slice(&mem_limbs[16..32]); + let mem_val_0 = sign_extend::<_, 16>(&mut mem_val_0); + + let mem_val_1 = limb_from_bits_le(mem_val_1.into_iter()); + let mem_val_0 = limb_from_bits_le(mem_val_0.into_iter()); + + // Range check + let sum = rs_limbs[1] * (mem - mem_val_1) + (rs_limbs[1] - P::ONES) * (mem - mem_val_0); + yield_constr.constraint(filter * lv.general.io().micro_op[0] * sum); + } + + // LWL: + // let val = mem << ((rs & 3) * 8); + // let mask = 0xffFFffFFu32 << ((rs & 3) * 8); + // (rt & (!mask)) | val + // Use mem_val_{rs[0]}_{rs[1]} to indicate the mem value for different value on rs' first and + // second bit + { + let mut mem_val_0_0 = [P::ZEROS; 32]; + let mut mem_val_0_1 = [P::ZEROS; 32]; + let mut mem_val_1_0 = 
[P::ZEROS; 32]; + let mut mem_val_1_1 = [P::ZEROS; 32]; + + mem_val_0_0[0..32].copy_from_slice(&mem_limbs[0..32]); + + mem_val_1_0[0..8].copy_from_slice(&rt_limbs[0..8]); + mem_val_1_0[8..].copy_from_slice(&mem_limbs[0..24]); + + mem_val_0_1[0..16].copy_from_slice(&rt_limbs[0..16]); + mem_val_0_1[16..].copy_from_slice(&mem_limbs[0..16]); + + mem_val_1_1[0..24].copy_from_slice(&rt_limbs[0..24]); + mem_val_1_1[24..].copy_from_slice(&mem_limbs[0..8]); + + let mem_val_0_0 = limb_from_bits_le(mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le(mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le(mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le(mem_val_1_1.into_iter()); + + let sum = (mem - mem_val_0_0) * (rs_limbs[1] - P::ONES) * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_0) * (rs_limbs[1] - P::ONES) * rs_limbs[0] + + (mem - mem_val_0_1) * rs_limbs[1] * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_1) * rs_limbs[1] * rs_limbs[0]; + yield_constr.constraint(filter * lv.general.io().micro_op[1] * sum); + } + + // LW: + { + let mem_value = limb_from_bits_le(mem_limbs.into_iter()); + yield_constr.constraint(filter * lv.general.io().micro_op[2] * (mem - mem_value)); + } + + // LBU: (mem >> (24 - (rs & 3) * 8)) & 0xff + { + let mut mem_val_0_0 = [P::ZEROS; 32]; + let mut mem_val_0_1 = [P::ZEROS; 32]; + let mut mem_val_1_0 = [P::ZEROS; 32]; + let mut mem_val_1_1 = [P::ZEROS; 32]; + + mem_val_0_0[0..8].copy_from_slice(&mem_limbs[24..32]); + mem_val_1_0[0..8].copy_from_slice(&mem_limbs[16..24]); + mem_val_0_1[0..8].copy_from_slice(&mem_limbs[8..16]); + mem_val_1_1[0..8].copy_from_slice(&mem_limbs[0..8]); + + let mem_val_0_0 = limb_from_bits_le(mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le(mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le(mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le(mem_val_1_1.into_iter()); + + let sum = (mem - mem_val_0_0) * (rs_limbs[1] - P::ONES) * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_0) * (rs_limbs[1] - P::ONES) * rs_limbs[0] + + (mem - mem_val_0_1) * rs_limbs[1] * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_1) * rs_limbs[1] * rs_limbs[0]; + yield_constr.constraint(filter * lv.general.io().micro_op[3] * sum); + } + + // LHU: (mem >> (16 - (rs & 2) * 8)) & 0xffff + { + let mut mem_val_0 = [P::ZEROS; 32]; + let mut mem_val_1 = [P::ZEROS; 32]; + + mem_val_0[0..16].copy_from_slice(&mem_limbs[16..32]); + mem_val_1[0..16].copy_from_slice(&mem_limbs[0..16]); + + let mem_val_1 = limb_from_bits_le(mem_val_1.into_iter()); + let mem_val_0 = limb_from_bits_le(mem_val_0.into_iter()); + + let sum = rs_limbs[1] * (mem - mem_val_1) + (rs_limbs[1] - P::ONES) * (mem - mem_val_0); + yield_constr.constraint(filter * lv.general.io().micro_op[4] * sum); + } + + // LWR: + // let val = mem >> (24 - (rs & 3) * 8); + // let mask = 0xffFFffFFu32 >> (24 - (rs & 3) * 8); + // (rt & (!mask)) | val + { + let mut mem_val_0_0 = [P::ZEROS; 32]; + let mut mem_val_0_1 = [P::ZEROS; 32]; + let mut mem_val_1_0 = [P::ZEROS; 32]; + let mut mem_val_1_1 = [P::ZEROS; 32]; + + mem_val_0_0[8..].copy_from_slice(&rt_limbs[8..32]); + mem_val_0_0[0..8].copy_from_slice(&mem_limbs[24..32]); + + mem_val_1_0[16..].copy_from_slice(&rt_limbs[16..32]); + mem_val_1_0[0..16].copy_from_slice(&mem_limbs[16..32]); + + mem_val_0_1[24..].copy_from_slice(&rt_limbs[0..8]); + mem_val_0_1[0..24].copy_from_slice(&mem_limbs[8..32]); + + mem_val_1_1[0..32].copy_from_slice(&mem_limbs[..]); + + let mem_val_0_0 = limb_from_bits_le(mem_val_0_0.into_iter()); + 
let mem_val_1_0 = limb_from_bits_le(mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le(mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le(mem_val_1_1.into_iter()); + + let sum = (mem - mem_val_0_0) * (rs_limbs[1] - P::ONES) * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_0) * (rs_limbs[1] - P::ONES) * rs_limbs[0] + + (mem - mem_val_0_1) * rs_limbs[1] * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_1) * rs_limbs[1] * rs_limbs[0]; + yield_constr.constraint(filter * lv.general.io().micro_op[5] * sum); + } + + // LL: + { + let mem_value = limb_from_bits_le(mem_limbs.into_iter()); + yield_constr.constraint(filter * lv.general.io().micro_op[6] * (mem - mem_value)); + } + + // LB: sign_extend::<8>((mem >> (24 - (rs & 3) * 8)) & 0xff) + { + let mut mem_val_0_0 = [P::ZEROS; 32]; + let mut mem_val_0_1 = [P::ZEROS; 32]; + let mut mem_val_1_0 = [P::ZEROS; 32]; + let mut mem_val_1_1 = [P::ZEROS; 32]; + + mem_val_0_0[0..8].copy_from_slice(&mem_limbs[24..]); + mem_val_1_0[0..8].copy_from_slice(&mem_limbs[16..24]); + mem_val_0_1[0..8].copy_from_slice(&mem_limbs[8..16]); + mem_val_1_1[0..8].copy_from_slice(&mem_limbs[0..8]); + + let mem_val_0_0 = sign_extend::<_, 8>(&mem_val_0_0); + let mem_val_1_0 = sign_extend::<_, 8>(&mem_val_1_0); + let mem_val_0_1 = sign_extend::<_, 8>(&mem_val_0_1); + let mem_val_1_1 = sign_extend::<_, 8>(&mem_val_1_1); + + let mem_val_0_0 = limb_from_bits_le(mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le(mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le(mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le(mem_val_1_1.into_iter()); + + let sum = (mem - mem_val_0_0) * (rs_limbs[1] - P::ONES) * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_0) * (rs_limbs[1] - P::ONES) * rs_limbs[0] + + (mem - mem_val_0_1) * rs_limbs[1] * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_1) * rs_limbs[1] * rs_limbs[0]; + yield_constr.constraint(filter * lv.general.io().micro_op[7] * sum); + } // Disable remaining memory channels, if any. 
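// A minimal integer sketch (not part of this patch) of the four-case
// selector polynomial used by the LWL/LBU/LWR/LB constraints above: for
// address bits b0, b1 in {0, 1}, three of the four products pick up a zero
// factor, so the sum vanishes iff `mem` equals the value selected by (b0, b1).

fn selector(b0: i64, b1: i64, mem: i64, v: [i64; 4]) -> i64 {
    (mem - v[0]) * (b1 - 1) * (b0 - 1) // (b0, b1) = (0, 0)
        + (mem - v[1]) * (b1 - 1) * b0 // (b0, b1) = (1, 0)
        + (mem - v[2]) * b1 * (b0 - 1) // (b0, b1) = (0, 1)
        + (mem - v[3]) * b1 * b0 // (b0, b1) = (1, 1)
}

fn main() {
    let v = [10, 20, 30, 40];
    assert_eq!(selector(1, 0, 20, v), 0); // the (1, 0) case matches v[1]
    assert_ne!(selector(1, 0, 30, v), 0); // any other value fails
}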
- for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS] { + // Note: SC needs 5 channel + for &channel in &lv.mem_channels[6..NUM_GP_CHANNELS] { yield_constr.constraint(filter * channel.used); } - - // Stack constraints - /* - stack::eval_packed_one( - lv, - nv, - filter, - stack::MLOAD_GENERAL_OP.unwrap(), - yield_constr, - ); - */ } fn eval_ext_circuit_load, const D: usize>( - builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, + builder: &mut CircuitBuilder, lv: &CpuColumnsView>, _nv: &CpuColumnsView>, yield_constr: &mut RecursiveConstraintConsumer, ) { - let mut filter = lv.op.m_op_general; - filter = builder.mul_extension(filter, lv.opcode_bits[0]); + let zeros = builder.zero_extension(); + let ones = builder.one_extension(); + let filter = builder.mul_extension(lv.op.m_op_load, lv.opcode_bits[5]); + + // Check mem channel segment is register + let diff = builder.add_const_extension( + lv.mem_channels[0].addr_segment, + -F::from_canonical_u64(Segment::RegisterFile as u64), + ); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + let diff = builder.add_const_extension( + lv.mem_channels[1].addr_segment, + -F::from_canonical_u64(Segment::RegisterFile as u64), + ); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + let rs = lv.mem_channels[0].value; + let rt = lv.mem_channels[1].value; + let mem = lv.mem_channels[3].value; + let rs_limbs = lv.general.io().rs_le; + let rt_limbs = lv.general.io().rt_le; + let mem_limbs = lv.general.io().mem_le; - let (addr_context, addr_segment, addr_virtual) = get_addr(lv); + // Calculate rs: + // let virt_raw = (rs as u32).wrapping_add(sign_extend::<16>(offset)); + let offset = load_offset_ext(builder, lv); + let virt_raw = builder.add_extension(rs, offset); - let load_channel = lv.mem_channels[3]; + let u32max = F::from_canonical_u64(1u64 << 32); + //yield_constr.constraint(filter * (rs_from_bits - virt_raw) * (rs_from_bits + u32max - virt_raw)); + let rs_from_bits = limb_from_bits_le_recursive(builder, rs_limbs.into_iter()); + let diff1 = builder.sub_extension(rs_from_bits, virt_raw); + + let diff2 = builder.add_const_extension(rs_from_bits, u32max); + let diff2 = builder.sub_extension(diff2, virt_raw); + + let constr = builder.mul_many_extension(&[filter, diff1, diff2]); + yield_constr.constraint(builder, constr); + + let rt_from_bits = limb_from_bits_le_recursive(builder, rt_limbs.into_iter()); + let diff = builder.sub_extension(rt_from_bits, rt); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + // Constrain mem address + // let virt = virt_raw & 0xFFFF_FFFC; + let mut tmp = rs_limbs.clone(); + tmp[0] = zeros; + tmp[1] = zeros; + let virt = limb_from_bits_le_recursive(builder, tmp.into_iter()); + + let mem_virt = lv.mem_channels[2].addr_virtual; + let diff = builder.sub_extension(virt, mem_virt); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + // Verify op + let op_inv = lv.general.io().diff_inv; + let op = lv.mem_channels[4].value; + let mul = builder.mul_extension(op, op_inv); + let diff = builder.sub_extension(ones, mul); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + // Constrain mem value + // LH: micro_op[0] * sign_extend::<16>((mem >> (16 - (rs & 2) * 8)) & 0xffff) { - let constr = builder.mul_sub_extension(filter, load_channel.used, filter); - yield_constr.constraint(builder, constr); 
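// For reference, the LH semantics being constrained in this block, in plain
// Rust (an illustrative sketch, not part of this patch): bit 1 of the raw
// address selects a halfword of the big-endian word `mem`, which is then
// sign-extended to 32 bits, matching sign_extend::<_, 16> above.

fn lh(mem: u32, virt_raw: u32) -> u32 {
    let half = (mem >> (16 - (virt_raw & 2) * 8)) & 0xffff;
    ((half as i32) << 16 >> 16) as u32 // sign_extend::<16>
}

fn main() {
    assert_eq!(lh(0x1234_8765, 0), 0x0000_1234); // bit 1 == 0: high halfword
    assert_eq!(lh(0x1234_8765, 2), 0xffff_8765); // bit 1 == 1: low halfword, sign bit set
}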
+ // Range value(rs[1]): rs[1] == 1 + let mut mem_val_1 = [zeros; 32]; + mem_val_1[0..16].copy_from_slice(&mem_limbs[0..16]); + let mem_val_1 = sign_extend_ext::<_, D, 16>(builder, &mem_val_1); + + // Range value(rs[1]): rs[1] == 0 + let mut mem_val_0 = [zeros; 32]; + mem_val_0[0..16].copy_from_slice(&mem_limbs[16..32]); + let mem_val_0 = sign_extend_ext::<_, D, 16>(builder, &mem_val_0); + + let mem_val_1 = limb_from_bits_le_recursive(builder, mem_val_1.into_iter()); + let mem_val_0 = limb_from_bits_le_recursive(builder, mem_val_0.into_iter()); + + // Range check + // let sum = rs_limbs[1] * (mem - mem_val_1) + (rs_limbs[1] - P::ONES) * (mem - mem_val_0); + // yield_constr.constraint(filter * lv.general.io().micro_op[0] * sum); + let diff1 = builder.sub_extension(mem, mem_val_1); + let diff2 = builder.sub_extension(mem, mem_val_0); + let coff2 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let mult2 = builder.mul_extension(diff2, coff2); + let sum = builder.arithmetic_extension(F::ONES, F::ONES, rs_limbs[1], diff1, mult2); + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[0], sum]); + yield_constr.constraint(builder, mult); } + + // LWL: + // let val = mem << ((rs & 3) * 8); + // let mask = 0xffFFffFFu32 << ((rs & 3) * 8); + // (rt & (!mask)) | val + // Use mem_val_{rs[0]}_{rs[1]} to indicate the mem value for different value on rs' first and + // second bit { - let constr = builder.mul_sub_extension(filter, load_channel.is_read, filter); - yield_constr.constraint(builder, constr); + let mut mem_val_0_0 = [zeros; 32]; + let mut mem_val_0_1 = [zeros; 32]; + let mut mem_val_1_0 = [zeros; 32]; + let mut mem_val_1_1 = [zeros; 32]; + + mem_val_0_0[0..32].copy_from_slice(&mem_limbs[0..32]); + + mem_val_1_0[0..8].copy_from_slice(&rt_limbs[0..8]); + mem_val_1_0[8..].copy_from_slice(&mem_limbs[0..24]); + + mem_val_0_1[0..16].copy_from_slice(&rt_limbs[0..16]); + mem_val_0_1[16..].copy_from_slice(&mem_limbs[0..16]); + + mem_val_1_1[0..24].copy_from_slice(&rt_limbs[0..24]); + mem_val_1_1[24..].copy_from_slice(&mem_limbs[0..8]); + + let mem_val_0_0 = limb_from_bits_le_recursive(builder, mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le_recursive(builder, mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le_recursive(builder, mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le_recursive(builder, mem_val_1_1.into_iter()); + + // let sum = + // (mem - mem_val_0_0) * (rs_limbs[1] - P::ONES) * (rs_limbs[0] - P::ONES) + + // (mem - mem_val_1_0) * (rs_limbs[1] - P::ONES) * rs_limbs[0] + + // (mem - mem_val_0_1) * rs_limbs[1] * (rs_limbs[0] - P::ONES) + + // (mem - mem_val_1_1) * rs_limbs[1] * rs_limbs[0]; + // yield_constr.constraint(filter * lv.general.io().micro_op[1] * sum); + let diff1 = builder.sub_extension(mem, mem_val_0_0); + let diff2 = builder.sub_extension(mem, mem_val_1_0); + let diff3 = builder.sub_extension(mem, mem_val_0_1); + let diff4 = builder.sub_extension(mem, mem_val_1_1); + + let coff10 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let coff11 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum1 = builder.mul_many_extension([diff1, coff10, coff11]); + + let coff21 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum2 = builder.mul_many_extension([diff2, coff21, rs_limbs[0]]); + + let coff30 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let sum3 = builder.mul_many_extension([diff3, rs_limbs[1], coff30]); + + let sum4 = builder.mul_many_extension([diff4, rs_limbs[1], 
rs_limbs[0]]); + + let sum = builder.add_many_extension([sum1, sum2, sum3, sum4]); + + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[1], sum]); + yield_constr.constraint(builder, mult); } - for (channel_field, target) in izip!( - [ - load_channel.addr_context, - load_channel.addr_segment, - load_channel.addr_virtual, - ], - [addr_context, addr_segment, addr_virtual] - ) { - let diff = builder.sub_extension(channel_field, target); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); + + // LW: + { + let mem_value = limb_from_bits_le_recursive(builder, mem_limbs.into_iter()); + // yield_constr.constraint(filter * lv.general.io().micro_op[2] * (mem - mem_value)); + let diff1 = builder.sub_extension(mem, mem_value); + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[2], diff1]); + yield_constr.constraint(builder, mult); + } + + // LBU: (mem >> (24 - (rs & 3) * 8)) & 0xff + { + let mut mem_val_0_0 = [zeros; 32]; + let mut mem_val_0_1 = [zeros; 32]; + let mut mem_val_1_0 = [zeros; 32]; + let mut mem_val_1_1 = [zeros; 32]; + + mem_val_0_0[0..8].copy_from_slice(&mem_limbs[24..32]); + mem_val_1_0[0..8].copy_from_slice(&mem_limbs[16..24]); + mem_val_0_1[0..8].copy_from_slice(&mem_limbs[8..16]); + mem_val_1_1[0..8].copy_from_slice(&mem_limbs[0..8]); + + let mem_val_0_0 = limb_from_bits_le_recursive(builder, mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le_recursive(builder, mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le_recursive(builder, mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le_recursive(builder, mem_val_1_1.into_iter()); + + /* + //yield_constr.constraint(filter * lv.general.io().micro_op[3] + // * (mem - mem_val_0_0) * (mem - mem_val_0_1) + // * (mem - mem_val_1_0) * (mem - mem_val_1_1)); + let diff1 = builder.sub_extension(mem, mem_val_0_0); + let diff2 = builder.sub_extension(mem, mem_val_0_1); + let diff3 = builder.sub_extension(mem, mem_val_1_0); + let diff4 = builder.sub_extension(mem, mem_val_1_1); + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[3], diff1, diff2, diff3, diff4]); + yield_constr.constraint(builder, mult); + */ + + let diff1 = builder.sub_extension(mem, mem_val_0_0); + let diff2 = builder.sub_extension(mem, mem_val_1_0); + let diff3 = builder.sub_extension(mem, mem_val_0_1); + let diff4 = builder.sub_extension(mem, mem_val_1_1); + + let coff10 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let coff11 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum1 = builder.mul_many_extension([diff1, coff10, coff11]); + + let coff21 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum2 = builder.mul_many_extension([diff2, coff21, rs_limbs[0]]); + + let coff30 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let sum3 = builder.mul_many_extension([diff3, rs_limbs[1], coff30]); + + let sum4 = builder.mul_many_extension([diff4, rs_limbs[1], rs_limbs[0]]); + + let sum = builder.add_many_extension([sum1, sum2, sum3, sum4]); + + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[3], sum]); + yield_constr.constraint(builder, mult); + } + + // LHU: (mem >> (16 - (rs & 2) * 8)) & 0xffff + { + let mut mem_val_0 = [zeros; 32]; + let mut mem_val_1 = [zeros; 32]; + + mem_val_0[0..16].copy_from_slice(&mem_limbs[16..32]); + mem_val_1[0..16].copy_from_slice(&mem_limbs[0..16]); + + let mem_val_1 = limb_from_bits_le_recursive(builder, mem_val_1.into_iter()); + let mem_val_0 = 
limb_from_bits_le_recursive(builder, mem_val_0.into_iter()); + + let diff1 = builder.sub_extension(mem, mem_val_1); + let diff2 = builder.sub_extension(mem, mem_val_0); + let coff2 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let mult2 = builder.mul_extension(diff2, coff2); + let sum = builder.arithmetic_extension(F::ONES, F::ONES, rs_limbs[1], diff1, mult2); + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[4], sum]); + yield_constr.constraint(builder, mult); + } + + // LWR: + // let val = mem >> (24 - (rs & 3) * 8); + // let mask = 0xffFFffFFu32 >> (24 - (rs & 3) * 8); + // (rt & (!mask)) | val + { + let mut mem_val_0_0 = [zeros; 32]; + let mut mem_val_0_1 = [zeros; 32]; + let mut mem_val_1_0 = [zeros; 32]; + let mut mem_val_1_1 = [zeros; 32]; + + mem_val_0_0[8..].copy_from_slice(&rt_limbs[8..32]); + mem_val_0_0[0..8].copy_from_slice(&mem_limbs[24..32]); + + mem_val_1_0[16..].copy_from_slice(&rt_limbs[16..32]); + mem_val_1_0[0..16].copy_from_slice(&mem_limbs[16..32]); + + mem_val_0_1[24..].copy_from_slice(&rt_limbs[0..8]); + mem_val_0_1[0..24].copy_from_slice(&mem_limbs[8..32]); + + mem_val_1_1[0..32].copy_from_slice(&mem_limbs[..]); + + let mem_val_0_0 = limb_from_bits_le_recursive(builder, mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le_recursive(builder, mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le_recursive(builder, mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le_recursive(builder, mem_val_1_1.into_iter()); + + let diff1 = builder.sub_extension(mem, mem_val_0_0); + let diff2 = builder.sub_extension(mem, mem_val_1_0); + let diff3 = builder.sub_extension(mem, mem_val_0_1); + let diff4 = builder.sub_extension(mem, mem_val_1_1); + + let coff10 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let coff11 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum1 = builder.mul_many_extension([diff1, coff10, coff11]); + + let coff21 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum2 = builder.mul_many_extension([diff2, coff21, rs_limbs[0]]); + + let coff30 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let sum3 = builder.mul_many_extension([diff3, rs_limbs[1], coff30]); + + let sum4 = builder.mul_many_extension([diff4, rs_limbs[1], rs_limbs[0]]); + + let sum = builder.add_many_extension([sum1, sum2, sum3, sum4]); + + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[5], sum]); + yield_constr.constraint(builder, mult); + } + + // LL: + { + let mem_value = limb_from_bits_le_recursive(builder, mem_limbs.into_iter()); + // yield_constr.constraint(filter * lv.general.io().micro_op[6] * (mem - mem_value)); + let diff1 = builder.sub_extension(mem, mem_value); + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[6], diff1]); + yield_constr.constraint(builder, mult); + } + + // LB: sign_extend::<8>((mem >> (24 - (rs & 3) * 8)) & 0xff) + { + let mut mem_val_0_0 = [zeros; 32]; + let mut mem_val_0_1 = [zeros; 32]; + let mut mem_val_1_0 = [zeros; 32]; + let mut mem_val_1_1 = [zeros; 32]; + + mem_val_0_0[0..8].copy_from_slice(&mem_limbs[24..]); + mem_val_1_0[0..8].copy_from_slice(&mem_limbs[16..24]); + mem_val_0_1[0..8].copy_from_slice(&mem_limbs[8..16]); + mem_val_1_1[0..8].copy_from_slice(&mem_limbs[0..8]); + + let mem_val_0_0 = sign_extend_ext::<_, D, 8>(builder, &mem_val_0_0); + let mem_val_1_0 = sign_extend_ext::<_, D, 8>(builder, &mem_val_1_0); + let mem_val_0_1 = sign_extend_ext::<_, D, 8>(builder, &mem_val_0_1); + let mem_val_1_1 = 
sign_extend_ext::<_, D, 8>(builder, &mem_val_1_1); + + let mem_val_0_0 = limb_from_bits_le_recursive(builder, mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le_recursive(builder, mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le_recursive(builder, mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le_recursive(builder, mem_val_1_1.into_iter()); + + let diff1 = builder.sub_extension(mem, mem_val_0_0); + let diff2 = builder.sub_extension(mem, mem_val_1_0); + let diff3 = builder.sub_extension(mem, mem_val_0_1); + let diff4 = builder.sub_extension(mem, mem_val_1_1); + + let coff10 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let coff11 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum1 = builder.mul_many_extension([diff1, coff10, coff11]); + + let coff21 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum2 = builder.mul_many_extension([diff2, coff21, rs_limbs[0]]); + + let coff30 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let sum3 = builder.mul_many_extension([diff3, rs_limbs[1], coff30]); + + let sum4 = builder.mul_many_extension([diff4, rs_limbs[1], rs_limbs[0]]); + + let sum = builder.add_many_extension([sum1, sum2, sum3, sum4]); + + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[7], sum]); + yield_constr.constraint(builder, mult); } // Disable remaining memory channels, if any. - for &channel in &lv.mem_channels[4..NUM_GP_CHANNELS] { + for &channel in &lv.mem_channels[6..NUM_GP_CHANNELS] { let constr = builder.mul_extension(filter, channel.used); yield_constr.constraint(builder, constr); } - - // Stack constraints - /* - stack::eval_ext_circuit_one( - builder, - lv, - nv, - filter, - stack::MLOAD_GENERAL_OP.unwrap(), - yield_constr, - ); - */ } fn eval_packed_store( @@ -107,214 +653,476 @@ fn eval_packed_store( _nv: &CpuColumnsView
<P>
, yield_constr: &mut ConstraintConsumer
<P>
, ) { - let filter = lv.op.m_op_general * (lv.opcode_bits[0] - P::ONES); + let filter = lv.op.m_op_store * lv.opcode_bits[5]; - let (addr_context, addr_segment, addr_virtual) = get_addr(lv); + // Check mem channel segment is register + let diff = lv.mem_channels[0].addr_segment + - P::Scalar::from_canonical_u64(Segment::RegisterFile as u64); + yield_constr.constraint(filter * diff); + let diff = lv.mem_channels[1].addr_segment + - P::Scalar::from_canonical_u64(Segment::RegisterFile as u64); + yield_constr.constraint(filter * diff); - let value_channel = lv.mem_channels[3]; - let store_channel = lv.mem_channels[4]; - yield_constr.constraint(filter * (store_channel.used - P::ONES)); - yield_constr.constraint(filter * store_channel.is_read); - yield_constr.constraint(filter * (store_channel.addr_context - addr_context)); - yield_constr.constraint(filter * (store_channel.addr_segment - addr_segment)); - yield_constr.constraint(filter * (store_channel.addr_virtual - addr_virtual)); - yield_constr.constraint(filter * (value_channel.value - store_channel.value)); + // Check memory is used + // Check is_read is 0/1 - // Disable remaining memory channels, if any. - for &channel in &lv.mem_channels[5..] { - yield_constr.constraint(filter * channel.used); - } + let rs = lv.mem_channels[0].value; + let rt = lv.mem_channels[1].value; + let mem = lv.mem_channels[3].value; + let rs_limbs = lv.general.io().rs_le; + let rt_limbs = lv.general.io().rt_le; + let mem_limbs = lv.general.io().mem_le; - // Stack constraints. - // Pops. - /* - for i in 1..4 { - let channel = lv.mem_channels[i]; - - yield_constr.constraint(filter * (channel.used - P::ONES)); - yield_constr.constraint(filter * (channel.is_read - P::ONES)); - - yield_constr.constraint(filter * (channel.addr_context - lv.context)); - yield_constr.constraint( - filter * (channel.addr_segment - P::Scalar::from_canonical_u64(Segment::Stack as u64)), - ); - // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. - let addr_virtual = lv.stack_len - P::Scalar::from_canonical_usize(i + 1); - yield_constr.constraint(filter * (channel.addr_virtual - addr_virtual)); - } - // Constrain `stack_inv_aux`. - let len_diff = lv.stack_len - P::Scalar::from_canonical_usize(4); - yield_constr.constraint( - lv.op.m_op_general - * (len_diff * lv.general.stack().stack_inv - lv.general.stack().stack_inv_aux), - ); - // If stack_len != 4 and MSTORE, read new top of the stack in nv.mem_channels[0]. - let top_read_channel = nv.mem_channels[0]; - let is_top_read = lv.general.stack().stack_inv_aux * (P::ONES - lv.opcode_bits[0]); - // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`. 
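// An integer sketch (not part of this patch) of the address arithmetic
// constrained below: the effective address is rs plus the sign-extended
// 16-bit immediate with u32 wraparound, and the word address clears the low
// two bits. In the field, the possible wrap is captured by the two-case
// identity (rs_from_bits - virt_raw) * (rs_from_bits + 2^32 - virt_raw) == 0.

fn effective_addr(rs: u32, offset: u16) -> (u32, u32) {
    let virt_raw = rs.wrapping_add(offset as i16 as i32 as u32);
    (virt_raw, virt_raw & 0xFFFF_FFFC) // (raw address, aligned word address)
}

fn main() {
    let (raw, virt) = effective_addr(0xffff_fffe, 0x0005); // wraps past 2^32
    assert_eq!((raw, virt), (0x0000_0003, 0x0000_0000));
    // Field-side view: the unwrapped sum equals raw, or raw + 2^32.
    let unwrapped = 0xffff_fffeu64 + 0x5;
    assert!(unwrapped == raw as u64 || unwrapped == raw as u64 + (1u64 << 32));
}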
+ // Calculate rs: + // let virt_raw = (rs as u32).wrapping_add(sign_extend::<16>(offset)); + let offset = load_offset(lv); + let virt_raw = rs + offset; + + let rs_from_bits = limb_from_bits_le(rs_limbs.into_iter()); + let u32max = P::Scalar::from_canonical_u64(1u64 << 32); yield_constr - .constraint(lv.op.m_op_general * (lv.general.stack().stack_inv_aux_2 - is_top_read)); - let new_filter = lv.op.m_op_general * lv.general.stack().stack_inv_aux_2; - yield_constr.constraint_transition(new_filter * (top_read_channel.used - P::ONES)); - yield_constr.constraint_transition(new_filter * (top_read_channel.is_read - P::ONES)); - yield_constr.constraint_transition(new_filter * (top_read_channel.addr_context - nv.context)); - yield_constr.constraint_transition( - new_filter - * (top_read_channel.addr_segment - - P::Scalar::from_canonical_u64(Segment::Stack as u64)), - ); - let addr_virtual = nv.stack_len - P::ONES; - yield_constr.constraint_transition(new_filter * (top_read_channel.addr_virtual - addr_virtual)); - // If stack_len == 4 or MLOAD, disable the channel. - yield_constr.constraint( - lv.op.m_op_general * (lv.general.stack().stack_inv_aux - P::ONES) * top_read_channel.used, - ); - yield_constr.constraint(lv.op.m_op_general * lv.opcode_bits[0] * top_read_channel.used); - */ -} + .constraint(filter * (rs_from_bits - virt_raw) * (rs_from_bits + u32max - virt_raw)); -fn eval_ext_circuit_store, const D: usize>( - builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder, - lv: &CpuColumnsView>, - _nv: &CpuColumnsView>, - yield_constr: &mut RecursiveConstraintConsumer, -) { - let filter = - builder.mul_sub_extension(lv.op.m_op_general, lv.opcode_bits[0], lv.op.m_op_general); + let rt_from_bits = limb_from_bits_le(rt_limbs.into_iter()); + yield_constr.constraint(filter * (rt_from_bits - rt)); - let (addr_context, addr_segment, addr_virtual) = get_addr(lv); + // Constrain mem address + // let virt = virt_raw & 0xFFFF_FFFC; + let mut tmp = rs_limbs.clone(); + tmp[0] = P::ZEROS; + tmp[1] = P::ZEROS; + let virt = limb_from_bits_le(tmp.into_iter()); - let value_channel = lv.mem_channels[3]; - let store_channel = lv.mem_channels[4]; + let mem_virt = lv.mem_channels[2].addr_virtual; + yield_constr.constraint(filter * (virt - mem_virt)); + + // Verify op + let op_inv = lv.general.io().diff_inv; + let op = lv.mem_channels[5].value; + yield_constr.constraint(filter * (P::ONES - op * op_inv)); + + // Constrain mem value + // SB: + // let val = (rt & 0xff) << (24 - (rs & 3) * 8); + // let mask = 0xffFFffFFu32 ^ (0xff << (24 - (rs & 3) * 8)); + // (mem & mask) | val { - let constr = builder.mul_sub_extension(filter, store_channel.used, filter); - yield_constr.constraint(builder, constr); + let mut mem_val_0_0 = [P::ZEROS; 32]; + let mut mem_val_1_0 = [P::ZEROS; 32]; + let mut mem_val_0_1 = [P::ZEROS; 32]; + let mut mem_val_1_1 = [P::ZEROS; 32]; + + // rs[0] = 0, rs[1] = 0 + mem_val_0_0[24..].copy_from_slice(&rt_limbs[0..8]); + mem_val_0_0[0..24].copy_from_slice(&mem_limbs[0..24]); + // rs[0] = 1, rs[1] = 0 + mem_val_1_0[24..].copy_from_slice(&mem_limbs[24..]); + mem_val_1_0[16..24].copy_from_slice(&rt_limbs[0..8]); + mem_val_1_0[0..16].copy_from_slice(&mem_limbs[0..16]); + // rs[0] = 0, rs[1] = 1 + mem_val_0_1[16..].copy_from_slice(&mem_limbs[16..]); + mem_val_0_1[8..16].copy_from_slice(&rt_limbs[0..8]); + mem_val_0_1[0..8].copy_from_slice(&mem_limbs[0..8]); + // rs[0] = 1, rs[1] = 1 + mem_val_1_1[0..8].copy_from_slice(&rt_limbs[0..8]); + mem_val_1_1[8..].copy_from_slice(&mem_limbs[8..]); + + let 
mem_val_0_0 = limb_from_bits_le(mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le(mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le(mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le(mem_val_1_1.into_iter()); + + let sum = (mem - mem_val_0_0) * (rs_limbs[1] - P::ONES) * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_0) * (rs_limbs[1] - P::ONES) * rs_limbs[0] + + (mem - mem_val_0_1) * rs_limbs[1] * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_1) * rs_limbs[1] * rs_limbs[0]; + yield_constr.constraint(filter * lv.general.io().micro_op[0] * sum); } + + // SH + // let val = (rt & 0xffff) << (16 - (rs & 2) * 8); + // let mask = 0xffFFffFFu32 ^ (0xffff << (16 - (rs & 2) * 8)); + // (mem & mask) | val { - let constr = builder.mul_extension(filter, store_channel.is_read); - yield_constr.constraint(builder, constr); - } - for (channel_field, target) in izip!( - [ - store_channel.addr_context, - store_channel.addr_segment, - store_channel.addr_virtual, - ], - [addr_context, addr_segment, addr_virtual] - ) { - let diff = builder.sub_extension(channel_field, target); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - let diff = builder.sub_extension(value_channel.value, store_channel.value); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); + let mut mem_val_0 = [P::ZEROS; 32]; + let mut mem_val_1 = [P::ZEROS; 32]; - // Disable remaining memory channels, if any. - for &channel in &lv.mem_channels[5..] { - let constr = builder.mul_extension(filter, channel.used); - yield_constr.constraint(builder, constr); + mem_val_0[16..].copy_from_slice(&rt_limbs[0..16]); + mem_val_0[0..16].copy_from_slice(&mem_limbs[0..16]); + + mem_val_1[0..16].copy_from_slice(&rt_limbs[0..16]); + mem_val_1[16..].copy_from_slice(&mem_limbs[16..]); + + let mem_val_1 = limb_from_bits_le(mem_val_1.into_iter()); + let mem_val_0 = limb_from_bits_le(mem_val_0.into_iter()); + + let sum = rs_limbs[1] * (mem - mem_val_1) + (rs_limbs[1] - P::ONES) * (mem - mem_val_0); + yield_constr.constraint(filter * lv.general.io().micro_op[1] * sum); } - // Stack constraints - // Pops. - /* - for i in 1..4 { - let channel = lv.mem_channels[i]; - - { - let constr = builder.mul_sub_extension(filter, channel.used, filter); - yield_constr.constraint(builder, constr); - } - { - let constr = builder.mul_sub_extension(filter, channel.is_read, filter); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.sub_extension(channel.addr_context, lv.context); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - { - let diff = builder.add_const_extension( - channel.addr_segment, - -F::from_canonical_u64(Segment::Stack as u64), - ); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); - } - // Remember that the first read (`i == 1`) is for the second stack element at `stack[stack_len - 1]`. 
- let addr_virtual = - builder.add_const_extension(lv.stack_len, -F::from_canonical_usize(i + 1)); - let diff = builder.sub_extension(channel.addr_virtual, addr_virtual); - let constr = builder.mul_extension(filter, diff); - yield_constr.constraint(builder, constr); + // SWL + // let val = rt >> ((rs & 3) * 8); + // let mask = 0xffFFffFFu32 >> ((rs & 3) * 8); + // (mem & (!mask)) | val + { + let mut mem_val_0_0 = [P::ZEROS; 32]; + let mut mem_val_1_0 = [P::ZEROS; 32]; + let mut mem_val_0_1 = [P::ZEROS; 32]; + let mut mem_val_1_1 = [P::ZEROS; 32]; + + // rs[0] = 0, rs[1] = 0 + mem_val_0_0[..].copy_from_slice(&rt_limbs[..]); + // rs[0] = 1, rs[1] = 0 + mem_val_1_0[0..24].copy_from_slice(&rt_limbs[8..]); + mem_val_1_0[24..].copy_from_slice(&mem_limbs[24..]); + // rs[0] = 0, rs[1] = 1 + mem_val_0_1[0..16].copy_from_slice(&rt_limbs[16..]); + mem_val_0_1[16..].copy_from_slice(&mem_limbs[16..]); + // rs[0] = 1, rs[1] = 1 + mem_val_1_1[0..8].copy_from_slice(&rt_limbs[24..]); + mem_val_1_1[8..].copy_from_slice(&mem_limbs[8..]); + + let mem_val_0_0 = limb_from_bits_le(mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le(mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le(mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le(mem_val_1_1.into_iter()); + + let sum = (mem - mem_val_0_0) * (rs_limbs[1] - P::ONES) * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_0) * (rs_limbs[1] - P::ONES) * rs_limbs[0] + + (mem - mem_val_0_1) * rs_limbs[1] * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_1) * rs_limbs[1] * rs_limbs[0]; + yield_constr.constraint(filter * lv.general.io().micro_op[2] * sum); } - // Constrain `stack_inv_aux`. + + // SW { - let len_diff = builder.add_const_extension(lv.stack_len, -F::from_canonical_usize(4)); - let diff = builder.mul_sub_extension( - len_diff, - lv.general.stack().stack_inv, - lv.general.stack().stack_inv_aux, - ); - let constr = builder.mul_extension(lv.op.m_op_general, diff); - yield_constr.constraint(builder, constr); + let rt_value = limb_from_bits_le(rt_limbs.into_iter()); + yield_constr.constraint(filter * lv.general.io().micro_op[3] * (mem - rt_value)); } - // If stack_len != 4 and MSTORE, read new top of the stack in nv.mem_channels[0]. - let top_read_channel = nv.mem_channels[0]; - let is_top_read = builder.mul_extension(lv.general.stack().stack_inv_aux, lv.opcode_bits[0]); - let is_top_read = builder.sub_extension(lv.general.stack().stack_inv_aux, is_top_read); - // Constrain `stack_inv_aux_2`. It contains `stack_inv_aux * opcode_bits[0]`. 
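// Plain-Rust reference for the unaligned-store merges constrained here
// (illustrative, not part of this patch): SWL stores the high-order bytes of
// rt down to the addressed byte, SWR the low-order bytes up from it, each
// leaving the rest of the word untouched.

fn swl(rt: u32, mem: u32, addr: u32) -> u32 {
    let sh = (addr & 3) * 8;
    (mem & !(0xffff_ffffu32 >> sh)) | (rt >> sh)
}

fn swr(rt: u32, mem: u32, addr: u32) -> u32 {
    let sh = 24 - (addr & 3) * 8;
    (mem & !(0xffff_ffffu32 << sh)) | (rt << sh)
}

fn main() {
    assert_eq!(swl(0xaabb_ccdd, 0x1122_3344, 1), 0x11aa_bbcc);
    assert_eq!(swr(0xaabb_ccdd, 0x1122_3344, 1), 0xccdd_3344);
}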
+ + // SWR + // let val = rt << (24 - (rs & 3) * 8); + // let mask = 0xffFFffFFu32 << (24 - (rs & 3) * 8); + // (mem & (!mask)) | val { - let diff = builder.sub_extension(lv.general.stack().stack_inv_aux_2, is_top_read); - let constr = builder.mul_extension(lv.op.m_op_general, diff); - yield_constr.constraint(builder, constr); + let mut mem_val_0_0 = [P::ZEROS; 32]; + let mut mem_val_1_0 = [P::ZEROS; 32]; + let mut mem_val_0_1 = [P::ZEROS; 32]; + let mut mem_val_1_1 = [P::ZEROS; 32]; + + // rs[0] = 0, rs[1] = 0 + mem_val_0_0[24..].copy_from_slice(&rt_limbs[0..8]); + mem_val_0_0[0..24].copy_from_slice(&mem_limbs[0..24]); + // rs[0] = 1, rs[1] = 0 + mem_val_1_0[16..].copy_from_slice(&rt_limbs[0..16]); + mem_val_1_0[0..16].copy_from_slice(&mem_limbs[0..16]); + // rs[0] = 0, rs[1] = 1 + mem_val_0_1[8..].copy_from_slice(&rt_limbs[0..24]); + mem_val_0_1[0..8].copy_from_slice(&mem_limbs[0..8]); + // rs[0] = 1, rs[1] = 1 + mem_val_1_1[..].copy_from_slice(&rt_limbs[..]); + + let mem_val_0_0 = limb_from_bits_le(mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le(mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le(mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le(mem_val_1_1.into_iter()); + + let sum = (mem - mem_val_0_0) * (rs_limbs[1] - P::ONES) * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_0) * (rs_limbs[1] - P::ONES) * rs_limbs[0] + + (mem - mem_val_0_1) * rs_limbs[1] * (rs_limbs[0] - P::ONES) + + (mem - mem_val_1_1) * rs_limbs[1] * rs_limbs[0]; + yield_constr.constraint(filter * lv.general.io().micro_op[4] * sum); } - let new_filter = builder.mul_extension(lv.op.m_op_general, lv.general.stack().stack_inv_aux_2); + + // SC: + // TODO: write back rt register { - let constr = builder.mul_sub_extension(new_filter, top_read_channel.used, new_filter); - yield_constr.constraint_transition(builder, constr); + let rt_value = limb_from_bits_le(rt_limbs.into_iter()); + yield_constr.constraint(filter * lv.general.io().micro_op[5] * (mem - rt_value)); + } + + // Disable remaining memory channels, if any. + for &channel in &lv.mem_channels[6..] 
{ + yield_constr.constraint(filter * channel.used); } +} + +fn eval_ext_circuit_store, const D: usize>( + builder: &mut CircuitBuilder, + lv: &CpuColumnsView>, + _nv: &CpuColumnsView>, + yield_constr: &mut RecursiveConstraintConsumer, +) { + let zeros = builder.zero_extension(); + let ones = builder.one_extension(); + let filter = builder.mul_extension(lv.op.m_op_store, lv.opcode_bits[5]); + + // Check mem channel segment is register + let diff = builder.add_const_extension( + lv.mem_channels[0].addr_segment, + -F::from_canonical_u64(Segment::RegisterFile as u64), + ); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + let diff = builder.add_const_extension( + lv.mem_channels[1].addr_segment, + -F::from_canonical_u64(Segment::RegisterFile as u64), + ); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + let rs = lv.mem_channels[0].value; + let rt = lv.mem_channels[1].value; + let mem = lv.mem_channels[3].value; + let rs_limbs = lv.general.io().rs_le; + let rt_limbs = lv.general.io().rt_le; + let mem_limbs = lv.general.io().mem_le; + + // Calculate rs: + // let virt_raw = (rs as u32).wrapping_add(sign_extend::<16>(offset)); + let offset = load_offset_ext(builder, lv); + let virt_raw = builder.add_extension(rs, offset); + + let u32max = F::from_canonical_u64(1u64 << 32); + //yield_constr.constraint(filter * (rs_from_bits - virt_raw) * (rs_from_bits + u32max - virt_raw)); + let rs_from_bits = limb_from_bits_le_recursive(builder, rs_limbs.into_iter()); + let diff1 = builder.sub_extension(rs_from_bits, virt_raw); + + let diff2 = builder.add_const_extension(rs_from_bits, u32max); + let diff2 = builder.sub_extension(diff2, virt_raw); + + let constr = builder.mul_many_extension(&[filter, diff1, diff2]); + yield_constr.constraint(builder, constr); + + let rt_from_bits = limb_from_bits_le_recursive(builder, rt_limbs.into_iter()); + let diff = builder.sub_extension(rt_from_bits, rt); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + // Constrain mem address + // let virt = virt_raw & 0xFFFF_FFFC; + let mut tmp = rs_limbs.clone(); + tmp[0] = zeros; + tmp[1] = zeros; + let virt = limb_from_bits_le_recursive(builder, tmp.into_iter()); + + let mem_virt = lv.mem_channels[2].addr_virtual; + let diff = builder.sub_extension(virt, mem_virt); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + // Verify op + let op_inv = lv.general.io().diff_inv; + let op = lv.mem_channels[5].value; + let mul = builder.mul_extension(op, op_inv); + let diff = builder.sub_extension(ones, mul); + let constr = builder.mul_extension(filter, diff); + yield_constr.constraint(builder, constr); + + // Constrain mem value + // SB: + // let val = (rt & 0xff) << (24 - (rs & 3) * 8); + // let mask = 0xffFFffFFu32 ^ (0xff << (24 - (rs & 3) * 8)); + // (mem & mask) | val { - let constr = builder.mul_sub_extension(new_filter, top_read_channel.is_read, new_filter); - yield_constr.constraint_transition(builder, constr); + let mut mem_val_0_0 = [zeros; 32]; + let mut mem_val_1_0 = [zeros; 32]; + let mut mem_val_0_1 = [zeros; 32]; + let mut mem_val_1_1 = [zeros; 32]; + + // rs[0] = 0, rs[1] = 0 + mem_val_0_0[24..].copy_from_slice(&rt_limbs[0..8]); + mem_val_0_0[0..24].copy_from_slice(&mem_limbs[0..24]); + // rs[0] = 1, rs[1] = 0 + mem_val_1_0[24..].copy_from_slice(&mem_limbs[24..]); + mem_val_1_0[16..24].copy_from_slice(&rt_limbs[0..8]); + 
mem_val_1_0[0..16].copy_from_slice(&mem_limbs[0..16]); + // rs[0] = 0, rs[1] = 1 + mem_val_0_1[16..].copy_from_slice(&mem_limbs[16..]); + mem_val_0_1[8..16].copy_from_slice(&rt_limbs[0..8]); + mem_val_0_1[0..8].copy_from_slice(&mem_limbs[0..8]); + // rs[0] = 1, rs[1] = 1 + mem_val_1_1[0..8].copy_from_slice(&rt_limbs[0..8]); + mem_val_1_1[8..].copy_from_slice(&mem_limbs[8..]); + + let mem_val_0_0 = limb_from_bits_le_recursive(builder, mem_val_0_0.into_iter()); + let mem_val_1_0 = limb_from_bits_le_recursive(builder, mem_val_1_0.into_iter()); + let mem_val_0_1 = limb_from_bits_le_recursive(builder, mem_val_0_1.into_iter()); + let mem_val_1_1 = limb_from_bits_le_recursive(builder, mem_val_1_1.into_iter()); + + let diff1 = builder.sub_extension(mem, mem_val_0_0); + let diff2 = builder.sub_extension(mem, mem_val_1_0); + let diff3 = builder.sub_extension(mem, mem_val_0_1); + let diff4 = builder.sub_extension(mem, mem_val_1_1); + + let coff10 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let coff11 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum1 = builder.mul_many_extension([diff1, coff10, coff11]); + + let coff21 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let sum2 = builder.mul_many_extension([diff2, coff21, rs_limbs[0]]); + + let coff30 = builder.add_const_extension(rs_limbs[0], -F::ONES); + let sum3 = builder.mul_many_extension([diff3, rs_limbs[1], coff30]); + + let sum4 = builder.mul_many_extension([diff4, rs_limbs[1], rs_limbs[0]]); + + let sum = builder.add_many_extension([sum1, sum2, sum3, sum4]); + + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[0], sum]); + yield_constr.constraint(builder, mult); } + + // SH + // let val = (rt & 0xffff) << (16 - (rs & 2) * 8); + // let mask = 0xffFFffFFu32 ^ (0xffff << (16 - (rs & 2) * 8)); + // (mem & mask) | val { - let diff = builder.sub_extension(top_read_channel.addr_context, nv.context); - let constr = builder.mul_extension(new_filter, diff); - yield_constr.constraint_transition(builder, constr); + let mut mem_val_0 = [zeros; 32]; + let mut mem_val_1 = [zeros; 32]; + + mem_val_0[16..].copy_from_slice(&rt_limbs[0..16]); + mem_val_0[0..16].copy_from_slice(&mem_limbs[0..16]); + + mem_val_1[0..16].copy_from_slice(&rt_limbs[0..16]); + mem_val_1[16..].copy_from_slice(&mem_limbs[16..]); + + let mem_val_1 = limb_from_bits_le_recursive(builder, mem_val_1.into_iter()); + let mem_val_0 = limb_from_bits_le_recursive(builder, mem_val_0.into_iter()); + + let diff1 = builder.sub_extension(mem, mem_val_1); + let diff2 = builder.sub_extension(mem, mem_val_0); + let coff2 = builder.add_const_extension(rs_limbs[1], -F::ONES); + let mult2 = builder.mul_extension(diff2, coff2); + let sum = builder.arithmetic_extension(F::ONES, F::ONES, rs_limbs[1], diff1, mult2); + let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[1], sum]); + yield_constr.constraint(builder, mult); } + + // SWL + // let val = rt >> ((rs & 3) * 8); + // let mask = 0xffFFffFFu32 >> ((rs & 3) * 8); + // (mem & (!mask)) | val { - let diff = builder.add_const_extension( - top_read_channel.addr_segment, - -F::from_canonical_u64(Segment::Stack as u64), - ); - let constr = builder.mul_extension(new_filter, diff); - yield_constr.constraint_transition(builder, constr); + let mut mem_val_0_0 = [zeros; 32]; + let mut mem_val_1_0 = [zeros; 32]; + let mut mem_val_0_1 = [zeros; 32]; + let mut mem_val_1_1 = [zeros; 32]; + + // rs[0] = 0, rs[1] = 0 + mem_val_0_0[..].copy_from_slice(&rt_limbs[..]); + // rs[0] = 1, rs[1] = 0 + 
+
+    // SWL
+    //    let val = rt >> ((rs & 3) * 8);
+    //    let mask = 0xffFFffFFu32 >> ((rs & 3) * 8);
+    //    (mem & (!mask)) | val
     {
-        let diff = builder.add_const_extension(
-            top_read_channel.addr_segment,
-            -F::from_canonical_u64(Segment::Stack as u64),
-        );
-        let constr = builder.mul_extension(new_filter, diff);
-        yield_constr.constraint_transition(builder, constr);
+        let mut mem_val_0_0 = [zeros; 32];
+        let mut mem_val_1_0 = [zeros; 32];
+        let mut mem_val_0_1 = [zeros; 32];
+        let mut mem_val_1_1 = [zeros; 32];
+
+        // rs[0] = 0, rs[1] = 0
+        mem_val_0_0[..].copy_from_slice(&rt_limbs[..]);
+        // rs[0] = 1, rs[1] = 0
+        mem_val_1_0[0..24].copy_from_slice(&rt_limbs[8..]);
+        mem_val_1_0[24..].copy_from_slice(&mem_limbs[24..]);
+        // rs[0] = 0, rs[1] = 1
+        mem_val_0_1[0..16].copy_from_slice(&rt_limbs[16..]);
+        mem_val_0_1[16..].copy_from_slice(&mem_limbs[16..]);
+        // rs[0] = 1, rs[1] = 1
+        mem_val_1_1[0..8].copy_from_slice(&rt_limbs[24..]);
+        mem_val_1_1[8..].copy_from_slice(&mem_limbs[8..]);
+
+        let mem_val_0_0 = limb_from_bits_le_recursive(builder, mem_val_0_0.into_iter());
+        let mem_val_1_0 = limb_from_bits_le_recursive(builder, mem_val_1_0.into_iter());
+        let mem_val_0_1 = limb_from_bits_le_recursive(builder, mem_val_0_1.into_iter());
+        let mem_val_1_1 = limb_from_bits_le_recursive(builder, mem_val_1_1.into_iter());
+
+        let diff1 = builder.sub_extension(mem, mem_val_0_0);
+        let diff2 = builder.sub_extension(mem, mem_val_1_0);
+        let diff3 = builder.sub_extension(mem, mem_val_0_1);
+        let diff4 = builder.sub_extension(mem, mem_val_1_1);
+
+        let coff10 = builder.add_const_extension(rs_limbs[0], -F::ONES);
+        let coff11 = builder.add_const_extension(rs_limbs[1], -F::ONES);
+        let sum1 = builder.mul_many_extension([diff1, coff10, coff11]);
+
+        let coff21 = builder.add_const_extension(rs_limbs[1], -F::ONES);
+        let sum2 = builder.mul_many_extension([diff2, coff21, rs_limbs[0]]);
+
+        let coff30 = builder.add_const_extension(rs_limbs[0], -F::ONES);
+        let sum3 = builder.mul_many_extension([diff3, rs_limbs[1], coff30]);
+
+        let sum4 = builder.mul_many_extension([diff4, rs_limbs[1], rs_limbs[0]]);
+
+        let sum = builder.add_many_extension([sum1, sum2, sum3, sum4]);
+
+        let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[2], sum]);
+        yield_constr.constraint(builder, mult);
     }
+
+    // SW
     {
-        let addr_virtual = builder.add_const_extension(nv.stack_len, -F::ONE);
-        let diff = builder.sub_extension(top_read_channel.addr_virtual, addr_virtual);
-        let constr = builder.mul_extension(new_filter, diff);
-        yield_constr.constraint_transition(builder, constr);
+        let rt_value = limb_from_bits_le_recursive(builder, rt_limbs.into_iter());
+        //yield_constr.constraint(filter * lv.general.io().micro_op[3] * (mem - rt_value));
+        let diff1 = builder.sub_extension(mem, rt_value);
+        let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[3], diff1]);
+        yield_constr.constraint(builder, mult);
     }
-    // If stack_len == 4 or MLOAD, disable the channel.
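+    // SW (and SC further down) replace the whole aligned word, so no byte
+    // selection is needed: the stored value must simply equal rt as recomposed
+    // from its claimed bit decomposition, and that decomposition was already
+    // tied to the rt channel value near the top of this function.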
+
+    // SWR
+    //    let val = rt << (24 - (rs & 3) * 8);
+    //    let mask = 0xffFFffFFu32 << (24 - (rs & 3) * 8);
+    //    (mem & (!mask)) | val
     {
-        let diff = builder.mul_sub_extension(
-            lv.op.m_op_general,
-            lv.general.stack().stack_inv_aux,
-            lv.op.m_op_general,
-        );
-        let constr = builder.mul_extension(diff, top_read_channel.used);
-        yield_constr.constraint(builder, constr);
+        let mut mem_val_0_0 = [zeros; 32];
+        let mut mem_val_1_0 = [zeros; 32];
+        let mut mem_val_0_1 = [zeros; 32];
+        let mut mem_val_1_1 = [zeros; 32];
+
+        // rs[0] = 0, rs[1] = 0
+        mem_val_0_0[24..].copy_from_slice(&rt_limbs[0..8]);
+        mem_val_0_0[0..24].copy_from_slice(&mem_limbs[0..24]);
+        // rs[0] = 1, rs[1] = 0
+        mem_val_1_0[16..].copy_from_slice(&rt_limbs[0..16]);
+        mem_val_1_0[0..16].copy_from_slice(&mem_limbs[0..16]);
+        // rs[0] = 0, rs[1] = 1
+        mem_val_0_1[8..].copy_from_slice(&rt_limbs[0..24]);
+        mem_val_0_1[0..8].copy_from_slice(&mem_limbs[0..8]);
+        // rs[0] = 1, rs[1] = 1
+        mem_val_1_1[..].copy_from_slice(&rt_limbs[..]);
+
+        let mem_val_0_0 = limb_from_bits_le_recursive(builder, mem_val_0_0.into_iter());
+        let mem_val_1_0 = limb_from_bits_le_recursive(builder, mem_val_1_0.into_iter());
+        let mem_val_0_1 = limb_from_bits_le_recursive(builder, mem_val_0_1.into_iter());
+        let mem_val_1_1 = limb_from_bits_le_recursive(builder, mem_val_1_1.into_iter());
+
+        let diff1 = builder.sub_extension(mem, mem_val_0_0);
+        let diff2 = builder.sub_extension(mem, mem_val_1_0);
+        let diff3 = builder.sub_extension(mem, mem_val_0_1);
+        let diff4 = builder.sub_extension(mem, mem_val_1_1);
+
+        let coff10 = builder.add_const_extension(rs_limbs[0], -F::ONES);
+        let coff11 = builder.add_const_extension(rs_limbs[1], -F::ONES);
+        let sum1 = builder.mul_many_extension([diff1, coff10, coff11]);
+
+        let coff21 = builder.add_const_extension(rs_limbs[1], -F::ONES);
+        let sum2 = builder.mul_many_extension([diff2, coff21, rs_limbs[0]]);
+
+        let coff30 = builder.add_const_extension(rs_limbs[0], -F::ONES);
+        let sum3 = builder.mul_many_extension([diff3, rs_limbs[1], coff30]);
+
+        let sum4 = builder.mul_many_extension([diff4, rs_limbs[1], rs_limbs[0]]);
+
+        let sum = builder.add_many_extension([sum1, sum2, sum3, sum4]);
+
+        let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[4], sum]);
+        yield_constr.constraint(builder, mult);
     }
+
+    // SC
     {
-        let mul = builder.mul_extension(lv.op.m_op_general, lv.opcode_bits[0]);
-        let constr = builder.mul_extension(mul, top_read_channel.used);
+        let rt_value = limb_from_bits_le_recursive(builder, rt_limbs.into_iter());
+        //yield_constr.constraint(filter * lv.general.io().micro_op[5] * (mem - rt_value));
+        let diff1 = builder.sub_extension(mem, rt_value);
+        let mult = builder.mul_many_extension([filter, lv.general.io().micro_op[5], diff1]);
+        yield_constr.constraint(builder, mult);
+    }
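+    // Worked example for SWR, using the big-endian byte order assumed above:
+    // for rs & 3 == 1 (rs0 = 1, rs1 = 0) the shift is 24 - 8 = 16, so
+    // val = rt << 16. Bits 16..32 of the stored word then come from
+    // rt_limbs[0..16] and bits 0..16 are kept from mem_limbs[0..16], which is
+    // exactly how mem_val_1_0 is assembled in that case.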
+
+    // Disable remaining memory channels, if any.
+    for &channel in &lv.mem_channels[6..]
     {
+        let constr = builder.mul_extension(filter, channel.used);
         yield_constr.constraint(builder, constr);
     }
-    */
 }
 
 pub fn eval_packed(
@@ -327,7 +1135,7 @@ pub fn eval_packed(
 }
 
 pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
-    builder: &mut plonky2::plonk::circuit_builder::CircuitBuilder<F, D>,
+    builder: &mut CircuitBuilder<F, D>,
     lv: &CpuColumnsView<ExtensionTarget<D>>,
     nv: &CpuColumnsView<ExtensionTarget<D>>,
     yield_constr: &mut RecursiveConstraintConsumer<F, D>,
diff --git a/src/cpu/syscall.rs b/src/cpu/syscall.rs
index dc792533..a6cc410b 100644
--- a/src/cpu/syscall.rs
+++ b/src/cpu/syscall.rs
@@ -205,7 +205,6 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
     let a2 = lv.mem_channels[3].value;
     let v0 = builder.zero_extension();
     let v1 = builder.zero_extension();
-    let one = builder.one_extension();
     let syscall = lv.general.syscall();
     let result_v0 = lv.mem_channels[4].value;
     let result_v1 = lv.mem_channels[5].value;
diff --git a/src/logic.rs b/src/logic.rs
index 865a77f0..4efb4e43 100644
--- a/src/logic.rs
+++ b/src/logic.rs
@@ -344,30 +344,16 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for LogicStark<F, D>
diff --git a/src/stark_testing.rs b/src/stark_testing.rs
index 9e47c57b..c79c154c 100644
--- a/src/stark_testing.rs
+++ b/src/stark_testing.rs
@@ -218,8 +218,8 @@ pub fn test_stark_cpu_check_constraints<
     let alphas = F::rand_vec(1);
     let z_last = F::Extension::rand();
-    let lagrange_first = F::Extension::rand();
-    let lagrange_last = F::Extension::rand();
+    let lagrange_first = F::Extension::ZERO;
+    let lagrange_last = F::Extension::ZERO;
     let mut consumer = ConstraintConsumer::<F::Extension>::new(
         alphas
             .iter()
diff --git a/src/witness/operation.rs b/src/witness/operation.rs
index 20678f1f..9a46e6f9 100644
--- a/src/witness/operation.rs
+++ b/src/witness/operation.rs
@@ -993,26 +993,82 @@ pub(crate) fn generate_mload_general<F: Field>(
     let address = MemoryAddress::new(0, Segment::Code, virt as usize);
     let (mem, log_in3) = mem_read_gp_with_log_and_fill(2, address, state, &mut row);
-    let rs = virt_raw as u32;
+    row.general
+        .io_mut()
+        .mem_le
+        .iter_mut()
+        .enumerate()
+        .for_each(|(i, v)| {
+            *v = F::from_canonical_u32((mem >> i) & 1);
+        });
+
+    let rs = virt_raw;
     let rt = rt as u32;
+    row.general
+        .io_mut()
+        .rs_le
+        .iter_mut()
+        .enumerate()
+        .for_each(|(i, v)| {
+            *v = F::from_canonical_u32((rs >> i) & 1);
+        });
+    row.general
+        .io_mut()
+        .rt_le
+        .iter_mut()
+        .enumerate()
+        .for_each(|(i, v)| {
+            *v = F::from_canonical_u32((rt >> i) & 1);
+        });
+
+    let diff = op as u32;
+
     let val = match op {
-        MemOp::LH => sign_extend::<16>((mem >> (16 - (rs & 2) * 8)) & 0xffff),
+        MemOp::LH => {
+            //diff = op as u32 - MemOp::LH as u32;
+            row.general.io_mut().micro_op[0] = F::ONE;
+            sign_extend::<16>((mem >> (16 - (rs & 2) * 8)) & 0xffff)
+        }
         MemOp::LWL => {
+            //diff = op as u32 - MemOp::LWL as u32;
+            row.general.io_mut().micro_op[1] = F::ONE;
             let val = mem << ((rs & 3) * 8);
             let mask = 0xffFFffFFu32 << ((rs & 3) * 8);
             (rt & (!mask)) | val
         }
-        MemOp::LW => mem,
-        MemOp::LBU => (mem >> (24 - (rs & 3) * 8)) & 0xff,
-        MemOp::LHU => (mem >> (16 - (rs & 2) * 8)) & 0xffff,
+        MemOp::LW => {
+            row.general.io_mut().micro_op[2] = F::ONE;
+            //diff = op as u32 - MemOp::LW as u32;
+            mem
+        }
+        MemOp::LBU => {
+            //diff = op as u32 - MemOp::LBU as u32;
+            row.general.io_mut().micro_op[3] = F::ONE;
+            (mem >> (24 - (rs & 3) * 8)) & 0xff
+        }
+        MemOp::LHU => {
+            //diff = op as u32 - MemOp::LHU as u32;
+            row.general.io_mut().micro_op[4] = F::ONE;
+            (mem >> (16 - (rs & 2) * 8)) & 0xffff
+        }
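+        // LWL above and LWR below are the classic MIPS unaligned-access pair:
+        // LWL fills the high-order end of rt from memory and LWR the low-order
+        // end, with the byte offset rs & 3 deciding (via the mask) how much of
+        // the old rt survives.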
         MemOp::LWR => {
+            //diff = op as u32 - MemOp::LWR as u32;
+            row.general.io_mut().micro_op[5] = F::ONE;
             let val = mem >> (24 - (rs & 3) * 8);
             let mask = 0xffFFffFFu32 >> (24 - (rs & 3) * 8);
             (rt & (!mask)) | val
         }
-        MemOp::LL => mem,
-        MemOp::LB => sign_extend::<8>((mem >> (24 - (rs & 3) * 8)) & 0xff),
+        MemOp::LL => {
+            //diff = op as u32 - MemOp::LL as u32;
+            row.general.io_mut().micro_op[6] = F::ONE;
+            mem
+        }
+        MemOp::LB => {
+            //diff = op as u32 - MemOp::LB as u32;
+            row.general.io_mut().micro_op[7] = F::ONE;
+            sign_extend::<8>((mem >> (24 - (rs & 3) * 8)) & 0xff)
+        }
         _ => todo!(),
     };
@@ -1022,6 +1078,18 @@ pub(crate) fn generate_mload_general<F: Field>(
     state.traces.push_memory(log_in2);
     state.traces.push_memory(log_in3);
     state.traces.push_memory(log_out0);
+
+    // aux1: op
+    let log_aux1 = reg_write_with_log(0, 4, op as usize, state, &mut row)?;
+    state.traces.push_memory(log_aux1);
+
+    let diff = F::from_canonical_u32(diff);
+    if let Some(inv) = diff.try_inverse() {
+        row.general.io_mut().diff_inv = inv;
+    } else {
+        row.general.io_mut().diff_inv = F::ZERO;
+    }
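+    // Witness side of the `op * diff_inv == 1` style check in the memio
+    // constraints: `diff` holds the MemOp selector value just written to the
+    // auxiliary register channel, and `diff_inv` its field inverse when one
+    // exists (the zero branch is only a fallback for a zero selector).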
+
     state.traces.push_cpu(row);
     Ok(())
 }
@@ -1042,32 +1110,75 @@ pub(crate) fn generate_mstore_general<F: Field>(
     let address = MemoryAddress::new(0, Segment::Code, virt as usize);
     let (mem, log_in3) = mem_read_gp_with_log_and_fill(2, address, state, &mut row);
-    let rs = virt_raw as u32;
+    row.general
+        .io_mut()
+        .mem_le
+        .iter_mut()
+        .enumerate()
+        .for_each(|(i, v)| {
+            *v = F::from_canonical_u32((mem >> i) & 1);
+        });
+
+    let rs = virt_raw;
     let rt = rt as u32;
+    row.general
+        .io_mut()
+        .rs_le
+        .iter_mut()
+        .enumerate()
+        .for_each(|(i, v)| {
+            *v = F::from_canonical_u32((rs >> i) & 1);
+        });
+    row.general
+        .io_mut()
+        .rt_le
+        .iter_mut()
+        .enumerate()
+        .for_each(|(i, v)| {
+            *v = F::from_canonical_u32((rt >> i) & 1);
+        });
+
+    let diff = op as u32;
+
     let val = match op {
         MemOp::SB => {
+            //diff = op as u32 - MemOp::SB as u32;
+            row.general.io_mut().micro_op[0] = F::ONE;
             let val = (rt & 0xff) << (24 - (rs & 3) * 8);
             let mask = 0xffFFffFFu32 ^ (0xff << (24 - (rs & 3) * 8));
             (mem & mask) | val
         }
         MemOp::SH => {
+            //diff = op as u32 - MemOp::SH as u32;
+            row.general.io_mut().micro_op[1] = F::ONE;
             let val = (rt & 0xffff) << (16 - (rs & 2) * 8);
             let mask = 0xffFFffFFu32 ^ (0xffff << (16 - (rs & 2) * 8));
             (mem & mask) | val
         }
         MemOp::SWL => {
+            //diff = op as u32 - MemOp::SWL as u32;
+            row.general.io_mut().micro_op[2] = F::ONE;
             let val = rt >> ((rs & 3) * 8);
             let mask = 0xffFFffFFu32 >> ((rs & 3) * 8);
             (mem & (!mask)) | val
         }
-        MemOp::SW => rt,
+        MemOp::SW => {
+            //diff = op as u32 - MemOp::SW as u32;
+            row.general.io_mut().micro_op[3] = F::ONE;
+            rt
+        }
         MemOp::SWR => {
+            //diff = op as u32 - MemOp::SWR as u32;
+            row.general.io_mut().micro_op[4] = F::ONE;
             let val = rt << (24 - (rs & 3) * 8);
             let mask = 0xffFFffFFu32 << (24 - (rs & 3) * 8);
             (mem & (!mask)) | val
         }
-        MemOp::SC => rt,
+        MemOp::SC => {
+            //diff = op as u32 - MemOp::SC as u32;
+            row.general.io_mut().micro_op[5] = F::ONE;
+            rt
+        }
         _ => todo!(),
     };
@@ -1084,6 +1195,17 @@ pub(crate) fn generate_mstore_general<F: Field>(
         state.traces.push_memory(log_out1);
     }
+    // aux1: op
+    let log_aux1 = reg_write_with_log(0, 5, op as usize, state, &mut row)?;
+    state.traces.push_memory(log_aux1);
+
+    let diff = F::from_canonical_u32(diff);
+    if let Some(inv) = diff.try_inverse() {
+        row.general.io_mut().diff_inv = inv;
+    } else {
+        row.general.io_mut().diff_inv = F::ZERO;
+    }
+
     state.traces.push_cpu(row);
     Ok(())
 }
diff --git a/src/witness/traces.rs b/src/witness/traces.rs
index 4cb6fbb8..32199081 100644
--- a/src/witness/traces.rs
+++ b/src/witness/traces.rs
@@ -1,6 +1,3 @@
-use std::mem::size_of;
-
-use itertools::Itertools;
 use plonky2::field::extension::Extendable;
 use plonky2::field::polynomial::PolynomialValues;
 use plonky2::hash::hash_types::RichField;
@@ -12,11 +9,10 @@ use crate::arithmetic::{BinaryOperator, Operation};
 //use crate::byte_packing::byte_packing_stark::BytePackingOp;
 use crate::config::StarkConfig;
 use crate::cpu::columns::CpuColumnsView;
-use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES;
-use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp;
+
 use crate::util::trace_rows_to_poly_values;
 use crate::witness::memory::MemoryOp;
-use crate::{arithmetic, keccak, keccak_sponge, logic};
+use crate::{arithmetic, logic};
 
 #[derive(Clone, Copy, Debug)]
 pub struct TraceCheckpoint {
diff --git a/src/witness/transition.rs b/src/witness/transition.rs
index 83a9c2ae..01664661 100644
--- a/src/witness/transition.rs
+++ b/src/witness/transition.rs
@@ -200,20 +200,20 @@ fn decode(registers: RegistersState, insn: u32) -> Result<Operation, ProgramError> {
         (0x06, _, _) => Ok(Operation::Branch(BranchCond::LE, rs, 0u8, offset)), // BLEZ
         (0x07, _, _) => Ok(Operation::Branch(BranchCond::GT, rs, 0u8, offset)), // BGTZ
+        (0b100000, _, _) => Ok(Operation::MloadGeneral(MemOp::LB, rs, rt, offset)),
         (0b100001, _, _) => Ok(Operation::MloadGeneral(MemOp::LH, rs, rt, offset)),
         (0b100010, _, _) => Ok(Operation::MloadGeneral(MemOp::LWL, rs, rt, offset)),
         (0b100011, _, _) => Ok(Operation::MloadGeneral(MemOp::LW, rs, rt, offset)),
         (0b100100, _, _) => Ok(Operation::MloadGeneral(MemOp::LBU, rs, rt, offset)),
         (0b100101, _, _) => Ok(Operation::MloadGeneral(MemOp::LHU, rs, rt, offset)),
         (0b100110, _, _) => Ok(Operation::MloadGeneral(MemOp::LWR, rs, rt, offset)),
+        (0b110000, _, _) => Ok(Operation::MloadGeneral(MemOp::LL, rs, rt, offset)),
         (0b101000, _, _) => Ok(Operation::MstoreGeneral(MemOp::SB, rs, rt, offset)),
         (0b101001, _, _) => Ok(Operation::MstoreGeneral(MemOp::SH, rs, rt, offset)),
         (0b101010, _, _) => Ok(Operation::MstoreGeneral(MemOp::SWL, rs, rt, offset)),
         (0b101011, _, _) => Ok(Operation::MstoreGeneral(MemOp::SW, rs, rt, offset)),
         (0b101110, _, _) => Ok(Operation::MstoreGeneral(MemOp::SWR, rs, rt, offset)),
-        (0b110000, _, _) => Ok(Operation::MloadGeneral(MemOp::LL, rs, rt, offset)),
         (0b111000, _, _) => Ok(Operation::MstoreGeneral(MemOp::SC, rs, rt, offset)),
-        (0b100000, _, _) => Ok(Operation::MloadGeneral(MemOp::LB, rs, rt, offset)),
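+        // The opcodes above follow the MIPS I encoding (LB = 0b100000 = 0x20,
+        // LL = 0b110000 = 0x30, SC = 0b111000 = 0x38); the reordering simply
+        // groups LB and LL with the other loads, ahead of the stores.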
 
         (0b001000, _, _) => Ok(Operation::BinaryArithmeticImm(
             arithmetic::BinaryOperator::ADDI,
@@ -305,7 +305,8 @@ fn fill_op_flag<F: Field>(op: Operation, row: &mut CpuColumnsView<F>) {
         Operation::GetContext => &mut flags.get_context,
         Operation::SetContext => &mut flags.set_context,
         Operation::ExitKernel => &mut flags.exit_kernel,
-        Operation::MloadGeneral(..) | Operation::MstoreGeneral(..) => &mut flags.m_op_general,
+        Operation::MloadGeneral(..) => &mut flags.m_op_load,
+        Operation::MstoreGeneral(..) => &mut flags.m_op_store,
     } = F::ONE;
 }
 
diff --git a/src/witness/util.rs b/src/witness/util.rs
index ec43d2c0..b6039860 100644
--- a/src/witness/util.rs
+++ b/src/witness/util.rs
@@ -1,16 +1,12 @@
 use plonky2::field::types::Field;
 
 use crate::cpu::columns::CpuColumnsView;
-use crate::cpu::kernel::keccak_util::keccakf_u8s;
-use crate::cpu::membus::NUM_CHANNELS;
+
 use crate::generation::state::GenerationState;
-use crate::keccak_sponge::columns::{KECCAK_RATE_BYTES, KECCAK_WIDTH_BYTES};
-use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp;
-use crate::logic;
+
 use crate::memory::segments::Segment;
 use crate::witness::errors::ProgramError;
 use crate::witness::memory::{MemoryAddress, MemoryChannel, MemoryOp, MemoryOpKind};
-use byteorder::{ByteOrder, LittleEndian};
 
 fn to_byte_checked(n: u32) -> u8 {
     let res: u8 = n.to_le_bytes()[0];