Skip to content

Commit

Permalink
Fuzzing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
awxkee committed Dec 28, 2024
1 parent 6020097 commit 642f00d
Showing 1 changed file with 31 additions and 31 deletions.
62 changes: 31 additions & 31 deletions src/avx2/vertical_u8_lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub(crate) fn convolve_vertical_avx_row_lp(
}

#[inline(always)]
unsafe fn m256dot<const SCALE: i32>(
unsafe fn m256dot(
store0: __m256i,
store1: __m256i,
row: __m256i,
Expand Down Expand Up @@ -108,15 +108,15 @@ unsafe fn convolve_vertical_avx2_row_impl(
let item_row1 =
_mm256_loadu_si256(src_ptr0.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row0, v_weight0);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row1, v_weight0);
(store0, store1) = m256dot(store0, store1, item_row0, v_weight0);
(store2, store3) = m256dot(store2, store3, item_row1, v_weight0);

let item_row10 = _mm256_loadu_si256(src_ptr1.as_ptr() as *const __m256i);
let item_row11 =
_mm256_loadu_si256(src_ptr1.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row10, v_weight1);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row11, v_weight1);
(store0, store1) = m256dot(store0, store1, item_row10, v_weight1);
(store2, store3) = m256dot(store2, store3, item_row11, v_weight1);
} else if bounds_size == 3 {
let py = bounds.start;
let weights = weight.get_unchecked(0..3);
Expand All @@ -134,22 +134,22 @@ unsafe fn convolve_vertical_avx2_row_impl(
let item_row1 =
_mm256_loadu_si256(src_ptr0.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row0, v_weight0);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row1, v_weight0);
(store0, store1) = m256dot(store0, store1, item_row0, v_weight0);
(store2, store3) = m256dot(store2, store3, item_row1, v_weight0);

let item_row10 = _mm256_loadu_si256(src_ptr1.as_ptr() as *const __m256i);
let item_row11 =
_mm256_loadu_si256(src_ptr1.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row10, v_weight1);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row11, v_weight1);
(store0, store1) = m256dot(store0, store1, item_row10, v_weight1);
(store2, store3) = m256dot(store2, store3, item_row11, v_weight1);

let item_row20 = _mm256_loadu_si256(src_ptr2.as_ptr() as *const __m256i);
let item_row21 =
_mm256_loadu_si256(src_ptr2.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row20, v_weight2);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row21, v_weight2);
(store0, store1) = m256dot(store0, store1, item_row20, v_weight2);
(store2, store3) = m256dot(store2, store3, item_row21, v_weight2);
} else if bounds_size == 4 {
let py = bounds.start;
let weights = weight.get_unchecked(0..4);
Expand All @@ -170,29 +170,29 @@ unsafe fn convolve_vertical_avx2_row_impl(
let item_row1 =
_mm256_loadu_si256(src_ptr0.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row0, v_weight0);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row1, v_weight0);
(store0, store1) = m256dot(store0, store1, item_row0, v_weight0);
(store2, store3) = m256dot(store2, store3, item_row1, v_weight0);

let item_row10 = _mm256_loadu_si256(src_ptr1.as_ptr() as *const __m256i);
let item_row11 =
_mm256_loadu_si256(src_ptr1.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row10, v_weight1);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row11, v_weight1);
(store0, store1) = m256dot(store0, store1, item_row10, v_weight1);
(store2, store3) = m256dot(store2, store3, item_row11, v_weight1);

let item_row20 = _mm256_loadu_si256(src_ptr2.as_ptr() as *const __m256i);
let item_row21 =
_mm256_loadu_si256(src_ptr2.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row20, v_weight2);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row21, v_weight2);
(store0, store1) = m256dot(store0, store1, item_row20, v_weight2);
(store2, store3) = m256dot(store2, store3, item_row21, v_weight2);

let item_row30 = _mm256_loadu_si256(src_ptr3.as_ptr() as *const __m256i);
let item_row31 =
_mm256_loadu_si256(src_ptr3.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row30, v_weight3);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row31, v_weight3);
(store0, store1) = m256dot(store0, store1, item_row30, v_weight3);
(store2, store3) = m256dot(store2, store3, item_row31, v_weight3);
} else {
for j in 0..bounds_size {
let py = bounds.start + j;
Expand All @@ -204,8 +204,8 @@ unsafe fn convolve_vertical_avx2_row_impl(
let item_row1 =
_mm256_loadu_si256(src_ptr.get_unchecked(32..).as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row0, v_weight);
(store2, store3) = m256dot::<SCALE>(store2, store3, item_row1, v_weight);
(store0, store1) = m256dot(store0, store1, item_row0, v_weight);
(store2, store3) = m256dot(store2, store3, item_row1, v_weight);
}
}

Expand Down Expand Up @@ -246,10 +246,10 @@ unsafe fn convolve_vertical_avx2_row_impl(
let src_ptr1 = src.get_unchecked(v_offset1..);

let item_row0 = _mm256_loadu_si256(src_ptr0.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row0, v_weight0);
(store0, store1) = m256dot(store0, store1, item_row0, v_weight0);

let item_row1 = _mm256_loadu_si256(src_ptr1.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row1, v_weight1);
(store0, store1) = m256dot(store0, store1, item_row1, v_weight1);
} else if bounds_size == 3 {
let py = bounds.start;
let weights = weight.get_unchecked(0..3);
Expand All @@ -264,13 +264,13 @@ unsafe fn convolve_vertical_avx2_row_impl(
let src_ptr2 = src.get_unchecked(v_offset2..);

let item_row0 = _mm256_loadu_si256(src_ptr0.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row0, v_weight0);
(store0, store1) = m256dot(store0, store1, item_row0, v_weight0);

let item_row1 = _mm256_loadu_si256(src_ptr1.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row1, v_weight1);
(store0, store1) = m256dot(store0, store1, item_row1, v_weight1);

let item_row2 = _mm256_loadu_si256(src_ptr2.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row2, v_weight2);
(store0, store1) = m256dot(store0, store1, item_row2, v_weight2);
} else if bounds_size == 4 {
let py = bounds.start;
let weights = weight.get_unchecked(0..4);
Expand All @@ -288,16 +288,16 @@ unsafe fn convolve_vertical_avx2_row_impl(
let src_ptr3 = src.get_unchecked(v_offset3..);

let item_row0 = _mm256_loadu_si256(src_ptr0.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row0, v_weight0);
(store0, store1) = m256dot(store0, store1, item_row0, v_weight0);

let item_row1 = _mm256_loadu_si256(src_ptr1.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row1, v_weight1);
(store0, store1) = m256dot(store0, store1, item_row1, v_weight1);

let item_row2 = _mm256_loadu_si256(src_ptr2.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row2, v_weight2);
(store0, store1) = m256dot(store0, store1, item_row2, v_weight2);

let item_row3 = _mm256_loadu_si256(src_ptr3.as_ptr() as *const __m256i);
(store0, store1) = m256dot::<SCALE>(store0, store1, item_row3, v_weight3);
(store0, store1) = m256dot(store0, store1, item_row3, v_weight3);
} else {
for j in 0..bounds_size {
let py = bounds.start + j;
Expand All @@ -307,7 +307,7 @@ unsafe fn convolve_vertical_avx2_row_impl(
let src_ptr = src.get_unchecked(v_offset..);
let item_row0 = _mm256_loadu_si256(src_ptr.as_ptr() as *const __m256i);

(store0, store1) = m256dot::<SCALE>(store0, store1, item_row0, v_weight);
(store0, store1) = m256dot(store0, store1, item_row0, v_weight);
}
}

Expand Down

0 comments on commit 642f00d

Please sign in to comment.