Skip to content

Commit

Permalink
use uint32_t weight for tdigest<float>
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSaydakov committed Feb 22, 2024
1 parent 7f0c235 commit 809e0ca
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
11 changes: 7 additions & 4 deletions tdigest/include/tdigest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ using tdigest_double = tdigest<double>;
*/
template <typename T, typename Allocator>
class tdigest {
static_assert(std::is_floating_point<T>::value, "Floating-point type expected");
// exclude long double by not using std::is_floating_point
static_assert(std::is_same<T, double>::value || std::is_same<T, float>::value, "Either double or float type expected");
static_assert(std::numeric_limits<T>::is_iec559, "IEEE 754 compatibility required");
public:
using value_type = T;
Expand All @@ -84,18 +85,20 @@ class tdigest {
static const bool USE_TWO_LEVEL_COMPRESSION = true;
static const bool USE_WEIGHT_LIMIT = true;

using W = typename std::conditional<std::is_same<T, double>::value, uint64_t, uint32_t>::type;

class centroid {
public:
centroid(T value, uint64_t weight): mean_(value), weight_(weight) {}
centroid(T value, W weight): mean_(value), weight_(weight) {}
void add(const centroid& other) {
weight_ += other.weight_;
mean_ += (other.mean_ - mean_) * other.weight_ / weight_;
}
T get_mean() const { return mean_; }
uint64_t get_weight() const { return weight_; }
T get_weight() const { return weight_; }
private:
T mean_;
uint64_t weight_;
W weight_;
};
using vector_centroid = std::vector<centroid, typename std::allocator_traits<Allocator>::template rebind_alloc<centroid>>;
using vector_bytes = std::vector<uint8_t, typename std::allocator_traits<Allocator>::template rebind_alloc<uint8_t>>;
Expand Down
8 changes: 4 additions & 4 deletions tdigest/include/tdigest_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
uint64_t total_weight = 0;
for (auto& c: centroids) {
const uint64_t weight = static_cast<uint64_t>(read_big_endian<double>(is));
const W weight = static_cast<W>(read_big_endian<double>(is));
const auto mean = read_big_endian<double>(is);
c = centroid(mean, weight);
total_weight += weight;
Expand All @@ -463,7 +463,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
uint64_t total_weight = 0;
for (auto& c: centroids) {
const uint64_t weight = static_cast<uint64_t>(read_big_endian<float>(is));
const W weight = static_cast<W>(read_big_endian<float>(is));
const auto mean = read_big_endian<float>(is);
c = centroid(mean, weight);
total_weight += weight;
Expand Down Expand Up @@ -507,7 +507,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
double mean;
ptr += copy_from_mem(ptr, mean);
mean = byteswap(mean);
c = centroid(mean, static_cast<uint64_t>(weight));
c = centroid(mean, static_cast<W>(weight));
total_weight += static_cast<uint64_t>(weight);
}
return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
Expand Down Expand Up @@ -539,7 +539,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
float mean;
ptr += copy_from_mem(ptr, mean);
mean = byteswap(mean);
c = centroid(mean, static_cast<uint64_t>(weight));
c = centroid(mean, static_cast<W>(weight));
total_weight += static_cast<uint64_t>(weight);
}
return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
Expand Down
10 changes: 10 additions & 0 deletions tdigest/test/tdigest_serialize_for_java.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,14 @@ TEST_CASE("tdigest double generate", "[serialize_for_java]") {
}
}

TEST_CASE("tdigest float generate", "[serialize_for_java]") {
const unsigned n_arr[] = {0, 1, 10, 100, 1000, 10000, 100000, 1000000};
for (const unsigned n: n_arr) {
tdigest_float td(100);
for (unsigned i = 1; i <= n; ++i) td.update(i);
std::ofstream os("tdigest_float_n" + std::to_string(n) + "_cpp.sk", std::ios::binary);
td.serialize(os);
}
}

} /* namespace datasketches */

0 comments on commit 809e0ca

Please sign in to comment.