-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test(mont_mul): add cios mont_mul test and benchmark
- Loading branch information
Showing
4 changed files
with
170 additions
and
0 deletions.
There are no files selected for viewing
28 changes: 28 additions & 0 deletions
28
mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_cios_benchmarks.metal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
using namespace metal; | ||
#include <metal_stdlib> | ||
#include <metal_math> | ||
#include "mont.metal" | ||
|
||
kernel void run( | ||
device BigInt* lhs [[ buffer(0) ]], | ||
device BigInt* rhs [[ buffer(1) ]], | ||
device BigInt* prime [[ buffer(2) ]], | ||
device array<uint, 1>* cost [[ buffer(3) ]], | ||
device BigInt* result [[ buffer(4) ]], | ||
uint gid [[ thread_position_in_grid ]] | ||
) { | ||
BigInt a; | ||
BigInt b; | ||
BigInt p; | ||
a.limbs = lhs->limbs; | ||
b.limbs = rhs->limbs; | ||
p.limbs = prime->limbs; | ||
array<uint, 1> cost_arr = *cost; | ||
|
||
BigInt c = mont_mul_cios(a, a, p); | ||
for (uint i = 1; i < cost_arr[0]; i ++) { | ||
c = mont_mul_cios(c, a, p); | ||
} | ||
BigInt res = mont_mul_cios(c, b, p); | ||
result->limbs = res.limbs; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
#[cfg(test)] | ||
pub mod mont_benchmarks; | ||
#[cfg(test)] | ||
pub mod mont_mul_cios; | ||
#[cfg(test)] | ||
pub mod mont_mul_modified; | ||
#[cfg(test)] | ||
pub mod mont_mul_optimised; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
132 changes: 132 additions & 0 deletions
132
mopro-msm/src/msm/metal_msm/tests/mont_backend/mont_mul_cios.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
// adapted from: https://github.com/geometryxyz/msl-secp256k1 | ||
|
||
use crate::msm::metal_msm::host::gpu::{ | ||
create_buffer, create_empty_buffer, get_default_device, read_buffer, | ||
}; | ||
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; | ||
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs}; | ||
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0}; | ||
use ark_bn254::Fr as ScalarField; | ||
use ark_ff::{BigInt, PrimeField}; | ||
use metal::*; | ||
use num_bigint::BigUint; | ||
|
||
#[test] | ||
#[serial_test::serial] | ||
pub fn test_mont_mul_14() { | ||
do_test(14); | ||
} | ||
|
||
#[test] | ||
#[serial_test::serial] | ||
pub fn test_mont_mul_15() { | ||
do_test(15); | ||
} | ||
|
||
pub fn do_test(log_limb_size: u32) { | ||
let modulus_bits = ScalarField::MODULUS_BIT_SIZE as u32; | ||
let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize; | ||
|
||
let r = calc_mont_radix(num_limbs, log_limb_size); | ||
let p: BigUint = ScalarField::MODULUS.try_into().unwrap(); | ||
let nsafe = calc_nsafe(log_limb_size); | ||
|
||
let res = calc_rinv_and_n0(&p, &r, log_limb_size); | ||
let n0 = res.1; | ||
|
||
let a = BigUint::parse_bytes( | ||
b"10ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", | ||
16, | ||
) | ||
.unwrap(); | ||
let b = BigUint::parse_bytes( | ||
b"11ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", | ||
16, | ||
) | ||
.unwrap(); | ||
|
||
let a_r = &a * &r % &p; | ||
let b_r = &b * &r % &p; | ||
let expected = (&a * &b * &r) % &p; | ||
|
||
let a_r_in_ark = ScalarField::from_bigint(a_r.clone().try_into().unwrap()).unwrap(); | ||
let b_r_in_ark = ScalarField::from_bigint(b_r.clone().try_into().unwrap()).unwrap(); | ||
let expected_in_ark = ScalarField::from_bigint(expected.clone().try_into().unwrap()).unwrap(); | ||
let expected_limbs = expected_in_ark | ||
.into_bigint() | ||
.to_limbs(num_limbs, log_limb_size); | ||
|
||
let device = get_default_device(); | ||
let a_buf = create_buffer( | ||
&device, | ||
&a_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size), | ||
); | ||
let b_buf = create_buffer( | ||
&device, | ||
&b_r_in_ark.into_bigint().to_limbs(num_limbs, log_limb_size), | ||
); | ||
let p_buf = create_buffer( | ||
&device, | ||
&ScalarField::MODULUS.to_limbs(num_limbs, log_limb_size), | ||
); | ||
let result_buf = create_empty_buffer(&device, num_limbs); | ||
|
||
let command_queue = device.new_command_queue(); | ||
let command_buffer = command_queue.new_command_buffer(); | ||
|
||
let compute_pass_descriptor = ComputePassDescriptor::new(); | ||
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); | ||
|
||
write_constants( | ||
"../mopro-msm/src/msm/metal_msm/shader", | ||
num_limbs, | ||
log_limb_size, | ||
n0, | ||
nsafe, | ||
); | ||
let library_path = compile_metal( | ||
"../mopro-msm/src/msm/metal_msm/shader/mont_backend", | ||
"mont_mul_cios.metal", | ||
); | ||
let library = device.new_library_with_file(library_path).unwrap(); | ||
let kernel = library.get_function("run", None).unwrap(); | ||
|
||
let pipeline_state_descriptor = ComputePipelineDescriptor::new(); | ||
pipeline_state_descriptor.set_compute_function(Some(&kernel)); | ||
|
||
let pipeline_state = device | ||
.new_compute_pipeline_state_with_function( | ||
pipeline_state_descriptor.compute_function().unwrap(), | ||
) | ||
.unwrap(); | ||
|
||
encoder.set_compute_pipeline_state(&pipeline_state); | ||
encoder.set_buffer(0, Some(&a_buf), 0); | ||
encoder.set_buffer(1, Some(&b_buf), 0); | ||
encoder.set_buffer(2, Some(&p_buf), 0); | ||
encoder.set_buffer(3, Some(&result_buf), 0); | ||
|
||
let thread_group_count = MTLSize { | ||
width: 1, | ||
height: 1, | ||
depth: 1, | ||
}; | ||
|
||
let thread_group_size = MTLSize { | ||
width: 1, | ||
height: 1, | ||
depth: 1, | ||
}; | ||
|
||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); | ||
encoder.end_encoding(); | ||
|
||
command_buffer.commit(); | ||
command_buffer.wait_until_completed(); | ||
|
||
let result_limbs: Vec<u32> = read_buffer(&result_buf, num_limbs); | ||
let result = BigInt::from_limbs(&result_limbs, log_limb_size); | ||
|
||
assert!(result == expected.try_into().unwrap()); | ||
assert!(result_limbs == expected_limbs); | ||
} |