Skip to content

Commit

Permalink
Add Hamming distance to cuvs CAGRA
Browse files Browse the repository at this point in the history
Signed-off-by: Mickael Ide <mide@nvidia.com>
  • Loading branch information
lowener committed Feb 7, 2025
1 parent 00c4ca9 commit 35bea55
Show file tree
Hide file tree
Showing 14 changed files with 184 additions and 117 deletions.
1 change: 1 addition & 0 deletions include/knowhere/index/index_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
{IndexEnum::INDEX_GPU_IVFFLAT, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_GPU_IVFPQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_GPU_CAGRA, VecType::VECTOR_INT8},

// hnsw
{IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT},
Expand Down
2 changes: 1 addition & 1 deletion src/common/cuvs/integration/brute_force_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
#include "common/cuvs/proto/cuvs_index_kind.hpp"

namespace cuvs_knowhere {
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::brute_force>;
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::brute_force, float>;
} // namespace cuvs_knowhere
3 changes: 2 additions & 1 deletion src/common/cuvs/integration/cagra_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
#include "common/cuvs/proto/cuvs_index_kind.hpp"

namespace cuvs_knowhere {
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::cagra>;
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::cagra, float>;
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::cagra, std::uint8_t>;
} // namespace cuvs_knowhere
165 changes: 87 additions & 78 deletions src/common/cuvs/integration/cuvs_knowhere_index.cuh

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions src/common/cuvs/integration/cuvs_knowhere_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
#include "common/cuvs/proto/cuvs_index_kind.hpp"
namespace cuvs_knowhere {

template <cuvs_proto::cuvs_index_kind IndexKind>
template <cuvs_proto::cuvs_index_kind IndexKind, typename DataType>
struct cuvs_knowhere_index {
auto static constexpr index_kind = IndexKind;

using data_type = cuvs_data_t<index_kind>;
using data_type = DataType;
using indexing_type = cuvs_indexing_t<index_kind>;
using input_indexing_type = cuvs_input_indexing_t<index_kind>;

Expand All @@ -45,7 +45,7 @@ struct cuvs_knowhere_index {
dim() const;
void
train(cuvs_knowhere_config const&, data_type const*, knowhere_indexing_type, knowhere_indexing_type);
std::tuple<knowhere_indexing_type*, knowhere_data_type*>
std::tuple<knowhere_indexing_type*, knowhere_distance_type*>
search(cuvs_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count,
knowhere_indexing_type feature_count, knowhere_bitset_data_type const* bitset_data = nullptr,
knowhere_bitset_indexing_type bitset_byte_size = knowhere_bitset_indexing_type{},
Expand All @@ -58,7 +58,7 @@ struct cuvs_knowhere_index {
serialize(std::ostream& os) const;
void
serialize_to_hnswlib(std::ostream& os) const;
static cuvs_knowhere_index<IndexKind>
static cuvs_knowhere_index<IndexKind, DataType>
deserialize(std::istream& is);
void
synchronize(bool is_without_mempool = false) const;
Expand All @@ -73,9 +73,10 @@ struct cuvs_knowhere_index {
}
};

extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::brute_force>;
extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::ivf_flat>;
extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::ivf_pq>;
extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::cagra>;
extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::brute_force, float>;
extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::ivf_flat, float>;
extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::ivf_pq, float>;
extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::cagra, float>;
extern template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::cagra, uint8_t>;

} // namespace cuvs_knowhere
2 changes: 1 addition & 1 deletion src/common/cuvs/integration/ivf_flat_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
#include "common/cuvs/proto/cuvs_index_kind.hpp"

namespace cuvs_knowhere {
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::ivf_flat>;
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::ivf_flat, float>;
} // namespace cuvs_knowhere
2 changes: 1 addition & 1 deletion src/common/cuvs/integration/ivf_pq_index.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
#include "common/cuvs/proto/cuvs_index_kind.hpp"

namespace cuvs_knowhere {
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::ivf_pq>;
template struct cuvs_knowhere_index<cuvs_proto::cuvs_index_kind::ivf_pq, float>;
} // namespace cuvs_knowhere
6 changes: 1 addition & 5 deletions src/common/cuvs/integration/type_mappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

namespace cuvs_knowhere {

using knowhere_data_type = float;
using knowhere_distance_type = float;
using knowhere_indexing_type = std::int64_t;
using knowhere_bitset_data_type = std::uint8_t;
using knowhere_bitset_indexing_type = std::uint32_t;
Expand Down Expand Up @@ -61,12 +61,8 @@ struct cuvs_io_type_mapper<true, cuvs_proto::cuvs_index_kind::cagra> : std::true
using indexing_type = std::uint32_t;
using input_indexing_type = std::int64_t;
};

} // namespace detail

template <cuvs_proto::cuvs_index_kind IndexKind>
using cuvs_data_t = typename detail::cuvs_io_type_mapper<true, IndexKind>::data_type;

template <cuvs_proto::cuvs_index_kind IndexKind>
using cuvs_indexing_t = typename detail::cuvs_io_type_mapper<true, IndexKind>::indexing_type;

Expand Down
15 changes: 11 additions & 4 deletions src/common/cuvs/proto/cuvs_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,17 @@ struct cuvs_index {
cuvs::neighbors::cagra::search(res, search_params, underlying_index, queries, neighbors_tmp, distances_tmp,
filter);
}
if (refine_ratio > 1.0f) {
if (dataset.has_value()) {
bool do_refine_step = refine_ratio > 1.0f;
if (do_refine_step && !dataset.has_value()) {
RAFT_LOG_WARN("Refinement requested, but no dataset provided. Ignoring refinement request.");
do_refine_step = false;
}
if (do_refine_step && !std::is_same_v<T, float>) {
RAFT_LOG_WARN("Refinement requested, but only float are supported. Ignoring refinement request.");
do_refine_step = false;
}
if constexpr (std::is_same_v<T, float>) {
if (do_refine_step) {
if constexpr (std::is_same_v<IdxT, InputIdxT>) {
cuvs::neighbors::refine(res, *dataset, queries, raft::make_const_mdspan(neighbors_tmp), neighbors,
distances, underlying_index.metric());
Expand All @@ -214,8 +223,6 @@ struct cuvs_index {
InputIdxT(distances.extent(1))),
underlying_index.metric());
}
} else {
RAFT_LOG_WARN("Refinement requested, but no dataset provided. Ignoring refinement request.");
}
}
}
Expand Down
20 changes: 10 additions & 10 deletions src/index/gpu_cuvs/gpu_cuvs.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ struct GpuCuvsIndexNode : public IndexNode {
Status
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
auto result = Status::success;
auto raft_cfg = cuvs_knowhere::cuvs_knowhere_config{};
auto cuvs_cfg = cuvs_knowhere::cuvs_knowhere_config{};
try {
raft_cfg = to_cuvs_knowhere_config(static_cast<const knowhere_config_type&>(*cfg));
cuvs_cfg = to_cuvs_knowhere_config(static_cast<const knowhere_config_type&>(*cfg));
} catch (const std::exception& e) {
LOG_KNOWHERE_ERROR_ << e.what();
result = Status::invalid_args;
Expand All @@ -90,9 +90,9 @@ struct GpuCuvsIndexNode : public IndexNode {
if (result == Status::success) {
auto rows = dataset->GetRows();
auto dim = dataset->GetDim();
auto const* data = reinterpret_cast<float const*>(dataset->GetTensor());
auto const* data = reinterpret_cast<const DataType*>(dataset->GetTensor());
try {
index_.train(raft_cfg, data, rows, dim);
index_.train(cuvs_cfg, data, rows, dim);
index_.synchronize(true);
} catch (const std::exception& e) {
LOG_KNOWHERE_ERROR_ << e.what();
Expand All @@ -110,10 +110,10 @@ struct GpuCuvsIndexNode : public IndexNode {
expected<DataSetPtr>
Search(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
auto result = Status::success;
auto raft_cfg = cuvs_knowhere::cuvs_knowhere_config{};
auto cuvs_cfg = cuvs_knowhere::cuvs_knowhere_config{};
auto err_msg = std::string{};
try {
raft_cfg = to_cuvs_knowhere_config(static_cast<const knowhere_config_type&>(*cfg));
cuvs_cfg = to_cuvs_knowhere_config(static_cast<const knowhere_config_type&>(*cfg));
} catch (const std::exception& e) {
err_msg = std::string{e.what()};
LOG_KNOWHERE_ERROR_ << e.what();
Expand All @@ -123,12 +123,12 @@ struct GpuCuvsIndexNode : public IndexNode {
try {
auto rows = dataset->GetRows();
auto dim = dataset->GetDim();
auto const* data = reinterpret_cast<float const*>(dataset->GetTensor());
auto const* data = reinterpret_cast<const DataType*>(dataset->GetTensor());
auto search_result =
index_.search(raft_cfg, data, rows, dim, bitset.data(), bitset.byte_size(), bitset.size());
index_.search(cuvs_cfg, data, rows, dim, bitset.data(), bitset.byte_size(), bitset.size());
std::this_thread::yield();
index_.synchronize();
return GenResultDataSet(rows, raft_cfg.k, std::get<0>(search_result), std::get<1>(search_result));
return GenResultDataSet(rows, cuvs_cfg.k, std::get<0>(search_result), std::get<1>(search_result));
} catch (const std::exception& e) {
err_msg = std::string{e.what()};
LOG_KNOWHERE_ERROR_ << e.what();
Expand Down Expand Up @@ -243,7 +243,7 @@ struct GpuCuvsIndexNode : public IndexNode {
}
}

using cuvs_knowhere_index_type = typename cuvs_knowhere::cuvs_knowhere_index<K>;
using cuvs_knowhere_index_type = typename cuvs_knowhere::cuvs_knowhere_index<K, DataType>;

protected:
cuvs_knowhere_index_type index_;
Expand Down
13 changes: 13 additions & 0 deletions src/index/gpu_cuvs/gpu_cuvs_cagra.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,17 @@ KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CAGRA, GpuCuvsCagraHybridIndexNode
RAFT_CUDA_TRY(cudaGetDeviceCount(&count));
return count * cuda_concurrent_size_per_device;
}());

// Register GpuCuvsCagraHybridIndexNode for the bin1 data type under the
// GPU_CUVS_CAGRA name, tagged with the GPU and BINARY feature flags.
// The last argument is an immediately-invoked lambda that evaluates to
// (number of CUDA devices) * cuda_concurrent_size_per_device.
// NOTE(review): presumably this value sizes the index's thread pool — confirm
// against the macro definition.
KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CUVS_CAGRA, GpuCuvsCagraHybridIndexNode, bin1,
knowhere::feature::GPU | knowhere::feature::BINARY, []() {
// Written by cudaGetDeviceCount below; the call is wrapped in RAFT_CUDA_TRY.
int count;
RAFT_CUDA_TRY(cudaGetDeviceCount(&count));
return count * cuda_concurrent_size_per_device;
}());
// Same registration as above, with the same feature flags and the same
// device-count-based value, under the GPU_CAGRA name.
KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CAGRA, GpuCuvsCagraHybridIndexNode, bin1,
knowhere::feature::GPU | knowhere::feature::BINARY, []() {
// Written by cudaGetDeviceCount below; the call is wrapped in RAFT_CUDA_TRY.
int count;
RAFT_CUDA_TRY(cudaGetDeviceCount(&count));
return count * cuda_concurrent_size_per_device;
}());
} // namespace knowhere
5 changes: 3 additions & 2 deletions src/index/gpu_cuvs/gpu_cuvs_cagra_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ struct GpuCuvsCagraConfig : public BaseConfig {
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
constexpr std::array<std::string_view, 4> legal_metric_list{"L2", "IP", "COSINE", "HAMMING"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
std::string msg = "metric type " + metric + " not found or not supported, supported: [L2 IP COSINE]";
std::string msg =
"metric type " + metric + " not found or not supported, supported: [L2 IP COSINE HAMMING]";
return HandleError(err_msg, msg, Status::invalid_metric_type);
}
}
Expand Down
48 changes: 43 additions & 5 deletions tests/ut/test_gpu_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ TEST_CASE("Test All GPU Index", "[search]") {
return json;
};
};
auto hamming_gen = [](auto&& upstream_gen) {
return [upstream_gen]() {
knowhere::Json json = upstream_gen();
json[knowhere::meta::METRIC_TYPE] = knowhere::metric::HAMMING;
json[knowhere::indexparam::BUILD_ALGO] = "ITERATIVE";
return json;
};
};

SECTION("Test Gpu Index Search") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
Expand Down Expand Up @@ -263,7 +272,7 @@ TEST_CASE("Test All GPU Index", "[search]") {
REQUIRE(recall >= 0.8f);
}

SECTION("Test Gpu Index Search Metric") {
SECTION("Test Gpu Index Search Cosine Metric") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_CUVS_BRUTEFORCE, cosine_gen(bruteforce_gen)),
Expand All @@ -276,21 +285,50 @@ TEST_CASE("Test All GPU Index", "[search]") {
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
auto train_ds = GenDataSet(nb, dim, seed);
auto query_ds = GenDataSet(nq, dim, seed);
auto query_ds = GenDataSet(nq, dim, seed + 1);
REQUIRE(idx.Type() == name);
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.HasRawData(json[knowhere::meta::METRIC_TYPE]) ==
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(name, version, json));
auto results = idx.Search(train_ds, json, nullptr);
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, train_ds, json, nullptr);
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, json, nullptr);
auto ids = results.value()->GetIds();
auto dist = results.value()->GetDistance();
auto gt_ids = gt.value()->GetIds();
auto gt_dist = gt.value()->GetDistance();
float recall = GetKNNRecall(*gt.value(), *results.value());
REQUIRE(recall > 0.65f);
}

// Builds a CAGRA index over binary vectors with the HAMMING metric, searches it
// with a separately seeded query set, and compares the results against
// knowhere::BruteForce::Search on the same train/query data.
SECTION("Test Gpu Index Search Hamming Metric") {
using std::make_tuple;
// Single-row table: only CAGRA is exercised, using the hamming_gen-decorated config.
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_CUVS_CAGRA, hamming_gen(cagra_gen))}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::bin1>(name, version).value();
// Round-trip the config through its string dump so the captured text and the
// JSON actually used for build/search come from the same serialization.
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
nb = 1500; // Reduce the dataset size so fewer results have distance 0 when checking query distances
// Train and query sets use different seeds (seed vs. seed + 1).
auto train_ds = GenBinDataSet(nb, dim, seed);
auto query_ds = GenBinDataSet(nq, dim, seed + 1);
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);
REQUIRE(idx.Count() == nb);
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
// Brute-force search over the same data serves as the reference result.
auto gt = knowhere::BruteForce::Search<knowhere::bin1>(train_ds, query_ds, json, nullptr);
auto ids = results.value()->GetIds();
auto dist = results.value()->GetDistance();
auto gt_ids = gt.value()->GetIds();
auto gt_dist = gt.value()->GetDistance();
float recall = GetKNNRecall(*gt.value(), *results.value());
REQUIRE(recall > 0.8f);
// Third argument is dist_less_better = true.
recall = GetKNNRelativeRecall(*gt.value(), *results.value(), true);
REQUIRE(recall > 0.95f);
// NOTE(review): the loop starts at i = 1, so entry 0 is never compared — confirm
// this is intentional.
// NOTE(review): ids/dist are indexed directly with i < nq; presumably the result
// buffers hold k entries per query, in which case this walks the first nq raw
// entries rather than each query's first neighbor — verify.
for (int i = 1; i < nq; ++i) {
CHECK(ids[i] == gt_ids[i]);
// Check query distance
CHECK(GetRelativeLoss(gt_dist[i], dist[i]) < 0.1f);
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ GetKNNRecall(const knowhere::DataSet& ground_truth, const std::vector<std::vecto
// Compare two ann-search results
// "ground_truth" here is just used as a baseline value for comparison. It is not real groundtruth and the knn
// results may be worse, we can call the compare results as "relative-recall".
// when the k-th distance of gt is worth, define the recall as 1.0f
// when the k-th distance of gt is worse, define the recall as 1.0f
// when the k-th distance of gt is better, define the recall as (intersection_count / size)
inline float
GetKNNRelativeRecall(const knowhere::DataSet& ground_truth, const knowhere::DataSet& result, bool dist_less_better) {
Expand Down

0 comments on commit 35bea55

Please sign in to comment.