Skip to content

Commit

Permalink
fix: stash code
Browse files Browse the repository at this point in the history
  • Loading branch information
eigmax committed Nov 25, 2023
1 parent afce9c1 commit dd730d1
Showing 1 changed file with 151 additions and 10 deletions.
161 changes: 151 additions & 10 deletions src/cpu/memio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ fn load_offset<P: PackedField>(lv: &CpuColumnsView<P>) -> P {
limb_from_bits_le(mem_offset.into_iter())
}

#[inline]
/// Recursive-circuit counterpart of `load_offset`: reassembles the 16-bit
/// instruction immediate from the decoded `func`/`shamt`/`rd` bit columns and
/// sign-extends it to 32 bits before combining into a single extension target.
///
/// Layout (indices into `mem_offset`):
///   [0..16)  sign fill, replicated from `func_bits[0]` (the immediate's top bit)
///   [16..22) func  (6 bits)
///   [22..27) shamt (5 bits)
///   [27..32) rd    (5 bits)
fn load_offset_ext<F: RichField + Extendable<D>, const D: usize>(builder: &mut CircuitBuilder<F, D>, lv: &CpuColumnsView<ExtensionTarget<D>>) -> ExtensionTarget<D> {
    let mut mem_offset = [builder.zero_extension(); 32];
    // `slice::fill` avoids the temporary Vec allocation of `vec![..; 16]`.
    mem_offset[0..16].fill(lv.func_bits[0]); // fill in MSB
    mem_offset[16..22].copy_from_slice(&lv.func_bits); // 6 bits
    mem_offset[22..27].copy_from_slice(&lv.shamt_bits); // 5 bits
    mem_offset[27..32].copy_from_slice(&lv.rd_bits); // 5 bits
    limb_from_bits_le_recursive(builder, mem_offset.into_iter())
}

#[inline]
fn sign_extend<P: PackedField, const N: usize>(limbs: &mut [P; 32]) {
let mut tmp = limbs.clone();
Expand All @@ -37,8 +47,18 @@ fn sign_extend<P: PackedField, const N: usize>(limbs: &mut [P; 32]) {
}
}

#[inline]
/// Recursive-circuit counterpart of `sign_extend`: replicates the bit at
/// index `N` (the sign bit of an `32 - N`-bit value stored in `limbs[N..32]`)
/// into all lower indices `[0..N)`. Bits at `[N..32)` are left untouched.
///
/// Note: `_builder` is unused but kept so the signature matches the other
/// `*_ext` helpers in this file.
fn sign_extend_ext<F: RichField + Extendable<D>, const D: usize, const N: usize>(_builder: &mut CircuitBuilder<F, D>, limbs: &mut [ExtensionTarget<D>; 32]) {
    // Only `limbs[N]` is read and only indices below N are written, so no
    // scratch copy of the whole array is needed (the old `i in N..32` loop
    // was a no-op self-assignment).
    let sign_bit = limbs[N];
    limbs[..N].fill(sign_bit);
}

/// Binary shift limbs by offset
///
#[inline]
fn shift<P: PackedField>(limbs: &mut [P; 32], offset: &P, left: bool) {
let mut tmp = limbs.clone();
Expand All @@ -63,6 +83,32 @@ fn shift<P: PackedField>(limbs: &mut [P; 32], offset: &P, left: bool) {
}
}

/// Binary shift of `limbs` by a constant `offset`, filling vacated positions
/// with the zero extension target. `left = true` shifts toward index 0.
///
/// Recursive-circuit counterpart of `shift`; here the shift amount is a
/// plain field constant, so it can be read out directly rather than
/// constrained in-circuit.
#[inline]
fn shift_ext<F: RichField + Extendable<D>, const D: usize>(builder: &mut CircuitBuilder<F, D>, limbs: &mut [ExtensionTarget<D>; 32], offset: &F, left: bool) {
    let zero = builder.zero_extension();
    // ExtensionTarget is Copy, so copying the whole array is cheap.
    let tmp = *limbs;
    // Extract the canonical integer value of the constant shift amount.
    // (Replaces the old Display -> String -> parse round-trip.)
    let offset = offset.to_canonical_u64() as usize;
    if left {
        for i in 0..32 {
            limbs[i] = if i + offset < 32 { tmp[i + offset] } else { zero };
        }
    } else {
        for i in 0..32 {
            // BUG FIX: the original tested `i - offset >= 0`, which underflows
            // for usize when i < offset (and the comparison is vacuously true
            // for unsigned types). Compare before subtracting instead.
            limbs[i] = if i >= offset { tmp[i - offset] } else { zero };
        }
    }
}

#[inline]
fn not<P: PackedField>(limbs: &mut [P; 32]) {
for i in 0..32 {
Expand All @@ -77,6 +123,32 @@ fn and<P: PackedField>(x: &mut [P; 32], y: &[P; 32]) {
}
}

/// Bitwise XOR of binary-valued limbs, in-circuit: for x, y in {0, 1},
/// x XOR y = x + y - 2xy. Result is written back into `x`.
#[inline]
fn xor_ext<F: RichField + Extendable<D>, const D: usize>(builder: &mut CircuitBuilder<F, D>, x: &mut [ExtensionTarget<D>; 32], y: &[ExtensionTarget<D>; 32]) {
    for (xi, &yi) in x.iter_mut().zip(y.iter()) {
        let sum = builder.add_extension(*xi, yi);
        let prod = builder.mul_extension(*xi, yi);
        let two_prod = builder.add_extension(prod, prod);
        *xi = builder.sub_extension(sum, two_prod);
    }
}

#[inline]
/// Bitwise NOT of binary-valued limbs, in-circuit: for b in {0, 1},
/// NOT b = 1 - b. Result is written back into `limbs`.
fn not_ext<F: RichField + Extendable<D>, const D: usize>(builder: &mut CircuitBuilder<F, D>, limbs: &mut [ExtensionTarget<D>; 32]) {
    let one = builder.one_extension();
    for limb in limbs.iter_mut() {
        *limb = builder.sub_extension(one, *limb);
    }
}

#[inline]
/// Bitwise AND of binary-valued limbs, in-circuit: for a, b in {0, 1},
/// a AND b = a * b. Result is written back into `x`.
fn and_ext<F: RichField + Extendable<D>, const D: usize>(builder: &mut CircuitBuilder<F, D>, x: &mut [ExtensionTarget<D>; 32], y: &[ExtensionTarget<D>; 32]) {
    for (xi, &yi) in x.iter_mut().zip(y.iter()) {
        *xi = builder.mul_extension(*xi, yi);
    }
}

/// For binary x, y, x + y - 2xy
#[inline]
fn xor<P: PackedField>(x: &mut [P; 32], y: &[P; 32]) {
Expand Down Expand Up @@ -112,21 +184,21 @@ fn eval_packed_load<P: PackedField>(
// If the operation is MLOAD_GENERAL, lv.opcode_bits[5] = 1
let filter = lv.op.m_op_general * lv.opcode_bits[5];

// check mem channel segment is register
// Check mem channel segment is register
let diff = lv.mem_channels[0].addr_segment
- P::Scalar::from_canonical_u64(Segment::RegisterFile as u64);
yield_constr.constraint(filter * diff);
let diff = lv.mem_channels[1].addr_segment
- P::Scalar::from_canonical_u64(Segment::RegisterFile as u64);
yield_constr.constraint(filter * diff);

// check memory is used
// check is_read is 0/1
// Check memory is used
// Check is_read is 0/1

let rs = lv.mem_channels[0].value;
let rt = lv.mem_channels[1].value;

// calculate rs:
// Calculate rs:
// let virt_raw = (rs as u32).wrapping_add(sign_extend::<16>(offset));
let offset = load_offset(lv);
let virt_raw = rs + offset;
Expand Down Expand Up @@ -169,15 +241,16 @@ fn eval_packed_load<P: PackedField>(
// where 16 - (rs & 2) * 8 = 8 * (2 - (rs & 2)) = 8 * 2 * (1 - rs[30])
let mut rs_tmp = lv.general.io().rs_le.clone();
let mut mem_tmp = lv.general.io().mem_le.clone();
rs_tmp[30] = (P::ONES - rs_tmp[30]) * P::Scalar::from_canonical_u32(16);

rs_tmp[30] = P::ONES - rs_tmp[30];
shift(&mut rs_tmp, & {P::Scalar::from_canonical_u32(16).into()}, true);
shift(&mut mem_tmp, &limb_from_bits_le(rs_tmp.into_iter()), false);
and(&mut mem_tmp, &bits_ffff);
sign_extend::<_, 16>(&mut mem_tmp);
let mem_value = limb_from_bits_le(mem_tmp.into_iter());
yield_constr.constraint(filter * lv.general.io().micro_op[0] * (mem - mem_value));



// Disable remaining memory channels, if any.
// Note: SC needs 5 channel
for &channel in &lv.mem_channels[7..NUM_GP_CHANNELS] {
Expand All @@ -191,6 +264,8 @@ fn eval_ext_circuit_load<F: RichField + Extendable<D>, const D: usize>(
_nv: &CpuColumnsView<ExtensionTarget<D>>,
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let zeros = builder.zero_extension();
let ones = builder.one_extension();
let mut filter = lv.op.m_op_general;
filter = builder.mul_extension(filter, lv.opcode_bits[5]);

Expand All @@ -208,13 +283,79 @@ fn eval_ext_circuit_load<F: RichField + Extendable<D>, const D: usize>(
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);

let rs = lv.mem_channels[0].value;
let rt = lv.mem_channels[1].value;
let offset = load_offset_ext(builder, lv);
let virt_raw = builder.add_extension(rs, offset);
let rs_from_bits = limb_from_bits_le_recursive(builder, lv.general.io().rs_le.into_iter());
let diff = builder.sub_extension(rs_from_bits, virt_raw);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);

let rt_from_bits = limb_from_bits_le_recursive(builder, lv.general.io().rt_le.into_iter());
let diff = builder.sub_extension(rt_from_bits, rt);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);

// constraint mem address
// let virt = virt_raw & 0xFFFF_FFFC;
let mut tmp = lv.general.io().rs_le.clone();
tmp[30] = zeros;
tmp[31] = zeros;
let virt = limb_from_bits_le_recursive(builder, tmp.into_iter());
let mem_virt = lv.mem_channels[2].addr_virtual;
let diff = builder.sub_extension(virt, mem_virt);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);

// Verify op
let op_inv = lv.general.io().diff_inv;
let op = lv.mem_channels[6].value;
let mul = builder.mul_extension(op, op_inv);
let diff = builder.sub_extension(ones, mul);
let constr = builder.mul_extension(filter, diff);
yield_constr.constraint(builder, constr);

// Constrain mem value
let mem = lv.mem_channels[2].value;

let mut bits_2 = [zeros; 32];
bits_2[30] = ones;
let mut bits_3 = [zeros; 32];
bits_3[30] = ones;
bits_3[31] = ones;

let mut bits_ffff = [zeros; 32];
for i in 16..32 {
bits_ffff[i] = ones;
}

// LH: micro_op[0] * sign_extend::<16>((mem >> (16 - (rs & 2) * 8)) & 0xffff)
// where 16 - (rs & 2) * 8 = 8 * (2 - (rs & 2)) = 8 * 2 * (1 - rs[30])
let mut rs_tmp = lv.general.io().rs_le.clone();
let mut mem_tmp = lv.general.io().mem_le.clone();

rs_tmp[30] = builder.sub_extension(ones, rs_tmp[30]);
let _16 = F::from_canonical_u32(16);
shift_ext(builder, &mut rs_tmp, &_16, true);
let tmp = limb_from_bits_le_recursive(builder, rs_tmp.into_iter());
shift_ext(builder, &mut mem_tmp, &tmp, false);

and_ext(builder, &mut mem_tmp, &bits_ffff);
sign_extend_ext::<_, D, 16>(builder, &mut mem_tmp);
let mem_value = limb_from_bits_le_recursive(builder, mem_tmp.into_iter());

let diff = builder.sub_extension(mem, mem_value);
let mul_1 = builder.mul_extension(filter, lv.general.io().micro_op[0]);
let constr = builder.mul_extension(mul_1, diff);
yield_constr.constraint(builder, constr);


// Disable remaining memory channels, if any.
/*
for &channel in &lv.mem_channels[5..NUM_GP_CHANNELS] {
for &channel in &lv.mem_channels[5..NUM_GP_CHANNELS] {
let constr = builder.mul_extension(filter, channel.used);
yield_constr.constraint(builder, constr);
}
*/
}

fn eval_packed_store<P: PackedField>(
Expand Down

0 comments on commit dd730d1

Please sign in to comment.