Skip to content

Commit

Permalink
Merge pull request #42 from zkmopro/refactor/limb_conversion
Browse files Browse the repository at this point in the history
refactor limb conversion: update GenericLimbConversion to BigInt (incl. BigInt<8> support)
  • Loading branch information
moven0831 authored Feb 13, 2025
2 parents aca53d1 + f2a7505 commit 21ac3c0
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 16 deletions.
3 changes: 2 additions & 1 deletion mopro-msm/src/msm/metal_msm/host/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ pub fn write_constants(
) {
let two_pow_word_size = 2u32.pow(log_limb_size);
let mask = two_pow_word_size - 1u32;
let slack = num_limbs as u32 * log_limb_size - BaseField::MODULUS_BIT_SIZE;
let num_limbs_wide = num_limbs + 1;
let basefield_modulus = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size);
let r = calc_mont_radix(num_limbs, log_limb_size);
Expand All @@ -132,6 +133,7 @@ pub fn write_constants(
data += format!("#define MASK {}\n", mask).as_str();
data += format!("#define N0 {}\n", n0).as_str();
data += format!("#define NSAFE {}\n", nsafe).as_str();
data += format!("#define SLACK {}\n", slack).as_str();

write_constant_array!(
data,
Expand Down Expand Up @@ -182,7 +184,6 @@ pub fn write_constants(
);

let p: BigUint = BaseField::MODULUS.try_into().unwrap();
let (rinv, n0) = calc_rinv_and_n0(&p, &r, log_limb_size);
let (bn254_zero_xr_limbs, bn254_zero_yr_limbs, bn254_zero_zr_limbs) = {
let bn254_zero_x: BigUint = bn254_zero.x.into();
let bn254_zero_y: BigUint = bn254_zero.y.into();
Expand Down
141 changes: 126 additions & 15 deletions mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use ark_ff::biginteger::{BigInteger, BigInteger256, BigInteger384};
use ark_ff::biginteger::{BigInt, BigInteger};
use std::convert::TryInto;

/// A trait that abstracts "to/from limbs" for *any* BigInteger type
pub trait GenericLimbConversion {
/// The number of 64-bit words in this BigInteger (e.g., 4 for BigInteger256)
/// The number of 64-bit words in this BigInteger (e.g., 4 for BigInt<4>)
const NUM_WORDS: usize;

/// Convert to big-endian `u32` limbs of length `2 * NUM_WORDS`.
Expand All @@ -26,12 +26,12 @@ pub trait GenericLimbConversion {
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
}

impl GenericLimbConversion for BigInteger256 {
impl GenericLimbConversion for BigInt<4> {
const NUM_WORDS: usize = 4; // 4 x 64-bit words

fn to_u32_limbs(&self) -> Vec<u32> {
let mut limbs = Vec::new();
// BigInteger256::to_bytes_be() gives us 32 bytes in BE
// BigInt<4>::to_bytes_be() gives us 32 bytes in BE
self.to_bytes_be().chunks(8).for_each(|chunk| {
let high = u32::from_be_bytes(chunk[0..4].try_into().unwrap());
let low = u32::from_be_bytes(chunk[4..8].try_into().unwrap());
Expand Down Expand Up @@ -81,17 +81,17 @@ impl GenericLimbConversion for BigInteger256 {
let low = u64::from(limb_pair[1]);
big_int[i] = (high << 32) | low;
}
BigInteger256::new(big_int)
BigInt::<4>::new(big_int)
}

fn from_u128(num: u128) -> Self {
let high = (num >> 64) as u64;
let low = num as u64;
BigInteger256::new([low, high, 0, 0])
BigInt::<4>::new([low, high, 0, 0])
}

fn from_u32(num: u32) -> Self {
BigInteger256::new([num as u64, 0, 0, 0])
BigInt::<4>::new([num as u64, 0, 0, 0])
}

fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
Expand All @@ -116,11 +116,11 @@ impl GenericLimbConversion for BigInteger256 {
if accumulated_bits > 0 && result_idx < 4 {
result[result_idx] = current_u64;
}
BigInteger256::new(result)
BigInt::<4>::new(result)
}
}

impl GenericLimbConversion for BigInteger384 {
impl GenericLimbConversion for BigInt<6> {
const NUM_WORDS: usize = 6; // 6 x 64-bit words

fn to_u32_limbs(&self) -> Vec<u32> {
Expand Down Expand Up @@ -172,17 +172,17 @@ impl GenericLimbConversion for BigInteger384 {
let low = u64::from(limb_pair[1]);
big_int[i] = (high << 32) | low;
}
BigInteger384::new(big_int)
BigInt::<6>::new(big_int)
}

fn from_u128(num: u128) -> Self {
let high = (num >> 64) as u64;
let low = num as u64;
BigInteger384::new([low, high, 0, 0, 0, 0])
BigInt::<6>::new([low, high, 0, 0, 0, 0])
}

fn from_u32(num: u32) -> Self {
BigInteger384::new([num as u64, 0, 0, 0, 0, 0])
BigInt::<6>::new([num as u64, 0, 0, 0, 0, 0])
}

fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
Expand All @@ -206,7 +206,96 @@ impl GenericLimbConversion for BigInteger384 {
if accumulated_bits > 0 && result_idx < 6 {
result[result_idx] = current_u64;
}
BigInteger384::new(result)
BigInt::<6>::new(result)
}
}

impl GenericLimbConversion for BigInt<8> {
const NUM_WORDS: usize = 8; // 8 x 64-bit words

fn to_u32_limbs(&self) -> Vec<u32> {
let mut limbs = Vec::new();
self.to_bytes_be().chunks(8).for_each(|chunk| {
let high = u32::from_be_bytes(chunk[0..4].try_into().unwrap());
let low = u32::from_be_bytes(chunk[4..8].try_into().unwrap());
limbs.push(high);
limbs.push(low);
});
limbs
}

fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
let mut result = vec![0u32; num_limbs];
let limb_size = 1u32 << log_limb_size;
let mask = limb_size - 1;

let bytes = self.to_bytes_le();
let mut val = 0u32;
let mut bits = 0u32;
let mut limb_idx = 0;

for &byte in bytes.iter() {
if limb_idx >= num_limbs {
break;
}
val |= (byte as u32) << bits;
bits += 8;

while bits >= log_limb_size && limb_idx < num_limbs {
result[limb_idx] = val & mask;
val >>= log_limb_size;
bits -= log_limb_size;
limb_idx += 1;
}
}
if bits > 0 && limb_idx < num_limbs {
result[limb_idx] = val;
}
result
}

fn from_u32_limbs(limbs: &[u32]) -> Self {
let mut big_int = [0u64; 8];
for (i, limb_pair) in limbs.chunks(2).rev().enumerate() {
let high = u64::from(limb_pair[0]);
let low = u64::from(limb_pair[1]);
big_int[i] = (high << 32) | low;
}
BigInt::<8>::new(big_int)
}

fn from_u128(num: u128) -> Self {
let high = (num >> 64) as u64;
let low = num as u64;
BigInt::<8>::new([low, high, 0, 0, 0, 0, 0, 0])
}

fn from_u32(num: u32) -> Self {
BigInt::<8>::new([num as u64, 0, 0, 0, 0, 0, 0, 0])
}

fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
let mut result = [0u64; 8];
let limb_bits = log_limb_size as usize;
let mut accumulated_bits = 0;
let mut current_u64 = 0u64;
let mut result_idx = 0;

for &limb in limbs {
current_u64 |= (limb as u64) << accumulated_bits;
accumulated_bits += limb_bits;

while accumulated_bits >= 64 && result_idx < 8 {
result[result_idx] = current_u64;
current_u64 = (limb as u64) >> (limb_bits - (accumulated_bits - 64));
accumulated_bits -= 64;
result_idx += 1;
}
}
if accumulated_bits > 0 && result_idx < 8 {
result[result_idx] = current_u64;
}
BigInt::<8>::new(result)
}
}

Expand All @@ -217,14 +306,15 @@ mod tests {
use crate::msm::metal_msm::utils::mont_params::calc_mont_radix;
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, PrimeField};
use num_bigint::{BigUint, RandBigInt};

#[test]
fn test_within_bigint256() {
let num_limbs = 16;
let log_limb_size = 16;

let p_limbs = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size);
let p_bigint256 = BigInteger256::from_limbs(&p_limbs, log_limb_size);
let p_bigint256 = BigInt::<4>::from_limbs(&p_limbs, log_limb_size);

assert_eq!(BaseField::MODULUS, p_bigint256);
}
Expand All @@ -238,9 +328,30 @@ mod tests {
let r = calc_mont_radix(num_limbs, log_limb_size); // r has 257 bits
let r_bigint384: BigInt<6> = r.try_into().unwrap();
let r_limbs = r_bigint384.to_limbs(num_limbs_wide, log_limb_size);
let r_reconstructed = BigInteger384::from_limbs(&r_limbs, log_limb_size);
let r_reconstructed = BigInt::<6>::from_limbs(&r_limbs, log_limb_size);

// Check if the original and reconstructed values are equal
assert_eq!(r_bigint384, r_reconstructed);
}

#[test]
fn test_within_bigint512() {
let num_limbs = 16;
let num_limbs_extra_wide = num_limbs * 2;
let log_limb_size = 16;

let mut rng = rand::thread_rng();
let p: BigUint = BaseField::MODULUS.try_into().unwrap();
let a = rng.gen_biguint_below(&p); // a has at most 254 bits
let r = calc_mont_radix(num_limbs, log_limb_size); // r has 257 bits

let mont_a = &a * &r; // mont_a has at most 511 bits

let mont_a_bigint512: BigInt<8> = mont_a.try_into().unwrap();
let mont_a_limbs = mont_a_bigint512.to_limbs(num_limbs_extra_wide, log_limb_size);
let mont_a_reconstructed = BigInt::<8>::from_limbs(&mont_a_limbs, log_limb_size);

// Check if the original and reconstructed values are equal
assert_eq!(mont_a_bigint512, mont_a_reconstructed);
}
}

0 comments on commit 21ac3c0

Please sign in to comment.