Skip to content

Commit

Permalink
Reduce the pq build time
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <qianya.cheng@zilliz.com>
  • Loading branch information
cqy123456 committed Feb 7, 2025
1 parent 4cc0cb2 commit 5a1f166
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 0 deletions.
169 changes: 169 additions & 0 deletions src/simd/distances_avx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -754,5 +754,174 @@ bf16_vec_norm_L2sqr_avx(const knowhere::bf16* x, size_t d) {
auto res = _mm256_reduce_add_ps(msum_0);
return res;
}

namespace {
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
inline void
fvec_L2sqr_ny_avx_impl(float* dis, const float* x, const float* y, size_t d, size_t ny) {
size_t i = 0;
for (; i + 3 < ny; i += 4) {
const float* __restrict y1 = y + d * i;
const float* __restrict y2 = y + d * (i + 1);
const float* __restrict y3 = y + d * (i + 2);
const float* __restrict y4 = y + d * (i + 3);
fvec_L2sqr_batch_4_avx(x, y1, y2, y3, y4, d, dis[i], dis[i + 1], dis[i + 2], dis[i + 3]);
}
while (i < ny) {
const float* __restrict y_i = y + d * i;
dis[i] = fvec_L2sqr_avx(x, y_i, d);
y += d;
i++;
}
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

inline void
fvec_L2sqr_ny_avx_d2_impl(float* dis, const float* x, const float* y, size_t d, size_t ny) {
size_t y_i = ny;
auto mx = _mm256_setr_ps(x[0], x[1], x[0], x[1], x[0], x[1], x[0], x[1]);
const float* y_i_addr = y;
while (y_i >= 16) {
auto my1 = _mm256_loadu_ps(y_i_addr);
auto my2 = _mm256_loadu_ps(y_i_addr + 8); // 4-th
auto my3 = _mm256_loadu_ps(y_i_addr + 16); // 8-th
auto my4 = _mm256_loadu_ps(y_i_addr + 24); // 12-th
my1 = _mm256_sub_ps(my1, mx);
my1 = _mm256_mul_ps(my1, my1);
my2 = _mm256_sub_ps(my2, mx);
my2 = _mm256_mul_ps(my2, my2);
my3 = _mm256_sub_ps(my3, mx);
my3 = _mm256_mul_ps(my3, my3);
my4 = _mm256_sub_ps(my4, mx);
my4 = _mm256_mul_ps(my4, my4);
my1 = _mm256_hadd_ps(my1, my2);
my3 = _mm256_hadd_ps(my3, my4);
my1 = _mm256_permutevar8x32_ps(my1, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7));
my3 = _mm256_permutevar8x32_ps(my3, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7));
_mm256_storeu_ps(dis + (ny - y_i), my1);
_mm256_storeu_ps(dis + (ny - y_i) + 8, my3);
y_i_addr += 16 * d;
y_i -= 16;
}
while (y_i >= 4) {
auto my1 = _mm256_loadu_ps(y_i_addr);
my1 = _mm256_sub_ps(my1, mx);
my1 = _mm256_mul_ps(my1, my1);
my1 = _mm256_hadd_ps(my1, my1);
my1 = _mm256_permutevar8x32_ps(my1, _mm256_setr_epi32(0, 1, 4, 5, 2, 3, 6, 7));
__m128 high = _mm256_extractf128_ps(my1, 0);
_mm_storeu_ps(dis + (ny - y_i), high);
y_i_addr += 4 * d;
y_i -= 4;
}
while (y_i > 0) {
float dis1;
dis1 = (x[0] - y_i_addr[0]) * (x[0] - y_i_addr[0]);
dis1 += (x[1] - y_i_addr[1]) * (x[1] - y_i_addr[1]);
dis[ny - y_i] = dis1;
y_i_addr += d;
y_i -= 1;
}
}

inline void
fvec_L2sqr_ny_avx_d4_impl(float* dis, const float* x, const float* y, size_t d, size_t ny) {
size_t y_i = ny;
auto mx_t = _mm_loadu_ps(x);
auto mx = _mm256_set_m128(mx_t, mx_t);
const float* y_i_addr = y;
while (y_i >= 8) {
auto my1 = _mm256_loadu_ps(y_i_addr);
auto my2 = _mm256_loadu_ps(y_i_addr + 8);
auto my3 = _mm256_loadu_ps(y_i_addr + 16);
auto my4 = _mm256_loadu_ps(y_i_addr + 24);
my1 = _mm256_sub_ps(my1, mx);
my1 = _mm256_mul_ps(my1, my1);
my2 = _mm256_sub_ps(my2, mx);
my2 = _mm256_mul_ps(my2, my2);
my3 = _mm256_sub_ps(my3, mx);
my3 = _mm256_mul_ps(my3, my3);
my4 = _mm256_sub_ps(my4, mx);
my4 = _mm256_mul_ps(my4, my4);
my1 = _mm256_hadd_ps(my1, my2);
my3 = _mm256_hadd_ps(my3, my4);
my1 = _mm256_hadd_ps(my1, my3);
my1 = _mm256_permutevar8x32_ps(my1, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
_mm256_storeu_ps(dis + (ny - y_i), my1);
y_i_addr += 8 * d;
y_i -= 8;
}
if (y_i >= 4) {
auto my1 = _mm256_loadu_ps(y_i_addr);
auto my2 = _mm256_loadu_ps(y_i_addr + 8);
my1 = _mm256_sub_ps(my1, mx);
my1 = _mm256_mul_ps(my1, my1);
my2 = _mm256_sub_ps(my2, mx);
my2 = _mm256_mul_ps(my2, my2);
my1 = _mm256_hadd_ps(my1, my2);
my1 = _mm256_hadd_ps(my1, my1);
my1 = _mm256_permutevar8x32_ps(my1, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
__m128 res = _mm256_extractf128_ps(my1, 0);
_mm_storeu_ps(dis + (ny - y_i), res);
y_i_addr = y_i_addr + 4 * d;
y_i -= 4;
}
if (y_i >= 2) {
auto my1 = _mm256_loadu_ps(y_i_addr);
my1 = _mm256_sub_ps(my1, mx);
my1 = _mm256_mul_ps(my1, my1);
__m128 high = _mm256_extractf128_ps(my1, 1);
__m128 low = _mm256_extractf128_ps(my1, 0);
__m128 sum_low = _mm_hadd_ps(low, low);
sum_low = _mm_hadd_ps(sum_low, sum_low);

__m128 sum_high = _mm_hadd_ps(high, high);
sum_high = _mm_hadd_ps(sum_high, sum_high);

dis[ny - y_i] = _mm_cvtss_f32(sum_low);
dis[ny - y_i + 1] = _mm_cvtss_f32(sum_high);
y_i_addr = y_i_addr + 2 * d;
y_i -= 2;
}
if (y_i > 0) {
float dis1, dis2;
dis1 = (x[0] - y_i_addr[0]) * (x[0] - y_i_addr[0]);
dis2 = (x[1] - y_i_addr[1]) * (x[1] - y_i_addr[1]);
dis1 += (x[2] - y_i_addr[2]) * (x[2] - y_i_addr[2]);
dis2 += (x[3] - y_i_addr[3]) * (x[3] - y_i_addr[3]);
dis[ny - y_i] = dis1 + dis2;
}
}
} // namespace

void
fvec_L2sqr_ny_avx(float* dis, const float* x, const float* y, size_t d, size_t ny) {
// todo: add more small dim support
if (d == 2) {
return fvec_L2sqr_ny_avx_d2_impl(dis, x, y, d, ny);
} else if (d == 4) {
return fvec_L2sqr_ny_avx_d4_impl(dis, x, y, d, ny);
} else {
return fvec_L2sqr_ny_avx_impl(dis, x, y, d, ny);
}
}

size_t
fvec_L2sqr_ny_nearest_avx(float* __restrict distances_tmp_buffer, const float* __restrict x, const float* __restrict y,
size_t d, size_t ny) {
fvec_L2sqr_ny_avx(distances_tmp_buffer, x, y, d, ny);

size_t nearest_idx = 0;
float min_dis = HUGE_VALF;

for (size_t i = 0; i < ny; i++) {
if (distances_tmp_buffer[i] < min_dis) {
min_dis = distances_tmp_buffer[i];
nearest_idx = i;
}
}
return nearest_idx;
}

} // namespace faiss
#endif
6 changes: 6 additions & 0 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ fp16_vec_norm_L2sqr_avx(const knowhere::fp16* x, size_t d);
float
bf16_vec_norm_L2sqr_avx(const knowhere::bf16* x, size_t d);

void
fvec_L2sqr_ny_avx(float* dis, const float* x, const float* y, size_t d, size_t ny);

size_t
fvec_L2sqr_ny_nearest_avx(float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny);

} // namespace faiss

#endif /* DISTANCES_AVX_H */
2 changes: 2 additions & 0 deletions src/simd/hook.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ fvec_hook(std::string& simd_type) {
bf16_vec_inner_product_batch_4 = bf16_vec_inner_product_batch_4_avx512;
fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_avx512;
bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_avx512;
fvec_L2sqr_ny_nearest = fvec_L2sqr_ny_nearest_avx; // avx2 compute small dim faster than avx512

simd_type = "AVX512";
support_pq_fast_scan = true;
Expand Down Expand Up @@ -248,6 +249,7 @@ fvec_hook(std::string& simd_type) {
fp16_vec_L2sqr_batch_4 = fp16_vec_L2sqr_batch_4_avx;
bf16_vec_L2sqr_batch_4 = bf16_vec_L2sqr_batch_4_avx;

fvec_L2sqr_ny_nearest = fvec_L2sqr_ny_nearest_avx;
simd_type = "AVX2";
support_pq_fast_scan = true;
} else if (use_sse4_2 && cpu_support_sse4_2()) {
Expand Down
34 changes: 34 additions & 0 deletions thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,36 @@ void all_inner_product(
}
}

void exhaustive_L2sqr_nearest_imp(
const float* __restrict x,
const float* __restrict y,
size_t d,
size_t nx,
size_t ny,
float* vals,
int64_t* ids) {
const size_t ny_batch_size = 256;
auto sub_dis = std::make_unique<float[]>(std::min(ny_batch_size, ny));
for (int64_t i = 0; i < nx; i++) {
const float* x_i = x + i * d;
size_t nearest_idx = 0;
float min_dis = HUGE_VALF;
// compute distances
for (auto j = 0; j < ny; j += ny_batch_size) {
const float* y_j = y + j * d;
const size_t y_j_n = std::min(ny_batch_size, ny - j);
auto batch_nearest_id =
fvec_L2sqr_ny_nearest(sub_dis.get(), x_i, y_j, d, y_j_n);
if (sub_dis[batch_nearest_id] < min_dis) {
nearest_idx = batch_nearest_id + j;
min_dis = sub_dis[batch_nearest_id];
}
}
ids[i] = nearest_idx;
vals[i] = min_dis;
}
}

void knn_L2sqr(
const float* x,
const float* y,
Expand All @@ -990,6 +1020,10 @@ void knn_L2sqr(
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
return;
}
if (k == 1 && sel == nullptr) {
exhaustive_L2sqr_nearest_imp(x, y, d, nx, ny, vals, ids);
return;
}
// // todo aguzhva: this is disabled for knowhere, because it requires
// // some dynamic kernel dispatching.
// if (k == 1) {
Expand Down
9 changes: 9 additions & 0 deletions thirdparty/faiss/faiss/utils/distances.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ void pairwise_indexed_inner_product(
const int64_t* iy,
float* dis);

void exhaustive_L2sqr_nearest_imp(
const float* __restrict x,
const float* __restrict y,
size_t d,
size_t nx,
size_t ny,
float* vals,
int64_t* ids);

/***************************************************************************
* KNN functions
***************************************************************************/
Expand Down

0 comments on commit 5a1f166

Please sign in to comment.