Skip to content

Commit

Permalink
ARM f16 compilation fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
awxkee committed Feb 6, 2025
1 parent 10ad62f commit 72fb987
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 90 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ workspace = { members = ["app", "wasm", "fuzz", "app/accelerate"], exclude = ["p

[package]
name = "pic-scale"
version = "0.5.0"
version = "0.5.1"
edition = "2021"
description = "High performance image scaling"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion picscale/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

122 changes: 35 additions & 87 deletions src/neon/f16_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ pub(crate) unsafe fn xvldq_f16_x2(ptr: *const f16) -> x_float16x8x2_t {
)
}

#[inline]
#[inline(always)]
#[cfg(feature = "nightly_f16")]
pub(crate) unsafe fn xvldq_f16_x4(ptr: *const f16) -> x_float16x8x4_t {
let ptr_u16 = ptr as *const u16;
Expand All @@ -126,46 +126,46 @@ pub(crate) unsafe fn xvldq_f16_x4(ptr: *const f16) -> x_float16x8x4_t {
)
}

#[inline]
#[inline(always)]
/// Returns the lower four lanes of an eight-lane f16 vector.
///
/// The input is bit-cast to `uint16x8_t`, halved with `vget_low_u16`, and the
/// result is bit-cast back, so lane bits are moved without being interpreted
/// as numbers.
///
/// # Safety
/// Relies on `transmute` between the project's f16 vector wrappers and the
/// NEON `u16` vectors; presumably these have identical size and layout —
/// confirm against the type definitions.
pub(crate) unsafe fn xvget_low_f16(x: x_float16x8_t) -> x_float16x4_t {
    let bits: uint16x8_t = std::mem::transmute::<x_float16x8_t, uint16x8_t>(x);
    let low_half: uint16x4_t = vget_low_u16(bits);
    std::mem::transmute::<uint16x4_t, x_float16x4_t>(low_half)
}

#[inline]
#[inline(always)]
/// Returns the upper four lanes of an eight-lane f16 vector.
///
/// The input is bit-cast to `uint16x8_t`, halved with `vget_high_u16`, and the
/// result is bit-cast back, so lane bits are moved without being interpreted
/// as numbers.
///
/// # Safety
/// Relies on `transmute` between the project's f16 vector wrappers and the
/// NEON `u16` vectors; presumably these have identical size and layout —
/// confirm against the type definitions.
pub(crate) unsafe fn xvget_high_f16(x: x_float16x8_t) -> x_float16x4_t {
    let bits: uint16x8_t = std::mem::transmute::<x_float16x8_t, uint16x8_t>(x);
    let high_half: uint16x4_t = vget_high_u16(bits);
    std::mem::transmute::<uint16x4_t, x_float16x4_t>(high_half)
}

#[inline]
#[inline(always)]
/// Joins two four-lane f16 vectors into one eight-lane vector.
///
/// `low` supplies the lower half and `high` the upper half of the result, as
/// defined by `vcombine_u16`; both halves travel as raw 16-bit lanes.
///
/// # Safety
/// Relies on `transmute` between the project's f16 vector wrappers and the
/// NEON `u16` vectors; presumably these have identical size and layout —
/// confirm against the type definitions.
pub(crate) unsafe fn xcombine_f16(low: x_float16x4_t, high: x_float16x4_t) -> x_float16x8_t {
    let low_bits: uint16x4_t = std::mem::transmute::<x_float16x4_t, uint16x4_t>(low);
    let high_bits: uint16x4_t = std::mem::transmute::<x_float16x4_t, uint16x4_t>(high);
    let joined: uint16x8_t = vcombine_u16(low_bits, high_bits);
    std::mem::transmute::<uint16x8_t, x_float16x8_t>(joined)
}

#[inline]
#[inline(always)]
/// Bit-casts a four-lane f16 vector to `uint16x4_t` without changing any bits.
///
/// # Safety
/// A plain `transmute`; presumably `x_float16x4_t` and `uint16x4_t` share size
/// and layout — confirm against the type definition.
pub(crate) unsafe fn xreinterpret_u16_f16(x: x_float16x4_t) -> uint16x4_t {
    std::mem::transmute::<x_float16x4_t, uint16x4_t>(x)
}

#[inline]
#[inline(always)]
/// Bit-casts an eight-lane f16 vector to `uint16x8_t` without changing any bits.
///
/// # Safety
/// A plain `transmute`; presumably `x_float16x8_t` and `uint16x8_t` share size
/// and layout — confirm against the type definition.
pub(crate) unsafe fn xreinterpretq_u16_f16(x: x_float16x8_t) -> uint16x8_t {
    std::mem::transmute::<x_float16x8_t, uint16x8_t>(x)
}

#[inline]
#[inline(always)]
/// Bit-casts a `uint16x4_t` to the four-lane f16 vector type without changing
/// any bits.
///
/// # Safety
/// A plain `transmute`; presumably `uint16x4_t` and `x_float16x4_t` share size
/// and layout — confirm against the type definition.
pub(crate) unsafe fn xreinterpret_f16_u16(x: uint16x4_t) -> x_float16x4_t {
    std::mem::transmute::<uint16x4_t, x_float16x4_t>(x)
}

#[inline]
#[inline(always)]
/// Bit-casts a `uint16x8_t` to the eight-lane f16 vector type without changing
/// any bits.
///
/// # Safety
/// A plain `transmute`; presumably `uint16x8_t` and `x_float16x8_t` share size
/// and layout — confirm against the type definition.
pub(crate) unsafe fn xreinterpretq_f16_u16(x: uint16x8_t) -> x_float16x8_t {
    std::mem::transmute::<uint16x8_t, x_float16x8_t>(x)
}
Expand Down Expand Up @@ -319,7 +319,7 @@ pub(super) unsafe fn xvrecpeq_f16(v: x_float16x8_t) -> x_float16x8_t {
// xreinterpretq_f16_u16(result)
// }

#[inline]
#[inline(always)]
pub(super) unsafe fn xvcombine_f16(v1: x_float16x4_t, v2: x_float16x4_t) -> x_float16x8_t {
xreinterpretq_f16_u16(vcombine_u16(
xreinterpret_u16_f16(v1),
Expand Down Expand Up @@ -380,11 +380,11 @@ pub(super) unsafe fn xvfmlalq_high_f16(
) -> float32x4_t {
let mut result: float32x4_t = a;
asm!(
"fmlal2 {0:v}.4s, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpretq_u16_f16(b),
in(vreg) xreinterpretq_u16_f16(c),
options(pure, nomem, nostack)
"fmlal2 {0:v}.4s, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpretq_u16_f16(b),
in(vreg) xreinterpretq_u16_f16(c),
options(pure, nomem, nostack)
);
result
}
Expand All @@ -405,11 +405,11 @@ pub(super) unsafe fn xvfmlalq_low_f16(
) -> float32x4_t {
let mut result: float32x4_t = a;
asm!(
"fmlal {0:v}.4s, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpretq_u16_f16(b),
in(vreg) xreinterpretq_u16_f16(c),
options(pure, nomem, nostack)
"fmlal {0:v}.4s, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpretq_u16_f16(b),
in(vreg) xreinterpretq_u16_f16(c),
options(pure, nomem, nostack)
);
result
}
Expand All @@ -429,41 +429,15 @@ pub(super) unsafe fn xvfmlalq_lane_low_f16<const LANE: i32>(
c: x_float16x4_t,
) -> float32x4_t {
let mut result: float32x4_t = a;
static_assert_uimm_bits!(LANE, 3);
let full_lane = xvcombine_f16(c, c);
if LANE == 0 {
asm!(
"fmlal {0:v}.4s, {1:v}.4h, {2:v}.h[0]",
inout(vreg) result,
in(vreg) xreinterpretq_u16_f16(b),
in(vreg) xreinterpretq_u16_f16(full_lane),
options(pure, nomem, nostack)
);
} else if LANE == 1 {
asm!(
"fmlal {0:v}.4s, {1:v}.4h, {2:v}.h[1]",
inout(vreg) result,
in(vreg) xreinterpretq_u16_f16(b),
in(vreg) xreinterpretq_u16_f16(full_lane),
options(pure, nomem, nostack)
);
} else if LANE == 2 {
asm!(
"fmlal {0:v}.4s, {1:v}.4h, {2:v}.h[2]",
inout(vreg) result,
in(vreg) xreinterpretq_u16_f16(b),
in(vreg) xreinterpretq_u16_f16(full_lane),
options(pure, nomem, nostack)
);
} else if LANE == 3 {
asm!(
"fmlal {0:v}.4s, {1:v}.4h, {2:v}.h[3]",
let lanes: uint16x8_t = vdupq_n_u16(vget_lane_u16::<LANE>(xreinterpret_u16_f16(c)));

asm!(
"fmlal {0:v}.4s, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpretq_u16_f16(b),
in(vreg) xreinterpretq_u16_f16(full_lane),
in(vreg) lanes,
options(pure, nomem, nostack)
);
}
);
result
}

Expand Down Expand Up @@ -791,39 +765,13 @@ pub(super) unsafe fn xvfmla_lane_f16<const LANE: i32>(
let mut result: uint16x4_t = xreinterpret_u16_f16(a);
let lanes: uint16x8_t = vdupq_n_u16(vget_lane_u16::<LANE>(xreinterpret_u16_f16(c)));

if LANE == 0 {
asm!(
"fmla {0:v}.4h, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpret_u16_f16(b),
in(vreg) lanes,
options(pure, nomem, nostack)
);
} else if LANE == 1 {
asm!(
"fmla {0:v}.4h, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpret_u16_f16(b),
in(vreg) lanes,
options(pure, nomem, nostack)
);
} else if LANE == 2 {
asm!(
"fmla {0:v}.4h, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpret_u16_f16(b),
in(vreg) lanes,
options(pure, nomem, nostack)
);
} else if LANE == 3 {
asm!(
asm!(
"fmla {0:v}.4h, {1:v}.4h, {2:v}.4h",
inout(vreg) result,
in(vreg) xreinterpret_u16_f16(b),
in(vreg) lanes,
options(pure, nomem, nostack)
);
}
);
xreinterpret_f16_u16(result)
}

Expand Down Expand Up @@ -990,27 +938,27 @@ pub(super) unsafe fn xvbslq_f16(
xreinterpretq_f16_u16(result)
}

#[inline]
#[inline(always)]
#[cfg(feature = "nightly_f16")]
/// Stores a four-lane f16 vector to memory starting at `ptr`.
///
/// The destination is viewed as `*mut u16` and the vector as `uint16x4_t`, then
/// written with `vst1_u16`.
///
/// # Safety
/// `ptr` must be valid for writing four consecutive 16-bit values.
pub(crate) unsafe fn xvst_f16(ptr: *mut f16, x: x_float16x4_t) {
    let dst = ptr.cast::<u16>();
    let bits = xreinterpret_u16_f16(x);
    vst1_u16(dst, bits)
}

#[inline]
#[inline(always)]
#[cfg(feature = "nightly_f16")]
/// Stores an eight-lane f16 vector to memory starting at `ptr`.
///
/// The destination is viewed as `*mut u16` and the vector as `uint16x8_t`, then
/// written with `vst1q_u16`.
///
/// # Safety
/// `ptr` must be valid for writing eight consecutive 16-bit values.
pub(crate) unsafe fn xvstq_f16(ptr: *mut f16, x: x_float16x8_t) {
    let dst = ptr.cast::<u16>();
    let bits = xreinterpretq_u16_f16(x);
    vst1q_u16(dst, bits)
}

#[inline]
#[inline(always)]
#[cfg(feature = "nightly_f16")]
/// Stores a pair of eight-lane f16 vectors to sixteen consecutive elements.
///
/// `x.0` is written at `ptr` and `x.1` eight elements further on, each through
/// `vst1q_u16` on the `u16` view of the data.
///
/// # Safety
/// `ptr` must be valid for writing sixteen consecutive 16-bit values.
pub(crate) unsafe fn xvstq_f16_x2(ptr: *mut f16, x: x_float16x8x2_t) {
    let base = ptr.cast::<u16>();
    let first = xreinterpretq_u16_f16(x.0);
    vst1q_u16(base, first);
    let second = xreinterpretq_u16_f16(x.1);
    vst1q_u16(base.add(8), second);
}

#[inline]
#[inline(always)]
#[cfg(feature = "nightly_f16")]
pub(crate) unsafe fn xvstq_f16_x4(ptr: *const f16, x: x_float16x8x4_t) {
let ptr_u16 = ptr as *mut u16;
Expand All @@ -1020,17 +968,17 @@ pub(crate) unsafe fn xvstq_f16_x4(ptr: *const f16, x: x_float16x8x4_t) {
vst1q_u16(ptr_u16.add(24), xreinterpretq_u16_f16(x.3));
}

#[inline]
#[inline(always)]
/// Broadcasts lane `N` of a four-lane f16 vector to all four lanes.
///
/// Works on the raw 16-bit lanes via `vdup_lane_u16::<N>`, so the selected
/// value is copied bit-for-bit.
///
/// # Safety
/// Goes through the reinterpret helpers, which bit-cast between the f16
/// wrapper and the NEON `u16` vector.
pub(crate) unsafe fn xvdup_lane_f16<const N: i32>(a: x_float16x4_t) -> x_float16x4_t {
    let bits = xreinterpret_u16_f16(a);
    let splat = vdup_lane_u16::<N>(bits);
    xreinterpret_f16_u16(splat)
}

#[inline]
#[inline(always)]
/// Broadcasts lane `N` of an eight-lane f16 vector into a four-lane vector.
///
/// Works on the raw 16-bit lanes via `vdup_laneq_u16::<N>`, so the selected
/// value is copied bit-for-bit.
///
/// # Safety
/// Goes through the reinterpret helpers, which bit-cast between the f16
/// wrappers and the NEON `u16` vectors.
pub(crate) unsafe fn xvdup_laneq_f16<const N: i32>(a: x_float16x8_t) -> x_float16x4_t {
    let bits = xreinterpretq_u16_f16(a);
    let splat = vdup_laneq_u16::<N>(bits);
    xreinterpret_f16_u16(splat)
}

#[inline]
#[inline(always)]
#[cfg(feature = "nightly_f16")]
pub(crate) unsafe fn xvld1q_lane_f16<const LANE: i32>(
ptr: *const f16,
Expand All @@ -1042,7 +990,7 @@ pub(crate) unsafe fn xvld1q_lane_f16<const LANE: i32>(
))
}

#[inline]
#[inline(always)]
#[cfg(feature = "nightly_f16")]
pub(crate) unsafe fn xvsetq_lane_f16<const LANE: i32>(v: f16, r: x_float16x8_t) -> x_float16x8_t {
xreinterpretq_f16_u16(vsetq_lane_u16::<LANE>(
Expand Down

0 comments on commit 72fb987

Please sign in to comment.