From f4378bc3cbd04a20b1215756cbb5e06948f58ccd Mon Sep 17 00:00:00 2001 From: moven0831 Date: Fri, 3 Jan 2025 20:44:43 +0800 Subject: [PATCH 1/2] fix(field): replace bigint_add_wide with bigint_add_unsafe to ignore incorrect carry --- .../src/msm/metal_msm/shader/field/ff.metal | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal index af8cf93..43a05d3 100644 --- a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal +++ b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal @@ -10,31 +10,21 @@ BigInt ff_add( BigInt b, BigInt p ) { - // Assign p to p_wide - BigIntWide p_wide; - for (uint i = 0; i < NUM_LIMBS; i ++) { - p_wide.limbs[i] = p.limbs[i]; - } - - // a + b - BigIntWide sum_wide = bigint_add_wide(a, b); + BigInt sum = bigint_add_unsafe(a, b); BigInt res; - - // if (a + b) >= p - if (bigint_wide_gte(sum_wide, p_wide)) { + if (bigint_gte(sum, p)) { // s = a + b - p - BigIntWide s = bigint_sub_wide(sum_wide, p_wide); - + BigInt s = bigint_sub(sum, p); for (uint i = 0; i < NUM_LIMBS; i ++) { res.limbs[i] = s.limbs[i]; } - } else { + } + else { for (uint i = 0; i < NUM_LIMBS; i ++) { - res.limbs[i] = sum_wide.limbs[i]; + res.limbs[i] = sum.limbs[i]; } } - return res; } From b14c60e88eb263629a56625d2edbda1c80488fbe Mon Sep 17 00:00:00 2001 From: moven0831 Date: Fri, 3 Jan 2025 22:37:32 +0800 Subject: [PATCH 2/2] refactor(tests): update field addition and subtraction tests to use random BigInt values - Changed the field tests to utilize random number generation for inputs, enhancing robustness against edge cases. - Updated the limb size and number of limbs from 20 to 16 for consistency. - Replaced the hardcoded values for the scalar field modulus with the BaseField's modulus. - Added assertions to ensure non-negativity and proper range for generated values. - Improved correctness checks by directly comparing results with expected values derived from Arkworks operations. --- .../src/msm/metal_msm/tests/field/ff_add.rs | 62 +++++++++------ .../src/msm/metal_msm/tests/field/ff_sub.rs | 78 ++++++++++++------- 2 files changed, 86 insertions(+), 54 deletions(-) diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs index 01d3976..053e54f 100644 --- a/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs @@ -5,37 +5,38 @@ use crate::msm::metal_msm::host::gpu::{ }; use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; -use ark_bn254::Fr as ScalarField; -use ark_ff::{BigInt, BigInteger, PrimeField}; +use ark_bn254::Fq as BaseField; +use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand}; +use ark_std::rand; use metal::*; #[test] #[serial_test::serial] pub fn test_ff_add() { - let log_limb_size = 13; - let num_limbs = 20; + let log_limb_size = 16; + let num_limbs = 16; // Scalar field modulus for bn254 - let p = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E72E131A029, - ]); - assert!(p == ScalarField::MODULUS); - - let a = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E72E131A028, - ]); - let b = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E7200000000, - ]); + let p = BaseField::MODULUS; + + let mut rng = rand::thread_rng(); + let mut a = BigInt::rand(&mut rng); + let mut b = BigInt::rand(&mut rng); + + // Reduce a and b if they are greater than or equal to the prime field modulus + while a >= p { + a.sub_with_borrow(&p); + } + + while b >= p { + b.sub_with_borrow(&p); + } + + // Ensure a and b are non-negative and less than p + assert!(a >= BigInt::from(0u64), "a must be non-negative"); + assert!(b >= BigInt::from(0u64), "b must be non-negative"); + assert!(a < p, "a must be less than p"); + assert!(b < p, "b must be less than p"); let device = get_default_device(); let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); @@ -51,6 +52,19 @@ pub fn test_ff_add() { while expected >= p { expected.sub_with_borrow(&p); } + // Ensure expected is non-negative and less than p + assert!( + expected >= BigInt::from(0u64), + "expected must be non-negative" + ); + assert!(expected < p, "expected must be less than p"); + + // Ensure the operation is correct using Arkworks + let a_field = BaseField::from(a); + let b_field = BaseField::from(b); + let expected_field = a_field + b_field; + assert!(expected_field == expected.into()); + let expected_limbs = expected.to_limbs(num_limbs, log_limb_size); let command_queue = device.new_command_queue(); diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs index f9ac345..b08d3f0 100644 --- a/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs @@ -5,37 +5,38 @@ use crate::msm::metal_msm::host::gpu::{ }; use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; -use ark_bn254::Fr as ScalarField; -use ark_ff::{BigInt, BigInteger, PrimeField}; +use ark_bn254::Fq as BaseField; +use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand}; +use ark_std::rand; use metal::*; #[test] #[serial_test::serial] pub fn test_ff_sub() { - let log_limb_size = 13; - let num_limbs = 20; + let log_limb_size = 16; + let num_limbs = 16; // Scalar field modulus for bn254 - let p = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E72E131A029, - ]); - assert!(p == ScalarField::MODULUS); - - let a = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E72E131A028, - ]); - let b = BigInt::new([ - 0xAAAAAAAAF0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E7200000000, - ]); + let p = BaseField::MODULUS; + + let mut rng = rand::thread_rng(); + let mut a = BigInt::rand(&mut rng); + let mut b = BigInt::rand(&mut rng); + + // Reduce a and b if they are greater than or equal to the prime field modulus + while a >= p { + a.sub_with_borrow(&p); + } + + while b >= p { + b.sub_with_borrow(&p); + } + + // Ensure a and b are non-negative and less than p + assert!(a >= BigInt::from(0u64), "a must be non-negative"); + assert!(b >= BigInt::from(0u64), "b must be non-negative"); + assert!(a < p, "a must be less than p"); + assert!(b < p, "b must be less than p"); let device = get_default_device(); let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); @@ -43,14 +44,31 @@ pub fn test_ff_sub() { let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size)); let result_buf = create_empty_buffer(&device, num_limbs); - // Perform (a - b) % p + // (a - b) % p let mut expected = a.clone(); - expected.sub_with_borrow(&b); - - // If result is negative, add p until it's positive - while expected < BigInt::zero() { - expected.add_with_carry(&p); + if a >= b { + expected.sub_with_borrow(&b); + } + // p - (b - a) + else { + let mut p_sub_b = p.clone(); + p_sub_b.sub_with_borrow(&b); + expected.add_with_carry(&p_sub_b); } + + // Ensure expected is non-negative and less than p + assert!( + expected >= BigInt::from(0u64), + "expected must be non-negative" + ); + assert!(expected < p, "expected must be less than p"); + + // Ensure the operation is correct using Arkworks + let a_field = BaseField::from(a); + let b_field = BaseField::from(b); + let expected_field = a_field - b_field; + assert!(expected_field == expected.into()); + let expected_limbs = expected.to_limbs(num_limbs, log_limb_size); let command_queue = device.new_command_queue();