Skip to content

Commit

Permalink
Merge branch 'feature/fix_loading_bad_memory' into 'master'
Browse files Browse the repository at this point in the history
Fix loading bad memory on some platforms

See merge request minknow/pod5-file-format!221
  • Loading branch information
0x55555555 committed Apr 6, 2023
2 parents 73617c6 + 48b9918 commit 2bf1706
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 36 deletions.
11 changes: 5 additions & 6 deletions c++/pod5_format/signal_compression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,10 @@ POD5_FORMAT_EXPORT arrow::Status decompress_signal(
")");
}

auto allocation_padding = svb16::decode_input_buffer_padding_byte_count();
ARROW_ASSIGN_OR_RAISE(
auto intermediate, arrow::AllocateResizableBuffer(decompressed_zstd_size, pool));
auto intermediate,
arrow::AllocateResizableBuffer(decompressed_zstd_size + allocation_padding, pool));
size_t const decompress_res = ZSTD_decompress(
intermediate->mutable_data(),
intermediate->size(),
Expand All @@ -101,11 +103,8 @@ POD5_FORMAT_EXPORT arrow::Status decompress_signal(
static constexpr bool UseDelta = true;
static constexpr bool UseZigzag = true;
auto consumed_count = svb16::decode<SampleType, UseDelta, UseZigzag>(
reinterpret_cast<SampleType *>(destination.data()),
intermediate->data(),
destination.size());

if (consumed_count != (std::size_t)intermediate->size()) {
destination, gsl::make_span(intermediate->data(), intermediate->size()));
if ((consumed_count + allocation_padding) != (std::size_t)intermediate->size()) {
return pod5::Status::Invalid("Remaining data at end of signal buffer");
}

Expand Down
23 changes: 18 additions & 5 deletions c++/pod5_format/svb16/decode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,30 @@

namespace svb16 {

// Required extra space after readable buffers passed in.
//
// Require 1 128 bit buffer beyond the end of all input readable buffers.
inline std::size_t decode_input_buffer_padding_byte_count()
{
#ifdef SVB16_X64
return sizeof(__m128i);
#else
return 0;
#endif
}

template <typename Int16T, bool UseDelta, bool UseZigzag>
size_t decode(Int16T * out, uint8_t const * SVB_RESTRICT in, uint32_t count, Int16T prev = 0)
size_t decode(gsl::span<Int16T> out, gsl::span<uint8_t const> in, Int16T prev = 0)
{
auto const keys = in;
auto const data = keys + ::svb16_key_length(count);
auto keys_length = ::svb16_key_length(out.size());
auto const keys = in.subspan(0, keys_length);
auto const data = in.subspan(keys_length);
#ifdef SVB16_X64
if (has_sse4_1()) {
return decode_sse<Int16T, UseDelta, UseZigzag>(out, keys, data, count, prev) - in;
return decode_sse<Int16T, UseDelta, UseZigzag>(out, keys, data, prev) - in.begin();
}
#endif
return decode_scalar<Int16T, UseDelta, UseZigzag>(out, keys, data, count, prev) - in;
return decode_scalar<Int16T, UseDelta, UseZigzag>(out, keys, data, prev) - in.begin();
}

} // namespace svb16
27 changes: 18 additions & 9 deletions c++/pod5_format/svb16/decode_scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include "common.hpp"

#include <gsl/gsl-lite.hpp>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
Expand All @@ -13,9 +16,8 @@ inline uint16_t zigzag_decode(uint16_t val)
return (val >> 1) ^ static_cast<uint16_t>(0 - (val & 1));
}

inline uint16_t decode_data(uint8_t const * SVB_RESTRICT * dataPtrPtr, uint8_t code)
inline uint16_t decode_data(gsl::span<uint8_t const>::iterator & dataPtr, uint8_t code)
{
uint8_t const * dataPtr = *dataPtrPtr;
uint16_t val;

if (code == 0) { // 1 byte
Expand All @@ -27,23 +29,26 @@ inline uint16_t decode_data(uint8_t const * SVB_RESTRICT * dataPtrPtr, uint8_t c
dataPtr += 2;
}

*dataPtrPtr = dataPtr;
return val;
}
} // namespace detail

template <typename Int16T, bool UseDelta, bool UseZigzag>
uint8_t const * decode_scalar(
Int16T * out,
uint8_t const * SVB_RESTRICT keys,
uint8_t const * SVB_RESTRICT data,
uint32_t count,
gsl::span<Int16T> out_span,
gsl::span<uint8_t const> keys_span,
gsl::span<uint8_t const> data_span,
Int16T prev = 0)
{
auto const count = out_span.size();
if (count == 0) {
return data;
return data_span.begin();
}

auto out = out_span.begin();
auto keys = keys_span.begin();
auto data = data_span.begin();

uint8_t shift = 0; // cycles 0 through 7 then resets
uint8_t key_byte = *keys++;
// need to do the arithmetic in unsigned space so it wraps
Expand All @@ -53,7 +58,7 @@ uint8_t const * decode_scalar(
shift = 0;
key_byte = *keys++;
}
uint16_t value = detail::decode_data(&data, (key_byte >> shift) & 0x01);
uint16_t value = detail::decode_data(data, (key_byte >> shift) & 0x01);
SVB16_IF_CONSTEXPR(UseZigzag) { value = detail::zigzag_decode(value); }
SVB16_IF_CONSTEXPR(UseDelta)
{
Expand All @@ -62,6 +67,10 @@ uint8_t const * decode_scalar(
}
*out++ = static_cast<Int16T>(value);
}

assert(out == out_span.end());
assert(keys == keys_span.end());
assert(data <= data_span.end());
return data;
}

Expand Down
42 changes: 34 additions & 8 deletions c++/pod5_format/svb16/decode_x64.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "shuffle_tables.hpp"
#include "svb16.h" // svb16_key_length

#include <gsl/gsl-lite.hpp>

#include <cstddef>
#include <cstdint>

Expand Down Expand Up @@ -71,10 +73,9 @@ template <typename Int16T, bool UseDelta, bool UseZigzag>

template <typename Int16T, bool UseDelta, bool UseZigzag>
[[gnu::target("sse4.1")]] uint8_t const * decode_sse(
Int16T * out,
uint8_t const * SVB_RESTRICT keys,
uint8_t const * SVB_RESTRICT data,
uint32_t count,
gsl::span<Int16T> out_span,
gsl::span<uint8_t const> keys_span,
gsl::span<uint8_t const> data_span,
Int16T prev = 0)
{
auto store_8 = [](Int16T * to, __m128i value, __m128i * prev) {
Expand All @@ -83,6 +84,11 @@ template <typename Int16T, bool UseDelta, bool UseZigzag>
// this code treats all input as uint16_t (except the zigzag code, which treats it as int16_t)
// this isn't a problem, as the scalar code does the same

auto out = out_span.begin();
auto const count = out_span.size();
auto keys_it = keys_span.begin();
auto data = data_span.begin();

// handle blocks of 32 values
if (count >= 64) {
size_t const key_bytes = count / 8;
Expand All @@ -91,7 +97,7 @@ template <typename Int16T, bool UseDelta, bool UseZigzag>
SVB16_IF_CONSTEXPR(UseDelta) { prev_reg = _mm_set1_epi16(prev); }

int64_t offset = -static_cast<int64_t>(key_bytes) / 8 + 1; // 8 -> 4?
uint64_t const * keyPtr64 = reinterpret_cast<uint64_t const *>(keys) - offset;
uint64_t const * keyPtr64 = reinterpret_cast<uint64_t const *>(keys_it) - offset;
uint64_t nextkeys;
memcpy(&nextkeys, keyPtr64 + offset, sizeof(nextkeys));

Expand Down Expand Up @@ -153,6 +159,14 @@ template <typename Int16T, bool UseDelta, bool UseZigzag>
keys >>= 16;
data_reg = detail::unpack((keys & 0x00FF), &data);
store_8(out + 48, data_reg, &prev_reg);

// Note we load at least sizeof(__m128i) bytes from the end of data
// here, need to ensure that is available to read.
//
// But we might not use it all depending on the unpacking.
//
// This is ok due to `decode_input_buffer_padding_byte_count` enuring
// extra space on the input buffer.
data_reg = detail::unpack((keys & 0xFF00) >> 8, &data);
store_8(out + 56, data_reg, &prev_reg);

Expand Down Expand Up @@ -183,8 +197,9 @@ template <typename Int16T, bool UseDelta, bool UseZigzag>
data_reg = _mm_cvtepu8_epi16(
_mm_lddqu_si128(reinterpret_cast<__m128i const *>(data + 48)));
store_8(out + 48, data_reg, &prev_reg);
// Only load the first 8 bytes here, otherwise we may run off the end of the buffer
data_reg = _mm_cvtepu8_epi16(
_mm_lddqu_si128(reinterpret_cast<__m128i const *>(data + 56)));
_mm_loadl_epi64(reinterpret_cast<__m128i const *>(data + 56)));
store_8(out + 56, data_reg, &prev_reg);
out += 64;
data += 64;
Expand Down Expand Up @@ -218,10 +233,21 @@ template <typename Int16T, bool UseDelta, bool UseZigzag>
}
prev = out[-1];

keys += key_bytes - (key_bytes & 7);
keys_it += key_bytes - (key_bytes & 7);
}

return decode_scalar<Int16T, UseDelta, UseZigzag>(out, keys, data, count & 63, prev);
assert(out <= out_span.end());
assert(keys_it <= keys_span.end());
assert(data <= data_span.end());

auto out_scalar_span = gsl::make_span(out, out_span.end());
assert(out_scalar_span.size() == (count & 63));

auto keys_scalar_span = gsl::make_span(keys_it, keys_span.end());
auto data_scalar_span = gsl::make_span(data, data_span.end());

return decode_scalar<Int16T, UseDelta, UseZigzag>(
out_scalar_span, keys_scalar_span, data_scalar_span, prev);
}

#endif // SVB16_X64
Expand Down
9 changes: 5 additions & 4 deletions c++/test/svb16_scalar_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ void test_scalar_encode_scalar_decode()
CHECK(encoded_count <= svb16_max_encoded_length(data.size()));

std::vector<Int16T> decoded(DATA_COUNT);
auto const encoded_span = gsl::make_span(encoded);
auto const key_length = svb16_key_length(data.size());
auto const consumed = svb16::decode_scalar<Int16T, UseDelta, UseZigzag>(
decoded.data(),
encoded.data(),
encoded.data() + svb16_key_length(data.size()),
DATA_COUNT)
gsl::make_span(decoded),
encoded_span.subspan(0, key_length),
encoded_span.subspan(key_length))
- encoded.data();

CHECK(consumed == encoded_count);
Expand Down
9 changes: 5 additions & 4 deletions c++/test/svb16_x64_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ void test_sse_encode_scalar_decode()
CHECK(encoded == encoded_scalar);

std::vector<Int16T> decoded(DATA_COUNT);
auto const encoded_span = gsl::make_span(encoded);
auto const key_length = svb16_key_length(data.size());
auto const consumed = svb16::decode_sse<Int16T, UseDelta, UseZigzag>(
decoded.data(),
encoded.data(),
encoded.data() + svb16_key_length(data.size()),
DATA_COUNT)
gsl::make_span(decoded),
encoded_span.subspan(0, key_length),
encoded_span.subspan(key_length))
- encoded.data();

CHECK(consumed == encoded_count);
Expand Down

0 comments on commit 2bf1706

Please sign in to comment.