From bf1636b41334f96f547b3583e9f752c7cc984f87 Mon Sep 17 00:00:00 2001 From: "Hao-Chen (Moven) Tsai" <60170228+moven0831@users.noreply.github.com> Date: Sat, 21 Dec 2024 21:58:12 +0800 Subject: [PATCH] Bigint impl for refactored metal EC backend (#21) * feat: import mont_mul backend from https://github.com/geometryxyz/msl-secp256k1 * chore: add refactored metal_msm, will remove the previous one once this is completed * chore: migrate prev utils mod to refactored metal msm * feat: conversion between bigint and arbitrary limb size * lint * chore: ignore all metal ir and lib * test(bigint): add host test * test(bigint): adapt from https://github.com/geometryxyz/msl-secp256k1 * chore: ignore all constants file since it's been automatically generated * refactor: add overflow detection and correct suitable bigint val for each cases --- Cargo.lock | 175 +++++++++++ mopro-msm/.gitignore | 9 +- mopro-msm/Cargo.toml | 3 + .../msm/metal/abstraction/limbs_conversion.rs | 71 +++++ mopro-msm/src/msm/metal_msm/host/errors.rs | 19 ++ mopro-msm/src/msm/metal_msm/host/gpu.rs | 114 +++++++ mopro-msm/src/msm/metal_msm/host/mod.rs | 4 + mopro-msm/src/msm/metal_msm/host/shader.rs | 124 ++++++++ mopro-msm/src/msm/metal_msm/host/state.rs | 118 +++++++ mopro-msm/src/msm/metal_msm/mod.rs | 3 + .../msm/metal_msm/shader/bigint/bigint.metal | 131 ++++++++ .../shader/bigint/bigint_add_unsafe.metal | 20 ++ .../shader/bigint/bigint_add_unsafe.metal.ir | Bin 0 -> 14928 bytes .../shader/bigint/bigint_add_wide.metal | 20 ++ .../metal_msm/shader/bigint/bigint_sub.metal | 20 ++ .../metal_msm/shader/bigint/constants.metal | 8 + .../msm/metal_msm/shader/bigint/u128.h.metal | 173 +++++++++++ .../msm/metal_msm/shader/bigint/u256.h.metal | 263 ++++++++++++++++ .../src/msm/metal_msm/shader/constants.metal | 8 + .../metal_msm/shader/curve/ec_point.h.metal | 110 +++++++ .../msm/metal_msm/shader/curve/jacobian.metal | 100 ++++++ .../curve/jacobian_add_2007_bl_unsafe.metal | 36 +++ 
.../shader/curve/jacobian_dbl_2009_l.metal | 29 ++ .../src/msm/metal_msm/shader/field/ff.metal | 63 ++++ .../msm/metal_msm/shader/field/ff_add.metal | 24 ++ .../msm/metal_msm/shader/field/ff_sub.metal | 24 ++ .../metal_msm/shader/field/fp_bn254.h.metal | 291 ++++++++++++++++++ .../metal_msm/shader/mont_backend/mont.metal | 94 ++++++ .../mont_backend/mont_mul_modified.metal | 25 ++ .../mont_backend/mont_mul_optimised.metal | 25 ++ .../tests/bigint/bigint_add_unsafe.rs | 93 ++++++ .../metal_msm/tests/bigint/bigint_add_wide.rs | 186 +++++++++++ .../msm/metal_msm/tests/bigint/bigint_sub.rs | 182 +++++++++++ .../src/msm/metal_msm/tests/bigint/mod.rs | 3 + mopro-msm/src/msm/metal_msm/tests/mod.rs | 1 + .../msm/metal_msm/utils/limbs_conversion.rs | 91 ++++++ mopro-msm/src/msm/metal_msm/utils/mod.rs | 2 + .../src/msm/metal_msm/utils/mont_reduction.rs | 40 +++ mopro-msm/src/msm/mod.rs | 1 + 39 files changed, 2702 insertions(+), 1 deletion(-) create mode 100644 mopro-msm/src/msm/metal_msm/host/errors.rs create mode 100644 mopro-msm/src/msm/metal_msm/host/gpu.rs create mode 100644 mopro-msm/src/msm/metal_msm/host/mod.rs create mode 100644 mopro-msm/src/msm/metal_msm/host/shader.rs create mode 100644 mopro-msm/src/msm/metal_msm/host/state.rs create mode 100644 mopro-msm/src/msm/metal_msm/mod.rs create mode 100644 mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal.ir create mode 100644 mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/bigint/u128.h.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/bigint/u256.h.metal create mode 100644 
mopro-msm/src/msm/metal_msm/shader/constants.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/curve/ec_point.h.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl_unsafe.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/field/ff.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/field/fp_bn254.h.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal create mode 100644 mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised.metal create mode 100644 mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs create mode 100644 mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs create mode 100644 mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs create mode 100644 mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs create mode 100644 mopro-msm/src/msm/metal_msm/tests/mod.rs create mode 100644 mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs create mode 100644 mopro-msm/src/msm/metal_msm/utils/mod.rs create mode 100644 mopro-msm/src/msm/metal_msm/utils/mont_reduction.rs diff --git a/Cargo.lock b/Cargo.lock index 9d9ad9c..703b01e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1216,6 +1216,83 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ 
+ "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1528,6 +1605,16 @@ version = "0.4.14" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.22" @@ -1658,6 +1745,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", + "serial_test", "thiserror", "toml", "wasmer 2.3.0 (git+https://github.com/oskarth/wasmer.git?rev=09c7070)", @@ -1867,6 +1955,29 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "paste" version = "1.0.15" @@ -1890,6 +2001,12 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkcs8" version = "0.10.2" @@ -2131,6 +2248,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + 
"bitflags 2.6.0", +] + [[package]] name = "regalloc" version = "0.0.34" @@ -2388,6 +2514,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "scc" +version = "2.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8d25269dd3a12467afe2e510f69fb0b46b698e5afb296b59f2145259deaf8e8" +dependencies = [ + "sdd", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2400,6 +2535,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" +[[package]] +name = "sdd" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49c1eeaf4b6a87c7479688c6d52b9f1153cedd3c489300564f932b065c6eab95" + [[package]] name = "seahash" version = "4.1.0" @@ -2503,6 +2644,31 @@ dependencies = [ "serde", ] +[[package]] +name = "serial_test" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d" +dependencies = [ + "futures", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "sha2" version = "0.10.8" @@ -2549,6 +2715,15 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.13.2" diff --git a/mopro-msm/.gitignore 
b/mopro-msm/.gitignore index 1012b39..f619967 100644 --- a/mopro-msm/.gitignore +++ b/mopro-msm/.gitignore @@ -17,4 +17,11 @@ Cargo.lock src/middleware/gpu_explorations/utils/vectors # GPU exploration - proptest generated files -proptest-regressions \ No newline at end of file +proptest-regressions + +# Metal shader intermediate files and libraries +src/msm/metal_msm/shader/**/*.ir +src/msm/metal_msm/shader/**/*.lib + +# Metal shader constants file +src/msm/metal_msm/shader/**/constants.metal \ No newline at end of file diff --git a/mopro-msm/Cargo.toml b/mopro-msm/Cargo.toml index 11f3639..5615835 100644 --- a/mopro-msm/Cargo.toml +++ b/mopro-msm/Cargo.toml @@ -58,6 +58,9 @@ serde_derive = "1.0" wasmer = { git = "https://github.com/oskarth/wasmer.git", rev = "09c7070" } witness = { git = "https://github.com/philsippl/circom-witness-rs.git" } +[dev-dependencies] +serial_test = "3.0.0" + # [dependencies.rayon] # version = "1" # optional=false \ No newline at end of file diff --git a/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs b/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs index ff9f7d2..ebf772b 100644 --- a/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs +++ b/mopro-msm/src/msm/metal/abstraction/limbs_conversion.rs @@ -6,10 +6,12 @@ use crate::msm::metal::abstraction::mont_reduction; // implement to_u32_limbs and from_u32_limbs for BigInt<4> pub trait ToLimbs { fn to_u32_limbs(&self) -> Vec; + fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec; } pub trait FromLimbs { fn from_u32_limbs(limbs: &[u32]) -> Self; + fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self; fn from_u128(num: u128) -> Self; fn from_u32(num: u32) -> Self; } @@ -26,6 +28,37 @@ impl ToLimbs for BigInteger256 { }); limbs } + + fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec { + let mut result = vec![0u32; num_limbs]; + let limb_size = 1u32 << log_limb_size; + let mask = limb_size - 1; + + // Convert to little-endian representation + 
let bytes = self.to_bytes_le(); + let mut val = 0u32; + let mut bits = 0u32; + let mut limb_idx = 0; + + for &byte in bytes.iter() { + val |= (byte as u32) << bits; + bits += 8; + + while bits >= log_limb_size && limb_idx < num_limbs { + result[limb_idx] = val & mask; + val >>= log_limb_size; + bits -= log_limb_size; + limb_idx += 1; + } + } + + // Handle any remaining bits + if bits > 0 && limb_idx < num_limbs { + result[limb_idx] = val; + } + + result + } } // convert from little endian to big endian @@ -40,6 +73,10 @@ impl ToLimbs for Fq { }); limbs } + + fn to_limbs(&self, num_limbs: usize, log_limb_size: u32) -> Vec { + self.0.to_limbs(num_limbs, log_limb_size) + } } impl FromLimbs for BigInteger256 { @@ -63,6 +100,35 @@ impl FromLimbs for BigInteger256 { fn from_u32(num: u32) -> Self { BigInteger256::new([num as u64, 0, 0, 0]) } + + fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self { + let mut result = [0u64; 4]; + let limb_size = log_limb_size as usize; + let mut accumulated_bits = 0; + let mut current_u64 = 0u64; + let mut result_idx = 0; + + for &limb in limbs { + // Add the current limb at the appropriate position + current_u64 |= (limb as u64) << accumulated_bits; + accumulated_bits += limb_size; + + // If we've accumulated 64 bits or more, store the result + while accumulated_bits >= 64 && result_idx < 4 { + result[result_idx] = current_u64; + current_u64 = limb as u64 >> (limb_size - (accumulated_bits - 64)); + accumulated_bits -= 64; + result_idx += 1; + } + } + + // Handle any remaining bits + if accumulated_bits > 0 && result_idx < 4 { + result[result_idx] = current_u64; + } + + BigInteger256::new(result) + } } impl FromLimbs for Fq { @@ -88,4 +154,9 @@ impl FromLimbs for Fq { num as u64, 0, 0, 0, ]))) } + + fn from_limbs(limbs: &[u32], log_limb_size: u32) -> Self { + let bigint = BigInteger256::from_limbs(limbs, log_limb_size); + Fq::new(mont_reduction::raw_reduction(bigint)) + } } diff --git a/mopro-msm/src/msm/metal_msm/host/errors.rs 
b/mopro-msm/src/msm/metal_msm/host/errors.rs new file mode 100644 index 0000000..9a1e935 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/host/errors.rs @@ -0,0 +1,19 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum MetalError { + #[error("Couldn't find a system default device for Metal")] + DeviceNotFound(), + #[error("Couldn't create a new Metal library: {0}")] + LibraryError(String), + #[error("Couldn't create a new Metal function object: {0}")] + FunctionError(String), + #[error("Couldn't create a new Metal compute pipeline: {0}")] + PipelineError(String), + #[error("Could not calculate {1} root of unity")] + RootOfUnityError(String, u64), + // #[error("Input length is {0}, which is not a power of two")] + // InputError(usize), + #[error("Invalid input: {0}")] + InputError(String), +} diff --git a/mopro-msm/src/msm/metal_msm/host/gpu.rs b/mopro-msm/src/msm/metal_msm/host/gpu.rs new file mode 100644 index 0000000..90392f3 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/host/gpu.rs @@ -0,0 +1,114 @@ +use metal::*; + +pub fn get_default_device() -> metal::Device { + Device::system_default().expect("No device found") +} + +pub fn create_buffer(device: &Device, data: &Vec) -> metal::Buffer { + device.new_buffer_with_data( + unsafe { std::mem::transmute(data.as_ptr()) }, + (data.len() * std::mem::size_of::()) as u64, + MTLResourceOptions::CPUCacheModeDefaultCache, + ) +} + +pub fn read_buffer(result_buf: &metal::Buffer, num_u32s: usize) -> Vec { + let ptr = result_buf.contents() as *const u32; + let result_limbs: Vec; + + // Check if ptr is not null + if !ptr.is_null() { + result_limbs = unsafe { std::slice::from_raw_parts(ptr, num_u32s) }.to_vec(); + } else { + panic!("Pointer is null"); + } + result_limbs +} + +pub fn create_empty_buffer(device: &Device, size: usize) -> metal::Buffer { + let data = vec![0u32; size]; + create_buffer(device, &data) +} + +// From metal-rs +pub fn create_counter_sample_buffer(device: &Device, num_samples: usize) -> 
CounterSampleBuffer { + let counter_sample_buffer_desc = metal::CounterSampleBufferDescriptor::new(); + counter_sample_buffer_desc.set_storage_mode(metal::MTLStorageMode::Shared); + counter_sample_buffer_desc.set_sample_count(num_samples as u64); + let counter_sets = device.counter_sets(); + + let timestamp_counter = counter_sets.iter().find(|cs| cs.name() == "timestamp"); + + counter_sample_buffer_desc + .set_counter_set(timestamp_counter.expect("No timestamp counter found")); + + device + .new_counter_sample_buffer_with_descriptor(&counter_sample_buffer_desc) + .unwrap() +} + +pub fn handle_compute_pass_sample_buffer_attachment( + compute_pass_descriptor: &ComputePassDescriptorRef, + counter_sample_buffer: &CounterSampleBufferRef, +) { + let sample_buffer_attachment_descriptor = compute_pass_descriptor + .sample_buffer_attachments() + .object_at(0) + .unwrap(); + + sample_buffer_attachment_descriptor.set_sample_buffer(counter_sample_buffer); + sample_buffer_attachment_descriptor.set_start_of_encoder_sample_index(0); + sample_buffer_attachment_descriptor.set_end_of_encoder_sample_index(1); +} + +pub fn resolve_samples_into_buffer( + command_buffer: &CommandBufferRef, + counter_sample_buffer: &CounterSampleBufferRef, + destination_buffer: &BufferRef, + num_samples: usize, +) { + let blit_encoder = command_buffer.new_blit_command_encoder(); + blit_encoder.resolve_counters( + counter_sample_buffer, + metal::NSRange::new(0_u64, num_samples as u64), + destination_buffer, + 0_u64, + ); + blit_encoder.end_encoding(); +} + +// From metal-rs +pub fn handle_timestamps( + resolved_sample_buffer: &BufferRef, + cpu_start: u64, + cpu_end: u64, + gpu_start: u64, + gpu_end: u64, + num_samples: usize, +) { + let samples = unsafe { + std::slice::from_raw_parts(resolved_sample_buffer.contents() as *const u64, num_samples) + }; + let pass_start = samples[0]; + let pass_end = samples[1]; + println!("samples: {:?}", samples); + + let cpu_time_span = cpu_end - cpu_start; + let 
gpu_time_span = gpu_end - gpu_start; + + let micros = microseconds_between_begin(pass_start, pass_end, gpu_time_span, cpu_time_span); + println!("Compute pass duration: {} µs", micros); +} + +// From metal-rs +/// +pub fn microseconds_between_begin( + begin: u64, + end: u64, + gpu_time_span: u64, + cpu_time_span: u64, +) -> f64 { + let time_span = (end as f64) - (begin as f64); + let nanoseconds = time_span / (gpu_time_span as f64) * (cpu_time_span as f64); + nanoseconds / 1000.0 +} diff --git a/mopro-msm/src/msm/metal_msm/host/mod.rs b/mopro-msm/src/msm/metal_msm/host/mod.rs new file mode 100644 index 0000000..223e361 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/host/mod.rs @@ -0,0 +1,4 @@ +pub mod errors; +// pub mod state; +pub mod gpu; +pub mod shader; diff --git a/mopro-msm/src/msm/metal_msm/host/shader.rs b/mopro-msm/src/msm/metal_msm/host/shader.rs new file mode 100644 index 0000000..330f8c4 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/host/shader.rs @@ -0,0 +1,124 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +/* + * It is necessary to hardcode certain constants into MSL source code but dynamically generate the + * code so that the Rust binary that runs a shader can insert said constants. 
+ * + * Shader lifecycle: + * + * MSL source -> Compiled .metallib file -> Loaded by program -> Sent to GPU + * + * xcrun -sdk macosx metal -c -o + * xcrun -sdk macosx metallib -o + */ + +use std::fs; +use std::path::PathBuf; +use std::process::Command; +use std::string::String; + +pub fn compile_metal(path_from_cargo_manifest_dir: &str, input_filename: &str) -> String { + let input_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(path_from_cargo_manifest_dir) + .join(input_filename); + let c = input_path.clone().into_os_string().into_string().unwrap(); + + let ir = input_path.clone().into_os_string().into_string().unwrap(); + let ir = format!("{}.ir", ir); + + let exe = if cfg!(target_os = "ios") { + Command::new("xcrun") + .args([ + "-sdk", + "iphoneos", + "metal", + "-c", + c.as_str(), + "-o", + ir.as_str(), + ]) + .output() + .expect("failed to compile") + } else if cfg!(target_os = "macos") { + Command::new("xcrun") + .args([ + "-sdk", + "macosx", + "metal", + "-c", + c.as_str(), + "-o", + ir.as_str(), + ]) + .output() + .expect("failed to compile") + } else { + panic!("Unsupported architecture"); + }; + + if exe.stderr.len() != 0 { + panic!("{}", String::from_utf8(exe.stderr).unwrap()); + } + + let lib = input_path.clone().into_os_string().into_string().unwrap(); + let lib = format!("{}.lib", lib); + + let exe = if cfg!(target_os = "ios") { + Command::new("xcrun") + .args(["-sdk", "iphoneos", "metal", ir.as_str(), "-o", lib.as_str()]) + .output() + .expect("failed to compile") + } else if cfg!(target_os = "macos") { + Command::new("xcrun") + .args(["-sdk", "macosx", "metal", ir.as_str(), "-o", lib.as_str()]) + .output() + .expect("failed to compile") + } else { + panic!("Unsupported architecture"); + }; + + if exe.stderr.len() != 0 { + panic!("{}", String::from_utf8(exe.stderr).unwrap()); + } + + lib +} + +pub fn write_constants( + filepath: &str, + num_limbs: usize, + log_limb_size: u32, + n0: u32, + nsafe: usize, +) { + let two_pow_word_size = 
2u32.pow(log_limb_size); + let mask = two_pow_word_size - 1u32; + + let mut data = "// THIS FILE IS AUTOGENERATED BY shader.rs\n".to_owned(); + data += format!("#define NUM_LIMBS {}\n", num_limbs).as_str(); + data += format!("#define NUM_LIMBS_WIDE {}\n", num_limbs + 1).as_str(); + data += format!("#define LOG_LIMB_SIZE {}\n", log_limb_size).as_str(); + data += format!("#define TWO_POW_WORD_SIZE {}\n", two_pow_word_size).as_str(); + data += format!("#define MASK {}\n", mask).as_str(); + data += format!("#define N0 {}\n", n0).as_str(); + data += format!("#define NSAFE {}\n", nsafe).as_str(); + + let output_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(filepath) + .join("constants.metal"); + fs::write(output_path, data).expect("Unable to write constants file"); +} + +#[cfg(test)] +pub mod tests { + use super::compile_metal; + + #[test] + pub fn test_compile() { + let lib_filepath = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader", + "bigint/bigint_add_unsafe.metal", + ); + println!("{}", lib_filepath); + } +} diff --git a/mopro-msm/src/msm/metal_msm/host/state.rs b/mopro-msm/src/msm/metal_msm/host/state.rs new file mode 100644 index 0000000..8a7cdde --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/host/state.rs @@ -0,0 +1,118 @@ +use metal::{ComputeCommandEncoderRef, MTLResourceOptions}; + +use super::errors::MetalError; + +const LIB_DATA: &[u8] = include_bytes!("../shader/msm.metallib"); + +/// Structure for abstracting basic calls to a Metal device and saving the state. Used for +/// implementing GPU parallel computations in Apple machines. +pub struct MetalState { + pub device: metal::Device, + pub library: metal::Library, + pub queue: metal::CommandQueue, +} + +impl MetalState { + /// Creates a new Metal state with an optional `device` (GPU). If `None` is passed then it will use + /// the system's default. 
+ pub fn new(device: Option) -> Result { + let device: metal::Device = + device.unwrap_or(metal::Device::system_default().ok_or(MetalError::DeviceNotFound())?); + + let library = device + .new_library_with_data(LIB_DATA) // TODO: allow different files + .map_err(MetalError::LibraryError)?; + let queue = device.new_command_queue(); + + Ok(Self { + device, + library, + queue, + }) + } + + /// Creates a pipeline based on a compute function `kernel` which needs to exist in the state's + /// library. A pipeline is used for issuing commands to the GPU through command buffers, + /// executing the `kernel` function. + pub fn setup_pipeline( + &self, + kernel_name: &str, + ) -> Result { + let kernel = self + .library + .get_function(kernel_name, None) + .map_err(MetalError::FunctionError)?; + + let pipeline = self + .device + .new_compute_pipeline_state_with_function(&kernel) + .map_err(MetalError::PipelineError)?; + + Ok(pipeline) + } + + /// Allocates `length` bytes of shared memory between CPU and the device (GPU). + pub fn alloc_buffer(&self, length: usize) -> metal::Buffer { + let size = mem::size_of::(); + + self.device.new_buffer( + (length * size) as u64, + MTLResourceOptions::StorageModeShared, // TODO: use managed mode + ) + } + + /// Allocates `data` in a buffer of shared memory between CPU and the device (GPU). + pub fn alloc_buffer_data(&self, data: &[T]) -> metal::Buffer { + let size = mem::size_of::(); + + self.device.new_buffer_with_data( + data.as_ptr() as *const ffi::c_void, + (data.len() * size) as u64, + MTLResourceOptions::StorageModeShared, // TODO: use managed mode + ) + } + + pub fn set_bytes(index: usize, data: &[T], encoder: &ComputeCommandEncoderRef) { + let size = mem::size_of::(); + + encoder.set_bytes( + index as u64, + (data.len() * size) as u64, + data.as_ptr() as *const ffi::c_void, + ); + } + + /// Creates a command buffer and a compute encoder in a pipeline, optionally issuing `buffers` + /// to it. 
+ pub fn setup_command( + &self, + pipeline: &metal::ComputePipelineState, + buffers: Option<&[(u64, &metal::Buffer)]>, + ) -> (&metal::CommandBufferRef, &metal::ComputeCommandEncoderRef) { + let command_buffer = self.queue.new_command_buffer(); + let command_encoder = command_buffer.new_compute_command_encoder(); + command_encoder.set_compute_pipeline_state(pipeline); + + if let Some(buffers) = buffers { + for (i, buffer) in buffers.iter() { + command_encoder.set_buffer(*i, Some(buffer), 0); + } + } + + (command_buffer, command_encoder) + } + + /// Returns a vector of a copy of the data that `buffer` holds, interpreting it into a specific + /// type `T`. + /// + /// BEWARE: this function uses an unsafe function for retrieveing the data, if the buffer's + /// contents don't match the specified `T`, expect undefined behaviour. Always make sure the + /// buffer you are retreiving from holds data of type `T`. + pub fn retrieve_contents(buffer: &metal::Buffer) -> Vec { + let ptr = buffer.contents() as *const T; + let len = buffer.length() as usize / mem::size_of::(); + let slice = unsafe { std::slice::from_raw_parts(ptr, len) }; + + slice.to_vec() + } +} diff --git a/mopro-msm/src/msm/metal_msm/mod.rs b/mopro-msm/src/msm/metal_msm/mod.rs new file mode 100644 index 0000000..0b2e95f --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/mod.rs @@ -0,0 +1,3 @@ +pub mod host; +pub mod tests; +pub mod utils; diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal new file mode 100644 index 0000000..d0ff646 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint.metal @@ -0,0 +1,131 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include "constants.metal" + +struct BigInt { + array limbs; +}; + +struct BigIntWide { + array limbs; +}; + +BigInt bigint_zero() { + BigInt s; + for (uint i = 0; i < NUM_LIMBS; i ++) { + s.limbs[i] = 0; + } + return s; +} + +BigInt 
bigint_add_unsafe( + BigInt lhs, + BigInt rhs +) { + BigInt result; + uint mask = (1 << LOG_LIMB_SIZE) - 1; + uint carry = 0; + + for (uint i = 0; i < NUM_LIMBS; i ++) { + uint c = lhs.limbs[i] + rhs.limbs[i] + carry; + result.limbs[i] = c & mask; + carry = c >> LOG_LIMB_SIZE; + } + return result; +} + +BigIntWide bigint_add_wide( + BigInt lhs, + BigInt rhs +) { + BigIntWide result; + uint mask = (1 << LOG_LIMB_SIZE) - 1; + uint carry = 0; + + for (uint i = 0; i < NUM_LIMBS; i ++) { + uint c = lhs.limbs[i] + rhs.limbs[i] + carry; + result.limbs[i] = c & mask; + carry = c >> LOG_LIMB_SIZE; + } + result.limbs[NUM_LIMBS] = carry; + + return result; +} + +BigInt bigint_sub( + BigInt lhs, + BigInt rhs +) { + uint borrow = 0; + + BigInt res; + + for (uint i = 0; i < NUM_LIMBS; i ++) { + res.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - borrow; + + if (lhs.limbs[i] < (rhs.limbs[i] + borrow)) { + res.limbs[i] = res.limbs[i] + TWO_POW_WORD_SIZE; + borrow = 1; + } else { + borrow = 0; + } + } + + return res; +} + + +BigIntWide bigint_sub_wide( + BigIntWide lhs, + BigIntWide rhs +) { + uint borrow = 0; + + BigIntWide res; + + for (uint i = 0; i < NUM_LIMBS; i ++) { + res.limbs[i] = lhs.limbs[i] - rhs.limbs[i] - borrow; + + if (lhs.limbs[i] < (rhs.limbs[i] + borrow)) { + res.limbs[i] = res.limbs[i] + TWO_POW_WORD_SIZE; + borrow = 1; + } else { + borrow = 0; + } + } + + return res; +} + +bool bigint_gte( + BigInt lhs, + BigInt rhs +) { + for (uint idx = 0; idx < NUM_LIMBS; idx ++) { + uint i = NUM_LIMBS - 1 - idx; + if (lhs.limbs[i] < rhs.limbs[i]) { + return false; + } else if (lhs.limbs[i] > rhs.limbs[i]) { + return true; + } + } + + return true; +} + +bool bigint_wide_gte( + BigIntWide lhs, + BigIntWide rhs +) { + for (uint idx = 0; idx < NUM_LIMBS_WIDE; idx ++) { + uint i = NUM_LIMBS_WIDE - 1 - idx; + if (lhs.limbs[i] < rhs.limbs[i]) { + return false; + } else if (lhs.limbs[i] > rhs.limbs[i]) { + return true; + } + } + + return true; +} diff --git 
a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal new file mode 100644 index 0000000..7791963 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal @@ -0,0 +1,20 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include +#include +#include "bigint.metal" + +kernel void run( + device BigInt* lhs [[ buffer(0) ]], + device BigInt* rhs [[ buffer(1) ]], + device BigInt* result [[ buffer(2) ]], + uint gid [[ thread_position_in_grid ]] +) { + BigInt a; + BigInt b; + a.limbs = lhs->limbs; + b.limbs = rhs->limbs; + BigInt res = bigint_add_unsafe(a, b); + result->limbs = res.limbs; +} diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal.ir b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_unsafe.metal.ir new file mode 100644 index 0000000000000000000000000000000000000000..9863771b56fa9edd98c9800ca00ccf076d61844b GIT binary patch literal 14928 zcmeHueOOahy6?_+2#^E}A4(Du1Pw(nAR?&Pe4s^2ZJ<<8Q4$ag7BGZ>h{eeVaf8xM zY@wx{b|&#-hIU#XY;8-OP8$+x8Kq7{ro(YK-axITj`g6!IE?M-+;{I}qk+!cd!KXu zyC)AjEBoE+eb>7__HX^xyB61+QNaj&cyOh!MUa2MN1=3HTMYsV$jU>c@kTCf2zxFL zL6AR%B1kYEjLDQBr9GBRt$2}c$v)T z$60CR?YX)^@zR`dw391;fxLT&d4?xaZfY0(j+9p37X66i5nJlBeWk}l=ZC;_*%R?#Cp=EF5g`$imx&wAvc&#S)%ckUmmbC4)mh+cE%bl zs<)#$H6z<&(b*UWnCP>)a^rL98@r@`k122eyo)?QrG!;QQ%xanBUNwSFbS_agy&7d zal7DxN$4^OW@H?fTzFl~zdI1>715ZRMtB9T>-*Zp8H+kYa#-Q3Bxc#s#a;1^s^>*g z_fjVHm5I&aFKj~SVTYZ(qNkrmwoyorARZ}A`WV8Z+S1w)DpC-QAb*7V#N1svi?uOH zN^^3MOymc|hOnk6aeOaHczr@J?Fb!q2*)-2t4_Yx5qee3pCSp678}|kVHl-YU793l zh7sML%h+#6^)`lxKV}jj#|0w;f~!Wsg$X|Rf7L0xpy5wV2(CLqZ;J&(CZUI9q^Nc1 zel4ZG7d5CEEuh}RFk%vmLC0w5K{KR`X04@($!L)=j97zQAc@8&z?8r>Ah-cuk%U+6 z0vBcqZ)ZKi+~165^q_`J#(oJyucfp~7zdis16oRh5zWw|Iy=LFOv!~e9el(kxaJ6j z-!78SD;A7T@aLIoKLGs|jW%j2%{COPS2UxgnE|Hidl?NMH8BJQsK?#gEbU(P)j~YZ+(?LN4SY-zt!V!~T z3e+`1U-hAbhBd+;>sGv9_A{YJ{%J3$g(DOEaR-0M5jx}G!#MfpJp#SMWHcL5gAQ$& 
z-=Q$nT^jxuj!>^$cn#`B3=;5mcz_Spyb1-$V=s`L6k{*uj}99DN1DhODhNY@Wh~k1 zHs#tOWlXy=yT_8%Wm$jLwC*EiOdDn0NoD3RrO%N>Rg-c_y@z$VfxU?jJe?XfFwp}@ zhT%N20!c&+(tmBul*u&mGIEBnuH_ih`rYf!nleX~Yx^kKCoSuSDadAc$Efg*Ei$Xk zT;`%9CP`<7!U*?$SowXFwOR|-|7xfcjC8afMynIev{Tj?;m;EKK>@oK;k2k320Nok z#=w4f=-+S|Ek^9G*2ZqOGq68g*kota!+XSKG{LQup|>*{U;w}mxDB*$Bk<{sz^C`X z!b7I+iLlb=Ox+9Ay(hyQnDg!4JX1E@lnBdH58oa-7}taI4}I6FIrTQoTfGX)jD15k z5u*|6^3XSAYXmJjO^3Tzk1+Hy%q({6q2I=f0Bhj(5^+AeVDjnLU;gSZ#oyJg z{LMFi(cPzYfAWW~Qd@u6`^l%(O&`);{^Z;0zx}=O@E>@9Q#m1-VpsBK}7JAi`^qeqqsXp2lHJ( zo}vkuXqfOA8#|C)og2Z{czimT7(rxCTV8w${02Q4Eyci?BX(C}k0R0LlC&%AH7f?hVDk)=_%fuc zm`48K1iy+$%5GO?wz)eMMu22M_xg|<2{~YWgp^Y!<$O8Ko>s9}aR&w0#QZU_ zz@-sZNjuhkWLf7?W}T(1|H!f)0AP5U*Q90~)Z()-d>c`n4V#mjB#fq68kpU0M9bhE zj7If6=z$*7irdrd+g|oxo$OmuO8F3goCCvn_=d^wf|x&O5?s*GzpyE@hLl-7_;N*Qd;e2KnB#{(_GDzk>&>-z_? zhTKrc^~1Sm>tX&;GwPuMOVGuCI?ukP=6vg=Pm4HTNjYDMIQLZSIxhQ5DI0G8;-!D% zWnWjZXQfLgZ)q@sFg6i&OANE9U{EZ89-I=#R*7B*|URmrg+{aKG959>&jp(9vvwe zqu1Kvvcfj<^TT!?8wjzS3B7dwn4Hi1PP+c5w5M_&7j3Q7H&pMd+h3v8=*t?aTPmyc zn%c@_-M;ENL#19*txu`dG?(qJY^tuPEGuuUs)GE4v|?XfgF#bgm@hYI^_804d-VGn zbqgL!3YFc%?;6zHS$-N|UCoN=aS4S|q_AR8=O=i4vc{ z+Gpj{4Ul(b?Y^eUIs>%ON|k=!-pUFCCMpLnlXd!iP1SXKR_@U?CTl7xpk?)V55aSI zlnl+hvO!<55=-$S%5aHRv%6Bivb=f^bca7FgT5mlNqa!RXlwkyKWCx zmowZ2j2(Y1EnP#Avreg5zmmGS42O}mih3<%V2u)?ucB(CZX>-(#8|~GS26ZNO0AG8 z)!<u9sO%MvYgKUtlt0kHNDRJ;EyXAGH>g#ZZnc49#dsK zUtlDEc9|sbm^!J;Bug_gev|R?`)l4=`+i!~s`vZe{jlJLbxzBAk0sk{*)VN+G&fNcyhDa^mrW4L^&i%%weDVl5gQyLGj6Waz^9 z)TW>k@ymVbSy8`9`|!iFYd$D2Z_RxK>{yeBzkMs^)?2qOp3bjJ`S$bLE@#9Gn8jBP z{`$q3r0>!&tNP$)l425?X;?qM0dsExq`@q`1m-kGONp)raMTU6N0!HJzQI55Tlrv% z7HmBLb46tOmumLcDwu*e*RhF+^R4oL7BvKa+*$L5C;-D~d5v!rz=E!kae z@-r~m@?q^XF5+L;VBb?XoAYd*NuuuP#VWrljAH9_tuT4exJ*ih+Oo=yM86x60~{%qr3k&|4VAErsXJ1s2zSPj zc~!=~?mUiVP$GyTmPKRVw6mrD4AL2g7AVTJ_k=$~PAqGPeNW9<>(8K?arE4Ea=xF- zuql>xl%o*yTt>2bcxwm=;kQcOc~nFzR8#ya#uXs{(n@G0M(cV)k{`=T*I8kBt|25(#2#7QPTSFE%_JmS zVp(hK*3uzs79n{u_DE(wZI_q!6d_T?vetU6WwW$7%LTDV)=kmsxr}TgPhl)88$tCF 
z#s)%ig-(e;Y7o>$ifWboVCRp@1!fH;t^_4Hg2 z>FfMKB>seSQ7ZB~N$0Xsr&fW?I7HWs>?v>P;m~$jwDEGAcuJ`OvNkrbx<*EIA)U*3 z(Am4$(+H;z9Ou&2=@ZfSRgy+Kqs0~)!!=2f6G1_jQS*nC4PHxXbcTqmK{hRV={U#w zBIA&kG2TZCV^XBAdZJq;_H2Zvnj#ANdP*t&t4Y#sUgIw)CpVULlap)Wjj6|s%r;W= z6Fo0%;BA%km=oIu{AN^%t&+Y18LfzGA-vgsU!^@|XT2$3vaI{6DO1*4?zOME{O!sWw@PLa!BvM~FX0x} z_=oK5|CnaKNw`%6FPut^Hn*cs6KE`>I}Z_AU2@!8y0+@Y-#4qKBr?2BO6CqMOPI*8OSD zKV-o~9T(@eO4O!UMO^%x`B?2Mx1gzJ+jx6KKJKci@Q!cE1>85bcjj-sxF>EkQ! zlZn!uX9r|7B+Z!+QiQ}&N)_}{fd8ORqMy6WlpuiouOiQlNS-s1)#ExRja=%Cyo#1D z&r_~0qAbSoqfQ$Df|_{d!{u%c#h{?Z<6vSn}Xdl=&Ru=4T$ z7py$_&#;mO^Wy(tCGkT44_5yFh?T_)v9fysR(3DI%I*bN*}VWOyBA<(_X4c!UVxR| z3$U_#0akV|z{>6gSlPV*E4vq9W%mNC>|TJC-3zd?djVEZeYNS3vMXgk^K2@<33V4+TOFk_SCKSe3D%@HOYX5KeWr?3( zj;q->oB{mODHn_e@XH)5b{N1f4Y2VSz%S2Ovd1U^{L;c`^#t%sQTD9m(ExsVM=og# z;FlLTn>4&H2rh?FoJdKRB9LE73o@-w1@g<2Hr5+~{4!q?vDFm7FCB`cv-9|6(L}_q zdHnJ%Zt$6T{L&dwrHQ;x_~hf4PDS$QJbp>p)^FW0k6+4yKeRpIm$L-F^hP$DqUs1D zhx1EY@Welf&8)r+x-gUmuD^OhvxB1*#4Sw1oF!}YR)SvfY9FQhj{c-avj?MP9B8OQFbd>7{BcJm>Wv)%h})&5eAf{YG;^y#^LF> zEDIgYJRH+$3IT4}Q^#bA86sGmD54RAmO?)qWBvlMY&$GtlFI?xaLjy|XBJ|ZNqOdB zp1Ey4&-_=IdEuX7CeAb4$k>|Pfv?GN%ru-5&Bx3<%{B~aXx0|iRWQvF>5|%o?d{NI}_()=D9-cJj^tl^3KQ1yovdk*-Z|akC}O5`h3jX znK-lnGn1@~=VN9g?8VK;%+tKRftY!c^oMS9XpJVSX@KWntx{DGz#!X`tTESN&?!;~ z8m#nmli5uqzEe)*Y*S3-i5b5Yu_vX(y4ICyIAyn#i5Pkjv91-1&lQfzGhv7QEU~Wb zqwMQ9PR_>DW6WQ~PjV86w`s67$% z)JQj(_XV8Hcn}51`zBA!`c}lb>zQ-wJs4Eh{SdU#&AC-^uCPb3e$0~n{hV8?cP6Sd z!f8@$IpNmSYo~c{syO%UoOcMfo_(^IwH3ombQ|H83%Qv$!NM@p$o8+vUC}u^6V*6o z))2-dlsA&BD=^F)oO6GVwDFYP`V&#m3(h&uLVHj1-sgrd(}dsg6*?Tw+?lAC$8VX8 z&oN^G!)XOXDT`UVBthM(pe7>PCNDlxChK@f8?>L0D2iG2B|$yLpk_j%EIy*wb-dCO z)JjNlidlyCpr7@1v=EYw#YY((|@#&kf@4TFZBlh z))U-AND7LNyzD&s@l^0Dgru;T^~zYV2ML`k{i)(3uTCGm93A=rk!KW(VO~UDm#Rqp zA-14I-Q{k@^4z7X3ztHUorz7Z^sN=E=`3TtiS!uta(CQ{qAyqP`ZBssN>A%IrKrz| zMQwZ|j1-mz@%&q8baWk4O9`o3OcpcSNj|Klk7E4Pj`P#q7}hFlAI4gY8ZV-5#ADoe z{FAr&70*+Yr4a1i9K@3^?oy<7DL3;rZtL1q-}P7r56ko8v2Nd*{#706L=ee?nwu$W z?UpsfAU41?FKQf}`;e=)S?8HFC}q{H^53 
z2HqCotSNRk5uAbmobevQlgVTJ!O@?e;8;f)$3?6^JrsTG#e?5O_)wE~Usd8zvwrC% zVyxXTbH~t%WY&HDuwnm88+h9jd(Dp+iQuR$wN=vpl$y4+ch1Z+_f^ercv+_(kXMMi zA(p3&v=R}KX#(PI?ee4{^NKFQ5*LfRr8U#G+i8yy&4b0=H1sp-;P)T!S1j)4QAg@X zQ6ohC@wnTPQ3xU_gamT_1j5Kv*i^`HT6ay+N(qZtgU3v{R(`LDC30FJ9>;DL^IZd+ zhuU1s_*lyRa+*Q3sA9%pr|2Z8UQIxTCh@*%%#`_>JFA@lGc4jfZRC%L1)B)`X?}hK zuSL!Nb07O80cKdxd4mjX(xTf5w|;q&V~y=~ulvZINVtVHK5tAcoFvUL;b$OP%(@NR z^ELuwvADB{*4&G>OwYLm?cZZwKkm-%n{z9xRiZKp?}}r$6K>7^d;{-IHRroN&S}D} zHqt0+j+l1D-k;zT4UyV@U)68k@V)!dx%jAx4qFZ-Uev}6n#1Bt>YGY&@DI}PrD7iz zL&N)`p7rp)x`cTlppWQ=im5yt1LhX6@i2kR$AHs54A2rIw!&}<#{kU(3`p`}0I@;B74J-(j{)Zj zaSWJ#fC1e&2GFF$SUr6z$-0&Grk$lBM$T@SK5p|$Oc6JUH*)ON)4W}e^qZ4x#Ov+J z+L`E7N2p1W7l=%=KR?YY_psg(1y>Q4U<`JfLw1=WwGUp#)OH{Z@)cv|#fd3tX!oQQD{?(vnythM-VTLDpQ`0*sGpxvGInq_V?;_=JfWX_OS zFyfdaZReo_88`aae^wK3uB!pqDaD{gn=*+vSFAW!IBQz_k$c_voLdPy6J;jBq*z!) zxYgTrns?j_hkrP~Cfq80yqHyJL|b|pTM4()ewk!V9CBxmQ!)ua*>RobA?0ZM>_{|M+x_J^rwnh+mVhJBtbg}&qjXLOF z_!n3)1z7O}h7}@?4=XPCu%cVR$Cs!$EjT}spf%6O3flv$mEoC+|)GaVt*4ISe#!e@=b&N?7rV`D-Rp8tjsKEJR zP*JmSl?Ge|P9tC{0$k-Lu2SMtS%<4!1r_R@J~^Db8{<0N^a;Tv;CQP6s8cvg4aYTd)w(dOYh1hyz)UVp$Q&I#{?k@eHX> z%Z4dSCngqAun=IT&v3}Bh^H{Yi3H#%PAu*RhY?^4E>z~=5F`+1OH54hIilvnmTaVPH+gC%`Eg1oR^A;tG9~t?kwlsZW7Zs)?+c@F`>) zt=oI?Lkg{ud$P#7iOAg$ikNXsQMM0ROFg*4BnMk^YB*R4M^rd1!p;n|N;b)(4oyVm z;|dbmE+(p*^(pY4)$q@gc(_6l279OqSJ*1&Uz*_W!WFbMY@b&|!|FL1F{}}M;qWPx zU<`+Won>g1RLKR`CImWMp_kUwjyC7|6e8iQ`we*nUP38`VCd7hLX%v0djkHi0 zJPI54y>vTgW6566}_*1$y)DfB zrXKm;Z>so?-_#Jx<~OyqZO~`xoFd+DDq?B9&y-Y>JL&VGLe=HZcEK*~@VWY$yguCT zm_;FEcgYINFP@=;{}R*5!Iw^ivJPgd$eTNoKa1i0O=qXR5WeI2Ov&_wEGRzVeSU-L zy4w_$>YLDt62euELL7ySm4c+2a39=!kLJ{K&^%GnnXmdg$g%8h+1B!l#};kez<Z2dexF$*kX?reMZz@ci*iwCsIQGia$S*67q8($IiJCDIo_DvJ={BC#29;H?X!U+n$>0JLOx} zQVA!8;eL(({=sqI4Bz?br^?{z>W7~0uijnx@H46ZSwmy_!(u`O#ATRDs?=w%7=Gh6 z3B5gr0CR*Mrc(yb0`-6_&K-pY(wLE+ticXYC)a8!_BAx8rl&|#r2idnEkL ImVfa1-)aa>0{{R3 literal 0 HcmV?d00001 diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal 
b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal
new file mode 100644
index 0000000..30ffd35
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_add_wide.metal
@@ -0,0 +1,20 @@
+// source: https://github.com/geometryxyz/msl-secp256k1
+
+using namespace metal;
+#include <metal_stdlib>
+#include <metal_math>
+#include "bigint.metal"
+
+// Kernel wrapper around bigint_add_wide.
+//
+// Reads one BigInt from buffer 0 and one from buffer 1, adds them, and writes
+// the BigIntWide sum to buffer 2. `gid` is not used: every thread reads and
+// writes the first element of each buffer.
+kernel void run(
+    device BigInt* lhs [[ buffer(0) ]],
+    device BigInt* rhs [[ buffer(1) ]],
+    device BigIntWide* result [[ buffer(2) ]],
+    uint gid [[ thread_position_in_grid ]]
+) {
+    // Copy the operands out of device memory into thread-local values.
+    BigInt a;
+    BigInt b;
+    a.limbs = lhs->limbs;
+    b.limbs = rhs->limbs;
+    BigIntWide res = bigint_add_wide(a, b);
+    result->limbs = res.limbs;
+}
diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal
new file mode 100644
index 0000000..552f012
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/bigint_sub.metal
@@ -0,0 +1,20 @@
+// source: https://github.com/geometryxyz/msl-secp256k1
+
+using namespace metal;
+#include <metal_stdlib>
+#include <metal_math>
+#include "bigint.metal"
+
+// Kernel wrapper around bigint_sub.
+//
+// Reads one BigInt from buffer 0 and one from buffer 1 and writes
+// lhs - rhs to buffer 2. `gid` is not used: every thread reads and writes
+// the first element of each buffer.
+kernel void run(
+    device BigInt* lhs [[ buffer(0) ]],
+    device BigInt* rhs [[ buffer(1) ]],
+    device BigInt* result [[ buffer(2) ]],
+    uint gid [[ thread_position_in_grid ]]
+) {
+    // Copy the operands out of device memory into thread-local values.
+    BigInt a;
+    BigInt b;
+    a.limbs = lhs->limbs;
+    b.limbs = rhs->limbs;
+    BigInt res = bigint_sub(a, b);
+    result->limbs = res.limbs;
+}
diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal
new file mode 100644
index 0000000..6ece636
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/constants.metal
@@ -0,0 +1,8 @@
+// THIS FILE IS AUTOGENERATED BY shader.rs
+#define NUM_LIMBS 20
+#define NUM_LIMBS_WIDE 21
+#define LOG_LIMB_SIZE 13
+#define TWO_POW_WORD_SIZE 8192
+#define MASK 8191
+#define N0 0
+#define NSAFE 0
diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/u128.h.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/u128.h.metal
new file
mode 100644
index 0000000..07a3c94
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/u128.h.metal
@@ -0,0 +1,173 @@
+// source: https://github.com/andrewmilson/ministark/blob/875fb385bab9fcbb347d4c69898b56cbeeb71ca1/gpu/src/metal/u256.h.metal
+
+#ifndef u128_h
+#define u128_h
+
+// Unsigned 128-bit integer stored as two 64-bit halves (`high`, `low`).
+//
+// Arithmetic wraps modulo 2^128. The two-argument constructor takes the
+// halves in (high, low) order.
+class u128
+{
+public:
+    u128() = default;
+    constexpr u128(int l) : low(l), high(0) {}
+    constexpr u128(unsigned long l) : low(l), high(0) {}
+    constexpr u128(bool b) : low(b), high(0) {}
+    constexpr u128(unsigned long h, unsigned long l) : low(l), high(h) {}
+
+    // Wrapping addition; the carry out of `low` is detected by the sum
+    // being smaller than one of its operands.
+    constexpr u128 operator+(const u128 rhs) const
+    {
+        return u128(high + rhs.high + ((low + rhs.low) < low), low + rhs.low);
+    }
+
+    constexpr u128 operator+=(const u128 rhs)
+    {
+        *this = *this + rhs;
+        return *this;
+    }
+
+    // Wrapping subtraction; the borrow out of `low` is detected by the
+    // difference being larger than the minuend.
+    constexpr inline u128 operator-(const u128 rhs) const
+    {
+        return u128(high - rhs.high - ((low - rhs.low) > low), low - rhs.low);
+    }
+
+    constexpr u128 operator-=(const u128 rhs)
+    {
+        *this = *this - rhs;
+        return *this;
+    }
+
+    constexpr bool operator==(const u128 rhs) const
+    {
+        return high == rhs.high && low == rhs.low;
+    }
+
+    constexpr bool operator!=(const u128 rhs) const
+    {
+        return !(*this == rhs);
+    }
+
+    constexpr bool operator<(const u128 rhs) const
+    {
+        return ((high == rhs.high) && (low < rhs.low)) || (high < rhs.high);
+    }
+
+    constexpr u128 operator&(const u128 rhs) const
+    {
+        return u128(high & rhs.high, low & rhs.low);
+    }
+
+    constexpr u128 operator|(const u128 rhs) const
+    {
+        return u128(high | rhs.high, low | rhs.low);
+    }
+
+    constexpr bool operator>(const u128 rhs) const
+    {
+        return ((high == rhs.high) && (low > rhs.low)) || (high > rhs.high);
+    }
+
+    constexpr bool operator>=(const u128 rhs) const
+    {
+        return !(*this < rhs);
+    }
+
+    constexpr bool operator<=(const u128 rhs) const
+    {
+        return !(*this > rhs);
+    }
+
+    // Logical right shift. Shifts of 0, 64 and >= 128 are handled
+    // separately so that no 64-bit half is ever shifted by 64 or more.
+    constexpr inline u128 operator>>(unsigned shift) const
+    {
+        // TODO: reduce branch conditions
+        if (shift >= 128)
+        {
+            return u128(0);
+        }
+        else if (shift == 64)
+        {
+            return u128(0, high);
+        }
+        else if (shift == 0)
+        {
+            return *this;
+        }
+        else if (shift < 64)
+        {
+            return u128(high >> shift, (high << (64 - shift)) | (low >> shift));
+        }
+        else if ((128 > shift) && (shift > 64))
+        {
+            return u128(0, (high >> (shift - 64)));
+        }
+        else
+        {
+            return u128(0);
+        }
+    }
+
+    // Left shift. Shifts of 0, 64 and >= 128 are handled separately so
+    // that no 64-bit half is ever shifted by 64 or more.
+    constexpr inline u128 operator<<(unsigned shift) const
+    {
+        // TODO: reduce branch conditions
+        if (shift >= 128)
+        {
+            return u128(0);
+        }
+        else if (shift == 64)
+        {
+            return u128(low, 0);
+        }
+        else if (shift == 0)
+        {
+            return *this;
+        }
+        else if (shift < 64)
+        {
+            return u128((high << shift) | (low >> (64 - shift)), low << shift);
+        }
+        else if ((128 > shift) && (shift > 64))
+        {
+            // Bits of `low` move up into `high`; they must be shifted
+            // left by the amount exceeding 64, not right.
+            return u128((low << (shift - 64)), 0);
+        }
+        else
+        {
+            return u128(0);
+        }
+    }
+
+    constexpr u128 operator>>=(unsigned rhs)
+    {
+        *this = *this >> rhs;
+        return *this;
+    }
+
+    // Multiplies by 0 or 1.
+    u128 operator*(const bool rhs) const
+    {
+        return u128(high * rhs, low * rhs);
+    }
+
+    // Wrapping multiplication modulo 2^128.
+    //
+    // (h1*2^64 + l1) * (h2*2^64 + l2) mod 2^128
+    //   = ((l1*h2 + h1*l2 + mulhi(l1, l2)) mod 2^64) * 2^64 + (l1*l2 mod 2^64)
+    // so the cross terms contribute their LOW 64 bits to `high`; only the
+    // l1*l2 term contributes its high half.
+    u128 operator*(const u128 rhs) const
+    {
+        unsigned long t_low_high = low * rhs.high;
+        unsigned long t_high = metal::mulhi(low, rhs.low);
+        unsigned long t_high_low = high * rhs.low;
+        unsigned long t_low = low * rhs.low;
+        return u128(t_low_high + t_high_low + t_high, t_low);
+    }
+
+    u128 operator*=(const u128 rhs)
+    {
+        *this = *this * rhs;
+        return *this;
+    }
+
+    // TODO: Could get better performance with smaller limb size
+    // Not sure what word size is for M1 GPU
+#ifdef __LITTLE_ENDIAN__
+    unsigned long low;
+    unsigned long high;
+#endif
+#ifdef __BIG_ENDIAN__
+    unsigned long high;
+    unsigned long low;
+#endif
+};
+
+#endif /* u128_h */
diff --git a/mopro-msm/src/msm/metal_msm/shader/bigint/u256.h.metal b/mopro-msm/src/msm/metal_msm/shader/bigint/u256.h.metal
new file mode 100644
index 0000000..74a3474
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/bigint/u256.h.metal
@@ -0,0 +1,263 @@
+// source:
https://github.com/andrewmilson/ministark/blob/875fb385bab9fcbb347d4c69898b56cbeeb71ca1/gpu/src/metal/u256.h.metal + +#ifndef u256_h +#define u256_h + +#include "u128.h.metal" + +class u256 +{ +public: + u256() = default; + constexpr u256(int l) : low(l), high(0) {} + constexpr u256(unsigned long l) : low(u128(l)), high(0) {} + constexpr u256(u128 l) : low(l), high(0) {} + constexpr u256(bool b) : low(b), high(0) {} + constexpr u256(u128 h, u128 l) : low(l), high(h) {} + constexpr u256(unsigned long hh, unsigned long hl, unsigned long lh, unsigned long ll) : + low(u128(lh, ll)), high(u128(hh, hl)) {} + + constexpr u256 operator+(const u256 rhs) const + { + return u256(high + rhs.high + ((low + rhs.low) < low), low + rhs.low); + } + + constexpr u256 operator+=(const u256 rhs) + { + *this = *this + rhs; + return *this; + } + + constexpr inline u256 operator-(const u256 rhs) const + { + return u256(high - rhs.high - ((low - rhs.low) > low), low - rhs.low); + } + + constexpr u256 operator-=(const u256 rhs) + { + *this = *this - rhs; + return *this; + } + + constexpr bool operator==(const u256 rhs) const + { + return high == rhs.high && low == rhs.low; + } + + constexpr bool operator!=(const u256 rhs) const + { + return !(*this == rhs); + } + + constexpr bool operator<(const u256 rhs) const + { + return ((high == rhs.high) && (low < rhs.low)) || (high < rhs.high); + } + + constexpr u256 operator&(const u256 rhs) const + { + return u256(high & rhs.high, low & rhs.low); + } + + constexpr bool operator>(const u256 rhs) const + { + return ((high == rhs.high) && (low > rhs.low)) || (high > rhs.high); + } + + constexpr bool operator>=(const u256 rhs) const + { + return !(*this < rhs); + } + + constexpr bool operator<=(const u256 rhs) const + { + return !(*this > rhs); + } + + constexpr inline u256 operator>>(unsigned shift) const + { + // TODO: reduce branch conditions + if (shift >= 256) + { + return u256(0); + } + else if (shift == 128) + { + return u256(0, high); + } + 
else if (shift == 0) + { + return *this; + } + else if (shift < 128) + { + return u256(high >> shift, (high << (128 - shift)) | (low >> shift)); + } + else if ((256 > shift) && (shift > 128)) + { + return u256(0, (high >> (shift - 128))); + } + else + { + return u256(0); + } + } + + constexpr u256 operator>>=(unsigned rhs) + { + *this = *this >> rhs; + return *this; + } + + u256 operator*(const bool rhs) const + { + return u256(high * rhs, low * rhs); + } + + u256 operator*(const u256 rhs) const + { + // split values into 4 64-bit parts + u128 top[2] = {u128(low.high), u128(low.low)}; + u128 bottom[3] = {u128(rhs.high.low), u128(rhs.low.high), u128(rhs.low.low)}; + + unsigned long tmp3_3 = high.high * rhs.low.low; + unsigned long tmp0_0 = low.low * rhs.high.high; + unsigned long tmp2_2 = high.low * rhs.low.high; + + u128 tmp2_3 = u128(high.low) * bottom[2]; + u128 tmp0_3 = top[1] * bottom[2]; + u128 tmp1_3 = top[0] * bottom[2]; + + u128 tmp0_2 = top[1] * bottom[1]; + u128 third64 = u128(tmp0_2.low) + u128(tmp0_3.high); + u128 tmp1_2 = top[0] * bottom[1]; + + u128 tmp0_1 = top[1] * bottom[0]; + u128 second64 = u128(tmp0_1.low) + u128(tmp0_2.high); + unsigned long first64 = tmp0_0 + tmp0_1.high; + + u128 tmp1_1 = top[0] * bottom[0]; + first64 += tmp1_1.low + tmp1_2.high; + + // second row + third64 += u128(tmp1_3.low); + second64 += u128(tmp1_2.low) + u128(tmp1_3.high); + + // third row + second64 += u128(tmp2_3.low); + first64 += tmp2_2 + tmp2_3.high; + + // fourth row + first64 += tmp3_3; + second64 += u128(third64.high); + first64 += second64.high; + + return u256(u128(first64, second64.low), u128(third64.low, tmp0_3.low)); + + + // // unsigned long t_low_high_low = high * rhs.low; + // // unsigned long t_low_low_high = low * rhs.high; + + // // unsigned long t_low = low * rhs.low; + + // // u128 t_low = low * rhs.low; + + // // unsigned long t_low_high = metal::mulhi(low.low, rhs.low.high); + // // unsigned long t_high_low = metal::mulhi(low.high, rhs.low.low); + 
// // unsigned long t_high = metal::mulhi(low.low, rhs.low.low); + // // unsigned long t_low = low.low * rhs.low.low; + + // // u128 low_low = u128(t_low_high + t_high_low + t_high, t_low); + + // // t_low_high = metal::mulhi(low.low, rhs.low.high); + // // t_high_low = metal::mulhi(low.high, rhs.low.low); + // // t_high = metal::mulhi(low.low, rhs.low.low); + // // t_low = low.low * rhs.low.low; + + // // return ; + + // // split values into 4 64-bit parts + // u128 top[3] = {u128(high.low), u128(low.high), u128(low.low)}; + // u128 bottom[3] = {u128(rhs.high.low), u128(rhs.low.high), u128(rhs.low.low)}; + // // u128 top[4] = {high >> 32, high & 0xffffffff, low >> 32, low & 0xffffffff}; + // // u128 bottom[4] = {rhs.high >> 32, rhs.high & 0xffffffff, rhs.low >> 32, rhs.low & 0xffffffff}; + // // u128 products[4][4]; + + // // // multiply each component of the values + // // Alternative: + // // for(int y = 3; y > -1; y--){ + // // for(int x = 3; x > -1; x--){ + // // products[3 - x][y] = top[x] * bottom[y]; + // // } + // // } + // u128 tmp0_3 = top[2] * bottom[2]; + // u128 tmp1_3 = top[1] * bottom[2]; + // u128 tmp2_3 = top[0] * bottom[2]; + // // u128 tmp3_3 = top[0] * bottom[2]; + // unsigned long tmp3_3 = high.high * rhs.low.low; + // // unsigned long tmp0 = low.low * rhs.high.high; + + // u128 tmp0_2 = top[2] * bottom[1]; + // u128 tmp1_2 = top[1] * bottom[1]; + // // u128 tmp2_2 = top[0] * bottom[1]; + // unsigned long tmp2_2 = high.low * rhs.low.high; + + + // u128 tmp0_1 = top[2] * bottom[0]; + // u128 tmp1_1 = top[1] * bottom[0]; + // // u128 tmp3_1 = top[0] * bottom[0]; + + // unsigned long tmp0_0 = low.low * rhs.high.high; + + // // first row + // u128 fourth64 = tmp0_3.low; + // u128 third64 = u128(tmp0_2.low) + u128(tmp0_3.high); + // u128 second64 = u128(tmp0_1.low) + u128(tmp0_2.high); + // u128 first64 = u128(tmp0_0) + u128(tmp0_1.high); + + // // second row + // third64 += u128(tmp1_3.low); + // second64 += u128(tmp1_2.low) + u128(tmp1_3.high); + 
// first64 += u128(tmp1_1.low) + u128(tmp1_2.high); + + // // third row + // second64 += u128(tmp2_3.low); + // first64 += u128(tmp2_2) + u128(tmp2_3.high); + + // // fourth row + // first64 += u128(tmp3_3); + // second64 += u128(third64.high); + // first64 += u128(second64.high); + + // // remove carry from current digit + // // fourth64 &= 0xffffffff; // TODO: figure out if this is a nop + // // third64 &= 0xffffffff; + // // second64 = u128(second64.low); + // // first64 &= 0xffffffff; + + // // combine components + // // return u256((first64 << 64) | second64, (third64 << 64) | fourth64); + // return u256(u128(first64.low, second64.low), u128(third64.low, fourth64.low)); + + // // return u128((first64.high second64, (third64 << 64) | fourth64); + } + + u256 operator*=(const u256 rhs) + { + *this = *this * rhs; + return *this; + } + + // TODO: Could get better performance with smaller limb size + // Not sure what word size is for M1 GPU +#ifdef __LITTLE_ENDIAN__ + u128 low; + u128 high; +#endif +#ifdef __BIG_ENDIAN__ + u128 high; + u128 low; +#endif +}; + +#endif /* u256_h */ diff --git a/mopro-msm/src/msm/metal_msm/shader/constants.metal b/mopro-msm/src/msm/metal_msm/shader/constants.metal new file mode 100644 index 0000000..650b981 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/constants.metal @@ -0,0 +1,8 @@ +// THIS FILE IS AUTOGENERATED BY shader.rs +#define NUM_LIMBS 16 +#define NUM_LIMBS_WIDE 17 +#define LOG_LIMB_SIZE 16 +#define TWO_POW_WORD_SIZE 65536 +#define MASK 65535 +#define N0 25481 +#define NSAFE 1 diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/ec_point.h.metal b/mopro-msm/src/msm/metal_msm/shader/curve/ec_point.h.metal new file mode 100644 index 0000000..5bdebf2 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/curve/ec_point.h.metal @@ -0,0 +1,110 @@ +#pragma once + +template +class ECPoint { +public: + Fp x; + Fp y; + Fp t; + Fp z; + + /* + TODO: + + // r is the montgomery radix + fn get_r() -> BigInt { + var r: BigInt; + 
{{{ r_limbs }}} + return r; + } + + fn get_paf() -> Point { + var result: Point; + let r = get_r(); + result.y = r; + result.z = r; + return result; + } + */ + constexpr ECPoint() : ECPoint(ECPoint::point_at_infinity()) {} + constexpr ECPoint(Fp _x, Fp _y, Fp _t, Fp _z) : x(_x), y(_y), t(_t), z(_z) {} + + constexpr ECPoint operator+(const ECPoint other) const { + + } + + void operator+=(const ECPoint other) { + *this = *this + other; + } + + static ECPoint point_at_infinity() { + return ECPoint(Fp(1), Fp(1), Fp(0)); // Updated to new neutral element (1, 1, 0) + } + + ECPoint operate_with_self(uint64_t exponent) const { + ECPoint result = point_at_infinity(); + ECPoint base = ECPoint(x, y, t, z); + + while (exponent > 0) { + if ((exponent & 1) == 1) { + result = result + base; + } + exponent = exponent >> 1; + base = base + base; + } + + return result; + } + + constexpr ECPoint operator*(uint64_t exponent) const { + return operate_with_self(exponent); + } + + constexpr void operator*=(uint64_t exponent) { + *this = operate_with_self(exponent); + } + + constexpr ECPoint neg() const { + return ECPoint(x, y.neg(), t, z); + } + + constexpr bool is_neutral_element(const ECPoint a_point) const { + return a_point.z == Fp(0); // Updated to check for (1, 1, 0) + } + + constexpr ECPoint double_in_place() const { + if (is_neutral_element(*this)) { + return *this; + } + + // Doubling formulas + Fp a_fp = Fp(A_CURVE).to_montgomery(); + Fp two = Fp(2).to_montgomery(); + Fp three = Fp(3).to_montgomery(); + + Fp eight = Fp(8).to_montgomery(); + + Fp xx = x * x; // x^2 + Fp yy = y * y; // y^2 + Fp yyyy = yy * yy; // y^4 + Fp zz = z * z; // z^2 + + // S = 2 * ((X1 + YY)^2 - XX - YYYY) + Fp s = two * (((x + yy) * (x + yy)) - xx - yyyy); + + // M = 3 * XX + a * ZZ ^ 2 + Fp m = (three * xx) + (a_fp * (zz * zz)); + + // X3 = T = M^2 - 2*S + Fp x3 = (m * m) - (two * s); + + // Z3 = (Y + Z) ^ 2 - YY - ZZ + // or Z3 = 2 * Y * Z + Fp z3 = two * y * z; + + // Y3 = M*(S-X3)-8*YYYY + Fp y3 = m 
* (s - x3) - eight * yyyy; + + return ECPoint(x3, y3, z3); + } +}; diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal new file mode 100644 index 0000000..71e9b55 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian.metal @@ -0,0 +1,100 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include +#include +#include "../mont_backend/mont.metal" + +struct Jacobian { + BigInt x; + BigInt y; + BigInt z; +}; + +Jacobian jacobian_add_2007_bl_unsafe( + Jacobian a, + Jacobian b, + BigInt p +) { + BigInt x1 = a.x; + BigInt y1 = a.y; + BigInt z1 = a.z; + BigInt x2 = b.x; + BigInt y2 = b.y; + BigInt z2 = b.z; + + BigInt z1z1 = mont_mul_optimised(z1, z1, p); + BigInt z2z2 = mont_mul_optimised(z2, z2, p); + BigInt u1 = mont_mul_optimised(x1, z2z2, p); + BigInt u2 = mont_mul_optimised(x2, z1z1, p); + BigInt y1z2 = mont_mul_optimised(y1, z2, p); + BigInt s1 = mont_mul_optimised(y1z2, z2z2, p); + + BigInt y2z1 = mont_mul_optimised(y2, z1, p); + BigInt s2 = mont_mul_optimised(y2z1, z1z1, p); + BigInt h = ff_sub(u2, u1, p); + BigInt h2 = ff_add(h, h, p); + BigInt i = mont_mul_optimised(h2, h2, p); + BigInt j = mont_mul_optimised(h, i, p); + + BigInt s2s1 = ff_sub(s2, s1, p); + BigInt r = ff_add(s2s1, s2s1, p); + BigInt v = mont_mul_optimised(u1, i, p); + BigInt v2 = ff_add(v, v, p); + BigInt r2 = mont_mul_optimised(r, r, p); + BigInt jv2 = ff_add(j, v2, p); + BigInt x3 = ff_sub(r2, jv2, p); + + BigInt vx3 = ff_sub(v, x3, p); + BigInt rvx3 = mont_mul_optimised(r, vx3, p); + BigInt s12 = ff_add(s1, s1, p); + BigInt s12j = mont_mul_optimised(s12, j, p); + BigInt y3 = ff_sub(rvx3, s12j, p); + + BigInt z1z2 = mont_mul_optimised(z1, z2, p); + BigInt z1z2h = mont_mul_optimised(z1z2, h, p); + BigInt z3 = ff_add(z1z2h, z1z2h, p); + + Jacobian result; + result.x = x3; + result.y = y3; + result.z = z3; + return result; +} + +Jacobian jacobian_dbl_2009_l( + Jacobian 
pt, + BigInt p +) { + BigInt x = pt.x; + BigInt y = pt.y; + BigInt z = pt.z; + + BigInt a = mont_mul_optimised(x, x, p); + BigInt b = mont_mul_optimised(y, y, p); + BigInt c = mont_mul_optimised(b, b, p); + BigInt x1b = ff_add(x, b, p); + BigInt x1b2 = mont_mul_optimised(x1b, x1b, p); + BigInt ac = ff_add(a, c, p); + BigInt x1b2ac = ff_sub(x1b2, ac, p); + BigInt d = ff_add(x1b2ac, x1b2ac, p); + BigInt a2 = ff_add(a, a, p); + BigInt e = ff_add(a2, a, p); + BigInt f = mont_mul_optimised(e, e, p); + BigInt d2 = ff_add(d, d, p); + BigInt x3 = ff_sub(f, d2, p); + BigInt c2 = ff_add(c, c, p); + BigInt c4 = ff_add(c2, c2, p); + BigInt c8 = ff_add(c4, c4, p); + BigInt dx3 = ff_sub(d, x3, p); + BigInt edx3 = mont_mul_optimised(e, dx3, p); + BigInt y3 = ff_sub(edx3, c8, p); + BigInt y1z1 = mont_mul_optimised(y, z, p); + BigInt z3 = ff_add(y1z1, y1z1, p); + + Jacobian result; + result.x = x3; + result.y = y3; + result.z = z3; + return result; +} diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl_unsafe.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl_unsafe.metal new file mode 100644 index 0000000..7f0a00e --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_add_2007_bl_unsafe.metal @@ -0,0 +1,36 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include +#include +#include "jacobian.metal" + +kernel void run( + device BigInt* prime [[ buffer(0) ]], + device BigInt* a_xr [[ buffer(1) ]], + device BigInt* a_yr [[ buffer(2) ]], + device BigInt* a_zr [[ buffer(3) ]], + device BigInt* b_xr [[ buffer(4) ]], + device BigInt* b_yr [[ buffer(5) ]], + device BigInt* b_zr [[ buffer(6) ]], + device BigInt* result_xr [[ buffer(7) ]], + device BigInt* result_yr [[ buffer(8) ]], + device BigInt* result_zr [[ buffer(9) ]], + uint gid [[ thread_position_in_grid ]] +) { + BigInt p; p.limbs = prime->limbs; + BigInt x1; x1.limbs = a_xr->limbs; + BigInt y1; y1.limbs = a_yr->limbs; + BigInt z1; 
z1.limbs = a_zr->limbs;
+    BigInt x2; x2.limbs = b_xr->limbs;
+    BigInt y2; y2.limbs = b_yr->limbs;
+    BigInt z2; z2.limbs = b_zr->limbs;
+
+    Jacobian a; a.x = x1; a.y = y1; a.z = z1;
+    Jacobian b; b.x = x2; b.y = y2; b.z = z2;
+
+    Jacobian res = jacobian_add_2007_bl_unsafe(a, b, p);
+    result_xr->limbs = res.x.limbs;
+    result_yr->limbs = res.y.limbs;
+    result_zr->limbs = res.z.limbs;
+}
diff --git a/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal
new file mode 100644
index 0000000..c6dcad1
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/curve/jacobian_dbl_2009_l.metal
@@ -0,0 +1,29 @@
+// source: https://github.com/geometryxyz/msl-secp256k1
+
+using namespace metal;
+#include <metal_stdlib>
+#include <metal_math>
+#include "jacobian.metal"
+
+// Kernel wrapper around jacobian_dbl_2009_l.
+//
+// Buffer 0 holds the prime, buffers 1-3 the x/y/z coordinates of the input
+// point, and buffers 4-6 receive the coordinates of the doubled point.
+// `gid` is not used: every thread reads and writes the first element of
+// each buffer.
+kernel void run(
+    device BigInt* prime [[ buffer(0) ]],
+    device BigInt* a_xr [[ buffer(1) ]],
+    device BigInt* a_yr [[ buffer(2) ]],
+    device BigInt* a_zr [[ buffer(3) ]],
+    device BigInt* result_xr [[ buffer(4) ]],
+    device BigInt* result_yr [[ buffer(5) ]],
+    device BigInt* result_zr [[ buffer(6) ]],
+    uint gid [[ thread_position_in_grid ]]
+) {
+    BigInt p; p.limbs = prime->limbs;
+    BigInt x1; x1.limbs = a_xr->limbs;
+    BigInt y1; y1.limbs = a_yr->limbs;
+    BigInt z1; z1.limbs = a_zr->limbs;
+
+    Jacobian a; a.x = x1; a.y = y1; a.z = z1;
+
+    Jacobian res = jacobian_dbl_2009_l(a, p);
+    result_xr->limbs = res.x.limbs;
+    result_yr->limbs = res.y.limbs;
+    result_zr->limbs = res.z.limbs;
+}
diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal
new file mode 100644
index 0000000..af8cf93
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/field/ff.metal
@@ -0,0 +1,63 @@
+// source: https://github.com/geometryxyz/msl-secp256k1
+
+using namespace metal;
+#include <metal_stdlib>
+#include <metal_math>
+#include "../bigint/bigint.metal"
+
+// Modular addition: returns a + b, reduced by a single conditional
+// subtraction of p.
+//
+// NOTE(review): one subtraction only yields a value below p when a and b are
+// themselves below p - confirm that callers guarantee this.
+BigInt ff_add(
+    BigInt a,
+    BigInt b,
+    BigInt p
+) {
+    // Assign p to p_wide
+    BigIntWide p_wide;
+    for (uint i = 0; i < NUM_LIMBS; i ++) {
+        p_wide.limbs[i] = p.limbs[i];
+    }
+    // The extra top limb must be zeroed explicitly: local structs are not
+    // zero-initialized, and bigint_wide_gte compares this limb first.
+    p_wide.limbs[NUM_LIMBS] = 0;
+
+    // a + b
+    BigIntWide sum_wide = bigint_add_wide(a, b);
+
+    BigInt res;
+
+    // if (a + b) >= p
+    if (bigint_wide_gte(sum_wide, p_wide)) {
+        // s = a + b - p
+        BigIntWide s = bigint_sub_wide(sum_wide, p_wide);
+
+        // Only the low NUM_LIMBS limbs are kept.
+        for (uint i = 0; i < NUM_LIMBS; i ++) {
+            res.limbs[i] = s.limbs[i];
+        }
+    } else {
+        for (uint i = 0; i < NUM_LIMBS; i ++) {
+            res.limbs[i] = sum_wide.limbs[i];
+        }
+    }
+
+    return res;
+}
+
+// Modular subtraction: returns a - b when a >= b, otherwise p - (b - a).
+//
+// NOTE(review): the second branch stays within [0, p) only when a and b are
+// below p - confirm that callers guarantee this.
+BigInt ff_sub(
+    BigInt a,
+    BigInt b,
+    BigInt p
+) {
+    // if a >= b
+    if (bigint_gte(a, b)) {
+        // a - b
+        return bigint_sub(a, b);
+    } else {
+        // p - (b - a)
+        BigInt r = bigint_sub(b, a);
+        return bigint_sub(p, r);
+    }
+}
diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal
new file mode 100644
index 0000000..b3539ba
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_add.metal
@@ -0,0 +1,24 @@
+// source: https://github.com/geometryxyz/msl-secp256k1
+
+using namespace metal;
+#include <metal_stdlib>
+#include <metal_math>
+#include "ff.metal"
+
+// Kernel wrapper around ff_add.
+//
+// Buffers 0 and 1 hold the operands, buffer 2 the prime, and buffer 3
+// receives the result. `gid` is not used: every thread reads and writes the
+// first element of each buffer.
+kernel void run(
+    device BigInt* lhs [[ buffer(0) ]],
+    device BigInt* rhs [[ buffer(1) ]],
+    device BigInt* prime [[ buffer(2) ]],
+    device BigInt* result [[ buffer(3) ]],
+    uint gid [[ thread_position_in_grid ]]
+) {
+    BigInt a;
+    BigInt b;
+    BigInt p;
+    a.limbs = lhs->limbs;
+    b.limbs = rhs->limbs;
+    p.limbs = prime->limbs;
+
+    BigInt res = ff_add(a, b, p);
+    result->limbs = res.limbs;
+}
diff --git a/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal b/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal
new file mode 100644
index 0000000..5f32e09
--- /dev/null
+++ b/mopro-msm/src/msm/metal_msm/shader/field/ff_sub.metal
@@ -0,0 +1,24 @@
+// source:
https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include +#include +#include "ff.metal" + +kernel void run( + device BigInt* lhs [[ buffer(0) ]], + device BigInt* rhs [[ buffer(1) ]], + device BigInt* prime [[ buffer(2) ]], + device BigInt* result [[ buffer(3) ]], + uint gid [[ thread_position_in_grid ]] +) { + BigInt a; + BigInt b; + BigInt p; + a.limbs = lhs->limbs; + b.limbs = rhs->limbs; + p.limbs = prime->limbs; + + BigInt res = ff_sub(a, b, p); + result->limbs = res.limbs; +} diff --git a/mopro-msm/src/msm/metal_msm/shader/field/fp_bn254.h.metal b/mopro-msm/src/msm/metal_msm/shader/field/fp_bn254.h.metal new file mode 100644 index 0000000..0834b47 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/field/fp_bn254.h.metal @@ -0,0 +1,291 @@ +#pragma once + +// #include "../arithmetics/unsigned_int.h.metal" + + +namespace { + // 8 limbs of 32 bits uint + typedef UnsignedInteger<8> u256; +} + +/* Constants for bn254 field operations + * N: base field modulus + * R_SQUARED: R^2 mod N + * R_SUB_N: R - N + * MU: Montgomery Multiplication Constant = -N^{-1} mod (2^32) + * + * For bn254, the modulus is "21888242871839275222246405745257275088696311157297823662689037894645226208583" [1, 2] + * We use 8 limbs of 32 bits unsigned integers to represent the constanst + * + * References: + * [1] https://github.com/arkworks-rs/algebra/blob/065cd24fc5ae17e024c892cee126ad3bd885f01c/curves/bn254/src/lib.rs + * [2] https://github.com/scipr-lab/libff/blob/develop/libff/algebra/curves/alt_bn128/alt_bn128.sage + */ + +constexpr static const constant u256 N = { + 0x30644E72, 0xE131A029, + 0xB85045B6, 0x8181585D, + 0x97816A91, 0x6871CA8D, + 0x3C208C16, 0xD87CFD47 +}; + +constexpr static const constant u256 R_SQUARED = { + 0x06D89F71, 0xCAB8351F, + 0x47AB1EFF, 0x0A417FF6, + 0xB5E71911, 0xD44501FB, + 0xF32CFC5B, 0x538AFA89 +}; + +constexpr static const constant u256 R_SUB_N = { + 0xCF9BB18D, 0x1ECE5FD6, + 0x47AFBA49, 0x7E7EA7A2, + 0x687E956E, 0x978E3572, + 
0xC3DF73E9, 0x278302B9 +}; + +constexpr static const constant uint64_t MU = 3834012553; + +class FpBN254 { +public: + u256 inner; + constexpr FpBN254() = default; + constexpr FpBN254(uint64_t v) : inner{u256::from_int(v)} {} + constexpr FpBN254(u256 v) : inner{v} {} + + constexpr explicit operator u256() const { + return inner; + } + + constexpr FpBN254 operator+(const FpBN254 rhs) const { + return FpBN254(add(inner, rhs.inner)); + } + + constexpr FpBN254 operator-(const FpBN254 rhs) const { + return FpBN254(sub(inner, rhs.inner)); + } + + constexpr FpBN254 operator*(const FpBN254 rhs) const { + return FpBN254(mul(inner, rhs.inner)); + } + + constexpr bool operator==(const FpBN254 rhs) const { + return inner == rhs.inner; + } + + constexpr bool operator!=(const FpBN254 rhs) const { + return !(inner == rhs.inner); + } + + constexpr explicit operator uint32_t() const { + return inner.m_limbs[7]; + } + + FpBN254 operator>>(const uint32_t rhs) const { + return FpBN254(inner >> rhs); + } + + FpBN254 operator<<(const uint32_t rhs) const { + return FpBN254(inner << rhs); + } + + constexpr static FpBN254 one() { + const FpBN254 ONE = FpBN254::mul(u256::from_int((uint32_t) 1), R_SQUARED); + return ONE; + } + + constexpr FpBN254 to_montgomery() { + return mul(inner, R_SQUARED); + } + + FpBN254 pow(uint32_t exp) const { + FpBN254 const ONE = one(); + FpBN254 res = ONE; + FpBN254 power = *this; + + while (exp > 0) { + if (exp & 1) { + res = res * power; + } + exp >>= 1; + power = power * power; + } + + return res; + } + + FpBN254 inverse() { + // Generate by the command: addchain search '21888242871839275222246405745257275088696311157297823662689037894645226208583 - 2' + // https://github.com/mmcloughlin/addchain + + // addchain: expr: "21888242871839275222246405745257275088696311157297823662689037894645226208583 - 2" + // addchain: hex: 30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd45 + // addchain: dec: 
21888242871839275222246405745257275088696311157297823662689037894645226208581 + // addchain: best: opt(dictionary(sliding_window(8),heuristic(use_first(halving,delta_largest)))) + // addchain: cost: 303 + // _10 = 2*1 + // _11 = 1 + _10 + // _101 = _10 + _11 + // _110 = 1 + _101 + // _1000 = _10 + _110 + // _1101 = _101 + _1000 + // _10010 = _101 + _1101 + // _10011 = 1 + _10010 + // _10100 = 1 + _10011 + // _10111 = _11 + _10100 + // _11100 = _101 + _10111 + // _100000 = _1101 + _10011 + // _100011 = _11 + _100000 + // _101011 = _1000 + _100011 + // _101111 = _10011 + _11100 + // _1000001 = _10010 + _101111 + // _1010011 = _10010 + _1000001 + // _1011011 = _1000 + _1010011 + // _1100001 = _110 + _1011011 + // _1110101 = _10100 + _1100001 + // _10010001 = _11100 + _1110101 + // _10010101 = _100000 + _1110101 + // _10110101 = _100000 + _10010101 + // _10111011 = _110 + _10110101 + // _11000001 = _110 + _10111011 + // _11000011 = _10 + _11000001 + // _11010011 = _10010 + _11000001 + // _11100001 = _100000 + _11000001 + // _11100011 = _10 + _11100001 + // _11100111 = _110 + _11100001 + // i57 = ((_11000001 << 8 + _10010001) << 10 + _11100111) << 7 + // i76 = ((_10111 + i57) << 9 + _10011) << 7 + _1101 + // i109 = ((i76 << 14 + _1010011) << 9 + _11100001) << 8 + // i127 = ((_1000001 + i109) << 10 + _1011011) << 5 + _1101 + // i161 = ((i127 << 8 + _11) << 12 + _101011) << 12 + // i186 = ((_10111011 + i161) << 8 + _101111) << 14 + _10110101 + // i214 = ((i186 << 9 + _10010001) << 5 + _1101) << 12 + // i236 = ((_11100011 + i214) << 8 + _10010101) << 11 + _11010011 + // i268 = ((i236 << 7 + _1100001) << 11 + _100011) << 12 + // i288 = ((_1011011 + i268) << 9 + _11000011) << 8 + _11100111 + // return (i288 << 7 + _1110101) << 6 + _101 + + u256 _10 = mul(inner, inner); + u256 _11 = mul(_10, inner); + u256 _101 = mul(_10, _11); + u256 _110 = mul(inner, _101); + u256 _1000 = mul(_10, _110); + u256 _1101 = mul(_101, _1000); + u256 _10010 = mul(_101, _1101); + u256 _10011 = 
mul(inner, _10010); + u256 _10100 = mul(inner, _10011); + u256 _10111 = mul(_11, _10100); + u256 _11100 = mul(_101, _10111); + u256 _100000 = mul(_1101, _10011); + u256 _100011 = mul(_11, _100000); + u256 _101011 = mul(_1000, _100011); + u256 _101111 = mul(_10011, _11100); + u256 _1000001 = mul(_10010, _101111); + u256 _1010011 = mul(_10010, _1000001); + u256 _1011011 = mul(_1000, _1010011); + u256 _1100001 = mul(_110, _1011011); + u256 _1110101 = mul(_10100, _1100001); + u256 _10010001 = mul(_11100, _1110101); + u256 _10010101 = mul(_100000, _1110101); + u256 _10110101 = mul(_100000, _10010101); + u256 _10111011 = mul(_110, _10110101); + u256 _11000001 = mul(_110, _10111011); + u256 _11000011 = mul(_10, _11000001); + u256 _11010011 = mul(_10010, _11000001); + u256 _11100001 = mul(_100000, _11000001); + u256 _11100011 = mul(_10, _11100001); + u256 _11100111 = mul(_110, _11100001); + u256 i57 = sqn<7>(mul(sqn<10>(mul(sqn<8>(_11000001),_10010001)),_11100111)); + u256 i76 = mul(sqn<7>(mul(sqn<9>(mul(_10111,i57)),_10011)), _1101); // last addend is _1101, matching the i76 step of the addition chain in the comment above + u256 i109 = sqn<8>(mul(sqn<9>(mul(sqn<14>(i76),_1010011)),_11100001)); + u256 i127 = mul(sqn<5>(mul(sqn<10>(mul(_1000001,i109)),_1011011)),_1101); + u256 i161 = sqn<12>(mul(sqn<12>(mul(sqn<8>(i127),_11)),_101011)); + u256 i186 = mul(sqn<14>(mul(sqn<8>(mul(_10111011,i161)),_101111)),_10110101); + u256 i214 = sqn<12>(mul(sqn<5>(mul(sqn<9>(i186),_10010001)),_1101)); + u256 i236 = mul(sqn<11>(mul(sqn<8>(mul(_11100011,i214)),_10010101)),_11010011); + u256 i268 = sqn<12>(mul(sqn<11>(mul(sqn<7>(i236),_1100001)),_100011)); + u256 i288 = mul(sqn<8>(mul(sqn<9>(mul(_1011011,i268)),_11000011)),_11100111); + return FpBN254(mul(sqn<6>(mul(sqn<7>(i288),_1110101)),_101)); + } + + FpBN254 neg() { + return FpBN254(sub(u256::from_int((uint32_t)0), inner)); + } + +private: + template<uint32_t N_ACC> + u256 sqn(u256 base) const { // squares `base` N_ACC times, i.e. the `<< N_ACC` steps of the chain + u256 result = base; +#pragma unroll + for (uint32_t i = 0; i < N_ACC; i++) { + result = mul(result, result); + } + return result; + } + + inline
u256 add(const u256 lhs, const u256 rhs) const { + u256 addition = lhs + rhs; + u256 res = addition; + + return res - u256::from_int((uint64_t)(addition >= N)) * N + u256::from_int((uint64_t)(addition < lhs)) * R_SUB_N; + } + + inline u256 sub(const u256 lhs, const u256 rhs) const { + return add(lhs, ((u256)N) - rhs); + } + + // Compute multiplication by performing single round of Montgomery reduction + constexpr static u256 mul(const u256 a, const u256 b) { + constexpr uint64_t NUM_LIMBS = 8; + metal::array<uint32_t, NUM_LIMBS> t = {}; // accumulator limbs; index NUM_LIMBS - 1 is processed first in every carry loop below + metal::array<uint32_t, 2> t_extra = {}; // two extra words holding carries out of t[0] + + u256 q = N; + + uint64_t i = NUM_LIMBS; + + while (i > 0) { + i -= 1; + uint64_t c = 0; + + uint64_t cs = 0; + uint64_t j = NUM_LIMBS; + while (j > 0) { + j -= 1; + cs = (uint64_t)t[j] + (uint64_t)a.m_limbs[j] * (uint64_t)b.m_limbs[i] + c; + c = cs >> 32; + t[j] = (uint32_t)((cs << 32) >> 32); + } + + cs = (uint64_t)t_extra[1] + c; + t_extra[0] = (uint32_t)(cs >> 32); + t_extra[1] = (uint32_t)((cs << 32) >> 32); + + uint64_t m = (((uint64_t)t[NUM_LIMBS - 1] * MU) << 32) >> 32; + + c = ((uint64_t)t[NUM_LIMBS - 1] + m * (uint64_t)q.m_limbs[NUM_LIMBS - 1]) >> 32; + + j = NUM_LIMBS - 1; + while (j > 0) { + j -= 1; + cs = (uint64_t)t[j] + m * (uint64_t)q.m_limbs[j] + c; + c = cs >> 32; + t[j + 1] = (uint32_t)((cs << 32) >> 32); + } + + cs = (uint64_t)t_extra[1] + c; + c = cs >> 32; + t[0] = (uint32_t)((cs << 32) >> 32); + + t_extra[1] = t_extra[0] + (uint32_t)c; + } + + u256 result {t}; + + uint64_t overflow = t_extra[0] > 0; + if (overflow || q <= result) { + result = result - q; + } + + return result; + } +}; diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal new file mode 100644 index 0000000..d246802 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont.metal @@ -0,0 +1,94 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include +#include +#include "../field/ff.metal" + +BigInt
conditional_reduce( + BigInt x, + BigInt y +) { + if (bigint_gte(x, y)) { + return bigint_sub(x, y); + } + + return x; +} + +/// An optimised variant of the Montgomery product algorithm from +/// https://github.com/mitschabaude/montgomery#13-x-30-bit-multiplication. +/// Known to work with 12 and 13-bit limbs. +BigInt mont_mul_optimised( + BigInt x, + BigInt y, + BigInt p +) { + BigInt s = bigint_zero(); + + for (uint i = 0; i < NUM_LIMBS; i ++) { + uint t = s.limbs[0] + x.limbs[i] * y.limbs[0]; + uint tprime = t & MASK; + uint qi = (N0 * tprime) & MASK; + uint c = (t + qi * p.limbs[0]) >> LOG_LIMB_SIZE; + s.limbs[0] = s.limbs[1] + x.limbs[i] * y.limbs[1] + qi * p.limbs[1] + c; + + for (uint j = 2; j < NUM_LIMBS; j ++) { + s.limbs[j - 1] = s.limbs[j] + x.limbs[i] * y.limbs[j] + qi * p.limbs[j]; + } + s.limbs[NUM_LIMBS - 2] = x.limbs[i] * y.limbs[NUM_LIMBS - 1] + qi * p.limbs[NUM_LIMBS - 1]; + } + + uint c = 0; + for (uint i = 0; i < NUM_LIMBS; i ++) { + uint v = s.limbs[i] + c; + c = v >> LOG_LIMB_SIZE; + s.limbs[i] = v & MASK; + } + + return conditional_reduce(s, p); +} + +/// A modified variant of the Montgomery product algorithm from +/// https://github.com/mitschabaude/montgomery#13-x-30-bit-multiplication. +/// Known to work with 14 and 15-bit limbs.
+BigInt mont_mul_modified( + BigInt x, + BigInt y, + BigInt p +) { + BigInt s = bigint_zero(); + + for (uint i = 0; i < NUM_LIMBS; i ++) { + uint t = s.limbs[0] + x.limbs[i] * y.limbs[0]; + uint tprime = t & MASK; + uint qi = (N0 * tprime) & MASK; + uint c = (t + qi * p.limbs[0]) >> LOG_LIMB_SIZE; + + for (uint j = 1; j < NUM_LIMBS - 1; j ++) { + uint t = s.limbs[j] + x.limbs[i] * y.limbs[j] + qi * p.limbs[j]; + if ((j - 1) % NSAFE == 0) { + t = t + c; + } + + c = t >> LOG_LIMB_SIZE; + + if (j % NSAFE == 0) { + c = t >> LOG_LIMB_SIZE; + s.limbs[j - 1] = t & MASK; + } else { + s.limbs[j - 1] = t; + } + } + s.limbs[NUM_LIMBS - 2] = x.limbs[i] * y.limbs[NUM_LIMBS - 1] + qi * p.limbs[NUM_LIMBS - 1]; + } + + uint c = 0; + for (uint i = 0; i < NUM_LIMBS; i ++) { + uint v = s.limbs[i] + c; + c = v >> LOG_LIMB_SIZE; + s.limbs[i] = v & MASK; + } + + return conditional_reduce(s, p); +} diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal new file mode 100644 index 0000000..77020d1 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_modified.metal @@ -0,0 +1,25 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include +#include +#include "mont.metal" + +kernel void run( + device BigInt* lhs [[ buffer(0) ]], + device BigInt* rhs [[ buffer(1) ]], + device BigInt* prime [[ buffer(2) ]], + device BigInt* result [[ buffer(3) ]], + uint gid [[ thread_position_in_grid ]] +) { + BigInt a; + BigInt b; + BigInt p; + a.limbs = lhs->limbs; + b.limbs = rhs->limbs; + p.limbs = prime->limbs; + + BigInt res = mont_mul_modified(a, b, p); + result->limbs = res.limbs; + +} diff --git a/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised.metal b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised.metal new file mode 100644 index 0000000..ffb0844 --- /dev/null +++ 
b/mopro-msm/src/msm/metal_msm/shader/mont_backend/mont_mul_optimised.metal @@ -0,0 +1,25 @@ +// source: https://github.com/geometryxyz/msl-secp256k1 + +using namespace metal; +#include +#include +#include "mont.metal" + +kernel void run( + device BigInt* lhs [[ buffer(0) ]], + device BigInt* rhs [[ buffer(1) ]], + device BigInt* prime [[ buffer(2) ]], + device BigInt* result [[ buffer(3) ]], + uint gid [[ thread_position_in_grid ]] +) { + BigInt a; + BigInt b; + BigInt p; + a.limbs = lhs->limbs; + b.limbs = rhs->limbs; + p.limbs = prime->limbs; + + BigInt res = mont_mul_optimised(a, b, p); + result->limbs = res.limbs; + +} diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs new file mode 100644 index 0000000..8058002 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_unsafe.rs @@ -0,0 +1,93 @@ +// adapted from: https://github.com/geometryxyz/msl-secp256k1 + +use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs}; +use crate::msm::metal_msm::host::gpu::{ + create_buffer, create_empty_buffer, get_default_device, read_buffer, +}; +use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; +use ark_ff::{BigInt, BigInteger}; +use metal::*; + +#[test] +#[serial_test::serial] +pub fn test_bigint_add_unsafe() { + let log_limb_size = 13; + let num_limbs = 20; + + // Create two test numbers (equivalent to the previous hex values) + let a = BigInt::new([ + 0x0000000100000001, + 0x0000000000000000, + 0x1800a1101800a110, + 0x0000000d0000000d, + ]); + let b = a.clone(); // Same value as a for this test + + let mut expected = a.clone(); + let overflow = expected.add_with_carry(&b); + + // We are testing add_unsafe, so the sum should not overflow + assert!(!overflow); + + let device = get_default_device(); + let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); + let b_buf = create_buffer(&device, 
&b.to_limbs(num_limbs, log_limb_size)); + let result_buf = create_empty_buffer(&device, num_limbs); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + num_limbs, + log_limb_size, + 0, + 0, + ); + let library_path = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + "bigint_add_unsafe.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&b_buf), 0); + encoder.set_buffer(2, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, num_limbs); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + assert!(result.eq(&expected)); +} diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs new file mode 100644 index 0000000..0eb7f9b --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_add_wide.rs @@ -0,0 +1,186 @@ +// adapted 
from: https://github.com/geometryxyz/msl-secp256k1 + +use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs}; +use crate::msm::metal_msm::host::gpu::{ + create_buffer, create_empty_buffer, get_default_device, read_buffer, +}; +use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; +use ark_ff::{BigInt, BigInteger}; +use metal::*; + +#[test] +#[serial_test::serial] +pub fn test_bigint_add() { + let log_limb_size = 13; + let num_limbs = 20; + + // Create two large numbers that will overflow when added + let a = BigInt::new([ + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + ]); + let b = BigInt::new([ + 0x1000000000000000, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ]); + + let mut expected = a.clone(); + + let overflow = expected.add_with_carry(&b); + assert!(overflow); + + let device = get_default_device(); + let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); + let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size)); + let result_buf = create_empty_buffer(&device, num_limbs + 1); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + num_limbs, + log_limb_size, + 0, + 0, + ); + + let library_path = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + "bigint_add_wide.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + 
pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&b_buf), 0); + encoder.set_buffer(2, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, num_limbs + 1); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + assert!(result.eq(&expected)); +} + +#[test] +#[serial_test::serial] +pub fn test_bigint_add_no_overflow() { + let log_limb_size = 13; + let num_limbs = 20; + + // Create two numbers that won't overflow when added + let a = BigInt::new([ + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000001, + ]); + let b = BigInt::new([ + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000002, + ]); + + let mut expected = a.clone(); + let overflow = expected.add_with_carry(&b); + assert!(!overflow); + + let device = get_default_device(); + let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); + let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size)); + let result_buf = create_empty_buffer(&device, num_limbs + 1); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + num_limbs, + log_limb_size, + 0, + 0, + ); + + let library_path = compile_metal( + 
"../mopro-msm/src/msm/metal_msm/shader/bigint", + "bigint_add_wide.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&b_buf), 0); + encoder.set_buffer(2, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, num_limbs + 1); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + assert!(result.eq(&expected)); +} diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs new file mode 100644 index 0000000..6341b9e --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/bigint_sub.rs @@ -0,0 +1,182 @@ +// adapted from: https://github.com/geometryxyz/msl-secp256k1 + +use core::borrow; + +use crate::msm::metal::abstraction::limbs_conversion::{FromLimbs, ToLimbs}; +use crate::msm::metal_msm::host::gpu::{ + create_buffer, create_empty_buffer, get_default_device, read_buffer, +}; +use crate::msm::metal_msm::host::shader::{compile_metal, write_constants}; +use ark_ff::{BigInt, BigInteger}; +use metal::*; + +#[test] +#[serial_test::serial] +pub fn test_bigint_sub() { + let log_limb_size = 13; + let num_limbs = 20; + + let mut a = 
BigInt::new([0xf09f8fb3, 0xefb88fe2, 0x808df09f, 0x8c880010]); + let b = BigInt::new([0xf09f8fb3, 0xefb88fe2, 0x808df09f, 0x8c880001]); + + let device = get_default_device(); + let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size)); + let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size)); + let result_buf = create_empty_buffer(&device, num_limbs); + + // perform a - b + let _borrow = a.sub_with_borrow(&b); + let expected_limbs = a.to_limbs(num_limbs, log_limb_size); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + num_limbs, + log_limb_size, + 0, + 0, + ); + let library_path = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + "bigint_sub.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&b_buf), 0); + encoder.set_buffer(2, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, 
num_limbs); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + assert!(result_limbs.eq(&expected_limbs)); + assert!(result.eq(&a)); +} + +#[test] +#[serial_test::serial] +fn test_bigint_sub_underflow() { + let device = Device::system_default().expect("no device found"); + let num_limbs = 20; + let log_limb_size = 13; + + // Create smaller number a and larger number b + let mut a = BigInt::from_u32(100); + let b = BigInt::from_u32(200); + + let a_limbs = a.to_limbs(num_limbs, log_limb_size); + let b_limbs = b.to_limbs(num_limbs, log_limb_size); + + let a_buf = device.new_buffer_with_data( + unsafe { std::mem::transmute(a_limbs.as_ptr()) }, + (a_limbs.len() * std::mem::size_of::()) as u64, + MTLResourceOptions::StorageModeShared, + ); + + let b_buf = device.new_buffer_with_data( + unsafe { std::mem::transmute(b_limbs.as_ptr()) }, + (b_limbs.len() * std::mem::size_of::()) as u64, + MTLResourceOptions::StorageModeShared, + ); + + let result_buf = device.new_buffer( + (num_limbs * std::mem::size_of::()) as u64, + MTLResourceOptions::StorageModeShared, + ); + + // Expected result is 2^256 - 100 (since we're doing a - b where b > a) + let _expected = a.sub_with_borrow(&b); + let expected_limbs = a.to_limbs(num_limbs, log_limb_size); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let compute_pass_descriptor = ComputePassDescriptor::new(); + let encoder = command_buffer.compute_command_encoder_with_descriptor(compute_pass_descriptor); + + write_constants( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + num_limbs, + log_limb_size, + 0, + 0, + ); + let library_path = compile_metal( + "../mopro-msm/src/msm/metal_msm/shader/bigint", + "bigint_sub.metal", + ); + let library = device.new_library_with_file(library_path).unwrap(); + let kernel = library.get_function("run", None).unwrap(); + + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + 
pipeline_state_descriptor.set_compute_function(Some(&kernel)); + + let pipeline_state = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + encoder.set_compute_pipeline_state(&pipeline_state); + encoder.set_buffer(0, Some(&a_buf), 0); + encoder.set_buffer(1, Some(&b_buf), 0); + encoder.set_buffer(2, Some(&result_buf), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result_limbs: Vec = read_buffer(&result_buf, num_limbs); + let result = BigInt::from_limbs(&result_limbs, log_limb_size); + + // assert!(result_limbs.eq(&expected_limbs)); // TODO: leading limb is incorrect + assert!(result.eq(&a)); +} diff --git a/mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs b/mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs new file mode 100644 index 0000000..b3c015a --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/bigint/mod.rs @@ -0,0 +1,3 @@ +pub mod bigint_add_unsafe; +pub mod bigint_add_wide; +pub mod bigint_sub; diff --git a/mopro-msm/src/msm/metal_msm/tests/mod.rs b/mopro-msm/src/msm/metal_msm/tests/mod.rs new file mode 100644 index 0000000..8f8d148 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/tests/mod.rs @@ -0,0 +1 @@ +pub mod bigint; diff --git a/mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs b/mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs new file mode 100644 index 0000000..ff9f7d2 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/utils/limbs_conversion.rs @@ -0,0 +1,91 @@ +use ark_bn254::Fq; +use ark_ff::biginteger::{BigInteger, BigInteger256}; + +use crate::msm::metal::abstraction::mont_reduction; + +// implement to_u32_limbs and from_u32_limbs for BigInt<4> +pub trait ToLimbs { + fn 
to_u32_limbs(&self) -> Vec; +} + +pub trait FromLimbs { + fn from_u32_limbs(limbs: &[u32]) -> Self; + fn from_u128(num: u128) -> Self; + fn from_u32(num: u32) -> Self; +} + +// convert from little endian to big endian +impl ToLimbs for BigInteger256 { + fn to_u32_limbs(&self) -> Vec { + let mut limbs = Vec::new(); + self.to_bytes_be().chunks(8).for_each(|chunk| { + let high = u32::from_be_bytes(chunk[0..4].try_into().unwrap()); + let low = u32::from_be_bytes(chunk[4..8].try_into().unwrap()); + limbs.push(high); + limbs.push(low); + }); + limbs + } +} + +// convert from little endian to big endian +impl ToLimbs for Fq { + fn to_u32_limbs(&self) -> Vec { + let mut limbs = Vec::new(); + self.0.to_bytes_be().chunks(8).for_each(|chunk| { + let high = u32::from_be_bytes(chunk[0..4].try_into().unwrap()); + let low = u32::from_be_bytes(chunk[4..8].try_into().unwrap()); + limbs.push(high); + limbs.push(low); + }); + limbs + } +} + +impl FromLimbs for BigInteger256 { + // convert from big endian to little endian for metal + fn from_u32_limbs(limbs: &[u32]) -> Self { + let mut big_int = [0u64; 4]; + for (i, limb) in limbs.chunks(2).rev().enumerate() { + let high = u64::from(limb[0]); + let low = u64::from(limb[1]); + big_int[i] = (high << 32) | low; + } + BigInteger256::new(big_int) + } + // provide little endian u128 since arkworks use this value as well + fn from_u128(num: u128) -> Self { + let high = (num >> 64) as u64; + let low = num as u64; + BigInteger256::new([low, high, 0, 0]) + } + // provide little endian u32 since arkworks use this value as well + fn from_u32(num: u32) -> Self { + BigInteger256::new([num as u64, 0, 0, 0]) + } +} + +impl FromLimbs for Fq { + // convert from big endian to little endian for metal + fn from_u32_limbs(limbs: &[u32]) -> Self { + let mut big_int = [0u64; 4]; + for (i, limb) in limbs.chunks(2).rev().enumerate() { + let high = u64::from(limb[0]); + let low = u64::from(limb[1]); + big_int[i] = (high << 32) | low; + } + 
Fq::new(mont_reduction::raw_reduction(BigInteger256::new(big_int))) + } + fn from_u128(num: u128) -> Self { + let high = (num >> 64) as u64; + let low = num as u64; + Fq::new(mont_reduction::raw_reduction(BigInteger256::new([ + low, high, 0, 0, + ]))) + } + fn from_u32(num: u32) -> Self { + Fq::new(mont_reduction::raw_reduction(BigInteger256::new([ + num as u64, 0, 0, 0, + ]))) + } +} diff --git a/mopro-msm/src/msm/metal_msm/utils/mod.rs b/mopro-msm/src/msm/metal_msm/utils/mod.rs new file mode 100644 index 0000000..defe5d5 --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/utils/mod.rs @@ -0,0 +1,2 @@ +pub mod limbs_conversion; +pub mod mont_reduction; diff --git a/mopro-msm/src/msm/metal_msm/utils/mont_reduction.rs b/mopro-msm/src/msm/metal_msm/utils/mont_reduction.rs new file mode 100644 index 0000000..db0069f --- /dev/null +++ b/mopro-msm/src/msm/metal_msm/utils/mont_reduction.rs @@ -0,0 +1,40 @@ +use ark_bn254::FqConfig; +use ark_ff::{ + biginteger::{arithmetic as fa, BigInt}, + fields::models::{MontBackend, MontConfig}, + Fp, +}; + +// Reference: https://github.com/arkworks-rs/algebra/blob/master/ff/src/fields/models/fp/montgomery_backend.rs#L373-L389 +const N: usize = 4; +pub fn into_bigint(a: Fp<MontBackend<FqConfig, N>, N>) -> BigInt<N> { + let a = a.0; + raw_reduction(a) +} + +pub fn raw_reduction(a: BigInt<N>) -> BigInt<N> { + let mut r = a.0; // parse into [u64; N] + + // Montgomery Reduction + for i in 0..N { + let k = r[i].wrapping_mul(<FqConfig as MontConfig<N>>::INV); + let mut carry = 0; + + fa::mac_with_carry( + r[i], + k, + <FqConfig as MontConfig<N>>::MODULUS.0[0], + &mut carry, + ); + for j in 1..N { + r[(j + i) % N] = fa::mac_with_carry( + r[(j + i) % N], + k, + <FqConfig as MontConfig<N>>::MODULUS.0[j], + &mut carry, + ); + } + r[i % N] = carry; + } + BigInt::new(r) +} diff --git a/mopro-msm/src/msm/mod.rs b/mopro-msm/src/msm/mod.rs index 3aab8ed..2044a8a 100644 --- a/mopro-msm/src/msm/mod.rs +++ b/mopro-msm/src/msm/mod.rs @@ -1,6 +1,7 @@ pub mod arkworks_pippenger; pub mod bucket_wise_msm; pub mod metal; +pub mod metal_msm; pub mod precompute_msm; pub mod
utils;