Skip to content

Commit

Permalink
Merge pull request #30 from zkmopro/test/metal/mont
Browse files Browse the repository at this point in the history
Test/metal/mont
  • Loading branch information
moven0831 authored Jan 4, 2025
2 parents a6df775 + 5350f18 commit dadc88c
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 82 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mopro-msm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ objc = { version = "=0.2.7" }
proptest = { version = "1.4.0" }
rayon = "1.5.1"
itertools = "0.13.0"
rand = "0.8.5"

[build-dependencies]
color-eyre = "0.6"
Expand Down
33 changes: 11 additions & 22 deletions mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::utils::mont_params::{
calc_bitwidth, calc_mont_radix, calc_nsafe, calc_num_limbs, calc_rinv_and_n0,
};
use ark_bn254::Fr as ScalarField;
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, PrimeField};
use metal::*;
use num_bigint::BigUint;
use num_bigint::{BigUint, RandBigInt};
use rand::thread_rng;
use stopwatch::Stopwatch;

#[test]
Expand Down Expand Up @@ -53,26 +54,14 @@ fn expensive_computation(
}

pub fn benchmark(log_limb_size: u32, shader_file: &str) -> Result<i64, String> {
let p = BigUint::parse_bytes(
b"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001",
16,
)
.unwrap();
assert!(p == ScalarField::MODULUS.try_into().unwrap());
let p: BigUint = BaseField::MODULUS.try_into().unwrap();

let p_bitwidth = calc_bitwidth(&p);
let num_limbs = calc_num_limbs(log_limb_size, p_bitwidth);

let a = BigUint::parse_bytes(
b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
16,
)
.unwrap();
let b = BigUint::parse_bytes(
b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
16,
)
.unwrap();
let mut rng = thread_rng();
let a = rng.gen_biguint_below(&p);
let b = rng.gen_biguint_below(&p);

let nsafe = calc_nsafe(log_limb_size);
if nsafe == 0 {
Expand All @@ -88,20 +77,20 @@ pub fn benchmark(log_limb_size: u32, shader_file: &str) -> Result<i64, String> {

let cost = 2u32.pow(16u32) as usize;
let expected = expensive_computation(cost, &a, &b, &p, &r);
let expected_limbs = ScalarField::from_bigint(expected.clone().try_into().unwrap())
let expected_limbs = BaseField::from_bigint(expected.clone().try_into().unwrap())
.unwrap()
.into_bigint()
.to_limbs(num_limbs, log_limb_size);

let ar_limbs = ScalarField::from_bigint(a_r.clone().try_into().unwrap())
let ar_limbs = BaseField::from_bigint(a_r.clone().try_into().unwrap())
.unwrap()
.into_bigint()
.to_limbs(num_limbs, log_limb_size);
let br_limbs = ScalarField::from_bigint(b_r.clone().try_into().unwrap())
let br_limbs = BaseField::from_bigint(b_r.clone().try_into().unwrap())
.unwrap()
.into_bigint()
.to_limbs(num_limbs, log_limb_size);
let p_limbs = &ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size);
let p_limbs = &BaseField::MODULUS.to_limbs(num_limbs, log_limb_size);

let device = get_default_device();

Expand Down
30 changes: 12 additions & 18 deletions mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use crate::msm::metal_msm::host::gpu::{
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0};
use ark_bn254::Fr as ScalarField;
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, PrimeField};
use metal::*;
use num_bigint::BigUint;
use num_bigint::{BigUint, RandBigInt};
use rand::thread_rng;

#[test]
#[serial_test::serial]
Expand All @@ -24,34 +25,27 @@ pub fn test_mont_mul_15() {
}

pub fn do_test(log_limb_size: u32) {
let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32;
let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32;
let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize;

let r = calc_mont_radix(num_limbs, log_limb_size);
let p: BigUint = ScalarField::MODULUS.try_into().unwrap();
let p: BigUint = BaseField::MODULUS.try_into().unwrap();
let nsafe = calc_nsafe(log_limb_size);

let res = calc_rinv_and_n0(&p, &r, log_limb_size);
let n0 = res.1;

let a = BigUint::parse_bytes(
b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
16,
)
.unwrap();
let b = BigUint::parse_bytes(
b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
16,
)
.unwrap();
let mut rng = thread_rng();
let a = rng.gen_biguint_below(&p);
let b = rng.gen_biguint_below(&p);

let a_r = &a * &r % &p;
let b_r = &b * &r % &p;
let expected = (&a * &b * &r) % &p;

let a_r_in_ark = ScalarField::from_bigint(a_r.clone().try_into().unwrap()).unwrap();
let b_r_in_ark = ScalarField::from_bigint(b_r.clone().try_into().unwrap()).unwrap();
let expected_in_ark = ScalarField::from_bigint(expected.clone().try_into().unwrap()).unwrap();
let a_r_in_ark = BaseField::from_bigint(a_r.clone().try_into().unwrap()).unwrap();
let b_r_in_ark = BaseField::from_bigint(b_r.clone().try_into().unwrap()).unwrap();
let expected_in_ark = BaseField::from_bigint(expected.clone().try_into().unwrap()).unwrap();
let expected_limbs = expected_in_ark
.into_bigint()
.to_limbs(num_limbs, log_limb_size);
Expand All @@ -67,7 +61,7 @@ pub fn do_test(log_limb_size: u32) {
);
let p_buf = create_buffer(
&device,
&ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size),
&BaseField::MODULUS.to_limbs(num_limbs, log_limb_size),
);
let result_buf = create_empty_buffer(&device, num_limbs);

Expand Down
30 changes: 12 additions & 18 deletions mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_modified.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use crate::msm::metal_msm::host::gpu::{
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0};
use ark_bn254::Fr as ScalarField;
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, PrimeField};
use metal::*;
use num_bigint::BigUint;
use num_bigint::{BigUint, RandBigInt};
use rand::thread_rng;

#[test]
#[serial_test::serial]
Expand All @@ -24,34 +25,27 @@ pub fn test_mont_mul_15() {
}

pub fn do_test(log_limb_size: u32) {
let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32;
let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32;
let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize;

let r = calc_mont_radix(num_limbs, log_limb_size);
let p: BigUint = ScalarField::MODULUS.try_into().unwrap();
let p: BigUint = BaseField::MODULUS.try_into().unwrap();
let nsafe = calc_nsafe(log_limb_size);

let res = calc_rinv_and_n0(&p, &r, log_limb_size);
let n0 = res.1;

let a = BigUint::parse_bytes(
b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
16,
)
.unwrap();
let b = BigUint::parse_bytes(
b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
16,
)
.unwrap();
let mut rng = thread_rng();
let a = rng.gen_biguint_below(&p);
let b = rng.gen_biguint_below(&p);

let a_r = &a * &r % &p;
let b_r = &b * &r % &p;
let expected = (&a * &b * &r) % &p;

let a_r_in_ark = ScalarField::from_bigint(a_r.clone().try_into().unwrap()).unwrap();
let b_r_in_ark = ScalarField::from_bigint(b_r.clone().try_into().unwrap()).unwrap();
let expected_in_ark = ScalarField::from_bigint(expected.clone().try_into().unwrap()).unwrap();
let a_r_in_ark = BaseField::from_bigint(a_r.clone().try_into().unwrap()).unwrap();
let b_r_in_ark = BaseField::from_bigint(b_r.clone().try_into().unwrap()).unwrap();
let expected_in_ark = BaseField::from_bigint(expected.clone().try_into().unwrap()).unwrap();
let expected_limbs = expected_in_ark
.into_bigint()
.to_limbs(num_limbs, log_limb_size);
Expand All @@ -67,7 +61,7 @@ pub fn do_test(log_limb_size: u32) {
);
let p_buf = create_buffer(
&device,
&ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size),
&BaseField::MODULUS.to_limbs(num_limbs, log_limb_size),
);
let result_buf = create_empty_buffer(&device, num_limbs);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ use crate::msm::metal_msm::host::gpu::{
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_rinv_and_n0};
use ark_bn254::Fr as ScalarField;
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, PrimeField};
use metal::*;
use num_bigint::BigUint;
use num_bigint::{BigUint, RandBigInt};
use rand::thread_rng;

#[test]
#[serial_test::serial]
Expand All @@ -28,33 +29,26 @@ pub fn test_mont_mul_13() {

pub fn do_test(log_limb_size: u32) {
// Calculate num_limbs based on modulus size and limb size
let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32;
let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32;
let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize;

let r = calc_mont_radix(num_limbs, log_limb_size);
let p: BigUint = ScalarField::MODULUS.try_into().unwrap();
let p: BigUint = BaseField::MODULUS.try_into().unwrap();

let res = calc_rinv_and_n0(&p, &r, log_limb_size);
let n0 = res.1;

let a = BigUint::parse_bytes(
b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
16,
)
.unwrap();
let b = BigUint::parse_bytes(
b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
16,
)
.unwrap();
let mut rng = thread_rng();
let a = rng.gen_biguint_below(&p);
let b = rng.gen_biguint_below(&p);

let a_r = &a * &r % &p;
let b_r = &b * &r % &p;
let expected = (&a * &b * &r) % &p;

let a_r_in_ark = ScalarField::from_bigint(a_r.clone().try_into().unwrap()).unwrap();
let b_r_in_ark = ScalarField::from_bigint(b_r.clone().try_into().unwrap()).unwrap();
let expected_in_ark = ScalarField::from_bigint(expected.clone().try_into().unwrap()).unwrap();
let a_r_in_ark = BaseField::from_bigint(a_r.clone().try_into().unwrap()).unwrap();
let b_r_in_ark = BaseField::from_bigint(b_r.clone().try_into().unwrap()).unwrap();
let expected_in_ark = BaseField::from_bigint(expected.clone().try_into().unwrap()).unwrap();
let expected_limbs = expected_in_ark
.into_bigint()
.to_limbs(num_limbs, log_limb_size);
Expand All @@ -70,7 +64,7 @@ pub fn do_test(log_limb_size: u32) {
);
let p_buf = create_buffer(
&device,
&ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size),
&BaseField::MODULUS.to_limbs(num_limbs, log_limb_size),
);
let result_buf = create_empty_buffer(&device, num_limbs);

Expand Down Expand Up @@ -139,17 +133,17 @@ pub fn do_test(log_limb_size: u32) {
pub fn test_number_conversions() {
// Setup parameters
let log_limb_size = 12;
let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32;
let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32;
let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize;

// Create test values using small numbers for clarity
let original_biguint = BigUint::parse_bytes(b"123456789", 10).unwrap();

// Convert BigUint to ScalarField
// Convert BigUint to BaseField
let scalar_field_value =
ScalarField::from_bigint(original_biguint.clone().try_into().unwrap()).unwrap();
BaseField::from_bigint(original_biguint.clone().try_into().unwrap()).unwrap();

// Convert ScalarField to limbs
// Convert BaseField to limbs
let limbs = scalar_field_value
.into_bigint()
.to_limbs(num_limbs, log_limb_size);
Expand All @@ -174,7 +168,7 @@ pub fn test_number_conversions() {
];

for value in test_values {
let scalar = ScalarField::from_bigint(value.clone().try_into().unwrap()).unwrap();
let scalar = BaseField::from_bigint(value.clone().try_into().unwrap()).unwrap();
let value_limbs = scalar.into_bigint().to_limbs(num_limbs, log_limb_size);
let converted: BigUint = BigInt::from_limbs(&value_limbs, log_limb_size)
.try_into()
Expand Down

0 comments on commit dadc88c

Please sign in to comment.