Skip to content

Commit

Permalink
test(mont_mul): add cios mont_mul test and benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
moven0831 committed Nov 8, 2024
1 parent 57bed4d commit d6c81b2
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using namespace metal;
#include <metal_stdlib>
#include <metal_math>
#include "mont.metal"

// Benchmark kernel for `mont_mul_cios`.
//
// Copies the limbs of `lhs`, `rhs` and `prime` into local BigInts, then chains
// Montgomery multiplications: first a * a, then the running value times `a`
// for every i in 1..cost[0], and finally the running value times `b`. The last
// product is stored in `result`. For cost[0] >= 1 this performs cost[0] + 1
// calls to `mont_mul_cios`; `gid` is not used.
kernel void run(
    device BigInt* lhs [[ buffer(0) ]],
    device BigInt* rhs [[ buffer(1) ]],
    device BigInt* prime [[ buffer(2) ]],
    device array<uint, 1>* cost [[ buffer(3) ]],
    device BigInt* result [[ buffer(4) ]],
    uint gid [[ thread_position_in_grid ]]
) {
    // Local copies of the operands and the modulus.
    BigInt base;
    base.limbs = lhs->limbs;

    BigInt multiplier;
    multiplier.limbs = rhs->limbs;

    BigInt modulus;
    modulus.limbs = prime->limbs;

    // The loop bound is read once from the single-element cost array.
    array<uint, 1> cost_arr = *cost;
    const uint iterations = cost_arr[0];

    // Start from a * a, then keep multiplying the accumulator by a.
    BigInt acc = mont_mul_cios(base, base, modulus);
    for (uint step = 1; step < iterations; step++) {
        acc = mont_mul_cios(acc, base, modulus);
    }

    // One last multiplication by b; its limbs are the kernel output.
    BigInt product = mont_mul_cios(acc, multiplier, modulus);
    result->limbs = product.limbs;
}
2 changes: 2 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/mont_backend/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#[cfg(test)]
pub mod mont_benchmarks;
#[cfg(test)]
pub mod mont_mul_cios;
#[cfg(test)]
pub mod mont_mul_modified;
#[cfg(test)]
pub mod mont_mul_optimised;
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ pub fn all_benchmarks() {
Err(e) => println!("benchmark for {}-bit limbs: {}", i, e),
}
}

println!("\n=== benchmarking mont_mul_cios ===");
for i in 11..17 {
match benchmark(i, "mont_mul_cios_benchmarks.metal") {
Ok(elapsed) => println!("benchmark for {}-bit limbs took {}ms", i, elapsed),
Err(e) => println!("benchmark for {}-bit limbs: {}", i, e),
}
}
}

fn expensive_computation(
Expand Down
132 changes: 132 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// adapted from: https://github.com/geometryxyz/msl-secp256k1

use crate::msm::metal_msm::host::gpu::{
create_buffer, create_empty_buffer, get_default_device, read_buffer,
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0};
use ark_bn254::Fr as ScalarField;
use ark_ff::{BigInt, PrimeField};
use metal::*;
use num_bigint::BigUint;

/// Checks the CIOS Montgomery multiplication shader with 14-bit limbs.
#[test]
#[serial_test::serial]
pub fn test_mont_mul_14() {
    let log_limb_size: u32 = 14;
    do_test(log_limb_size);
}

/// Checks the CIOS Montgomery multiplication shader with 15-bit limbs.
#[test]
#[serial_test::serial]
pub fn test_mont_mul_15() {
    let log_limb_size: u32 = 15;
    do_test(log_limb_size);
}

/// Runs the `run` kernel of `mont_mul_cios.metal` once on the default Metal
/// device and compares its output with a CPU-side reference.
///
/// Two fixed operands `a` and `b` are converted to Montgomery form
/// (`a * r mod p`, `b * r mod p`), split into `log_limb_size`-bit limbs and
/// uploaded together with the modulus. The kernel result is expected to be
/// `a * b * r mod p`, checked both as a big integer and limb by limb.
///
/// # Arguments
///
/// * `log_limb_size` - number of bits per limb; also determines how many
///   limbs are needed to hold the scalar-field modulus.
///
/// # Panics
///
/// Panics if the shader library cannot be loaded, the `run` function or the
/// pipeline state cannot be created, a conversion fails, or the GPU result
/// differs from the expected value.
pub fn do_test(log_limb_size: u32) {
    let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32;
    // Ceiling division: enough limbs to cover every bit of the modulus.
    let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize;

    let r = calc_mont_radix(num_limbs, log_limb_size);
    let p: BigUint = ScalarField::MODULUS.try_into().unwrap();
    let nsafe = calc_nsafe(log_limb_size);

    // Only the second component (n0) is needed by the shader constants.
    let n0 = calc_rinv_and_n0(&p, &r, log_limb_size).1;

    let a = BigUint::parse_bytes(
        b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
        16,
    )
    .unwrap();
    let b = BigUint::parse_bytes(
        b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001",
        16,
    )
    .unwrap();

    // Operands in Montgomery form and the expected Montgomery-form product.
    let a_r = &a * &r % &p;
    let b_r = &b * &r % &p;
    let expected = (&a * &b * &r) % &p;

    // `a_r` and `b_r` are not needed afterwards, so they are moved, not cloned.
    let a_r_in_ark = ScalarField::from_bigint(a_r.try_into().unwrap()).unwrap();
    let b_r_in_ark = ScalarField::from_bigint(b_r.try_into().unwrap()).unwrap();
    // `expected` is compared against the GPU output below, hence the clone.
    let expected_in_ark = ScalarField::from_bigint(expected.clone().try_into().unwrap()).unwrap();
    let expected_limbs = expected_in_ark
        .into_bigint()
        .to_limbs(num_limbs, log_limb_size);

    let device = get_default_device();
    let a_buf = create_buffer(
        &device,
        &a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size),
    );
    let b_buf = create_buffer(
        &device,
        &b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size),
    );
    let p_buf = create_buffer(
        &device,
        &ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size),
    );
    let result_buf = create_empty_buffer(&device, num_limbs);

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let compute_pass_descriptor = ComputePassDescriptor::new();
    let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

    // The shader constants must be written before the shader is compiled.
    write_constants(
        "../mopro-msm/src/msm/metal_msm/shader",
        num_limbs,
        log_limb_size,
        n0,
        nsafe,
    );
    let library_path = compile_metal(
        "../mopro-msm/src/msm/metal_msm/shader/mont_backend",
        "mont_mul_cios.metal",
    );
    let library = device.new_library_with_file(library_path).unwrap();
    let kernel = library.get_function("run", None).unwrap();

    // Build the pipeline straight from the kernel function; no intermediate
    // pipeline descriptor is required for this.
    let pipeline_state = device
        .new_compute_pipeline_state_with_function(&kernel)
        .unwrap();

    encoder.set_compute_pipeline_state(&pipeline_state);
    encoder.set_buffer(0, Some(&a_buf), 0);
    encoder.set_buffer(1, Some(&b_buf), 0);
    encoder.set_buffer(2, Some(&p_buf), 0);
    encoder.set_buffer(3, Some(&result_buf), 0);

    // A single thread in a single thread group is dispatched.
    let thread_group_count = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();

    command_buffer.commit();
    command_buffer.wait_until_completed();

    let result_limbs: Vec<u32> = read_buffer(&result_buf, num_limbs);
    let result = BigInt::from_limbs(&result_limbs, log_limb_size);

    // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
    assert_eq!(
        result,
        expected.try_into().unwrap(),
        "GPU result differs from a * b * r mod p"
    );
    assert_eq!(
        result_limbs, expected_limbs,
        "GPU result limbs differ from the expected limbs"
    );
}

0 comments on commit d6c81b2

Please sign in to comment.