diff --git a/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs b/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs index ebf772b..10738a3 100644 --- a/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs +++ b/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs @@ -12,6 +12,7 @@ pub trait ToLimbs { pub trait FromLimbs { fn from_u32_limbs(limbs: &[u32]) -> Self; fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self; + fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self; fn from_u128(num: u128) -> Self; fn from_u32(num: u32) -> Self; } @@ -77,6 +78,10 @@ impl ToLimbs for Fq { fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec { self.0.to_limbs(num_limbs, log_limb_size) } + + fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec { + self.0.to_limbs(num_limbs, log_limb_size) + } } impl FromLimbs for BigInteger256 { @@ -129,6 +134,35 @@ impl FromLimbs for BigInteger256 { BigInteger256::new(result) } + + fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self { + let mut result = [0u64; 4]; + let limb_size = log_limb_size as usize; + let mut accumulated_bits = 0; + let mut current_u64 = 0u64; + let mut result_idx = 0; + + for &limb in limbs { + // Add the current limb at the appropriate position + current_u64 |= (limb as u64) << accumulated_bits; + accumulated_bits += limb_size; + + // If we've accumulated 64 bits or more, store the result + while accumulated_bits >= 64 && result_idx < 4 { + result[result_idx] = current_u64; + current_u64 = limb as u64 >> (limb_size - (accumulated_bits - 64)); + accumulated_bits -= 64; + result_idx += 1; + } + } + + // Handle any remaining bits + if accumulated_bits > 0 && result_idx < 4 { + result[result_idx] = current_u64; + } + + BigInteger256::new(result) + } } impl FromLimbs for Fq { @@ -159,4 +193,9 @@ impl FromLimbs for Fq { let bigint = BigInteger256::from_limbs(limbs, log_limb_size); Fq::new(mont_reduction::raw_reduction(bigint)) } + + fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self { + let bigint = BigInteger256::from_limbs(limbs, log_limb_size); + Fq::new(mont_reduction::raw_reduction(bigint)) + } } diff --git a/mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal b/mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal index 545e5bb..45219c2 100644 --- a/mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal +++ b/mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal @@ -9,12 +9,12 @@ namespace { } /* Constants for bn254 field operations - * N: base field modulus + * N: scalar field modulus * R_SQUARED: R^2 mod N * R_SUB_N: R - N * MU: Montgomery Multiplication Constant = -N^{-1} mod (2^32) * - * For bn254, the modulus is "21888242871839275222246405745257275088696311157297823662689037894645226208583" [1, 2] + * For bn254, the modulus is "21888242871839275222246405745257275088548364400416034343698204186575808495617" [1, 2] * We use 8 limbs of 32 bits unsigned integers to represent the constanst * * References: diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal index d0ff646..f6455f8 100644 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal @@ -1,7 +1,7 @@ // source: https://github.com/geometryxyz/msl-secp256k1 using namespace metal; -#include "constants.metal" +#include "../constants.metal" struct BigInt { array limbs; diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal deleted file mode 100644 index 6ece636..0000000 --- a/mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal +++ /dev/null @@ -1,8 +0,0 @@ -// THIS FILE IS AUTOGENERATED BY shader.rs -#define NUM_LIMBS 20 -#define NUM_LIMBS_WIDE 21 -#define LOG_LIMB_SIZE 13 -#define TWO_POW_WORD_SIZE 8192 -#define MASK 8191 -#define N0 0 -#define NSAFE 0 diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs index 8058002..c35ffb6 100644 --- a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs @@ -41,7 +41,7 @@ pub fn test_bigint_add_unsafe() { let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); write_constants( - "../mopro-msm/src/msm/metal_msm/shader/bigint", + "../mopro-msm/src/msm/metal_msm/shader", num_limbs, log_limb_size, 0, diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs index 0eb7f9b..fbd1f72 100644 --- a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs @@ -45,7 +45,7 @@ pub fn test_bigint_add() { let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); write_constants( - "../mopro-msm/src/msm/metal_msm/shader/bigint", + "../mopro-msm/src/msm/metal_msm/shader", num_limbs, log_limb_size, 0, @@ -133,7 +133,7 @@ pub fn test_bigint_add_no_overflow() { let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); write_constants( - "../mopro-msm/src/msm/metal_msm/shader/bigint", + "../mopro-msm/src/msm/metal_msm/shader", num_limbs, log_limb_size, 0, diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs index 6341b9e..561967d 100644 --- a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs @@ -35,7 +35,7 @@ pub fn test_bigint_sub() { let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); write_constants( - "../mopro-msm/src/msm/metal_msm/shader/bigint", + "../mopro-msm/src/msm/metal_msm/shader", num_limbs, log_limb_size, 0, @@ -129,7 +129,7 @@ fn test_bigint_sub_underflow() { let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); write_constants( - "../mopro-msm/src/msm/metal_msm/shader/bigint", + "../mopro-msm/src/msm/metal_msm/shader", num_limbs, log_limb_size, 0, diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs index b3c015a..d1955db 100644 --- a/mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] pub mod bigint_add_unsafe; +#[cfg(test)] pub mod bigint_add_wide; +#[cfg(test)] pub mod bigint_sub; diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs new file mode 100644 index 0000000..0dda4e1 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs @@ -0,0 +1,114 @@ +// adapted from: https://github.com/geometryxyz/msl-secp256k1 + +use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs}; +use crate::msm::metal_msm::host::gpu::{ + create_buffer, create_empty_buffer, get_default_device, read_buffer, +}; +use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; +use ark_bn254::Fr as ScalarField; +use ark_ff::{BigInt, BigInteger, PrimeField}; +use metal::*; + +#[test] +#[serial_test::serial] +pub fn test_ff_add() { + let log_limb_size = 13; + let num_limbs = 20; + + // Scalar field modulus for bn254 + let p = BigInt::new([ + 0x43E1F593F0000001, + 0x2833E84879B97091, + 0xB85045B68181585D, + 0x30644E72E131A029, + ]); + assert!(p == ScalarField::MODULUS); + + let a = BigInt::new([ + 0x43E1F593F0000001, + 0x2833E84879B97091, + 0xB85045B68181585D, + 0x30644E72E131A028, + ]); + let b = BigInt::new([ + 0x43E1F593F0000001, + 0x2833E84879B97091, + 0xB85045B68181585D, + 0x30644E7200000000, + ]); + + let device = get_default_device(); + let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); + let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size)); + let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size)); + let result_buf = create_empty_buffer(&device, num_limbs); + + // Perform (a + b) % p + let mut expected = a.clone(); + expected.add_with_carry(&b); + + // While result >= p, subtract p + while expected >= p { + expected.sub_with_borrow(&p); + } + let expected_limbs = expected.to_limbs(num_limbs, log_limb_size); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader", + num_limbs, + log_limb_size, + 0, + 0, + ); + let library_path = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader/field", + "ff_add.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&b_buf), 0); + encoder.set_buffer(2, Some(&p_buf), 0); + encoder.set_buffer(3, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, num_limbs); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + assert!(result == expected); + assert!(result_limbs == expected_limbs); +} diff --git a/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs b/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs new file mode 100644 index 0000000..7b25f14 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs @@ -0,0 +1,114 @@ +// adapted from: https://github.com/geometryxyz/msl-secp256k1 + +use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs}; +use crate::msm::metal_msm::host::gpu::{ + create_buffer, create_empty_buffer, get_default_device, read_buffer, +}; +use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; +use ark_bn254::Fr as ScalarField; +use ark_ff::{BigInt, BigInteger, PrimeField}; +use metal::*; + +#[test] +#[serial_test::serial] +pub fn test_ff_sub() { + let log_limb_size = 13; + let num_limbs = 20; + + // Scalar field modulus for bn254 + let p = BigInt::new([ + 0x43E1F593F0000001, + 0x2833E84879B97091, + 0xB85045B68181585D, + 0x30644E72E131A029, + ]); + assert!(p == ScalarField::MODULUS); + + let a = BigInt::new([ + 0x43E1F593F0000001, + 0x2833E84879B97091, + 0xB85045B68181585D, + 0x30644E72E131A028, + ]); + let b = BigInt::new([ + 0xAAAAAAAAF0000001, + 0x2833E84879B97091, + 0xB85045B68181585D, + 0x30644E7200000000, + ]); + + let device = get_default_device(); + let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); + let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size)); + let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size)); + let result_buf = create_empty_buffer(&device, num_limbs); + + // Perform (a - b) % p + let mut expected = a.clone(); + expected.sub_with_borrow(&b); + + // If result is negative, add p until it's positive + while expected < BigInt::zero() { + expected.add_with_carry(&p); + } + let expected_limbs = expected.to_limbs(num_limbs, log_limb_size); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader", + num_limbs, + log_limb_size, + 0, + 0, + ); + let library_path = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader/field", + "ff_sub.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&b_buf), 0); + encoder.set_buffer(2, Some(&p_buf), 0); + encoder.set_buffer(3, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, num_limbs); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + assert!(result == expected); + assert!(result_limbs == expected_limbs); +} diff --git a/mopro-msm/src/msm/metal_msm/tests/field/mod.rs b/mopro-msm/src/msm/metal_msm/tests/field/mod.rs new file mode 100644 index 0000000..3be2ba6 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/field/mod.rs @@ -0,0 +1,4 @@ +#[cfg(test)] +pub mod ff_add; +#[cfg(test)] +pub mod ff_sub; diff --git a/mopro-msm/src/msm/metal_msm/tests/mod.rs b/mopro-msm/src/msm/metal_msm/tests/mod.rs index 8f8d148..0eaa24f 100644 --- a/mopro-msm/src/msm/metal_msm/tests/mod.rs +++ b/mopro-msm/src/msm/metal_msm/tests/mod.rs @@ -1 +1,4 @@ +#[cfg(test)] pub mod bigint; +#[cfg(test)] +pub mod field;