diff --git a/mopro-msm/src/msm/metal_msm/host/shader.rs b/mopro-msm/src/msm/metal_msm/host/shader.rs index 2ce0e5c..dc9d665 100644 --- a/mopro-msm/src/msm/metal_msm/host/shader.rs +++ b/mopro-msm/src/msm/metal_msm/host/shader.rs @@ -111,6 +111,7 @@ pub fn write_constants( ) { let two_pow_word_size = 2u32.pow(log_limb_size); let mask = two_pow_word_size - 1u32; + let slack = num_limbs as u32 * log_limb_size - BaseField::MODULUS_BIT_SIZE; let num_limbs_wide = num_limbs + 1; let basefield_modulus = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size); let r = calc_mont_radix(num_limbs, log_limb_size); @@ -132,6 +133,7 @@ pub fn write_constants( data += format!("#define MASK {}\n", mask).as_str(); data += format!("#define N0 {}\n", n0).as_str(); data += format!("#define NSAFE {}\n", nsafe).as_str(); + data += format!("#define SLACK {}\n", slack).as_str(); write_constant_array!( data, @@ -182,7 +184,6 @@ pub fn write_constants( ); let p: BigUint = BaseField::MODULUS.try_into().unwrap(); - let (rinv, n0) = calc_rinv_and_n0(&p, &r, log_limb_size); let (bn254_zero_xr_limbs, bn254_zero_yr_limbs, bn254_zero_zr_limbs) = { let bn254_zero_x: BigUint = bn254_zero.x.into(); let bn254_zero_y: BigUint = bn254_zero.y.into(); diff --git a/mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs b/mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs index 3d087c1..e71d4a5 100644 --- a/mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs +++ b/mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs @@ -1,9 +1,9 @@ -use ark_ff::biginteger::{BigInteger, BigInteger256, BigInteger384}; +use ark_ff::biginteger::{BigInt, BigInteger}; use std::convert::TryInto; /// A trait that abstracts "to/from limbs" for *any* BigInteger type pub trait GenericLimbConversion { - /// The number of 64-bit words in this BigInteger (e.g., 4 for BigInteger256) + /// The number of 64-bit words in this BigInteger (e.g., 4 for BigInt<4>) const NUM_WORDS: usize; /// Convert to big-endian `u32` limbs of length `2 * NUM_WORDS`. @@ -26,12 +26,12 @@ pub trait GenericLimbConversion { fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self; } -impl GenericLimbConversion for BigInteger256 { +impl GenericLimbConversion for BigInt<4> { const NUM_WORDS: usize = 4; // 4 x 64-bit words fn to_u32_limbs(&self) -> Vec { let mut limbs = Vec::new(); - // BigInteger256::to_bytes_be() gives us 32 bytes in BE + // BigInt<4>::to_bytes_be() gives us 32 bytes in BE self.to_bytes_be().chunks(8).for_each(|chunk| { let high = u32::from_be_bytes(chunk[0..4].try_into().unwrap()); let low = u32::from_be_bytes(chunk[4..8].try_into().unwrap()); @@ -81,17 +81,17 @@ impl GenericLimbConversion for BigInteger256 { let low = u64::from(limb_pair[1]); big_int[i] = (high << 32) | low; } - BigInteger256::new(big_int) + BigInt::<4>::new(big_int) } fn from_u128(num: u128) -> Self { let high = (num >> 64) as u64; let low = num as u64; - BigInteger256::new([low, high, 0, 0]) + BigInt::<4>::new([low, high, 0, 0]) } fn from_u32(num: u32) -> Self { - BigInteger256::new([num as u64, 0, 0, 0]) + BigInt::<4>::new([num as u64, 0, 0, 0]) } fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self { @@ -116,11 +116,11 @@ impl GenericLimbConversion for BigInteger256 { if accumulated_bits > 0 && result_idx < 4 { result[result_idx] = current_u64; } - BigInteger256::new(result) + BigInt::<4>::new(result) } } -impl GenericLimbConversion for BigInteger384 { +impl GenericLimbConversion for BigInt<6> { const NUM_WORDS: usize = 6; // 6 x 64-bit words fn to_u32_limbs(&self) -> Vec { @@ -172,17 +172,17 @@ impl GenericLimbConversion for BigInteger384 { let low = u64::from(limb_pair[1]); big_int[i] = (high << 32) | low; } - BigInteger384::new(big_int) + BigInt::<6>::new(big_int) } fn from_u128(num: u128) -> Self { let high = (num >> 64) as u64; let low = num as u64; - BigInteger384::new([low, high, 0, 0, 0, 0]) + BigInt::<6>::new([low, high, 0, 0, 0, 0]) } fn from_u32(num: u32) -> Self { - BigInteger384::new([num as u64, 0, 0, 0, 0, 0]) + BigInt::<6>::new([num as u64, 0, 0, 0, 0, 0]) } fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self { @@ -206,7 +206,96 @@ impl GenericLimbConversion for BigInteger384 { if accumulated_bits > 0 && result_idx < 6 { result[result_idx] = current_u64; } - BigInteger384::new(result) + BigInt::<6>::new(result) + } +} + +impl GenericLimbConversion for BigInt<8> { + const NUM_WORDS: usize = 8; // 8 x 64-bit words + + fn to_u32_limbs(&self) -> Vec { + let mut limbs = Vec::new(); + self.to_bytes_be().chunks(8).for_each(|chunk| { + let high = u32::from_be_bytes(chunk[0..4].try_into().unwrap()); + let low = u32::from_be_bytes(chunk[4..8].try_into().unwrap()); + limbs.push(high); + limbs.push(low); + }); + limbs + } + + fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec { + let mut result = vec![0u32; num_limbs]; + let limb_size = 1u32 << log_limb_size; + let mask = limb_size - 1; + + let bytes = self.to_bytes_le(); + let mut val = 0u32; + let mut bits = 0u32; + let mut limb_idx = 0; + + for &byte in bytes.iter() { + if limb_idx >= num_limbs { + break; + } + val |= (byte as u32) << bits; + bits += 8; + + while bits >= log_limb_size && limb_idx < num_limbs { + result[limb_idx] = val & mask; + val >>= log_limb_size; + bits -= log_limb_size; + limb_idx += 1; + } + } + if bits > 0 && limb_idx < num_limbs { + result[limb_idx] = val; + } + result + } + + fn from_u32_limbs(limbs: &[u32]) -> Self { + let mut big_int = [0u64; 8]; + for (i, limb_pair) in limbs.chunks(2).rev().enumerate() { + let high = u64::from(limb_pair[0]); + let low = u64::from(limb_pair[1]); + big_int[i] = (high << 32) | low; + } + BigInt::<8>::new(big_int) + } + + fn from_u128(num: u128) -> Self { + let high = (num >> 64) as u64; + let low = num as u64; + BigInt::<8>::new([low, high, 0, 0, 0, 0, 0, 0]) + } + + fn from_u32(num: u32) -> Self { + BigInt::<8>::new([num as u64, 0, 0, 0, 0, 0, 0, 0]) + } + + fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self { + let mut result = [0u64; 8]; + let limb_bits = log_limb_size as usize; + let mut accumulated_bits = 0; + let mut current_u64 = 0u64; + let mut result_idx = 0; + + for &limb in limbs { + current_u64 |= (limb as u64) << accumulated_bits; + accumulated_bits += limb_bits; + + while accumulated_bits >= 64 && result_idx < 8 { + result[result_idx] = current_u64; + current_u64 = (limb as u64) >> (limb_bits - (accumulated_bits - 64)); + accumulated_bits -= 64; + result_idx += 1; + } + } + if accumulated_bits > 0 && result_idx < 8 { + result[result_idx] = current_u64; + } + BigInt::<8>::new(result) } } @@ -217,6 +306,7 @@ mod tests { use crate::msm::metal_msm::utils::mont_params::calc_mont_radix; use ark_bn254::Fq as BaseField; use ark_ff::{BigInt, PrimeField}; + use num_bigint::{BigUint, RandBigInt}; #[test] fn test_within_bigint256() { @@ -224,7 +314,7 @@ mod tests { let log_limb_size = 16; let p_limbs = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size); - let p_bigint256 = BigInteger256::from_limbs(&p_limbs, log_limb_size); + let p_bigint256 = BigInt::<4>::from_limbs(&p_limbs, log_limb_size); assert_eq!(BaseField::MODULUS, p_bigint256); } @@ -238,9 +328,30 @@ mod tests { let r = calc_mont_radix(num_limbs, log_limb_size); // r has 257 bits let r_bigint384: BigInt<6> = r.try_into().unwrap(); let r_limbs = r_bigint384.to_limbs(num_limbs_wide, log_limb_size); - let r_reconstructed = BigInteger384::from_limbs(&r_limbs, log_limb_size); + let r_reconstructed = BigInt::<6>::from_limbs(&r_limbs, log_limb_size); // Check if the original and reconstructed values are equal assert_eq!(r_bigint384, r_reconstructed); } + + #[test] + fn test_within_bigint512() { + let num_limbs = 16; + let num_limbs_extra_wide = num_limbs * 2; + let log_limb_size = 16; + + let mut rng = rand::thread_rng(); + let p: BigUint = BaseField::MODULUS.try_into().unwrap(); + let a = rng.gen_biguint_below(&p); // a has at most 254 bits + let r = calc_mont_radix(num_limbs, log_limb_size); // r has 257 bits + + let mont_a = &a * &r; // mont_a has at most 511 bits + + let mont_a_bigint512: BigInt<8> = mont_a.try_into().unwrap(); + let mont_a_limbs = mont_a_bigint512.to_limbs(num_limbs_extra_wide, log_limb_size); + let mont_a_reconstructed = BigInt::<8>::from_limbs(&mont_a_limbs, log_limb_size); + + // Check if the original and reconstructed values are equal + assert_eq!(mont_a_bigint512, mont_a_reconstructed); + } }