Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/metal/ec #31

Merged
merged 6 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 98 additions & 24 deletions mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// source: https://github.com/geometryxyz/msl-secp256k1
// algorithms: https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html

using namespace metal;
#include <metal_stdlib>
Expand All @@ -11,7 +12,12 @@ struct Jacobian {
BigInt z;
};

Jacobian jacobian_add_2007_bl_unsafe(
// Curve point described by two coordinates (x, y) and no z coordinate.
// Taken as the second operand of jacobian_madd_2007_bl (mixed addition).
// NOTE(review): the coordinates are fed straight into mont_mul_cios, so they
// are presumably expected in Montgomery form -- confirm against callers.
struct Affine {
BigInt x;
BigInt y;
};

Jacobian jacobian_add_2007_bl(
Jacobian a,
Jacobian b,
BigInt p
Expand All @@ -23,36 +29,36 @@ Jacobian jacobian_add_2007_bl_unsafe(
BigInt y2 = b.y;
BigInt z2 = b.z;

BigInt z1z1 = mont_mul_optimised(z1, z1, p);
BigInt z2z2 = mont_mul_optimised(z2, z2, p);
BigInt u1 = mont_mul_optimised(x1, z2z2, p);
BigInt u2 = mont_mul_optimised(x2, z1z1, p);
BigInt y1z2 = mont_mul_optimised(y1, z2, p);
BigInt s1 = mont_mul_optimised(y1z2, z2z2, p);
BigInt z1z1 = mont_mul_cios(z1, z1, p);
BigInt z2z2 = mont_mul_cios(z2, z2, p);
BigInt u1 = mont_mul_cios(x1, z2z2, p);
BigInt u2 = mont_mul_cios(x2, z1z1, p);
BigInt y1z2 = mont_mul_cios(y1, z2, p);
BigInt s1 = mont_mul_cios(y1z2, z2z2, p);

BigInt y2z1 = mont_mul_optimised(y2, z1, p);
BigInt s2 = mont_mul_optimised(y2z1, z1z1, p);
BigInt y2z1 = mont_mul_cios(y2, z1, p);
BigInt s2 = mont_mul_cios(y2z1, z1z1, p);
BigInt h = ff_sub(u2, u1, p);
BigInt h2 = ff_add(h, h, p);
BigInt i = mont_mul_optimised(h2, h2, p);
BigInt j = mont_mul_optimised(h, i, p);
BigInt i = mont_mul_cios(h2, h2, p);
BigInt j = mont_mul_cios(h, i, p);

BigInt s2s1 = ff_sub(s2, s1, p);
BigInt r = ff_add(s2s1, s2s1, p);
BigInt v = mont_mul_optimised(u1, i, p);
BigInt v = mont_mul_cios(u1, i, p);
BigInt v2 = ff_add(v, v, p);
BigInt r2 = mont_mul_optimised(r, r, p);
BigInt r2 = mont_mul_cios(r, r, p);
BigInt jv2 = ff_add(j, v2, p);
BigInt x3 = ff_sub(r2, jv2, p);

BigInt vx3 = ff_sub(v, x3, p);
BigInt rvx3 = mont_mul_optimised(r, vx3, p);
BigInt rvx3 = mont_mul_cios(r, vx3, p);
BigInt s12 = ff_add(s1, s1, p);
BigInt s12j = mont_mul_optimised(s12, j, p);
BigInt s12j = mont_mul_cios(s12, j, p);
BigInt y3 = ff_sub(rvx3, s12j, p);

BigInt z1z2 = mont_mul_optimised(z1, z2, p);
BigInt z1z2h = mont_mul_optimised(z1z2, h, p);
BigInt z1z2 = mont_mul_cios(z1, z2, p);
BigInt z1z2h = mont_mul_cios(z1z2, h, p);
BigInt z3 = ff_add(z1z2h, z1z2h, p);

Jacobian result;
Expand All @@ -70,26 +76,26 @@ Jacobian jacobian_dbl_2009_l(
BigInt y = pt.y;
BigInt z = pt.z;

BigInt a = mont_mul_optimised(x, x, p);
BigInt b = mont_mul_optimised(y, y, p);
BigInt c = mont_mul_optimised(b, b, p);
BigInt a = mont_mul_cios(x, x, p);
BigInt b = mont_mul_cios(y, y, p);
BigInt c = mont_mul_cios(b, b, p);
BigInt x1b = ff_add(x, b, p);
BigInt x1b2 = mont_mul_optimised(x1b, x1b, p);
BigInt x1b2 = mont_mul_cios(x1b, x1b, p);
BigInt ac = ff_add(a, c, p);
BigInt x1b2ac = ff_sub(x1b2, ac, p);
BigInt d = ff_add(x1b2ac, x1b2ac, p);
BigInt a2 = ff_add(a, a, p);
BigInt e = ff_add(a2, a, p);
BigInt f = mont_mul_optimised(e, e, p);
BigInt f = mont_mul_cios(e, e, p);
BigInt d2 = ff_add(d, d, p);
BigInt x3 = ff_sub(f, d2, p);
BigInt c2 = ff_add(c, c, p);
BigInt c4 = ff_add(c2, c2, p);
BigInt c8 = ff_add(c4, c4, p);
BigInt dx3 = ff_sub(d, x3, p);
BigInt edx3 = mont_mul_optimised(e, dx3, p);
BigInt edx3 = mont_mul_cios(e, dx3, p);
BigInt y3 = ff_sub(edx3, c8, p);
BigInt y1z1 = mont_mul_optimised(y, z, p);
BigInt y1z1 = mont_mul_cios(y, z, p);
BigInt z3 = ff_add(y1z1, y1z1, p);

Jacobian result;
Expand All @@ -98,3 +104,71 @@ Jacobian jacobian_dbl_2009_l(
result.z = z3;
return result;
}

// Mixed addition: Jacobian point `a` plus affine point `b`, using the
// "madd-2007-bl" formulas from
// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-madd-2007-bl
//
// All field arithmetic goes through mont_mul_cios / ff_add / ff_sub with the
// modulus `p`. The body is branch-free.
// NOTE(review): no special-casing of a == b or of a point at infinity is
// visible here -- callers presumably have to rule those inputs out.
Jacobian jacobian_madd_2007_bl(
    Jacobian a,
    Affine b,
    BigInt p
) {
    // Z1Z1 = Z1^2
    BigInt zz = mont_mul_cios(a.z, a.z, p);

    // S2 = Y2*Z1*Z1Z1
    BigInt y2z1 = mont_mul_cios(b.y, a.z, p);
    BigInt s2 = mont_mul_cios(y2z1, zz, p);

    // U2 = X2*Z1Z1, then H = U2-X1
    BigInt u2 = mont_mul_cios(b.x, zz, p);
    BigInt h = ff_sub(u2, a.x, p);

    // HH = H^2 and I = 4*HH (two successive doublings)
    BigInt hh = mont_mul_cios(h, h, p);
    BigInt hh_doubled = ff_add(hh, hh, p);
    BigInt i = ff_add(hh_doubled, hh_doubled, p);

    // J = H*I
    BigInt j = mont_mul_cios(h, i, p);

    // V = X1*I
    BigInt v = mont_mul_cios(a.x, i, p);

    // r = 2*(S2-Y1)
    BigInt s2_less_y1 = ff_sub(s2, a.y, p);
    BigInt r = ff_add(s2_less_y1, s2_less_y1, p);

    Jacobian sum;

    // X3 = r^2-J-2*V, evaluated as r^2 - (J + 2*V)
    BigInt r_squared = mont_mul_cios(r, r, p);
    BigInt v_doubled = ff_add(v, v, p);
    BigInt j_plus_2v = ff_add(j, v_doubled, p);
    sum.x = ff_sub(r_squared, j_plus_2v, p);

    // Y3 = r*(V-X3)-2*Y1*J
    BigInt v_less_x3 = ff_sub(v, sum.x, p);
    BigInt lhs = mont_mul_cios(r, v_less_x3, p);
    BigInt y1_j = mont_mul_cios(a.y, j, p);
    BigInt rhs = ff_add(y1_j, y1_j, p);
    sum.y = ff_sub(lhs, rhs, p);

    // Z3 = (Z1+H)^2-Z1Z1-HH
    BigInt z1_h = ff_add(a.z, h, p);
    BigInt z1_h_squared = mont_mul_cios(z1_h, z1_h, p);
    BigInt z1_h_squared_less_zz = ff_sub(z1_h_squared, zz, p);
    sum.z = ff_sub(z1_h_squared_less_zz, hh, p);

    return sum;
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ kernel void run(
Jacobian a; a.x = x1; a.y = y1; a.z = z1;
Jacobian b; b.x = x2; b.y = y2; b.z = z2;

Jacobian res = jacobian_add_2007_bl_unsafe(a, b, p);
Jacobian res = jacobian_add_2007_bl(a, b, p);
result_xr->limbs = res.x.limbs;
result_yr->limbs = res.y.limbs;
result_zr->limbs = res.z.limbs;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// source: https://github.com/geometryxyz/msl-secp256k1

using namespace metal;
#include <metal_stdlib>
#include <metal_math>
#include "jacobian.metal"

// Test kernel for jacobian_madd_2007_bl.
// Inputs:  buffer 0 = modulus, buffers 1-3 = Jacobian operand (x, y, z),
//          buffers 4-5 = affine operand (x, y).
// Outputs: buffers 6-8 = resulting Jacobian point (x, y, z).
// `gid` is declared but not used; every thread reads and writes the same
// single-element buffers.
kernel void run(
    device BigInt* prime [[ buffer(0) ]],
    device BigInt* a_xr [[ buffer(1) ]],
    device BigInt* a_yr [[ buffer(2) ]],
    device BigInt* a_zr [[ buffer(3) ]],
    device BigInt* b_xr [[ buffer(4) ]],
    device BigInt* b_yr [[ buffer(5) ]],
    device BigInt* result_xr [[ buffer(6) ]],
    device BigInt* result_yr [[ buffer(7) ]],
    device BigInt* result_zr [[ buffer(8) ]],
    uint gid [[ thread_position_in_grid ]]
) {
    // Copy the modulus limbs out of device memory.
    BigInt p;
    p.limbs = prime->limbs;

    // First operand: Jacobian point, filled coordinate by coordinate.
    Jacobian a;
    a.x.limbs = a_xr->limbs;
    a.y.limbs = a_yr->limbs;
    a.z.limbs = a_zr->limbs;

    // Second operand: affine point (no z coordinate).
    Affine b;
    b.x.limbs = b_xr->limbs;
    b.y.limbs = b_yr->limbs;

    Jacobian res = jacobian_madd_2007_bl(a, b, p);

    // Write the three result coordinates back to the output buffers.
    result_xr->limbs = res.x.limbs;
    result_yr->limbs = res.y.limbs;
    result_zr->limbs = res.z.limbs;
}
176 changes: 176 additions & 0 deletions mopro-msm/src/msm/metal_msm/tests/curve/jacobian_add_2007_b1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// adapted from https://github.com/geometryxyz/msl-secp256k1

use ark_bn254::{Fq as BaseField, Fr as ScalarField, G1Affine as GAffine, G1Projective as G};
use ark_ec::AffineRepr;
use metal::*;

use crate::msm::metal_msm::host::gpu::{
create_buffer, create_empty_buffer, get_default_device, read_buffer,
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use crate::msm::metal_msm::utils::mont_params::{calc_mont_radix, calc_nsafe, calc_rinv_and_n0};
use ark_ff::{BigInt, PrimeField};
use ark_std::{rand::thread_rng, UniformRand};
use num_bigint::BigUint;

/// Runs the `jacobian_add_2007_bl.metal` kernel on two random, distinct BN254 G1
/// points and checks the GPU result against the sum computed by Arkworks.
///
/// Coordinates are sent to the GPU as `(v * r) mod p` split into 16-bit limbs and
/// the result is mapped back with `(v * rinv) mod p` before comparison.
#[test]
#[serial_test::serial]
pub fn test_jacobian_add_2007_bl() {
    let log_limb_size = 16;
    let p: BigUint = BaseField::MODULUS.try_into().unwrap();

    let modulus_bits = BaseField::MODULUS_BIT_SIZE as u32;
    let num_limbs = ((modulus_bits + log_limb_size - 1) / log_limb_size) as usize;

    let r = calc_mont_radix(num_limbs, log_limb_size);
    let res = calc_rinv_and_n0(&p, &r, log_limb_size);
    let rinv = res.0;
    let n0 = res.1;
    let nsafe = calc_nsafe(log_limb_size);

    // Generate 2 random points as scalar multiples of the generator.
    let (a, b) = {
        let mut rng = thread_rng();
        let base_point = GAffine::generator().into_group();

        let s1 = ScalarField::rand(&mut rng);
        let mut s2 = ScalarField::rand(&mut rng);

        // Ensure s1 and s2 are different (if s1 == s2, we use pDBL instead of pADD)
        while s1 == s2 {
            s2 = ScalarField::rand(&mut rng);
        }

        (base_point * s1, base_point * s2)
    };

    // Compute the sum in projective form using Arkworks
    let expected = a + b;

    // Encode one base-field coordinate as limbs of (v * r) mod p.
    let to_mont_limbs = |coord: BaseField| {
        let value: BigUint = coord.into();
        ark_ff::BigInt::<4>::try_from((value * &r) % &p)
            .unwrap()
            .to_limbs(num_limbs, log_limb_size)
    };

    let p_limbs = BaseField::MODULUS.to_limbs(num_limbs, log_limb_size);
    let axr_limbs = to_mont_limbs(a.x);
    let ayr_limbs = to_mont_limbs(a.y);
    let azr_limbs = to_mont_limbs(a.z);
    let bxr_limbs = to_mont_limbs(b.x);
    let byr_limbs = to_mont_limbs(b.y);
    let bzr_limbs = to_mont_limbs(b.z);

    let device = get_default_device();
    let prime_buf = create_buffer(&device, &p_limbs);
    let axr_buf = create_buffer(&device, &axr_limbs);
    let ayr_buf = create_buffer(&device, &ayr_limbs);
    let azr_buf = create_buffer(&device, &azr_limbs);
    let bxr_buf = create_buffer(&device, &bxr_limbs);
    let byr_buf = create_buffer(&device, &byr_limbs);
    let bzr_buf = create_buffer(&device, &bzr_limbs);
    let result_xr_buf = create_empty_buffer(&device, num_limbs);
    let result_yr_buf = create_empty_buffer(&device, num_limbs);
    let result_zr_buf = create_empty_buffer(&device, num_limbs);

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let compute_pass_descriptor = ComputePassDescriptor::new();
    let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor);

    // The shader constants must be written before the kernel is compiled.
    write_constants(
        "../mopro-msm/src/msm/metal_msm/shader",
        num_limbs,
        log_limb_size,
        n0,
        nsafe,
    );
    let library_path = compile_metal(
        "../mopro-msm/src/msm/metal_msm/shader/curve",
        "jacobian_add_2007_bl.metal",
    );
    let library = device.new_library_with_file(library_path).unwrap();
    let kernel = library.get_function("run", None).unwrap();

    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
    pipeline_state_descriptor.set_compute_function(Some(&kernel));

    let pipeline_state = device
        .new_compute_pipeline_state_with_function(
            pipeline_state_descriptor.compute_function().unwrap(),
        )
        .unwrap();

    // Buffer slots: 0 = modulus, 1-3 = a, 4-6 = b, 7-9 = result.
    encoder.set_compute_pipeline_state(&pipeline_state);
    encoder.set_buffer(0, Some(&prime_buf), 0);
    encoder.set_buffer(1, Some(&axr_buf), 0);
    encoder.set_buffer(2, Some(&ayr_buf), 0);
    encoder.set_buffer(3, Some(&azr_buf), 0);
    encoder.set_buffer(4, Some(&bxr_buf), 0);
    encoder.set_buffer(5, Some(&byr_buf), 0);
    encoder.set_buffer(6, Some(&bzr_buf), 0);
    encoder.set_buffer(7, Some(&result_xr_buf), 0);
    encoder.set_buffer(8, Some(&result_yr_buf), 0);
    encoder.set_buffer(9, Some(&result_zr_buf), 0);

    // A single thread is enough: the kernel performs one point addition.
    let thread_group_count = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width: 1,
        height: 1,
        depth: 1,
    };

    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();

    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Decode result limbs back into a plain residue: (v * rinv) mod p.
    let from_mont_limbs = |limbs: Vec<u32>| -> BigUint {
        let value: BigUint = BigInt::from_limbs(&limbs, log_limb_size)
            .try_into()
            .unwrap();
        (value * &rinv) % &p
    };

    let result_x = from_mont_limbs(read_buffer(&result_xr_buf, num_limbs));
    let result_y = from_mont_limbs(read_buffer(&result_yr_buf, num_limbs));
    let result_z = from_mont_limbs(read_buffer(&result_zr_buf, num_limbs));

    let result = G::new(result_x.into(), result_y.into(), result_z.into());
    // assert_eq! prints both points on failure, unlike a bare assert!.
    assert_eq!(result, expected);
}
Loading