diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs index 01d3976..053e54f 100644 --- a/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs @@ -5,37 +5,38 @@ use crate::msm::metal_msm::host::gpu::{ }; use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; -use ark_bn254::Fr as ScalarField; -use ark_ff::{BigInt, BigInteger, PrimeField}; +use ark_bn254::Fq as BaseField; +use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand}; +use ark_std::rand; use metal::*; #[test] #[serial_test::serial] pub fn test_ff_add() { - let log_limb_size = 13; - let num_limbs = 20; + let log_limb_size = 16; + let num_limbs = 16; // Scalar field modulus for bn254 - let p = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E72E131A029, - ]); - assert!(p == ScalarField::MODULUS); - - let a = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E72E131A028, - ]); - let b = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E7200000000, - ]); + let p = BaseField::MODULUS; + + let mut rng = rand::thread_rng(); + let mut a = BigInt::rand(&mut rng); + let mut b = BigInt::rand(&mut rng); + + // Reduce a and b if they are greater than or equal to the prime field modulus + while a >= p { + a.sub_with_borrow(&p); + } + + while b >= p { + b.sub_with_borrow(&p); + } + + // Ensure a and b are non-negative and less than p + assert!(a >= BigInt::from(0u64), "a must be non-negative"); + assert!(b >= BigInt::from(0u64), "b must be non-negative"); + assert!(a < p, "a must be less than p"); + assert!(b < p, "b must be less than p"); let device = get_default_device(); let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); @@ -51,6 +52,19 @@ pub fn test_ff_add() { while expected >= p { expected.sub_with_borrow(&p); } + // Ensure expected is non-negative and less than p + assert!( + expected >= BigInt::from(0u64), + "expected must be non-negative" + ); + assert!(expected < p, "expected must be less than p"); + + // Ensure the operation is correct using Arkworks + let a_field = BaseField::from(a); + let b_field = BaseField::from(b); + let expected_field = a_field + b_field; + assert!(expected_field == expected.into()); + let expected_limbs = expected.to_limbs(num_limbs, log_limb_size); let command_queue = device.new_command_queue(); diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs index f9ac345..b08d3f0 100644 --- a/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs @@ -5,37 +5,38 @@ use crate::msm::metal_msm::host::gpu::{ }; use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; -use ark_bn254::Fr as ScalarField; -use ark_ff::{BigInt, BigInteger, PrimeField}; +use ark_bn254::Fq as BaseField; +use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand}; +use ark_std::rand; use metal::*; #[test] #[serial_test::serial] pub fn test_ff_sub() { - let log_limb_size = 13; - let num_limbs = 20; + let log_limb_size = 16; + let num_limbs = 16; // Scalar field modulus for bn254 - let p = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E72E131A029, - ]); - assert!(p == ScalarField::MODULUS); - - let a = BigInt::new([ - 0x43E1F593F0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E72E131A028, - ]); - let b = BigInt::new([ - 0xAAAAAAAAF0000001, - 0x2833E84879B97091, - 0xB85045B68181585D, - 0x30644E7200000000, - ]); + let p = BaseField::MODULUS; + + let mut rng = rand::thread_rng(); + let mut a = BigInt::rand(&mut rng); + let mut b = BigInt::rand(&mut rng); + + // Reduce a and b if they are greater than or equal to the prime field modulus + while a >= p { + a.sub_with_borrow(&p); + } + + while b >= p { + b.sub_with_borrow(&p); + } + + // Ensure a and b are non-negative and less than p + assert!(a >= BigInt::from(0u64), "a must be non-negative"); + assert!(b >= BigInt::from(0u64), "b must be non-negative"); + assert!(a < p, "a must be less than p"); + assert!(b < p, "b must be less than p"); let device = get_default_device(); let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); @@ -43,14 +44,31 @@ pub fn test_ff_sub() { let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size)); let result_buf = create_empty_buffer(&device, num_limbs); - // Perform (a - b) % p + // (a - b) % p let mut expected = a.clone(); - expected.sub_with_borrow(&b); - - // If result is negative, add p until it's positive - while expected < BigInt::zero() { - expected.add_with_carry(&p); + if a >= b { + expected.sub_with_borrow(&b); + } + // p - (b - a) + else { + let mut p_sub_b = p.clone(); + p_sub_b.sub_with_borrow(&b); + expected.add_with_carry(&p_sub_b); } + + // Ensure expected is non-negative and less than p + assert!( + expected >= BigInt::from(0u64), + "expected must be non-negative" + ); + assert!(expected < p, "expected must be less than p"); + + // Ensure the operation is correct using Arkworks + let a_field = BaseField::from(a); + let b_field = BaseField::from(b); + let expected_field = a_field - b_field; + assert!(expected_field == expected.into()); + let expected_limbs = expected.to_limbs(num_limbs, log_limb_size); let command_queue = device.new_command_queue();