Skip to content

Commit

Permalink
Field impl for refactored metal EC backend (#22)
Browse files Browse the repository at this point in the history
* feat: conversion between bigint and arbitrary limb size

* lint

* test(bigint): adapt from https://github.com/geometryxyz/msl-secp256k1

* refactor: add overflow detection and correct suitable bigint val for each cases

* test(field): adapt the ff tests from https://github.com/geometryxyz/msl-secp256k1

* test(field): add check to scalarfield modulus match

* chore: correct the docs, use new refactored code for correct implementation

* chore: update path for contants.metal
  • Loading branch information
moven0831 authored Dec 21, 2024
1 parent bf1636b commit a0ae553
Show file tree
Hide file tree
Showing 12 changed files with 285 additions and 16 deletions.
39 changes: 39 additions & 0 deletions mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait ToLimbs {
pub trait FromLimbs {
fn from_u32_limbs(limbs: &[u32]) -> Self;
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self;
fn from_u128(num: u128) -> Self;
fn from_u32(num: u32) -> Self;
}
Expand Down Expand Up @@ -77,6 +78,10 @@ impl ToLimbs for Fq {
fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
self.0.to_limbs(num_limbs, log_limb_size)
}

fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec<u32> {
self.0.to_limbs(num_limbs, log_limb_size)
}
}

impl FromLimbs for BigInteger256 {
Expand Down Expand Up @@ -129,6 +134,35 @@ impl FromLimbs for BigInteger256 {

BigInteger256::new(result)
}

fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
let mut result = [0u64; 4];
let limb_size = log_limb_size as usize;
let mut accumulated_bits = 0;
let mut current_u64 = 0u64;
let mut result_idx = 0;

for &limb in limbs {
// Add the current limb at the appropriate position
current_u64 |= (limb as u64) << accumulated_bits;
accumulated_bits += limb_size;

// If we've accumulated 64 bits or more, store the result
while accumulated_bits >= 64 && result_idx < 4 {
result[result_idx] = current_u64;
current_u64 = limb as u64 >> (limb_size - (accumulated_bits - 64));
accumulated_bits -= 64;
result_idx += 1;
}
}

// Handle any remaining bits
if accumulated_bits > 0 && result_idx < 4 {
result[result_idx] = current_u64;
}

BigInteger256::new(result)
}
}

impl FromLimbs for Fq {
Expand Down Expand Up @@ -159,4 +193,9 @@ impl FromLimbs for Fq {
let bigint = BigInteger256::from_limbs(limbs, log_limb_size);
Fq::new(mont_reduction::raw_reduction(bigint))
}

fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self {
let bigint = BigInteger256::from_limbs(limbs, log_limb_size);
Fq::new(mont_reduction::raw_reduction(bigint))
}
}
4 changes: 2 additions & 2 deletions mopro-msm/src/msm/metal/shader/fields/fp_bn254.h.metal
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ namespace {
}

/* Constants for bn254 field operations
* N: base field modulus
* N: scalar field modulus
* R_SQUARED: R^2 mod N
* R_SUB_N: R - N
* MU: Montgomery Multiplication Constant = -N^{-1} mod (2^32)
*
* For bn254, the modulus is "21888242871839275222246405745257275088696311157297823662689037894645226208583" [1, 2]
* For bn254, the modulus is "21888242871839275222246405745257275088548364400416034343698204186575808495617" [1, 2]
* We use 8 limbs of 32 bits unsigned integers to represent the constanst
*
* References:
Expand Down
2 changes: 1 addition & 1 deletion mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// source: https://github.com/geometryxyz/msl-secp256k1

using namespace metal;
#include "constants.metal"
#include "../constants.metal"

struct BigInt {
array<uint, NUM_LIMBS> limbs;
Expand Down
8 changes: 0 additions & 8 deletions mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub fn test_bigint_add_unsafe() {
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

write_constants(
"../mopro-msm/src/msm/metal_msm/shader/bigint",
"../mopro-msm/src/msm/metal_msm/shader",
num_limbs,
log_limb_size,
0,
Expand Down
4 changes: 2 additions & 2 deletions mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub fn test_bigint_add() {
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

write_constants(
"../mopro-msm/src/msm/metal_msm/shader/bigint",
"../mopro-msm/src/msm/metal_msm/shader",
num_limbs,
log_limb_size,
0,
Expand Down Expand Up @@ -133,7 +133,7 @@ pub fn test_bigint_add_no_overflow() {
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

write_constants(
"../mopro-msm/src/msm/metal_msm/shader/bigint",
"../mopro-msm/src/msm/metal_msm/shader",
num_limbs,
log_limb_size,
0,
Expand Down
4 changes: 2 additions & 2 deletions mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub fn test_bigint_sub() {
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

write_constants(
"../mopro-msm/src/msm/metal_msm/shader/bigint",
"../mopro-msm/src/msm/metal_msm/shader",
num_limbs,
log_limb_size,
0,
Expand Down Expand Up @@ -129,7 +129,7 @@ fn test_bigint_sub_underflow() {
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

write_constants(
"../mopro-msm/src/msm/metal_msm/shader/bigint",
"../mopro-msm/src/msm/metal_msm/shader",
num_limbs,
log_limb_size,
0,
Expand Down
3 changes: 3 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[cfg(test)]
pub mod bigint_add_unsafe;
#[cfg(test)]
pub mod bigint_add_wide;
#[cfg(test)]
pub mod bigint_sub;
114 changes: 114 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// adapted from: https://github.com/geometryxyz/msl-secp256k1

use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::host::gpu::{
create_buffer, create_empty_buffer, get_default_device, read_buffer,
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use ark_bn254::Fr as ScalarField;
use ark_ff::{BigInt, BigInteger, PrimeField};
use metal::*;

#[test]
#[serial_test::serial]
pub fn test_ff_add() {
let log_limb_size = 13;
let num_limbs = 20;

// Scalar field modulus for bn254
let p = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A029,
]);
assert!(p == ScalarField::MODULUS);

let a = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A028,
]);
let b = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E7200000000,
]);

let device = get_default_device();
let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size));
let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size));
let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size));
let result_buf = create_empty_buffer(&device, num_limbs);

// Perform (a + b) % p
let mut expected = a.clone();
expected.add_with_carry(&b);

// While result >= p, subtract p
while expected >= p {
expected.sub_with_borrow(&p);
}
let expected_limbs = expected.to_limbs(num_limbs, log_limb_size);

let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();

let compute_pass_descriptor = ComputePassDescriptor::new();
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

write_constants(
"../mopro-msm/src/msm/metal_msm/shader",
num_limbs,
log_limb_size,
0,
0,
);
let library_path = compile_metal(
"../mopro-msm/src/msm/metal_msm/shader/field",
"ff_add.metal",
);
let library = device.new_library_with_file(library_path).unwrap();
let kernel = library.get_function("run", None).unwrap();

let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&kernel));

let pipeline_state = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();

encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&a_buf), 0);
encoder.set_buffer(1, Some(&b_buf), 0);
encoder.set_buffer(2, Some(&p_buf), 0);
encoder.set_buffer(3, Some(&result_buf), 0);

let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};

let thread_group_size = MTLSize {
width: 1,
height: 1,
depth: 1,
};

encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();

command_buffer.commit();
command_buffer.wait_until_completed();

let result_limbs: Vec<u32> = read_buffer(&result_buf, num_limbs);
let result = BigInt::from_limbs(&result_limbs, log_limb_size);

assert!(result == expected);
assert!(result_limbs == expected_limbs);
}
114 changes: 114 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// adapted from: https://github.com/geometryxyz/msl-secp256k1

use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::host::gpu::{
create_buffer, create_empty_buffer, get_default_device, read_buffer,
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use ark_bn254::Fr as ScalarField;
use ark_ff::{BigInt, BigInteger, PrimeField};
use metal::*;

#[test]
#[serial_test::serial]
pub fn test_ff_sub() {
let log_limb_size = 13;
let num_limbs = 20;

// Scalar field modulus for bn254
let p = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A029,
]);
assert!(p == ScalarField::MODULUS);

let a = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A028,
]);
let b = BigInt::new([
0xAAAAAAAAF0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E7200000000,
]);

let device = get_default_device();
let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size));
let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size));
let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size));
let result_buf = create_empty_buffer(&device, num_limbs);

// Perform (a - b) % p
let mut expected = a.clone();
expected.sub_with_borrow(&b);

// If result is negative, add p until it's positive
while expected < BigInt::zero() {
expected.add_with_carry(&p);
}
let expected_limbs = expected.to_limbs(num_limbs, log_limb_size);

let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();

let compute_pass_descriptor = ComputePassDescriptor::new();
let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

write_constants(
"../mopro-msm/src/msm/metal_msm/shader",
num_limbs,
log_limb_size,
0,
0,
);
let library_path = compile_metal(
"../mopro-msm/src/msm/metal_msm/shader/field",
"ff_sub.metal",
);
let library = device.new_library_with_file(library_path).unwrap();
let kernel = library.get_function("run", None).unwrap();

let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&kernel));

let pipeline_state = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();

encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&a_buf), 0);
encoder.set_buffer(1, Some(&b_buf), 0);
encoder.set_buffer(2, Some(&p_buf), 0);
encoder.set_buffer(3, Some(&result_buf), 0);

let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};

let thread_group_size = MTLSize {
width: 1,
height: 1,
depth: 1,
};

encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();

command_buffer.commit();
command_buffer.wait_until_completed();

let result_limbs: Vec<u32> = read_buffer(&result_buf, num_limbs);
let result = BigInt::from_limbs(&result_limbs, log_limb_size);

assert!(result == expected);
assert!(result_limbs == expected_limbs);
}
4 changes: 4 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/field/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#[cfg(test)]
pub mod ff_add;
#[cfg(test)]
pub mod ff_sub;
3 changes: 3 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
#[cfg(test)]
pub mod bigint;
#[cfg(test)]
pub mod field;

0 comments on commit a0ae553

Please sign in to comment.