Skip to content

Commit 3750734

Browse files
a theoretical fix for binary data type (zilliztech#676)
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
1 parent 0d4834f commit 3750734

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

src/index/flat/flat.cc

+6-4
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ class FlatIndexNode : public IndexNode {
117117
faiss::SearchParameters search_params;
118118
search_params.sel = id_selector;
119119

120-
index_->search(1, (const uint8_t*)x + index * dim / 8, k, cur_i_dis, cur_ids, &search_params);
120+
index_->search(1, (const uint8_t*)x + index * ((dim + 7) / 8), k, cur_i_dis, cur_ids,
121+
&search_params);
121122

122123
if (index_->metric_type == faiss::METRIC_Hamming) {
123124
for (int64_t j = 0; j < k; j++) {
@@ -189,7 +190,8 @@ class FlatIndexNode : public IndexNode {
189190
faiss::SearchParameters search_params;
190191
search_params.sel = id_selector;
191192

192-
index_->range_search(1, (const uint8_t*)xq + index * dim / 8, radius, &res, &search_params);
193+
index_->range_search(1, (const uint8_t*)xq + index * ((dim + 7) / 8), radius, &res,
194+
&search_params);
193195
}
194196
auto elem_cnt = res.lims[1];
195197
result_dist_array[index].resize(elem_cnt);
@@ -238,9 +240,9 @@ class FlatIndexNode : public IndexNode {
238240
if constexpr (std::is_same<IndexType, faiss::IndexBinaryFlat>::value) {
239241
uint8_t* data = nullptr;
240242
try {
241-
data = new uint8_t[rows * dim / 8];
243+
data = new uint8_t[rows * ((dim + 7) / 8)];
242244
for (int64_t i = 0; i < rows; i++) {
243-
index_->reconstruct(ids[i], data + i * dim / 8);
245+
index_->reconstruct(ids[i], data + i * ((dim + 7) / 8));
244246
}
245247
return GenResultDataSet(rows, dim, data);
246248
} catch (const std::exception& e) {

src/index/ivf/ivf.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSetPtr dataset, const Config
656656
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
657657

658658
if constexpr (std::is_same<IndexType, faiss::IndexBinaryIVF>::value) {
659-
auto cur_data = (const uint8_t*)data + index * dim / 8;
659+
auto cur_data = (const uint8_t*)data + index * ((dim + 7) / 8);
660660

661661
int32_t* i_distances = reinterpret_cast<int32_t*>(distances.get());
662662

@@ -781,7 +781,7 @@ IvfIndexNode<DataType, IndexType>::RangeSearch(const DataSetPtr dataset, const C
781781
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
782782

783783
if constexpr (std::is_same<IndexType, faiss::IndexBinaryIVF>::value) {
784-
auto cur_data = (const uint8_t*)xq + index * dim / 8;
784+
auto cur_data = (const uint8_t*)xq + index * ((dim + 7) / 8);
785785

786786
faiss::IVFSearchParameters ivf_search_params;
787787
ivf_search_params.nprobe = index_->nlist;
@@ -931,11 +931,11 @@ IvfIndexNode<DataType, IndexType>::GetVectorByIds(const DataSetPtr dataset) cons
931931
auto ids = dataset->GetIds();
932932

933933
try {
934-
auto data = std::make_unique<uint8_t[]>(dim * rows / 8);
934+
auto data = std::make_unique<uint8_t[]>(rows * ((dim + 7) / 8));
935935
for (int64_t i = 0; i < rows; i++) {
936936
int64_t id = ids[i];
937937
assert(id >= 0 && id < index_->ntotal);
938-
index_->reconstruct(id, data.get() + i * dim / 8);
938+
index_->reconstruct(id, data.get() + i * ((dim + 7) / 8));
939939
}
940940
return GenResultDataSet(rows, dim, std::move(data));
941941
} catch (const std::exception& e) {

0 commit comments

Comments
 (0)