Skip to content

Commit

Permalink
change the implementation of the pyramid from being index-based to datacell-based
Browse files Browse the repository at this point in the history

Signed-off-by: jinjiabao.jjb <jinjiabao.jjb@antgroup.com>
  • Loading branch information
jinjiabao.jjb committed Feb 2, 2025
1 parent 11f89df commit 29c0334
Show file tree
Hide file tree
Showing 16 changed files with 610 additions and 390 deletions.
11 changes: 6 additions & 5 deletions include/vsag/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ extern const char* const DISKANN_PARAMETER_USE_OPQ;
extern const char* const DISKANN_PARAMETER_USE_ASYNC_IO;
extern const char* const DISKANN_PARAMETER_USE_BSA;
extern const char* const DISKANN_PARAMETER_GRAPH_TYPE;
extern const char* const DISKANN_PARAMETER_ALPHA;
extern const char* const DISKANN_PARAMETER_GRAPH_ITER_TURN;
extern const char* const DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE;
extern const char* const ODESCENT_PARAMETER_ALPHA;
extern const char* const ODESCENT_PARAMETER_GRAPH_ITER_TURN;
extern const char* const ODESCENT_PARAMETER_NEIGHBOR_SAMPLE_RATE;
extern const char* const DISKANN_GRAPH_TYPE_VAMANA;
extern const char* const DISKANN_GRAPH_TYPE_ODESCENT;
extern const char* const GRAPH_TYPE_ODESCENT;

extern const char* const DISKANN_PARAMETER_BEAM_SEARCH;
extern const char* const DISKANN_PARAMETER_IO_LIMIT;
Expand All @@ -79,9 +79,10 @@ extern const char* const HNSW_PARAMETER_CONSTRUCTION;
extern const char* const HNSW_PARAMETER_USE_STATIC;
extern const char* const HNSW_PARAMETER_REVERSED_EDGES;

extern const char* const PYRAMID_PARAMETER_BASE_CODES;

extern const char* const INDEX_PARAM;

extern const char* const PYRAMID_PARAMETER_SUBINDEX_TYPE;
extern const char PART_SLASH;

// statistic key
Expand Down
11 changes: 5 additions & 6 deletions src/constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,21 @@ const char* const DISKANN_PARAMETER_IO_LIMIT = "io_limit";
const char* const DISKANN_PARAMETER_EF_SEARCH = "ef_search";
const char* const DISKANN_PARAMETER_REORDER = "use_reorder";
const char* const DISKANN_PARAMETER_GRAPH_TYPE = "graph_type";
const char* const DISKANN_PARAMETER_ALPHA = "alpha";
const char* const DISKANN_PARAMETER_GRAPH_ITER_TURN = "graph_iter_turn";
const char* const DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE = "neighbor_sample_rate";
const char* const ODESCENT_PARAMETER_ALPHA = "alpha";
const char* const ODESCENT_PARAMETER_GRAPH_ITER_TURN = "graph_iter_turn";
const char* const ODESCENT_PARAMETER_NEIGHBOR_SAMPLE_RATE = "neighbor_sample_rate";

const char* const DISKANN_GRAPH_TYPE_VAMANA = "vamana";
const char* const DISKANN_GRAPH_TYPE_ODESCENT = "odescent";
const char* const GRAPH_TYPE_ODESCENT = "odescent";

const char* const HNSW_PARAMETER_EF_RUNTIME = "ef_search";
const char* const HNSW_PARAMETER_M = "max_degree";
const char* const HNSW_PARAMETER_CONSTRUCTION = "ef_construction";
const char* const HNSW_PARAMETER_USE_STATIC = "use_static";
const char* const HNSW_PARAMETER_REVERSED_EDGES = "use_reversed_edges";

const char* const PYRAMID_PARAMETER_BASE_CODES = "base_codes";
const char* const INDEX_PARAM = "index_param";

const char* const PYRAMID_PARAMETER_SUBINDEX_TYPE = "sub_index_type";
const char PART_SLASH = '/';

// statistic key
Expand Down
4 changes: 2 additions & 2 deletions src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ Engine::CreateIndex(const std::string& origin_name, const std::string& parameter
CHECK_ARGUMENT(parsed_params.contains(INDEX_PARAM),
fmt::format("parameters must contains {}", INDEX_PARAM));
auto& pyramid_param_obj = parsed_params[INDEX_PARAM];
auto pyramid_params =
PyramidParameters::FromJson(pyramid_param_obj, index_common_params);
PyramidParameters pyramid_params;
pyramid_params.FromJson(pyramid_param_obj);
logger::debug("created a pyramid index");
return std::make_shared<Pyramid>(pyramid_params, index_common_params);
} else {
Expand Down
11 changes: 8 additions & 3 deletions src/impl/basic_searcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,14 @@ BasicSearcher::Search(const GraphInterfacePtr& graph_data_cell,
}
}
}

while (top_candidates.size() > inner_search_param.topk_) {
top_candidates.pop();
if (inner_search_param.topk_ > 0) {
while (top_candidates.size() > inner_search_param.topk_) {
top_candidates.pop();
}
} else if (inner_search_param.radius_ > 0) {
while (top_candidates.top().first > inner_search_param.radius_ + 2e-6) {
top_candidates.pop();
}
}

return top_candidates;
Expand Down
4 changes: 2 additions & 2 deletions src/impl/basic_searcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace vsag {

class InnerSearchParam {
public:
int topk_{0};
float radius_{0.0f};
int64_t topk_{0};
float radius_{-1.0f};
InnerIdType ep_{0};
uint64_t ef_{10};
BaseFilterFunctor* is_id_allowed_{nullptr};
Expand Down
7 changes: 2 additions & 5 deletions src/impl/odescent_graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,7 @@ class LinearCongruentialGenerator {
};

bool
ODescent::Build(const uint32_t* valid_ids, int64_t data_num) {
if (is_build_) {
return false;
}
is_build_ = true;
ODescent::Build(const uint32_t* valid_ids, uint64_t data_num) {
valid_ids_ = valid_ids;
if (valid_ids_ != nullptr) {
data_num_ = data_num;
Expand Down Expand Up @@ -125,6 +121,7 @@ ODescent::init_graph() {
for (int64_t i = start; i < end; ++i) {
UnorderedSet<uint32_t> ids_set(allocator_);
ids_set.insert(i);
graph[i].neighbors.clear();
graph[i].neighbors.reserve(max_degree_);
int64_t max_neighbors = std::min(data_num_ - 1, max_degree_);
for (int j = 0; j < max_neighbors; ++j) {
Expand Down
4 changes: 2 additions & 2 deletions src/impl/odescent_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct Linklist {
};
class ODescent {
public:
ODescent(int64_t max_degree,
ODescent(uint64_t max_degree,
float alpha,
int64_t turn,
float sample_rate,
Expand All @@ -94,7 +94,7 @@ class ODescent {
}

bool
Build(const uint32_t* valid_ids = nullptr, int64_t data_num = 0);
Build(const uint32_t* valid_ids = nullptr, uint64_t data_num = 0);

void
SaveGraph(std::stringstream& out);
Expand Down
2 changes: 1 addition & 1 deletion src/index/diskann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ DiskANN::build(const DatasetPtr& base) {
fmt::format("base.num_elements({}) must be greater than 1", data_num));

std::vector<size_t> failed_locs;
if (diskann_params_.graph_type == DISKANN_GRAPH_TYPE_ODESCENT) {
if (diskann_params_.graph_type == GRAPH_TYPE_ODESCENT) {
SlowTaskTimer t("odescent build full (graph)");
FlattenDataCellParamPtr flatten_param =
std::make_shared<vsag::FlattenDataCellParameter>();
Expand Down
22 changes: 11 additions & 11 deletions src/index/diskann_zparameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,36 +110,36 @@ DiskannParameters::FromJson(JsonType& diskann_param_obj, IndexCommonParam index_
fmt::format("ef_construction({}) must in range[$max_degree({}), 64]",
obj.ef_construction,
obj.max_degree));
} else if (obj.graph_type == DISKANN_GRAPH_TYPE_ODESCENT) {
} else if (obj.graph_type == GRAPH_TYPE_ODESCENT) {
// set obj.alpha
if (diskann_param_obj.contains(DISKANN_PARAMETER_ALPHA)) {
obj.alpha = diskann_param_obj[DISKANN_PARAMETER_ALPHA];
if (diskann_param_obj.contains(ODESCENT_PARAMETER_ALPHA)) {
obj.alpha = diskann_param_obj[ODESCENT_PARAMETER_ALPHA];
CHECK_ARGUMENT(
(obj.alpha >= 1.0 && obj.alpha <= 2.0),
fmt::format(
"{} must in range[1.0, 2.0], now is {}", DISKANN_PARAMETER_ALPHA, obj.alpha));
"{} must in range[1.0, 2.0], now is {}", ODESCENT_PARAMETER_ALPHA, obj.alpha));
}
// set obj.turn
if (diskann_param_obj.contains(DISKANN_PARAMETER_GRAPH_ITER_TURN)) {
obj.turn = diskann_param_obj[DISKANN_PARAMETER_GRAPH_ITER_TURN];
if (diskann_param_obj.contains(ODESCENT_PARAMETER_GRAPH_ITER_TURN)) {
obj.turn = diskann_param_obj[ODESCENT_PARAMETER_GRAPH_ITER_TURN];
CHECK_ARGUMENT((obj.turn > 0),
fmt::format("{} must be greater than 0, now is {}",
DISKANN_PARAMETER_GRAPH_ITER_TURN,
ODESCENT_PARAMETER_GRAPH_ITER_TURN,
obj.turn));
}
// set obj.sample_rate
if (diskann_param_obj.contains(DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE)) {
obj.sample_rate = diskann_param_obj[DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE];
if (diskann_param_obj.contains(ODESCENT_PARAMETER_NEIGHBOR_SAMPLE_RATE)) {
obj.sample_rate = diskann_param_obj[ODESCENT_PARAMETER_NEIGHBOR_SAMPLE_RATE];
CHECK_ARGUMENT((obj.sample_rate > 0.05 && obj.sample_rate < 0.5),
fmt::format("{} must in range[0.05, 0.5], now is {}",
DISKANN_PARAMETER_NEIGHBOR_SAMPLE_RATE,
ODESCENT_PARAMETER_NEIGHBOR_SAMPLE_RATE,
obj.sample_rate));
}
} else {
throw std::invalid_argument(fmt::format("parameters[{}] must in [{}, {}], now is {}",
DISKANN_PARAMETER_GRAPH_TYPE,
DISKANN_GRAPH_TYPE_VAMANA,
DISKANN_GRAPH_TYPE_ODESCENT,
GRAPH_TYPE_ODESCENT,
obj.graph_type));
}
return obj;
Expand Down
Loading

0 comments on commit 29c0334

Please sign in to comment.