26
26
#include " knowhere/sparse_utils.h"
27
27
#include " knowhere/utils.h"
28
28
29
+ #ifdef NOT_COMPILE_FOR_SWIG
30
+ #include " knowhere/tracer.h"
31
+ #endif
32
+
29
33
namespace knowhere {
30
34
31
35
/* knowhere wrapper API to call faiss brute force search for all metric types */
@@ -47,13 +51,25 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
47
51
48
52
auto xq = query->GetTensor ();
49
53
auto nq = query->GetRows ();
54
+
50
55
BruteForceConfig cfg;
51
56
std::string msg;
52
57
auto status = Config::Load (cfg, config, knowhere::SEARCH, &msg);
53
58
if (status != Status::success) {
54
59
return expected<DataSetPtr>::Err (status, msg);
55
60
}
56
61
62
+ #ifdef NOT_COMPILE_FOR_SWIG
63
+ std::shared_ptr<tracer::trace::Span> span = nullptr ;
64
+ if (cfg.trace_id .has_value ()) {
65
+ auto ctx = tracer::TraceContext{(uint8_t *)cfg.trace_id .value ().c_str (), (uint8_t *)cfg.span_id .value ().c_str (),
66
+ (uint8_t )cfg.trace_flags .value ()};
67
+ span = tracer::StartSpan (" knowhere bf search" , &ctx);
68
+ span->SetAttribute (meta::METRIC_TYPE, cfg.metric_type .value ());
69
+ span->SetAttribute (meta::TOPK, cfg.k .value ());
70
+ }
71
+ #endif
72
+
57
73
std::string metric_str = cfg.metric_type .value ();
58
74
auto result = Str2FaissMetricType (metric_str);
59
75
if (result.error () != Status::success) {
@@ -133,7 +149,15 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
133
149
if (ret != Status::success) {
134
150
return expected<DataSetPtr>::Err (ret, " failed to brute force search" );
135
151
}
136
- return GenResultDataSet (nq, cfg.k .value (), labels.release (), distances.release ());
152
+ auto res = GenResultDataSet (nq, cfg.k .value (), labels.release (), distances.release ());
153
+
154
+ #ifdef NOT_COMPILE_FOR_SWIG
155
+ if (cfg.trace_id .has_value ()) {
156
+ span->End ();
157
+ }
158
+ #endif
159
+
160
+ return res;
137
161
}
138
162
139
163
template <typename DataType>
@@ -156,6 +180,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
156
180
BruteForceConfig cfg;
157
181
RETURN_IF_ERROR (Config::Load (cfg, config, knowhere::SEARCH));
158
182
183
+ #ifdef NOT_COMPILE_FOR_SWIG
184
+ std::shared_ptr<tracer::trace::Span> span = nullptr ;
185
+ if (cfg.trace_id .has_value ()) {
186
+ auto ctx = tracer::TraceContext{(uint8_t *)cfg.trace_id .value ().c_str (), (uint8_t *)cfg.span_id .value ().c_str (),
187
+ (uint8_t )cfg.trace_flags .value ()};
188
+ span = tracer::StartSpan (" knowhere bf search with buf" , &ctx);
189
+ span->SetAttribute (meta::METRIC_TYPE, cfg.metric_type .value ());
190
+ span->SetAttribute (meta::TOPK, cfg.k .value ());
191
+ }
192
+ #endif
193
+
159
194
std::string metric_str = cfg.metric_type .value ();
160
195
auto result = Str2FaissMetricType (cfg.metric_type .value ());
161
196
if (result.error () != Status::success) {
@@ -232,6 +267,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
232
267
}));
233
268
}
234
269
RETURN_IF_ERROR (WaitAllSuccess (futs));
270
+
271
+ #ifdef NOT_COMPILE_FOR_SWIG
272
+ if (cfg.trace_id .has_value ()) {
273
+ span->End ();
274
+ }
275
+ #endif
276
+
235
277
return Status::success;
236
278
}
237
279
@@ -261,6 +303,21 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
261
303
return expected<DataSetPtr>::Err (status, std::move (msg));
262
304
}
263
305
306
+ #ifdef NOT_COMPILE_FOR_SWIG
307
+ std::shared_ptr<tracer::trace::Span> span = nullptr ;
308
+ if (cfg.trace_id .has_value ()) {
309
+ auto ctx = tracer::TraceContext{(uint8_t *)cfg.trace_id .value ().c_str (), (uint8_t *)cfg.span_id .value ().c_str (),
310
+ (uint8_t )cfg.trace_flags .value ()};
311
+ span = tracer::StartSpan (" knowhere bf range search" , &ctx);
312
+ span->SetAttribute (meta::METRIC_TYPE, cfg.metric_type .value ());
313
+ span->SetAttribute (meta::METRIC_TYPE, cfg.metric_type .value ());
314
+ span->SetAttribute (meta::RADIUS, cfg.radius .value ());
315
+ if (cfg.range_filter .value () != defaultRangeFilter) {
316
+ span->SetAttribute (meta::RANGE_FILTER, cfg.range_filter .value ());
317
+ }
318
+ }
319
+ #endif
320
+
264
321
std::string metric_str = cfg.metric_type .value ();
265
322
auto result = Str2FaissMetricType (metric_str);
266
323
if (result.error () != Status::success) {
@@ -351,7 +408,15 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
351
408
float * distances = nullptr ;
352
409
size_t * lims = nullptr ;
353
410
GetRangeSearchResult (result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
354
- return GenResultDataSet (nq, ids, distances, lims);
411
+ auto res = GenResultDataSet (nq, ids, distances, lims);
412
+
413
+ #ifdef NOT_COMPILE_FOR_SWIG
414
+ if (cfg.trace_id .has_value ()) {
415
+ span->End ();
416
+ }
417
+ #endif
418
+
419
+ return res;
355
420
}
356
421
357
422
Status
@@ -430,12 +495,32 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d
430
495
if (status != Status::success) {
431
496
return expected<DataSetPtr>::Err (status, msg);
432
497
}
498
+
499
+ #ifdef NOT_COMPILE_FOR_SWIG
500
+ std::shared_ptr<tracer::trace::Span> span = nullptr ;
501
+ if (cfg.trace_id .has_value ()) {
502
+ auto ctx = tracer::TraceContext{(uint8_t *)cfg.trace_id .value ().c_str (), (uint8_t *)cfg.span_id .value ().c_str (),
503
+ (uint8_t )cfg.trace_flags .value ()};
504
+ span = tracer::StartSpan (" knowhere bf search with buf" , &ctx);
505
+ span->SetAttribute (meta::METRIC_TYPE, cfg.metric_type .value ());
506
+ span->SetAttribute (meta::TOPK, cfg.k .value ());
507
+ }
508
+ #endif
509
+
433
510
int topk = cfg.k .value ();
434
511
auto labels = std::make_unique<sparse::label_t []>(nq * topk);
435
512
auto distances = std::make_unique<float []>(nq * topk);
436
513
437
514
SearchSparseWithBuf (base_dataset, query_dataset, labels.get (), distances.get (), config, bitset);
438
- return GenResultDataSet (nq, topk, labels.release (), distances.release ());
515
+ auto res = GenResultDataSet (nq, topk, labels.release (), distances.release ());
516
+
517
+ #ifdef NOT_COMPILE_FOR_SWIG
518
+ if (cfg.trace_id .has_value ()) {
519
+ span->End ();
520
+ }
521
+ #endif
522
+
523
+ return res;
439
524
}
440
525
441
526
} // namespace knowhere
0 commit comments