From 68a295556ffb13d7589a14a3947a09efbd229cf7 Mon Sep 17 00:00:00 2001 From: HuggingFace-MacMini-Wozniak Date: Thu, 6 Jul 2023 22:00:05 +0200 Subject: [PATCH 1/6] Adding new intrinsics for ggblas. --- src/binary16.rs | 2 +- src/binary16/arch.rs | 6 +- src/binary16/arch/aarch64.rs | 111 +++++++++++++++++++++++++++++++++++ src/lib.rs | 5 +- 4 files changed, 121 insertions(+), 3 deletions(-) diff --git a/src/binary16.rs b/src/binary16.rs index 9f0ae36..8e88103 100644 --- a/src/binary16.rs +++ b/src/binary16.rs @@ -21,7 +21,7 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "zerocopy")] use zerocopy::{AsBytes, FromBytes}; -pub(crate) mod arch; +pub mod arch; /// A 16-bit floating point type implementing the IEEE 754-2008 standard [`binary16`] a.k.a "half" /// format. diff --git a/src/binary16/arch.rs b/src/binary16/arch.rs index d33103e..b29ff77 100644 --- a/src/binary16/arch.rs +++ b/src/binary16/arch.rs @@ -6,7 +6,11 @@ use core::mem; mod x86; #[cfg(target_arch = "aarch64")] -mod aarch64; +pub mod aarch64; + +#[cfg(target_arch = "arm")] +mod arm; + macro_rules! convert_fn { (if x86_feature("f16c") { $f16c:expr } diff --git a/src/binary16/arch/aarch64.rs b/src/binary16/arch/aarch64.rs index 9441e76..81908f0 100644 --- a/src/binary16/arch/aarch64.rs +++ b/src/binary16/arch/aarch64.rs @@ -7,6 +7,117 @@ use core::{ ptr, }; +use crate::f16; + +#[repr(simd)] +#[allow(non_camel_case_types)] +#[derive(Clone, Copy)] +pub struct float16x8_t(pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16); + +#[repr(simd)] +#[allow(non_camel_case_types)] +#[derive(Clone, Copy)] +pub struct float16x4_t(pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16); + +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vcvt_f32_f16(i: float16x4_t) -> float32x4_t { + let result: float32x4_t; + asm!( + "fcvtl {0:v}.4s, {1:v}.4h", + out(vreg) result, + in(vreg) i, + options(pure, nomem, nostack, preserves_flags)); + result +} + +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vget_high_f16_f32(i: float16x8_t) -> float32x4_t { + let result: float32x4_t; + asm!( + "fcvtl2 {0:v}.4s, {1:v}.8h", + out(vreg) result, + in(vreg) i, + options(pure, nomem, nostack, preserves_flags)); + result +} + +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vget_low_f16_f32(i: float16x8_t) -> float32x4_t { + let result: float32x4_t; + asm!( + "fcvtl {0:v}.4s, {1:v}.4h", + out(vreg) result, + in(vreg) i, + options(pure, nomem, nostack, preserves_flags)); + result +} + +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vaddq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t { + let result: float16x8_t; + asm!( + "fadd {0:v}.8h, {1:v}.8h, {2:v}.8h", + out(vreg) result, + in(vreg) a, + in(vreg) b, + options(pure, nomem, nostack, preserves_flags)); + result +} + +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vst1q_f16(mut ptr: *mut f16, mut val: float16x8_t){ + ptr::copy_nonoverlapping(&val, ptr.cast(), 8); + // asm!( + // "vst1q_f16 {0:s}, {1:h}", + // out(vreg) ptr, + // in(vreg) val, + // options(pure, nomem, nostack, preserves_flags)); +} + +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vld1q_f16(ptr: *const f16) -> float16x8_t{ + let mut result = MaybeUninit::::uninit(); + ptr::copy_nonoverlapping(ptr.cast(), &mut result, 8); + // asm!( + // "vld1q_f16 {0:s}, {1:h}", + // out(vreg) result, + // in(vreg) ptr, + // options(pure, nomem, nostack, preserves_flags)); + result.assume_init() +} + +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vfmaq_f16(a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t{ + // let result: float16x8_t; + asm!( + "fmla {0:v}.8h, {1:v}.8h, {2:v}.8h", + in(vreg) a, + in(vreg) b, + in(vreg) c, + options(nomem, nostack, preserves_flags)); + // result + a +} + +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vdupq_n_f16(a: u16) -> float16x8_t{ + let result: float16x8_t; + asm!( + "dup {0:v}.8h, {1:v}.h[0]", + out(vreg) result, + in(vreg) a, + options(pure, nomem, nostack, preserves_flags)); + result +} + #[target_feature(enable = "fp16")] #[inline] pub(super) unsafe fn f16_to_f32_fp16(i: u16) -> f32 { diff --git a/src/lib.rs b/src/lib.rs index a4dca74..c9e1b3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(repr_simd)] //! A crate that provides support for half-precision 16-bit floating point types. //! //! This crate provides the [`f16`] type, which is an implementation of the IEEE 754-2008 standard @@ -201,11 +202,13 @@ #![doc(test(attr(deny(warnings), allow(unused))))] #![cfg_attr(docsrs, feature(doc_auto_cfg))] + + #[cfg(feature = "alloc")] extern crate alloc; mod bfloat; -mod binary16; +pub mod binary16; mod leading_zeros; #[cfg(feature = "num-traits")] mod num_traits; From dd6ce8ed4984e59b13c09a1c3fe0553b205b95d7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 6 Jul 2023 22:23:03 +0200 Subject: [PATCH 2/6] Remove nightly requirements. --- src/binary16/arch/aarch64.rs | 11 +++-------- src/lib.rs | 1 - 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/binary16/arch/aarch64.rs b/src/binary16/arch/aarch64.rs index 81908f0..7230140 100644 --- a/src/binary16/arch/aarch64.rs +++ b/src/binary16/arch/aarch64.rs @@ -1,6 +1,6 @@ use core::{ arch::{ - aarch64::{float32x4_t, float64x2_t, uint16x4_t}, + aarch64::{float32x4_t, float64x2_t, uint16x4_t, uint16x8_t}, asm, }, mem::MaybeUninit, @@ -9,15 +9,10 @@ use core::{ use crate::f16; -#[repr(simd)] #[allow(non_camel_case_types)] -#[derive(Clone, Copy)] -pub struct float16x8_t(pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16); - -#[repr(simd)] +type float16x8_t = uint16x8_t; #[allow(non_camel_case_types)] -#[derive(Clone, Copy)] -pub struct float16x4_t(pub(crate) u16, pub(crate) u16, pub(crate) u16, pub(crate) u16); +type float16x4_t = uint16x4_t; #[target_feature(enable = "fp16")] #[inline] diff --git a/src/lib.rs b/src/lib.rs index c9e1b3d..57310d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(repr_simd)] //! A crate that provides support for half-precision 16-bit floating point types. //! //! This crate provides the [`f16`] type, which is an implementation of the IEEE 754-2008 standard From 31ef63ba651b2e1eb8fe654c877fdc8db03b4958 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Sun, 30 Jul 2023 15:45:10 +0200 Subject: [PATCH 3/6] Going around clippy. --- src/binary16/arch.rs | 1 + src/binary16/arch/aarch64.rs | 22 ++++++++++++++++++++-- src/lib.rs | 1 + 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/binary16/arch.rs b/src/binary16/arch.rs index b29ff77..34fdc4d 100644 --- a/src/binary16/arch.rs +++ b/src/binary16/arch.rs @@ -5,6 +5,7 @@ use core::mem; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] mod x86; +#[allow(missing_docs)] #[cfg(target_arch = "aarch64")] pub mod aarch64; diff --git a/src/binary16/arch/aarch64.rs b/src/binary16/arch/aarch64.rs index 7230140..186b10c 100644 --- a/src/binary16/arch/aarch64.rs +++ b/src/binary16/arch/aarch64.rs @@ -1,3 +1,4 @@ +#![allow(clippy::missing_safety_doc)] use core::{ arch::{ aarch64::{float32x4_t, float64x2_t, uint16x4_t, uint16x8_t}, @@ -14,9 +15,12 @@ type float16x8_t = uint16x8_t; #[allow(non_camel_case_types)] type float16x4_t = uint16x4_t; +/// Convert to higher precision +/// Takes the 64 bits and convert them as [`float32x4_t`] +/// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FCVTL--FCVTL2--vector-) #[target_feature(enable = "fp16")] #[inline] -pub unsafe fn vcvt_f32_f16(i: float16x4_t) -> float32x4_t { +pub unsafe fn vcvt_f16_f32(i: float16x4_t) -> float32x4_t { let result: float32x4_t; asm!( "fcvtl {0:v}.4s, {1:v}.4h", @@ -26,6 +30,9 @@ pub unsafe fn vcvt_f32_f16(i: float16x4_t) -> float32x4_t { result } +/// Convert to higher precision +/// Takes the top 64 bits and convert them as [`float32x4_t`] +/// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FCVTL--FCVTL2--vector-) #[target_feature(enable = "fp16")] #[inline] pub unsafe fn vget_high_f16_f32(i: float16x8_t) -> float32x4_t { @@ -38,6 +45,9 @@ pub unsafe fn vget_high_f16_f32(i: float16x8_t) -> float32x4_t { result } +/// Convert to higher precision +/// Takes the lower 64 bits and convert them as [`float32x4_t`] +/// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FCVTL--FCVTL2--vector-) #[target_feature(enable = "fp16")] #[inline] pub unsafe fn vget_low_f16_f32(i: float16x8_t) -> float32x4_t { @@ -50,6 +60,8 @@ pub unsafe fn vget_low_f16_f32(i: float16x8_t) -> float32x4_t { result } +/// Floating point addition +/// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FADD--vector-) #[target_feature(enable = "fp16")] #[inline] pub unsafe fn vaddq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t { @@ -63,9 +75,10 @@ pub unsafe fn vaddq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t { result } +/// Casts [`float16x8t`] to raw pointer. #[target_feature(enable = "fp16")] #[inline] -pub unsafe fn vst1q_f16(mut ptr: *mut f16, mut val: float16x8_t){ +pub unsafe fn vst1q_f16(ptr: *mut f16, val: float16x8_t){ ptr::copy_nonoverlapping(&val, ptr.cast(), 8); // asm!( // "vst1q_f16 {0:s}, {1:h}", @@ -74,6 +87,8 @@ pub unsafe fn vst1q_f16(mut ptr: *mut f16, mut val: float16x8_t){ // options(pure, nomem, nostack, preserves_flags)); } +/// Casts pointer to [`float16x8t`]. +/// This functions assumes pointer is aligned #[target_feature(enable = "fp16")] #[inline] pub unsafe fn vld1q_f16(ptr: *const f16) -> float16x8_t{ @@ -87,6 +102,8 @@ pub unsafe fn vld1q_f16(ptr: *const f16) -> float16x8_t{ result.assume_init() } +/// Broadcast value into [`float16x8_t`] +/// Fused multiply add [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FMLA--vector-) #[target_feature(enable = "fp16")] #[inline] pub unsafe fn vfmaq_f16(a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t{ @@ -101,6 +118,7 @@ pub unsafe fn vfmaq_f16(a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float a } +/// Broadcast value into [`float16x8_t`] #[target_feature(enable = "fp16")] #[inline] pub unsafe fn vdupq_n_f16(a: u16) -> float16x8_t{ diff --git a/src/lib.rs b/src/lib.rs index 57310d3..1848828 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -207,6 +207,7 @@ extern crate alloc; mod bfloat; +#[allow(missing_docs)] pub mod binary16; mod leading_zeros; #[cfg(feature = "num-traits")] From a62777133178e40c82370a870eeaf84716783b1a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 1 Aug 2023 10:22:43 +0200 Subject: [PATCH 4/6] More intrinsics. Still failing when compiling `fmlaq` without black_box. --- src/binary16/arch/aarch64.rs | 93 ++++++++++++++++++++++++++++++++++-- 1 file changed, 90 insertions(+), 3 deletions(-) diff --git a/src/binary16/arch/aarch64.rs b/src/binary16/arch/aarch64.rs index 186b10c..181d2ce 100644 --- a/src/binary16/arch/aarch64.rs +++ b/src/binary16/arch/aarch64.rs @@ -75,6 +75,81 @@ pub unsafe fn vaddq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t { result } +/// Floating point multiplication +/// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FADD--vector-) +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vmulq_f16(a: float16x8_t, b: float16x8_t) -> float16x8_t { + let result: float16x8_t; + asm!( + "fmul {0:v}.8h, {1:v}.8h, {2:v}.8h", + out(vreg) result, + in(vreg) a, + in(vreg) b, + options(pure, nomem, nostack, preserves_flags)); + result +} + +/// Floating point multiplication +/// [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FADD--vector-) +#[target_feature(enable = "fp16")] +#[inline] +pub unsafe fn vget_lane_f16(a: float16x8_t) -> u16 { + todo!("lane!"); + // let result: u16; + // match LANE { + // 0=> asm!( + // "dup {0:h}, {1:v}.8h[0]", + // out(vreg) result, + // in(vreg) a, + // options(pure, nomem, nostack, preserves_flags)), + // 1=> asm!( + // "dup {0:v}, {1:v}.8h[1]", + // out(vreg) result, + // in(vreg) a, + // options(nomem, nostack, preserves_flags)), + // 2=> asm!( + // "dup {0:v}, {1:v}.8h[2]", + // out(vreg) result, + // in(vreg) a, + // options(nomem, nostack, preserves_flags)), + // 3=> asm!( + // "dup {0:v}, {1:v}.8h[3]", + // out(vreg) result, + // in(vreg) a, + // options(nomem, nostack, preserves_flags)), + // 4=> asm!( + // "dup {0:v}, {1:v}.8h[4]", + // out(vreg) result, + // in(vreg) a, + // options(nomem, nostack, preserves_flags)), + // 5=> asm!( + // "dup {0:v}, {1:v}.8h[5]", + // out(vreg) result, + // in(vreg) a, + // options(nomem, nostack, preserves_flags)), + // 6=> asm!( + // "dup {0:v}, {1:v}.8h[6]", + // out(vreg) result, + // in(vreg) a, + // options(nomem, nostack, preserves_flags)), + // 7=> asm!( + // "dup {0:v}, {1:v}.8h[7]", + // out(vreg) result, + // in(vreg) a, + // options(nomem, nostack, preserves_flags)), + // _ => unimplemented!("get_lane_f16 - {LANE}") + // } + result +} + +#[inline] +pub unsafe fn vfmaq_laneq_f16(a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t { + let c = vget_lane_f16::(c); + let result = core::mem::transmute([c, c, c, c, c, c, c, c]); + vfmaq_f16(a, b, result) +} + /// Casts [`float16x8t`] to raw pointer. #[target_feature(enable = "fp16")] #[inline] @@ -107,14 +182,12 @@ pub unsafe fn vld1q_f16(ptr: *const f16) -> float16x8_t{ #[target_feature(enable = "fp16")] #[inline] pub unsafe fn vfmaq_f16(a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t{ - // let result: float16x8_t; asm!( "fmla {0:v}.8h, {1:v}.8h, {2:v}.8h", in(vreg) a, in(vreg) b, in(vreg) c, options(nomem, nostack, preserves_flags)); - // result a } @@ -124,13 +197,27 @@ pub unsafe fn vfmaq_f16(a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float pub unsafe fn vdupq_n_f16(a: u16) -> float16x8_t{ let result: float16x8_t; asm!( - "dup {0:v}.8h, {1:v}.h[0]", + "dup {0:v}.8h, {1:h}", out(vreg) result, in(vreg) a, options(pure, nomem, nostack, preserves_flags)); result } +#[cfg(test)] +mod tests{ + use super::*; + + #[test] + fn vdupq(){ + unsafe{ + let a = vdupq_n_f16(std::mem::transmute(f16::ONE)); + let b: f16 = std::mem::transmute(vget_lane_f16::<0>(a)); + assert_eq!(b, f16::ONE); + } + } +} + #[target_feature(enable = "fp16")] #[inline] pub(super) unsafe fn f16_to_f32_fp16(i: u16) -> f32 { From 4d309a04c4d4f18f895cc62e38b99d03ef76b5ee Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 1 Aug 2023 10:26:49 +0200 Subject: [PATCH 5/6] Fix. --- src/binary16/arch/aarch64.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/binary16/arch/aarch64.rs b/src/binary16/arch/aarch64.rs index 181d2ce..f954240 100644 --- a/src/binary16/arch/aarch64.rs +++ b/src/binary16/arch/aarch64.rs @@ -140,7 +140,7 @@ pub unsafe fn vget_lane_f16(a: float16x8_t) -> u16 { // options(nomem, nostack, preserves_flags)), // _ => unimplemented!("get_lane_f16 - {LANE}") // } - result + // result } #[inline] From d3e042abd53cc8c8d80abd65c38be8c12e267434 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 1 Aug 2023 11:31:31 +0200 Subject: [PATCH 6/6] `in` -> `inout` --- src/binary16/arch/aarch64.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/binary16/arch/aarch64.rs b/src/binary16/arch/aarch64.rs index f954240..46ab052 100644 --- a/src/binary16/arch/aarch64.rs +++ b/src/binary16/arch/aarch64.rs @@ -181,10 +181,10 @@ pub unsafe fn vld1q_f16(ptr: *const f16) -> float16x8_t{ /// Fused multiply add [doc](https://developer.arm.com/documentation/dui0801/g/A64-SIMD-Vector-Instructions/FMLA--vector-) #[target_feature(enable = "fp16")] #[inline] -pub unsafe fn vfmaq_f16(a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t{ +pub unsafe fn vfmaq_f16(mut a: float16x8_t, b: float16x8_t, c: float16x8_t) -> float16x8_t{ asm!( "fmla {0:v}.8h, {1:v}.8h, {2:v}.8h", - in(vreg) a, + inout(vreg) a, in(vreg) b, in(vreg) c, options(nomem, nostack, preserves_flags));