Skip to content

Commit 5699751

Browse files
authored
Add trace span in bruteforce search (zilliztech#370)
Signed-off-by: Yudong Cai <yudong.cai@zilliz.com>
1 parent e9574eb commit 5699751

File tree

2 files changed

+93
-6
lines changed

2 files changed

+93
-6
lines changed

src/common/comp/brute_force.cc

+88-3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
#include "knowhere/sparse_utils.h"
2727
#include "knowhere/utils.h"
2828

29+
#ifdef NOT_COMPILE_FOR_SWIG
30+
#include "knowhere/tracer.h"
31+
#endif
32+
2933
namespace knowhere {
3034

3135
/* knowhere wrapper API to call faiss brute force search for all metric types */
@@ -47,13 +51,25 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
4751

4852
auto xq = query->GetTensor();
4953
auto nq = query->GetRows();
54+
5055
BruteForceConfig cfg;
5156
std::string msg;
5257
auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg);
5358
if (status != Status::success) {
5459
return expected<DataSetPtr>::Err(status, msg);
5560
}
5661

62+
#ifdef NOT_COMPILE_FOR_SWIG
63+
std::shared_ptr<tracer::trace::Span> span = nullptr;
64+
if (cfg.trace_id.has_value()) {
65+
auto ctx = tracer::TraceContext{(uint8_t*)cfg.trace_id.value().c_str(), (uint8_t*)cfg.span_id.value().c_str(),
66+
(uint8_t)cfg.trace_flags.value()};
67+
span = tracer::StartSpan("knowhere bf search", &ctx);
68+
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
69+
span->SetAttribute(meta::TOPK, cfg.k.value());
70+
}
71+
#endif
72+
5773
std::string metric_str = cfg.metric_type.value();
5874
auto result = Str2FaissMetricType(metric_str);
5975
if (result.error() != Status::success) {
@@ -133,7 +149,15 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
133149
if (ret != Status::success) {
134150
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
135151
}
136-
return GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());
152+
auto res = GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());
153+
154+
#ifdef NOT_COMPILE_FOR_SWIG
155+
if (cfg.trace_id.has_value()) {
156+
span->End();
157+
}
158+
#endif
159+
160+
return res;
137161
}
138162

139163
template <typename DataType>
@@ -156,6 +180,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
156180
BruteForceConfig cfg;
157181
RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH));
158182

183+
#ifdef NOT_COMPILE_FOR_SWIG
184+
std::shared_ptr<tracer::trace::Span> span = nullptr;
185+
if (cfg.trace_id.has_value()) {
186+
auto ctx = tracer::TraceContext{(uint8_t*)cfg.trace_id.value().c_str(), (uint8_t*)cfg.span_id.value().c_str(),
187+
(uint8_t)cfg.trace_flags.value()};
188+
span = tracer::StartSpan("knowhere bf search with buf", &ctx);
189+
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
190+
span->SetAttribute(meta::TOPK, cfg.k.value());
191+
}
192+
#endif
193+
159194
std::string metric_str = cfg.metric_type.value();
160195
auto result = Str2FaissMetricType(cfg.metric_type.value());
161196
if (result.error() != Status::success) {
@@ -232,6 +267,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
232267
}));
233268
}
234269
RETURN_IF_ERROR(WaitAllSuccess(futs));
270+
271+
#ifdef NOT_COMPILE_FOR_SWIG
272+
if (cfg.trace_id.has_value()) {
273+
span->End();
274+
}
275+
#endif
276+
235277
return Status::success;
236278
}
237279

@@ -261,6 +303,21 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
261303
return expected<DataSetPtr>::Err(status, std::move(msg));
262304
}
263305

306+
#ifdef NOT_COMPILE_FOR_SWIG
307+
std::shared_ptr<tracer::trace::Span> span = nullptr;
308+
if (cfg.trace_id.has_value()) {
309+
auto ctx = tracer::TraceContext{(uint8_t*)cfg.trace_id.value().c_str(), (uint8_t*)cfg.span_id.value().c_str(),
310+
(uint8_t)cfg.trace_flags.value()};
311+
span = tracer::StartSpan("knowhere bf range search", &ctx);
312+
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
313+
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
314+
span->SetAttribute(meta::RADIUS, cfg.radius.value());
315+
if (cfg.range_filter.value() != defaultRangeFilter) {
316+
span->SetAttribute(meta::RANGE_FILTER, cfg.range_filter.value());
317+
}
318+
}
319+
#endif
320+
264321
std::string metric_str = cfg.metric_type.value();
265322
auto result = Str2FaissMetricType(metric_str);
266323
if (result.error() != Status::success) {
@@ -351,7 +408,15 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
351408
float* distances = nullptr;
352409
size_t* lims = nullptr;
353410
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
354-
return GenResultDataSet(nq, ids, distances, lims);
411+
auto res = GenResultDataSet(nq, ids, distances, lims);
412+
413+
#ifdef NOT_COMPILE_FOR_SWIG
414+
if (cfg.trace_id.has_value()) {
415+
span->End();
416+
}
417+
#endif
418+
419+
return res;
355420
}
356421

357422
Status
@@ -430,12 +495,32 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d
430495
if (status != Status::success) {
431496
return expected<DataSetPtr>::Err(status, msg);
432497
}
498+
499+
#ifdef NOT_COMPILE_FOR_SWIG
500+
std::shared_ptr<tracer::trace::Span> span = nullptr;
501+
if (cfg.trace_id.has_value()) {
502+
auto ctx = tracer::TraceContext{(uint8_t*)cfg.trace_id.value().c_str(), (uint8_t*)cfg.span_id.value().c_str(),
503+
(uint8_t)cfg.trace_flags.value()};
504+
span = tracer::StartSpan("knowhere bf search with buf", &ctx);
505+
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
506+
span->SetAttribute(meta::TOPK, cfg.k.value());
507+
}
508+
#endif
509+
433510
int topk = cfg.k.value();
434511
auto labels = std::make_unique<sparse::label_t[]>(nq * topk);
435512
auto distances = std::make_unique<float[]>(nq * topk);
436513

437514
SearchSparseWithBuf(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset);
438-
return GenResultDataSet(nq, topk, labels.release(), distances.release());
515+
auto res = GenResultDataSet(nq, topk, labels.release(), distances.release());
516+
517+
#ifdef NOT_COMPILE_FOR_SWIG
518+
if (cfg.trace_id.has_value()) {
519+
span->End();
520+
}
521+
#endif
522+
523+
return res;
439524
}
440525

441526
} // namespace knowhere

src/common/tracer.cc

+5-3
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,11 @@ AddEvent(const std::string& event_label) {
120120

121121
bool
122122
isEmptyID(const uint8_t* id, int length) {
123-
for (int i = 0; i < length; i++) {
124-
if (id[i] != 0) {
125-
return false;
123+
if (id != nullptr) {
124+
for (int i = 0; i < length; i++) {
125+
if (id[i] != 0) {
126+
return false;
127+
}
126128
}
127129
}
128130
return true;

0 commit comments

Comments
 (0)