From 809e0ca7c22e6e2ea4b85fc07de7dd94caa50129 Mon Sep 17 00:00:00 2001
From: AlexanderSaydakov
Date: Thu, 22 Feb 2024 15:26:53 -0800
Subject: [PATCH] use uint32_t weight for tdigest

---
Review notes (text between the "---" marker and the first "diff --git" line is
discarded when the patch is applied; it is not part of the commit message):
 * centroid::get_weight() returns W, the same integer type as the weight_
   member and the constructor argument. Declaring it as T would convert the
   stored integer weight to float/double on every read; an IEEE 754 float has
   a 24-bit significand, so uint32_t weights above 2^24 would not all be
   returned exactly.
 * W is uint64_t when T is double and uint32_t otherwise (i.e. for float).
 * Line breaks, indentation and template argument lists were restored from a
   flattened copy of this patch; confirm the context lines against the target
   tree before applying.
 * The post-image blob id on the tdigest.hpp "index" line may not match the
   W get_weight() version of the file.

 tdigest/include/tdigest.hpp                 | 11 +++++++----
 tdigest/include/tdigest_impl.hpp            |  8 ++++----
 tdigest/test/tdigest_serialize_for_java.cpp | 10 ++++++++++
 3 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/tdigest/include/tdigest.hpp b/tdigest/include/tdigest.hpp
index 0b7c54a3..b86e4eca 100644
--- a/tdigest/include/tdigest.hpp
+++ b/tdigest/include/tdigest.hpp
@@ -74,7 +74,8 @@ using tdigest_double = tdigest<double>;
  */
 template <typename T, typename Allocator>
 class tdigest {
-  static_assert(std::is_floating_point<T>::value, "Floating-point type expected");
+  // exclude long double by not using std::is_floating_point
+  static_assert(std::is_same<T, double>::value || std::is_same<T, float>::value, "Either double or float type expected");
   static_assert(std::numeric_limits<T>::is_iec559, "IEEE 754 compatibility required");
 public:
   using value_type = T;
@@ -84,18 +85,20 @@ class tdigest {
   static const bool USE_TWO_LEVEL_COMPRESSION = true;
   static const bool USE_WEIGHT_LIMIT = true;
 
+  using W = typename std::conditional<std::is_same<T, double>::value, uint64_t, uint32_t>::type;
+
   class centroid {
   public:
-    centroid(T value, uint64_t weight): mean_(value), weight_(weight) {}
+    centroid(T value, W weight): mean_(value), weight_(weight) {}
     void add(const centroid& other) {
       weight_ += other.weight_;
       mean_ += (other.mean_ - mean_) * other.weight_ / weight_;
     }
     T get_mean() const { return mean_; }
-    uint64_t get_weight() const { return weight_; }
+    W get_weight() const { return weight_; }
   private:
     T mean_;
-    uint64_t weight_;
+    W weight_;
   };
   using vector_centroid = std::vector<centroid, typename std::allocator_traits<Allocator>::template rebind_alloc<centroid>>;
   using vector_bytes = std::vector<uint8_t, typename std::allocator_traits<Allocator>::template rebind_alloc<uint8_t>>;
diff --git a/tdigest/include/tdigest_impl.hpp b/tdigest/include/tdigest_impl.hpp
index 7656e60e..cdde7332 100644
--- a/tdigest/include/tdigest_impl.hpp
+++ b/tdigest/include/tdigest_impl.hpp
@@ -445,7 +445,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
     vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
     uint64_t total_weight = 0;
     for (auto& c: centroids) {
-      const uint64_t weight = static_cast<uint64_t>(read_big_endian<double>(is));
+      const W weight = static_cast<W>(read_big_endian<double>(is));
       const auto mean = read_big_endian<double>(is);
       c = centroid(mean, weight);
       total_weight += weight;
@@ -463,7 +463,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& alloc
   vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
   uint64_t total_weight = 0;
   for (auto& c: centroids) {
-    const uint64_t weight = static_cast<uint64_t>(read_big_endian<float>(is));
+    const W weight = static_cast<W>(read_big_endian<float>(is));
     const auto mean = read_big_endian<float>(is);
     c = centroid(mean, weight);
     total_weight += weight;
@@ -507,7 +507,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
       double mean;
       ptr += copy_from_mem(ptr, mean);
       mean = byteswap(mean);
-      c = centroid(mean, static_cast<uint64_t>(weight));
+      c = centroid(mean, static_cast<W>(weight));
       total_weight += static_cast<uint64_t>(weight);
     }
     return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
@@ -539,7 +539,7 @@ tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size,
     float mean;
     ptr += copy_from_mem(ptr, mean);
     mean = byteswap(mean);
-    c = centroid(mean, static_cast<uint64_t>(weight));
+    c = centroid(mean, static_cast<W>(weight));
     total_weight += static_cast<uint64_t>(weight);
   }
   return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
diff --git a/tdigest/test/tdigest_serialize_for_java.cpp b/tdigest/test/tdigest_serialize_for_java.cpp
index ed0e4f8c..1f3c1fb1 100644
--- a/tdigest/test/tdigest_serialize_for_java.cpp
+++ b/tdigest/test/tdigest_serialize_for_java.cpp
@@ -34,4 +34,14 @@ TEST_CASE("tdigest double generate", "[serialize_for_java]") {
   }
 }
 
+TEST_CASE("tdigest float generate", "[serialize_for_java]") {
+  const unsigned n_arr[] = {0, 1, 10, 100, 1000, 10000, 100000, 1000000};
+  for (const unsigned n: n_arr) {
+    tdigest_float td(100);
+    for (unsigned i = 1; i <= n; ++i) td.update(i);
+    std::ofstream os("tdigest_float_n" + std::to_string(n) + "_cpp.sk", std::ios::binary);
+    td.serialize(os);
+  }
+}
+
 } /* namespace datasketches */