Skip to content

Commit

Permalink
added serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSaydakov committed Jan 24, 2024
1 parent 1540be6 commit e98561c
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 36 deletions.
45 changes: 44 additions & 1 deletion tdigest/include/tdigest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class tdigest {
uint64_t weight_;
};
using vector_centroid = std::vector<centroid, typename std::allocator_traits<Allocator>::template rebind_alloc<centroid>>;
using vector_bytes = std::vector<uint8_t, typename std::allocator_traits<Allocator>::template rebind_alloc<uint8_t>>;

struct centroid_cmp {
centroid_cmp(bool reverse): reverse_(reverse) {}
Expand Down Expand Up @@ -176,11 +177,43 @@ class tdigest {
*/
string<Allocator> to_string(bool print_centroids = false) const;

/**
 * Serializes this t-Digest into a given stream in a binary form.
 * Values still held in the internal buffer are merged into the centroids
 * first, so the sketch is modified internally even though the method is const.
 * @param os output stream
 */
void serialize(std::ostream& os) const;

/**
 * Serializes this t-Digest as a vector of bytes.
 * An optional header can be reserved in front of the sketch.
 * The sketch does not write anything meaningful into that space;
 * it is there for the caller to fill in.
 * Values still held in the internal buffer are merged into the centroids
 * first, so the sketch is modified internally even though the method is const.
 * @param header_size_bytes space to reserve in front of the sketch
 * @return serialized sketch as a vector of bytes
 */
vector_bytes serialize(unsigned header_size_bytes = 0) const;

/**
 * Deserializes a t-Digest from a given stream.
 * @param is input stream
 * @param allocator instance of an Allocator
 * @return an instance of t-Digest
 * @throws std::invalid_argument if the sketch type, the serial version or
 * the preamble size found in the stream is not the expected one
 */
static tdigest deserialize(std::istream& is, const Allocator& allocator = Allocator());

/**
 * Deserializes a t-Digest from a given array of bytes.
 * The available size is checked through ensure_minimum_memory() before each
 * section is read (presumably that helper raises when the array is too short;
 * confirm against its definition).
 * @param bytes pointer to the array of bytes
 * @param size the size of the array
 * @param allocator instance of an Allocator
 * @return an instance of t-Digest
 * @throws std::invalid_argument if the sketch type, the serial version or
 * the preamble size found in the array is not the expected one
 */
static tdigest deserialize(const void* bytes, size_t size, const Allocator& allocator = Allocator());

private:
Allocator allocator_;
bool reverse_merge_;
uint16_t k_;
uint16_t internal_k_;
uint32_t merge_count_;
T min_;
T max_;
size_t centroids_capacity_;
Expand All @@ -190,6 +223,16 @@ class tdigest {
vector_centroid buffer_;
uint64_t buffered_weight_;

// Size of the serialized preamble, counted in 8-byte words:
// an empty sketch has the short form, a non-empty one has the long form
static const uint8_t PREAMBLE_LONGS_EMPTY = 1;
static const uint8_t PREAMBLE_LONGS_NON_EMPTY = 2;
// Values written into the preamble by serialize() and verified by deserialize()
static const uint8_t SERIAL_VERSION = 1;
static const uint8_t SKETCH_TYPE = 20;

// Bit positions (not masks) inside the flags byte of the preamble
enum flags { IS_EMPTY, REVERSE_MERGE };

// for deserialize: builds a sketch from already merged centroids.
// The parameter is named total_weight (no trailing underscore) so that it
// matches the definition and does not shadow the total_weight_ data member.
tdigest(bool reverse_merge, uint16_t k, T min, T max, vector_centroid&& centroids, uint64_t total_weight, const Allocator& allocator);

void merge_new_values();
void merge_new_values(bool force, uint16_t k);
void merge_new_values(uint16_t k);
Expand Down
210 changes: 175 additions & 35 deletions tdigest/include/tdigest_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,46 +23,20 @@
#include <cmath>
#include <sstream>

#include "common_defs.hpp"
#include "memory_operations.hpp"

namespace datasketches {

/**
 * Creates an empty t-Digest with the given compression parameter.
 *
 * All the work is delegated to the private constructor: the sketch starts
 * with no centroids, zero total weight, reverse_merge set to false, and
 * min/max set to +infinity/-infinity so that the first update replaces them.
 * A delegating constructor call must be the only entry of the initializer
 * list, so no members are initialized here.
 *
 * @param k compression parameter
 * @param allocator instance of an Allocator
 * @throws std::invalid_argument if k is less than 10 (checked by the
 * private constructor)
 */
template<typename T, typename A>
tdigest<T, A>::tdigest(uint16_t k, const A& allocator):
tdigest(false, k, std::numeric_limits<T>::infinity(), -std::numeric_limits<T>::infinity(), vector_centroid(allocator), 0, allocator)
{}

template<typename T, typename A>
void tdigest<T, A>::update(T value) {
// check for NaN
if (buffer_.size() >= buffer_capacity_ - centroids_.size() - 1) merge_new_values(); // - 1 for compatibility with Java
if (buffer_.size() >= buffer_capacity_ - centroids_.size()) merge_new_values();
buffer_.push_back(centroid(value, 1));
++buffered_weight_;
min_ = std::min(min_, value);
Expand Down Expand Up @@ -237,7 +211,6 @@ string<A> tdigest<T, A>::to_string(bool print_centroids) const {
os << " Buffer capacity : " << buffer_capacity_ << std::endl;
os << " Total Weight : " << total_weight_ << std::endl;
os << " Buffered Weight : " << buffered_weight_ << std::endl;
os << " Merge count : " << merge_count_ << std::endl;
if (!is_empty()) {
os << " Min : " << min_ << std::endl;
os << " Max : " << max_ << std::endl;
Expand Down Expand Up @@ -267,7 +240,7 @@ void tdigest<T, A>::merge_new_values(bool force, uint16_t k) {

template<typename T, typename A>
void tdigest<T, A>::merge_new_values(uint16_t k) {
const bool reverse = USE_ALTERNATING_SORT & (merge_count_ & 1);
const bool reverse = USE_ALTERNATING_SORT & reverse_merge_;
for (const auto& centroid: centroids_) buffer_.push_back(centroid);
centroids_.clear();
std::stable_sort(buffer_.begin(), buffer_.end(), centroid_cmp(reverse));
Expand Down Expand Up @@ -310,7 +283,7 @@ void tdigest<T, A>::merge_new_values(uint16_t k) {
min_ = std::min(min_, centroids_.front().get_mean());
max_ = std::max(max_, centroids_.back().get_mean());
}
++merge_count_;
reverse_merge_ = !reverse_merge_;
buffer_.clear();
buffered_weight_ = 0;
}
Expand All @@ -320,6 +293,173 @@ double tdigest<T, A>::weighted_average(double x1, double w1, double x2, double w
return (x1 * w1 + x2 * w2) / (w1 + w2);
}

/**
 * Writes the binary image of this t-Digest into a stream.
 *
 * Layout: a fixed 8-byte preamble (preamble longs, serial version, sketch
 * type, k, flags, 2 unused bytes). For a non-empty sketch this is followed by
 * the number of centroids, 4 unused bytes, min, max and the raw centroids.
 *
 * @param os output stream
 */
template<typename T, typename A>
void tdigest<T, A>::serialize(std::ostream& os) const {
  // pending buffered values must be folded into the centroids before writing
  const_cast<tdigest*>(this)->merge_new_values(); // side effect

  const bool empty = is_empty();
  const uint8_t preamble_longs = empty ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY;
  uint8_t flags_byte = 0;
  if (empty) flags_byte |= 1 << flags::IS_EMPTY;
  if (reverse_merge_) flags_byte |= 1 << flags::REVERSE_MERGE;

  // first 8 bytes, present in every image
  write(os, preamble_longs);
  write(os, SERIAL_VERSION);
  write(os, SKETCH_TYPE);
  write(os, k_);
  write(os, flags_byte);
  write<uint16_t>(os, 0); // unused

  // an empty sketch consists of the first 8 bytes only
  if (empty) return;

  // second 8 bytes of the preamble
  const uint32_t num_centroids = static_cast<uint32_t>(centroids_.size());
  write(os, num_centroids);
  write<uint32_t>(os, 0); // unused

  // data section
  write(os, min_);
  write(os, max_);
  write(os, centroids_.data(), sizeof(centroid) * centroids_.size());
}

/**
 * Serializes this t-Digest as a vector of bytes, optionally leaving room for
 * a caller-defined header in front of the sketch image.
 *
 * The returned vector holds header_size_bytes bytes of header space followed
 * by the same image that the stream overload writes: an 8-byte preamble and,
 * for a non-empty sketch, the number of centroids, 4 unused bytes, min, max
 * and the raw centroids. The whole vector is zero-filled on allocation, so
 * the header space and the unused bytes are zero.
 *
 * @param header_size_bytes space to reserve in front of the sketch
 * @return serialized sketch as a vector of bytes
 */
template<typename T, typename A>
auto tdigest<T, A>::serialize(unsigned header_size_bytes) const -> vector_bytes {
  // pending buffered values must be folded into the centroids before writing
  const_cast<tdigest*>(this)->merge_new_values(); // side effect

  const bool empty = is_empty();
  const uint8_t preamble_longs = empty ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY;
  // min, max and centroids exist only in the image of a non-empty sketch
  const size_t data_bytes = empty ? 0 : sizeof(T) * 2 + sizeof(centroid) * centroids_.size();
  // the header space must be part of the allocation because writing starts after it
  const size_t size_bytes = header_size_bytes + preamble_longs * sizeof(uint64_t) + data_bytes;
  vector_bytes bytes(size_bytes, 0, allocator_);
  uint8_t* ptr = bytes.data() + header_size_bytes;

  // first 8 bytes, present in every image
  *ptr++ = preamble_longs;
  *ptr++ = SERIAL_VERSION;
  *ptr++ = SKETCH_TYPE;
  ptr += copy_to_mem(k_, ptr);
  const uint8_t flags_byte(
    (empty ? 1 << flags::IS_EMPTY : 0) |
    (reverse_merge_ ? 1 << flags::REVERSE_MERGE : 0)
  );
  *ptr++ = flags_byte;
  ptr += 2; // unused
  if (empty) return bytes;

  // second 8 bytes of the preamble
  ptr += copy_to_mem(static_cast<uint32_t>(centroids_.size()), ptr);
  ptr += 4; // unused

  // data section
  ptr += copy_to_mem(min_, ptr);
  ptr += copy_to_mem(max_, ptr);
  copy_to_mem(centroids_.data(), ptr, centroids_.size() * sizeof(centroid));
  return bytes;
}

/**
 * Deserializes a t-Digest from a stream.
 *
 * Reads the 8-byte preamble, validates it, and for a non-empty sketch reads
 * the number of centroids, min, max and the raw centroids. The total weight
 * is recomputed as the sum of the centroid weights.
 *
 * The stream state is verified after each group of reads, before the values
 * are used, so a truncated or failed stream is reported instead of producing
 * a sketch from values that were never read (or sizing an allocation with
 * such a value).
 *
 * @param is input stream
 * @param allocator instance of an Allocator
 * @return an instance of t-Digest
 * @throws std::runtime_error if the stream fails or ends prematurely
 * @throws std::invalid_argument if the sketch type, the serial version or
 * the preamble size is not the expected one
 */
template<typename T, typename A>
tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
  const auto check_stream = [&is]() {
    if (!is.good()) throw std::runtime_error("error reading from std::istream");
  };

  const auto preamble_longs = read<uint8_t>(is);
  const auto serial_version = read<uint8_t>(is);
  const auto sketch_type = read<uint8_t>(is);
  check_stream();
  if (sketch_type != SKETCH_TYPE) {
    throw std::invalid_argument("sketch type mismatch: expected " + std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type));
  }
  if (serial_version != SERIAL_VERSION) {
    throw std::invalid_argument("serial version mismatch: expected " + std::to_string(SERIAL_VERSION) + ", actual " + std::to_string(serial_version));
  }
  const auto k = read<uint16_t>(is);
  const auto flags_byte = read<uint8_t>(is);
  read<uint16_t>(is); // unused
  check_stream();
  const bool is_empty = flags_byte & (1 << flags::IS_EMPTY);
  const uint8_t expected_preamble_longs = is_empty ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY;
  if (preamble_longs != expected_preamble_longs) {
    throw std::invalid_argument("preamble longs mismatch: expected " + std::to_string(expected_preamble_longs) + ", actual " + std::to_string(preamble_longs));
  }

  // an empty sketch consists of the first 8 bytes only
  if (is_empty) return tdigest(k, allocator);

  const auto num_centroids = read<uint32_t>(is);
  read<uint32_t>(is); // unused

  const T min = read<T>(is);
  const T max = read<T>(is);
  // num_centroids sizes the allocation below, so make sure it was really read
  check_stream();
  vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
  read(is, centroids.data(), num_centroids * sizeof(centroid));
  check_stream();
  uint64_t total_weight = 0;
  for (const auto& c: centroids) total_weight += c.get_weight();
  const bool reverse_merge = flags_byte & (1 << flags::REVERSE_MERGE);
  return tdigest(reverse_merge, k, min, max, std::move(centroids), total_weight, allocator);
}

/**
 * Reconstructs a t-Digest from a serialized image held in memory.
 *
 * The image starts with an 8-byte preamble (preamble longs, serial version,
 * sketch type, k, flags, 2 unused bytes). A non-empty image continues with
 * the number of centroids, 4 unused bytes, min, max and the raw centroids.
 * The remaining size is passed to ensure_minimum_memory() before each section
 * is read. The total weight is recomputed from the centroid weights.
 *
 * @param bytes pointer to the array of bytes
 * @param size the size of the array
 * @param allocator instance of an Allocator
 * @return an instance of t-Digest
 * @throws std::invalid_argument if the sketch type, the serial version or
 * the preamble size is not the expected one
 */
template<typename T, typename A>
tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A& allocator) {
  ensure_minimum_memory(size, 8);
  const char* cursor = static_cast<const char*>(bytes);
  const char* const limit = cursor + size;

  // first 8 bytes, present in every image
  const uint8_t preamble_longs = *cursor++;
  const uint8_t serial_version = *cursor++;
  const uint8_t sketch_type = *cursor++;
  if (sketch_type != SKETCH_TYPE) {
    throw std::invalid_argument("sketch type mismatch: expected " + std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type));
  }
  if (serial_version != SERIAL_VERSION) {
    throw std::invalid_argument("serial version mismatch: expected " + std::to_string(SERIAL_VERSION) + ", actual " + std::to_string(serial_version));
  }
  uint16_t k;
  cursor += copy_from_mem(cursor, k);
  const uint8_t flags_byte = *cursor++;
  const bool empty = (flags_byte & (1 << flags::IS_EMPTY)) != 0;
  const bool reverse_merge = (flags_byte & (1 << flags::REVERSE_MERGE)) != 0;
  const uint8_t expected_preamble_longs = empty ? PREAMBLE_LONGS_EMPTY : PREAMBLE_LONGS_NON_EMPTY;
  if (preamble_longs != expected_preamble_longs) {
    throw std::invalid_argument("preamble longs mismatch: expected " + std::to_string(expected_preamble_longs) + ", actual " + std::to_string(preamble_longs));
  }
  cursor += 2; // unused

  // an empty sketch consists of the first 8 bytes only
  if (empty) return tdigest(k, allocator);

  // second 8 bytes of the preamble
  ensure_minimum_memory(limit - cursor, 8);
  uint32_t num_centroids;
  cursor += copy_from_mem(cursor, num_centroids);
  cursor += 4; // unused

  // data section: min, max, then the centroids
  ensure_minimum_memory(limit - cursor, sizeof(T) * 2 + sizeof(centroid) * num_centroids);
  T min;
  cursor += copy_from_mem(cursor, min);
  T max;
  cursor += copy_from_mem(cursor, max);
  vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
  copy_from_mem(cursor, centroids.data(), sizeof(centroid) * num_centroids);

  uint64_t total_weight = 0;
  for (const auto& item: centroids) {
    total_weight += item.get_weight();
  }
  return tdigest(reverse_merge, k, min, max, std::move(centroids), total_weight, allocator);
}

/**
 * Private constructor used by the public constructor and by deserialize().
 *
 * Takes ownership of already merged centroids and derives the internal sizes
 * from k: the centroid capacity, the buffer capacity and the internal
 * compression parameter internal_k_. Both vectors get their capacity reserved.
 *
 * @param reverse_merge initial value of the reverse_merge_ flag
 * @param k compression parameter
 * @param min smallest value seen so far
 * @param max largest value seen so far
 * @param centroids merged centroids to adopt
 * @param total_weight sum of the weights of the given centroids
 * @param allocator instance of an Allocator
 * @throws std::invalid_argument if k is less than 10
 */
template<typename T, typename A>
tdigest<T, A>::tdigest(bool reverse_merge, uint16_t k, T min, T max, vector_centroid&& centroids, uint64_t total_weight, const A& allocator):
  allocator_(allocator),
  reverse_merge_(reverse_merge),
  k_(k),
  internal_k_(k),
  min_(min),
  max_(max),
  centroids_capacity_(0),
  centroids_(std::move(centroids)),
  total_weight_(total_weight),
  buffer_capacity_(0),
  buffer_(allocator),
  buffered_weight_(0)
{
  if (k < 10) throw std::invalid_argument("k must be at least 10");

  // extra room for centroids when the weight limit is in use, more of it for small k
  const size_t slack = USE_WEIGHT_LIMIT ? (k < 30 ? 30 : 10) : 0;

  centroids_capacity_ = 2 * k_ + slack;
  buffer_capacity_ = 5 * centroids_capacity_;

  // two-level compression runs the merge with a larger internal k
  const double scale = USE_TWO_LEVEL_COMPRESSION
      ? std::max(1.0, static_cast<double>(buffer_capacity_) / centroids_capacity_ - 1.0)
      : 1.0;
  internal_k_ = static_cast<uint16_t>(std::ceil(std::sqrt(scale) * k_));

  // capacities never drop below what internal_k_ requires
  centroids_capacity_ = std::max<size_t>(centroids_capacity_, internal_k_ + slack);
  buffer_capacity_ = std::max<size_t>(buffer_capacity_, 2 * centroids_capacity_);

  centroids_.reserve(centroids_capacity_);
  buffer_.reserve(buffer_capacity_);
}

} /* namespace datasketches */

#endif // _TDIGEST_IMPL_HPP_
Loading

0 comments on commit e98561c

Please sign in to comment.