Skip to content

Commit

Permalink
Add a batch interface for getDistanceByLabel (#337)
Browse files Browse the repository at this point in the history
Signed-off-by: zourunxin.zrx <zourunxin.zrx@oceanbase.com>
  • Loading branch information
Carrot-77 authored Jan 24, 2025
1 parent 9ad9cc5 commit 11f89df
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 0 deletions.
13 changes: 13 additions & 0 deletions include/vsag/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,19 @@ class Index {
throw std::runtime_error("Index doesn't support get distance by id");
};

/**
 * @brief Calculate the distances between one query vector and a batch of vectors given by ID.
 *
 * This default implementation does not compute anything: it throws
 * std::runtime_error, so index types that support the feature must override it.
 *
 * @param query is the embedding of the query
 * @param ids is an array of unique identifiers of the vectors in the index to compare against
 * @param count is the number of entries in ids
 * @return a dataset holding the distances for the input ids; '-1' indicates an invalid distance
 */
virtual tl::expected<DatasetPtr, Error>
CalDistanceById(const float* query, const int64_t* ids, int64_t count) const {
throw std::runtime_error("Index doesn't support get distance by id");
};

/**
* @brief Checks if the specified feature is supported by the index.
*
Expand Down
6 changes: 6 additions & 0 deletions src/algorithm/hnswlib/algorithm_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
#include "space_interface.h"
#include "stream_reader.h"
#include "typing.h"
#include "vsag/dataset.h"
#include "vsag/errors.h"
#include "vsag/expected.hpp"

namespace hnswlib {

Expand Down Expand Up @@ -66,6 +69,9 @@ class AlgorithmInterface {
// Pure virtual: returns a single float distance for `label` and `data_point`.
// NOTE(review): presumably the distance between the query at data_point and the
// vector stored under label — confirm against the implementations.
virtual float
getDistanceByLabel(LabelType label, const void* data_point) = 0;

// Pure virtual batch counterpart of getDistanceByLabel.
// ids        : array of `count` labels to look up
// data_point : query vector handed to the distance function
// count      : number of entries in ids
// Returns a dataset of distances (or an error); the implementations in
// hnswalg.cpp and hnswalg_static.h write -1 for labels they cannot find.
virtual tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) = 0;

// Pure virtual: returns a read-only float pointer for `label`.
// NOTE(review): presumably the stored vector data for that label — confirm.
virtual const float*
getDataByLabel(LabelType label) const = 0;

Expand Down
28 changes: 28 additions & 0 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,34 @@ HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) {
return dist;
}

/**
 * @brief Computes the distance from one query vector to every vector named in `ids`.
 *
 * @param ids        array of `count` external labels to look up
 * @param data_point query vector handed to the distance function
 * @param count      number of entries in `ids`
 * @return a dataset whose distance array has `count` entries, index-aligned with
 *         `ids`; an entry is -1 when the label is not found in label_lookup_
 */
tl::expected<vsag::DatasetPtr, vsag::Error>
HierarchicalNSW::getBatchDistanceByLabel(const int64_t* ids,
                                         const void* data_point,
                                         int64_t count) {
    // The shared lock is held for the whole batch: lookups and distance computation.
    std::shared_lock lock_table(label_lookup_lock_);

    auto result = vsag::Dataset::Make();
    result->Owner(true, allocator_);
    auto* distances = static_cast<float*>(allocator_->Allocate(sizeof(float) * count));
    result->Distances(distances);

    // NOTE(review): normalizeVector presumably rebinds data_point to a normalized
    // copy that normalize_query keeps alive (as in getDistanceByLabel) — confirm
    // before touching these two lines.
    std::shared_ptr<float[]> normalize_query;
    normalizeVector(data_point, normalize_query);

    // int64_t index: `count` is 64-bit, an `int` counter would overflow on large batches.
    for (int64_t i = 0; i < count; ++i) {
        auto search = label_lookup_.find(ids[i]);
        if (search == label_lookup_.end()) {
            // Unknown label: keep the slot so results stay aligned with ids.
            distances[i] = -1;
            continue;
        }
        InnerIdType internal_id = search->second;
        distances[i] =
            fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
    }

    // One slot per requested id, including the -1 sentinels.
    result->NumElements(count);
    return std::move(result);
}

bool
HierarchicalNSW::isValidLabel(LabelType label) {
std::shared_lock lock_table(label_lookup_lock_);
Expand Down
4 changes: 4 additions & 0 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "algorithm_interface.h"
#include "block_manager.h"
#include "visited_list_pool.h"
#include "vsag/dataset.h"
namespace hnswlib {
using InnerIdType = vsag::InnerIdType;
using linklistsizeint = unsigned int;
Expand Down Expand Up @@ -146,6 +147,9 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
// Override of AlgorithmInterface::getDistanceByLabel; defined in hnswalg.cpp.
// Returns a single float distance for `label` and `data_point`.
float
getDistanceByLabel(LabelType label, const void* data_point) override;

// Batch distance lookup; defined in hnswalg.cpp.
// Fills one distance per entry of `ids` (count entries) for the query at
// data_point, writing -1 for labels missing from the label lookup table.
tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) override;

// Override defined in hnswalg.cpp; the definition starts by taking a shared lock
// on label_lookup_lock_.
// NOTE(review): presumably reports whether `label` exists in the index — confirm.
bool
isValidLabel(LabelType label) override;

Expand Down
24 changes: 24 additions & 0 deletions src/algorithm/hnswlib/hnswalg_static.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,30 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
return dist;
}

/**
 * @brief Computes the distance from one query vector to every vector named in `ids`.
 *
 * @param ids        array of `count` external labels to look up
 * @param data_point query vector handed to the distance function
 * @param count      number of entries in `ids`
 * @return a dataset whose distance array has `count` entries, index-aligned with
 *         `ids`; an entry is -1 when the label is not found in label_lookup_
 */
tl::expected<vsag::DatasetPtr, vsag::Error>
getBatchDistanceByLabel(const int64_t* ids, const void* data_point, int64_t count) override {
    // The lock is held for the whole batch: lookups and distance computation.
    std::unique_lock<std::mutex> lock_table(label_lookup_lock);

    auto result = vsag::Dataset::Make();
    result->Owner(true, allocator_);
    auto* distances = static_cast<float*>(allocator_->Allocate(sizeof(float) * count));
    result->Distances(distances);

    // int64_t index: `count` is 64-bit, an `int` counter would overflow on large batches.
    for (int64_t i = 0; i < count; ++i) {
        auto search = label_lookup_.find(ids[i]);
        if (search == label_lookup_.end()) {
            // Unknown label: keep the slot so results stay aligned with ids.
            distances[i] = -1;
            continue;
        }
        InnerIdType internal_id = search->second;
        distances[i] =
            fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
    }

    // The array always holds `count` slots (including -1 sentinels), so report
    // `count`; reporting only the number of hits would hide trailing entries
    // from callers and disagree with HierarchicalNSW::getBatchDistanceByLabel.
    result->NumElements(count);
    return std::move(result);
}

void
copyDataByLabel(LabelType label, void* data_point) override {
std::unique_lock lock_table(label_lookup_lock);
Expand Down
5 changes: 5 additions & 0 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ class HNSW : public Index {
SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector));
};

// Batch distance-by-id for HNSW: forwards to alg_hnsw_->getBatchDistanceByLabel
// inside SAFE_CALL. Note the argument order changes on the way down: this method
// takes (vector, ids, count) while the algorithm layer takes (ids, vector, count).
// NOTE(review): SAFE_CALL presumably converts thrown exceptions into an Error
// result — confirm against the macro definition.
virtual tl::expected<DatasetPtr, Error>
CalDistanceById(const float* vector, const int64_t* ids, int64_t count) const override {
SAFE_CALL(return alg_hnsw_->getBatchDistanceByLabel(ids, vector, count));
};

// Override of Index::CheckFeature, which is documented as checking whether the
// specified feature is supported by the index. Defined out of line.
[[nodiscard]] bool
CheckFeature(IndexFeature feature) const override;

Expand Down
40 changes: 40 additions & 0 deletions tests/test_hnsw_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,46 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Id", "[ft][hn
}
}

// For every metric type ("l2", "ip", "cosine") and every dimension in `dims`:
// builds an HNSW index and checks batch distance-by-id against the ground truth.
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Batch Calc Dis Id", "[ft][hnsw]") {
    auto origin_size = vsag::Options::Instance().block_size_limit();
    auto size = GENERATE(1024 * 1024 * 2);
    auto metric_type = GENERATE("l2", "ip", "cosine");
    const std::string name = "hnsw";
    for (auto& dim : dims) {
        // The block size limit is a global option: set it for this build and
        // restore the original value at the end of each iteration.
        vsag::Options::Instance().set_block_size_limit(size);
        auto param = GenerateHNSWBuildParametersString(metric_type, dim);
        auto index = TestFactory(name, param, true);
        auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
        TestBuildIndex(index, dataset, true);
        TestBatchCalcDistanceById(index, dataset);
        vsag::Options::Instance().set_block_size_limit(origin_size);
    }
}

// Static-HNSW variant of the batch distance-by-id test ("l2" metric only).
// Each dimension from `dims` is rounded up to a multiple of 4 before use.
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex,
                             "static HNSW Batch Calc Dis Id",
                             "[ft][hnsw]") {
    auto origin_size = vsag::Options::Instance().block_size_limit();
    auto size = GENERATE(1024 * 1024 * 2);
    auto metric_type = GENERATE("l2");
    auto use_static = GENERATE(true);
    const std::string name = "hnsw";
    for (const auto& raw_dim : dims) {
        // Round into a local copy: writing through a reference into `dims`
        // would permanently alter the fixture's dimension list for every test
        // that runs afterwards.
        auto dim = raw_dim;
        if (dim % 4 != 0) {
            dim = ((dim / 4) + 1) * 4;
        }
        // The block size limit is a global option: set it for this build and
        // restore the original value at the end of each iteration.
        vsag::Options::Instance().set_block_size_limit(size);
        auto param = GenerateHNSWBuildParametersString(metric_type, dim, use_static);
        auto index = TestFactory(name, param, true);
        auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
        TestBuildIndex(index, dataset, true);
        TestBatchCalcDistanceById(index, dataset);
        vsag::Options::Instance().set_block_size_limit(origin_size);
    }
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Vector", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
24 changes: 24 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,30 @@ TestIndex::TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dat
}
}

/**
 * @brief Verifies Index::CalDistanceById (batch form) against the dataset's ground truth.
 *
 * For each query, the ground-truth ids of that query are passed to the index in
 * a single batch call, and every returned distance must differ from the
 * ground-truth distance by less than `error`.
 *
 * @param index   index under test
 * @param dataset test dataset providing query_, ground_truth_ and top_k
 * @param error   maximum allowed absolute difference per distance
 */
void
TestIndex::TestBatchCalcDistanceById(const IndexPtr& index,
                                     const TestDatasetPtr& dataset,
                                     float error) {
    auto queries = dataset->query_;
    auto query_count = queries->GetNumElements();
    auto dim = queries->GetDim();
    auto gts = dataset->ground_truth_;
    auto gt_topK = dataset->top_k;
    for (int64_t i = 0; i < query_count; ++i) {
        // The batch API takes a raw pointer, so no per-query Dataset wrapper is needed.
        const float* query_vector = queries->GetFloat32Vectors() + i * dim;
        auto result =
            index->CalDistanceById(query_vector, gts->GetIds() + (i * gt_topK), gt_topK);
        // Fail with a clear assertion instead of letting value() throw on an error result.
        REQUIRE(result.has_value());
        const float* expected = gts->GetDistances() + i * gt_topK;
        const float* actual = result.value()->GetDistances();
        for (int64_t j = 0; j < gt_topK; ++j) {
            REQUIRE(std::abs(expected[j] - actual[j]) < error);
        }
    }
}

void
TestIndex::TestSerializeFile(const IndexPtr& index_from,
const IndexPtr& index_to,
Expand Down
5 changes: 5 additions & 0 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ class TestIndex {
// Single-id distance check; `error` defaults to 1e-5.
// NOTE(review): presumably `error` is the allowed absolute difference from the
// ground-truth distance, as in TestBatchCalcDistanceById — confirm in test_index.cpp.
static void
TestCalcDistanceById(const IndexPtr& index, const TestDatasetPtr& dataset, float error = 1e-5);

// Batch distance check: for each query in the dataset, calls
// index->CalDistanceById on that query's ground-truth ids and requires every
// returned distance to be within `error` (default 1e-5) of the ground-truth distance.
static void
TestBatchCalcDistanceById(const IndexPtr& index,
const TestDatasetPtr& dataset,
float error = 1e-5);

static void
TestSerializeFile(const IndexPtr& index_from,
const IndexPtr& index_to,
Expand Down

0 comments on commit 11f89df

Please sign in to comment.