diff --git a/tdigest/include/tdigest.hpp b/tdigest/include/tdigest.hpp index 70c737ee..357f203b 100644 --- a/tdigest/include/tdigest.hpp +++ b/tdigest/include/tdigest.hpp @@ -96,6 +96,7 @@ class tdigest { uint64_t weight_; }; using vector_centroid = std::vector::template rebind_alloc>; + using vector_bytes = std::vector::template rebind_alloc>; struct centroid_cmp { centroid_cmp(bool reverse): reverse_(reverse) {} @@ -176,11 +177,43 @@ class tdigest { */ string to_string(bool print_centroids = false) const; + /** + * This method serializes t-Digest into a given stream in a binary form + * @param os output stream + */ + void serialize(std::ostream& os) const; + + /** + * This method serializes t-Digest as a vector of bytes. + * An optional header can be reserved in front of the sketch. + * It is an uninitialized space of a given size. + * @param header_size_bytes space to reserve in front of the sketch + * @return serialized sketch as a vector of bytes + */ + vector_bytes serialize(unsigned header_size_bytes = 0) const; + + /** + * This method deserializes t-Digest from a given stream. + * @param is input stream + * @param allocator instance of an Allocator + * @return an instance of t-Digest + */ + static tdigest deserialize(std::istream& is, const Allocator& allocator = Allocator()); + + /** + * This method deserializes t-Digest from a given array of bytes. 
+ * @param bytes pointer to the array of bytes + * @param size the size of the array + * @param allocator instance of an Allocator + * @return an instance of t-Digest + */ + static tdigest deserialize(const void* bytes, size_t size, const Allocator& allocator = Allocator()); + private: Allocator allocator_; + bool reverse_merge_; uint16_t k_; uint16_t internal_k_; - uint32_t merge_count_; T min_; T max_; size_t centroids_capacity_; @@ -190,6 +223,16 @@ class tdigest { vector_centroid buffer_; uint64_t buffered_weight_; + static const uint8_t PREAMBLE_LONGS_EMPTY = 1; + static const uint8_t PREAMBLE_LONGS_NON_EMPTY = 2; + static const uint8_t SERIAL_VERSION = 1; + static const uint8_t SKETCH_TYPE = 20; + + enum flags { IS_EMPTY, REVERSE_MERGE }; + + // for deserialize + tdigest(bool reverse_merge, uint16_t k, T min, T max, vector_centroid&& centroids, uint64_t total_weight_, const Allocator& allocator); + void merge_new_values(); void merge_new_values(bool force, uint16_t k); void merge_new_values(uint16_t k); diff --git a/tdigest/include/tdigest_impl.hpp b/tdigest/include/tdigest_impl.hpp index aeb6157a..c7128ce9 100644 --- a/tdigest/include/tdigest_impl.hpp +++ b/tdigest/include/tdigest_impl.hpp @@ -23,46 +23,20 @@ #include #include +#include "common_defs.hpp" +#include "memory_operations.hpp" + namespace datasketches { template tdigest::tdigest(uint16_t k, const A& allocator): -allocator_(allocator), -k_(k), -internal_k_(k), -merge_count_(0), -min_(std::numeric_limits::infinity()), -max_(-std::numeric_limits::infinity()), -centroids_capacity_(0), -centroids_(allocator), -total_weight_(0), -buffer_capacity_(0), -buffer_(allocator), -buffered_weight_(0) -{ - if (k < 10) throw std::invalid_argument("k must be at least 10"); - size_t fudge = 0; - if (USE_WEIGHT_LIMIT) { - fudge = 10; - if (k < 30) fudge +=20; - } - centroids_capacity_ = 2 * k_ + fudge; - buffer_capacity_ = 5 * centroids_capacity_; - double scale = std::max(1.0, static_cast(buffer_capacity_) / 
centroids_capacity_ - 1.0); - if (!USE_TWO_LEVEL_COMPRESSION) scale = 1; - internal_k_ = std::ceil(std::sqrt(scale) * k_); - if (centroids_capacity_ < internal_k_ + fudge) { - centroids_capacity_ = internal_k_ + fudge; - } - if (buffer_capacity_ < 2 * centroids_capacity_) buffer_capacity_ = 2 * centroids_capacity_; - centroids_.reserve(centroids_capacity_); - buffer_.reserve(buffer_capacity_); -} +tdigest(false, k, std::numeric_limits::infinity(), -std::numeric_limits::infinity(), vector_centroid(allocator), 0, allocator) +{} template void tdigest::update(T value) { // check for NaN - if (buffer_.size() >= buffer_capacity_ - centroids_.size() - 1) merge_new_values(); // - 1 for compatibility with Java + if (buffer_.size() >= buffer_capacity_ - centroids_.size()) merge_new_values(); buffer_.push_back(centroid(value, 1)); ++buffered_weight_; min_ = std::min(min_, value); @@ -237,7 +211,6 @@ string tdigest::to_string(bool print_centroids) const { os << " Buffer capacity : " << buffer_capacity_ << std::endl; os << " Total Weight : " << total_weight_ << std::endl; os << " Buffered Weight : " << buffered_weight_ << std::endl; - os << " Merge count : " << merge_count_ << std::endl; if (!is_empty()) { os << " Min : " << min_ << std::endl; os << " Max : " << max_ << std::endl; @@ -267,7 +240,7 @@ void tdigest::merge_new_values(bool force, uint16_t k) { template void tdigest::merge_new_values(uint16_t k) { - const bool reverse = USE_ALTERNATING_SORT & (merge_count_ & 1); + const bool reverse = USE_ALTERNATING_SORT & reverse_merge_; for (const auto& centroid: centroids_) buffer_.push_back(centroid); centroids_.clear(); std::stable_sort(buffer_.begin(), buffer_.end(), centroid_cmp(reverse)); @@ -310,7 +283,7 @@ void tdigest::merge_new_values(uint16_t k) { min_ = std::min(min_, centroids_.front().get_mean()); max_ = std::max(max_, centroids_.back().get_mean()); } - ++merge_count_; + reverse_merge_ = !reverse_merge_; buffer_.clear(); buffered_weight_ = 0; } @@ -320,6 +293,173 @@ 
double tdigest<T, A>::weighted_average(double x1, double w1, double x2, double w
   return (x1 * w1 + x2 * w2) / (w1 + w2);
 }
 
+template<typename T, typename A>
+void tdigest<T, A>::serialize(std::ostream& os) const { // stream form: 8-byte preamble, then (non-empty only) count, min, max and the raw centroid array
+  const_cast<tdigest*>(this)->merge_new_values(); // side effect: this const method mutates the sketch before writing it
+  write(os, is_empty() ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY);
+  write(os, SERIAL_VERSION);
+  write(os, SKETCH_TYPE);
+  write(os, k_);
+  const uint8_t flags_byte(
+    (is_empty() ? 1 << flags::IS_EMPTY : 0) |
+    (reverse_merge_ ? 1 << flags::REVERSE_MERGE : 0)
+  );
+  write(os, flags_byte);
+  write<uint16_t>(os, 0); // unused: 2 bytes, matching "ptr += 2" in the byte-vector overload (pads the preamble to 8 bytes)
+
+  if (is_empty()) return; // an empty sketch is the 8-byte preamble only
+
+  write(os, static_cast<uint32_t>(centroids_.size()));
+  write<uint32_t>(os, 0); // unused: 4 bytes, matching "ptr += 4" in the byte-vector overload
+
+  write(os, min_);
+  write(os, max_);
+  write(os, centroids_.data(), centroids_.size() * sizeof(centroid));
+}
+
+template<typename T, typename A>
+auto tdigest<T, A>::serialize(unsigned header_size_bytes) const -> vector_bytes { // byte-vector form: same layout as the stream form, written after header_size_bytes reserved bytes
+  const_cast<tdigest*>(this)->merge_new_values(); // side effect: this const method mutates the sketch before writing it
+  const uint8_t preamble_longs = is_empty() ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY;
+  const size_t size_bytes = preamble_longs * sizeof(uint64_t) + (is_empty() ? 0 : sizeof(T) * 2 + sizeof(centroid) * centroids_.size()); // min/max/centroids exist only for a non-empty sketch, as in the stream form
+  vector_bytes bytes(header_size_bytes + size_bytes, 0, allocator_); // the reserved header must be part of the allocation: writing starts at offset header_size_bytes
+  uint8_t* ptr = bytes.data() + header_size_bytes;
+
+  *ptr++ = preamble_longs;
+  *ptr++ = SERIAL_VERSION;
+  *ptr++ = SKETCH_TYPE;
+  ptr += copy_to_mem(k_, ptr);
+  const uint8_t flags_byte(
+    (is_empty() ? 1 << flags::IS_EMPTY : 0) |
+    (reverse_merge_ ? 1 << flags::REVERSE_MERGE : 0)
+  );
+  *ptr++ = flags_byte;
+  ptr += 2; // unused (already zero: the vector is value-filled with 0)
+  if (is_empty()) return bytes;
+
+  ptr += copy_to_mem(static_cast<uint32_t>(centroids_.size()), ptr);
+  ptr += 4; // unused (already zero: the vector is value-filled with 0)
+
+  ptr += copy_to_mem(min_, ptr);
+  ptr += copy_to_mem(max_, ptr);
+  copy_to_mem(centroids_.data(), ptr, centroids_.size() * sizeof(centroid));
+  return bytes;
+}
+
+template<typename T, typename A>
+tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
+  const auto preamble_longs = read<uint8_t>(is);
+  const auto serial_version = read<uint8_t>(is);
+  const auto sketch_type = read<uint8_t>(is);
+  if (sketch_type != SKETCH_TYPE) {
+    throw std::invalid_argument("sketch type mismatch: expected " + std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type));
+  }
+  if (serial_version != SERIAL_VERSION) {
+    throw std::invalid_argument("serial version mismatch: expected " + std::to_string(SERIAL_VERSION) + ", actual " + std::to_string(serial_version));
+  }
+  const auto k = read<uint16_t>(is);
+  const auto flags_byte = read<uint8_t>(is);
+  const bool is_empty = flags_byte & (1 << flags::IS_EMPTY);
+  const uint8_t expected_preamble_longs = is_empty ?
PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY;
+  if (preamble_longs != expected_preamble_longs) {
+    throw std::invalid_argument("preamble longs mismatch: expected " + std::to_string(expected_preamble_longs) + ", actual " + std::to_string(preamble_longs));
+  }
+  read<uint16_t>(is); // unused: 2 padding bytes that close the 8-byte preamble
+
+  if (is_empty) return tdigest(k, allocator);
+
+  const auto num_centroids = read<uint32_t>(is);
+  read<uint32_t>(is); // unused: 4 padding bytes after the centroid count
+
+  const T min = read<T>(is);
+  const T max = read<T>(is);
+  vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
+  read(is, centroids.data(), num_centroids * sizeof(centroid)); // NOTE(review): num_centroids comes from the input and the stream state is not checked here - confirm read() reports short reads
+  uint64_t total_weight = 0;
+  for (const auto& c: centroids) total_weight += c.get_weight(); // total weight is not stored; it is recomputed from the centroids
+  const bool reverse_merge = flags_byte & (1 << flags::REVERSE_MERGE);
+  return tdigest(reverse_merge, k, min, max, std::move(centroids), total_weight, allocator);
+}
+
+template<typename T, typename A>
+tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A& allocator) { // byte-array form: same layout as the stream form, with explicit bounds checks
+  ensure_minimum_memory(size, 8); // the fixed preamble is 8 bytes
+  const char* ptr = static_cast<const char*>(bytes);
+  const char* end_ptr = static_cast<const char*>(bytes) + size;
+
+  const uint8_t preamble_longs = *ptr++;
+  const uint8_t serial_version = *ptr++;
+  const uint8_t sketch_type = *ptr++;
+  if (sketch_type != SKETCH_TYPE) {
+    throw std::invalid_argument("sketch type mismatch: expected " + std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type));
+  }
+  if (serial_version != SERIAL_VERSION) {
+    throw std::invalid_argument("serial version mismatch: expected " + std::to_string(SERIAL_VERSION) + ", actual " + std::to_string(serial_version));
+  }
+  uint16_t k;
+  ptr += copy_from_mem(ptr, k);
+  const uint8_t flags_byte = *ptr++;
+  const bool is_empty = flags_byte & (1 << flags::IS_EMPTY);
+  const uint8_t expected_preamble_longs = is_empty ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY;
+  if (preamble_longs != expected_preamble_longs) {
+    throw std::invalid_argument("preamble longs mismatch: expected " + std::to_string(expected_preamble_longs) + ", actual " + std::to_string(preamble_longs));
+  }
+  ptr += 2; // unused
+
+  if (is_empty) return tdigest(k, allocator);
+
+  ensure_minimum_memory(end_ptr - ptr, 8); // centroid count (4 bytes) + unused (4 bytes)
+  uint32_t num_centroids;
+  ptr += copy_from_mem(ptr, num_centroids);
+  ptr += 4; // unused
+
+  ensure_minimum_memory(end_ptr - ptr, sizeof(T) * 2 + sizeof(centroid) * num_centroids); // checked before the vector is sized from the untrusted count
+  T min;
+  ptr += copy_from_mem(ptr, min);
+  T max;
+  ptr += copy_from_mem(ptr, max);
+  vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
+  copy_from_mem(ptr, centroids.data(), sizeof(centroid) * num_centroids);
+  uint64_t total_weight = 0;
+  for (const auto& c: centroids) total_weight += c.get_weight(); // total weight is not stored; it is recomputed from the centroids
+  const bool reverse_merge = flags_byte & (1 << flags::REVERSE_MERGE);
+  return tdigest(reverse_merge, k, min, max, std::move(centroids), total_weight, allocator);
+}
+
+template<typename T, typename A>
+tdigest<T, A>::tdigest(bool reverse_merge, uint16_t k, T min, T max, vector_centroid&& centroids, uint64_t total_weight, const A& allocator):
+allocator_(allocator),
+reverse_merge_(reverse_merge),
+k_(k),
+internal_k_(k),
+min_(min),
+max_(max),
+centroids_capacity_(0),
+centroids_(std::move(centroids)),
+total_weight_(total_weight),
+buffer_capacity_(0),
+buffer_(allocator),
+buffered_weight_(0)
+{
+  if (k < 10) throw std::invalid_argument("k must be at least 10");
+  size_t fudge = 0;
+  if (USE_WEIGHT_LIMIT) {
+    fudge = 10;
+    if (k < 30) fudge +=20;
+  }
+  centroids_capacity_ = 2 * k_ + fudge;
+  buffer_capacity_ = 5 * centroids_capacity_;
+  double scale = std::max(1.0, static_cast<double>(buffer_capacity_) / centroids_capacity_ - 1.0);
+  if (!USE_TWO_LEVEL_COMPRESSION) scale = 1;
+  internal_k_ = std::ceil(std::sqrt(scale) * k_);
+  if (centroids_capacity_ < internal_k_ + fudge) {
+    centroids_capacity_ = internal_k_ + fudge;
+  }
+  if
(buffer_capacity_ < 2 * centroids_capacity_) buffer_capacity_ = 2 * centroids_capacity_;
+  centroids_.reserve(centroids_capacity_);
+  buffer_.reserve(buffer_capacity_);
+}
+
 } /* namespace datasketches */
 
 #endif // _TDIGEST_IMPL_HPP_
diff --git a/tdigest/test/tdigest_test.cpp b/tdigest/test/tdigest_test.cpp
index 678afca2..b1627d58 100644
--- a/tdigest/test/tdigest_test.cpp
+++ b/tdigest/test/tdigest_test.cpp
@@ -155,4 +155,78 @@ TEST_CASE("merge large", "[tdigest]") {
   REQUIRE(td1.get_rank(n) == 1);
 }
 
+TEST_CASE("serialize deserialize stream empty", "[tdigest]") { // round trip of an empty sketch through a binary stream
+  tdigest<double> td(100);
+  std::stringstream s(std::ios::in | std::ios::out | std::ios::binary);
+  td.serialize(s);
+  auto deserialized_td = tdigest<double>::deserialize(s);
+  REQUIRE(td.get_k() == deserialized_td.get_k());
+  REQUIRE(td.get_total_weight() == deserialized_td.get_total_weight());
+  REQUIRE(td.is_empty() == deserialized_td.is_empty());
+}
+
+TEST_CASE("serialize deserialize stream non empty", "[tdigest]") { // round trip of a 1000-value sketch through a binary stream
+  tdigest<double> td(100);
+  for (int i = 0; i < 1000; ++i) td.update(i);
+  std::stringstream s(std::ios::in | std::ios::out | std::ios::binary);
+  td.serialize(s);
+  auto deserialized_td = tdigest<double>::deserialize(s);
+  REQUIRE(td.get_k() == deserialized_td.get_k());
+  REQUIRE(td.get_total_weight() == deserialized_td.get_total_weight());
+  REQUIRE(td.is_empty() == deserialized_td.is_empty());
+  REQUIRE(td.get_min_value() == deserialized_td.get_min_value());
+  REQUIRE(td.get_max_value() == deserialized_td.get_max_value());
+  REQUIRE(td.get_rank(500) == deserialized_td.get_rank(500));
+  REQUIRE(td.get_quantile(0.5) == deserialized_td.get_quantile(0.5));
+}
+
+TEST_CASE("serialize deserialize bytes empty", "[tdigest]") { // round trip of an empty sketch through a byte vector
+  tdigest<double> td(100);
+  auto bytes = td.serialize();
+  auto deserialized_td = tdigest<double>::deserialize(bytes.data(), bytes.size());
+  REQUIRE(td.get_k() == deserialized_td.get_k());
+  REQUIRE(td.get_total_weight() == deserialized_td.get_total_weight());
+  REQUIRE(td.is_empty() == deserialized_td.is_empty());
+}
+
+TEST_CASE("serialize deserialize bytes non empty", "[tdigest]") { // round trip of a 1000-value sketch through a byte vector
+  tdigest<double> td(100);
+  for (int i = 0; i < 1000; ++i) td.update(i);
+  auto bytes = td.serialize();
+  auto deserialized_td = tdigest<double>::deserialize(bytes.data(), bytes.size());
+  REQUIRE(td.get_k() == deserialized_td.get_k());
+  REQUIRE(td.get_total_weight() == deserialized_td.get_total_weight());
+  REQUIRE(td.is_empty() == deserialized_td.is_empty());
+  REQUIRE(td.get_min_value() == deserialized_td.get_min_value());
+  REQUIRE(td.get_max_value() == deserialized_td.get_max_value());
+  REQUIRE(td.get_rank(500) == deserialized_td.get_rank(500));
+  REQUIRE(td.get_quantile(0.5) == deserialized_td.get_quantile(0.5));
+}
+
+TEST_CASE("serialize deserialize stream and bytes equivalence", "[tdigest]") { // both serialized forms must be byte-identical and deserialize to equal sketches
+  tdigest<double> td(100);
+  for (int i = 0; i < 1000; ++i) td.update(i);
+  std::stringstream s(std::ios::in | std::ios::out | std::ios::binary);
+  td.serialize(s);
+  auto bytes = td.serialize();
+
+  REQUIRE(bytes.size() == static_cast<size_t>(s.tellp()));
+  for (size_t i = 0; i < bytes.size(); ++i) {
+    REQUIRE(bytes[i] == static_cast<uint8_t>(s.get())); // compare as unsigned bytes; get() returns an int
+  }
+
+  s.seekg(0); // rewind
+  auto deserialized_td1 = tdigest<double>::deserialize(s);
+  auto deserialized_td2 = tdigest<double>::deserialize(bytes.data(), bytes.size());
+  REQUIRE(bytes.size() == static_cast<size_t>(s.tellg())); // the stream reader must have consumed exactly the serialized image
+
+  REQUIRE(deserialized_td1.get_k() == deserialized_td2.get_k());
+  REQUIRE(deserialized_td1.get_total_weight() == deserialized_td2.get_total_weight());
+  REQUIRE(deserialized_td1.is_empty() == deserialized_td2.is_empty());
+  REQUIRE(deserialized_td1.get_min_value() == deserialized_td2.get_min_value());
+  REQUIRE(deserialized_td1.get_max_value() == deserialized_td2.get_max_value());
+  REQUIRE(deserialized_td1.get_rank(500) == deserialized_td2.get_rank(500));
+  REQUIRE(deserialized_td1.get_quantile(0.5) == deserialized_td2.get_quantile(0.5));
+}
+
 } /* namespace datasketches */