Skip to content

Commit

Permalink
refactor(tests): update field addition and subtraction tests to use r…
Browse files Browse the repository at this point in the history
…andom BigInt values

- Changed the field tests to utilize random number generation for inputs, enhancing robustness against edge cases.
- Updated the limb size and number of limbs from 20 to 16 for consistency.
- Replaced the hardcoded values for the scalar field modulus with the BaseField's modulus.
- Added assertions to ensure non-negativity and proper range for generated values.
- Improved correctness checks by directly comparing results with expected values derived from Arkworks operations.
  • Loading branch information
moven0831 committed Jan 3, 2025
1 parent f4378bc commit b14c60e
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 54 deletions.
62 changes: 38 additions & 24 deletions mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,38 @@ use crate::msm::metal_msm::host::gpu::{
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use ark_bn254::Fr as ScalarField;
use ark_ff::{BigInt, BigInteger, PrimeField};
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand};
use ark_std::rand;
use metal::*;

#[test]
#[serial_test::serial]
pub fn test_ff_add() {
let log_limb_size = 13;
let num_limbs = 20;
let log_limb_size = 16;
let num_limbs = 16;

// Scalar field modulus for bn254
let p = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A029,
]);
assert!(p == ScalarField::MODULUS);

let a = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A028,
]);
let b = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E7200000000,
]);
let p = BaseField::MODULUS;

let mut rng = rand::thread_rng();
let mut a = BigInt::rand(&mut rng);
let mut b = BigInt::rand(&mut rng);

// Reduce a and b if they are greater than or equal to the prime field modulus
while a >= p {
a.sub_with_borrow(&p);
}

while b >= p {
b.sub_with_borrow(&p);
}

// Ensure a and b are non-negative and less than p
assert!(a >= BigInt::from(0u64), "a must be non-negative");
assert!(b >= BigInt::from(0u64), "b must be non-negative");
assert!(a < p, "a must be less than p");
assert!(b < p, "b must be less than p");

let device = get_default_device();
let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size));
Expand All @@ -51,6 +52,19 @@ pub fn test_ff_add() {
while expected >= p {
expected.sub_with_borrow(&p);
}
// Ensure expected is non-negative and less than p
assert!(
expected >= BigInt::from(0u64),
"expected must be non-negative"
);
assert!(expected < p, "expected must be less than p");

// Ensure the operation is correct using Arkworks
let a_field = BaseField::from(a);
let b_field = BaseField::from(b);
let expected_field = a_field + b_field;
assert!(expected_field == expected.into());

let expected_limbs = expected.to_limbs(num_limbs, log_limb_size);

let command_queue = device.new_command_queue();
Expand Down
78 changes: 48 additions & 30 deletions mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,70 @@ use crate::msm::metal_msm::host::gpu::{
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use ark_bn254::Fr as ScalarField;
use ark_ff::{BigInt, BigInteger, PrimeField};
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand};
use ark_std::rand;
use metal::*;

#[test]
#[serial_test::serial]
pub fn test_ff_sub() {
let log_limb_size = 13;
let num_limbs = 20;
let log_limb_size = 16;
let num_limbs = 16;

// Scalar field modulus for bn254
let p = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A029,
]);
assert!(p == ScalarField::MODULUS);

let a = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A028,
]);
let b = BigInt::new([
0xAAAAAAAAF0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E7200000000,
]);
let p = BaseField::MODULUS;

let mut rng = rand::thread_rng();
let mut a = BigInt::rand(&mut rng);
let mut b = BigInt::rand(&mut rng);

// Reduce a and b if they are greater than or equal to the prime field modulus
while a >= p {
a.sub_with_borrow(&p);
}

while b >= p {
b.sub_with_borrow(&p);
}

// Ensure a and b are non-negative and less than p
assert!(a >= BigInt::from(0u64), "a must be non-negative");
assert!(b >= BigInt::from(0u64), "b must be non-negative");
assert!(a < p, "a must be less than p");
assert!(b < p, "b must be less than p");

let device = get_default_device();
let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size));
let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size));
let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size));
let result_buf = create_empty_buffer(&device, num_limbs);

// Perform (a - b) % p
// (a - b) % p
let mut expected = a.clone();
expected.sub_with_borrow(&b);

// If result is negative, add p until it's positive
while expected < BigInt::zero() {
expected.add_with_carry(&p);
if a >= b {
expected.sub_with_borrow(&b);
}
// p - (b - a)
else {
let mut p_sub_b = p.clone();
p_sub_b.sub_with_borrow(&b);
expected.add_with_carry(&p_sub_b);
}

// Ensure expected is non-negative and less than p
assert!(
expected >= BigInt::from(0u64),
"expected must be non-negative"
);
assert!(expected < p, "expected must be less than p");

// Ensure the operation is correct using Arkworks
let a_field = BaseField::from(a);
let b_field = BaseField::from(b);
let expected_field = a_field - b_field;
assert!(expected_field == expected.into());

let expected_limbs = expected.to_limbs(num_limbs, log_limb_size);

let command_queue = device.new_command_queue();
Expand Down

0 comments on commit b14c60e

Please sign in to comment.