diff --git a/c++/pod5_format/signal_compression.cpp b/c++/pod5_format/signal_compression.cpp index 98f40d0..69e0d3a 100644 --- a/c++/pod5_format/signal_compression.cpp +++ b/c++/pod5_format/signal_compression.cpp @@ -81,8 +81,10 @@ POD5_FORMAT_EXPORT arrow::Status decompress_signal( ")"); } + auto allocation_padding = svb16::decode_input_buffer_padding_byte_count(); ARROW_ASSIGN_OR_RAISE( - auto intermediate, arrow::AllocateResizableBuffer(decompressed_zstd_size, pool)); + auto intermediate, + arrow::AllocateResizableBuffer(decompressed_zstd_size + allocation_padding, pool)); size_t const decompress_res = ZSTD_decompress( intermediate->mutable_data(), intermediate->size(), @@ -101,11 +103,8 @@ POD5_FORMAT_EXPORT arrow::Status decompress_signal( static constexpr bool UseDelta = true; static constexpr bool UseZigzag = true; auto consumed_count = svb16::decode( - reinterpret_cast(destination.data()), - intermediate->data(), - destination.size()); - - if (consumed_count != (std::size_t)intermediate->size()) { + destination, gsl::make_span(intermediate->data(), intermediate->size())); + if ((consumed_count + allocation_padding) != (std::size_t)intermediate->size()) { return pod5::Status::Invalid("Remaining data at end of signal buffer"); } diff --git a/c++/pod5_format/svb16/decode.hpp b/c++/pod5_format/svb16/decode.hpp index 0869aca..c8c221d 100644 --- a/c++/pod5_format/svb16/decode.hpp +++ b/c++/pod5_format/svb16/decode.hpp @@ -10,17 +10,30 @@ namespace svb16 { +// Required extra space after readable buffers passed in. +// +// Require 1 128 bit buffer beyond the end of all input readable buffers. 
+inline std::size_t decode_input_buffer_padding_byte_count() +{ +#ifdef SVB16_X64 + return sizeof(__m128i); +#else + return 0; +#endif +} + template -size_t decode(Int16T * out, uint8_t const * SVB_RESTRICT in, uint32_t count, Int16T prev = 0) +size_t decode(gsl::span out, gsl::span in, Int16T prev = 0) { - auto const keys = in; - auto const data = keys + ::svb16_key_length(count); + auto keys_length = ::svb16_key_length(out.size()); + auto const keys = in.subspan(0, keys_length); + auto const data = in.subspan(keys_length); #ifdef SVB16_X64 if (has_sse4_1()) { - return decode_sse(out, keys, data, count, prev) - in; + return decode_sse(out, keys, data, prev) - in.begin(); } #endif - return decode_scalar(out, keys, data, count, prev) - in; + return decode_scalar(out, keys, data, prev) - in.begin(); } } // namespace svb16 diff --git a/c++/pod5_format/svb16/decode_scalar.hpp b/c++/pod5_format/svb16/decode_scalar.hpp index 15b3798..f7d0aba 100644 --- a/c++/pod5_format/svb16/decode_scalar.hpp +++ b/c++/pod5_format/svb16/decode_scalar.hpp @@ -2,6 +2,9 @@ #include "common.hpp" +#include + +#include #include #include #include @@ -13,9 +16,8 @@ inline uint16_t zigzag_decode(uint16_t val) return (val >> 1) ^ static_cast(0 - (val & 1)); } -inline uint16_t decode_data(uint8_t const * SVB_RESTRICT * dataPtrPtr, uint8_t code) +inline uint16_t decode_data(gsl::span::iterator & dataPtr, uint8_t code) { - uint8_t const * dataPtr = *dataPtrPtr; uint16_t val; if (code == 0) { // 1 byte @@ -27,23 +29,26 @@ inline uint16_t decode_data(uint8_t const * SVB_RESTRICT * dataPtrPtr, uint8_t c dataPtr += 2; } - *dataPtrPtr = dataPtr; return val; } } // namespace detail template uint8_t const * decode_scalar( - Int16T * out, - uint8_t const * SVB_RESTRICT keys, - uint8_t const * SVB_RESTRICT data, - uint32_t count, + gsl::span out_span, + gsl::span keys_span, + gsl::span data_span, Int16T prev = 0) { + auto const count = out_span.size(); if (count == 0) { - return data; + return 
data_span.begin(); } + auto out = out_span.begin(); + auto keys = keys_span.begin(); + auto data = data_span.begin(); + uint8_t shift = 0; // cycles 0 through 7 then resets uint8_t key_byte = *keys++; // need to do the arithmetic in unsigned space so it wraps @@ -53,7 +58,7 @@ uint8_t const * decode_scalar( shift = 0; key_byte = *keys++; } - uint16_t value = detail::decode_data(&data, (key_byte >> shift) & 0x01); + uint16_t value = detail::decode_data(data, (key_byte >> shift) & 0x01); SVB16_IF_CONSTEXPR(UseZigzag) { value = detail::zigzag_decode(value); } SVB16_IF_CONSTEXPR(UseDelta) { @@ -62,6 +67,10 @@ uint8_t const * decode_scalar( } *out++ = static_cast(value); } + + assert(out == out_span.end()); + assert(keys == keys_span.end()); + assert(data <= data_span.end()); return data; } diff --git a/c++/pod5_format/svb16/decode_x64.hpp b/c++/pod5_format/svb16/decode_x64.hpp index b885e0d..16d29f0 100644 --- a/c++/pod5_format/svb16/decode_x64.hpp +++ b/c++/pod5_format/svb16/decode_x64.hpp @@ -6,6 +6,8 @@ #include "shuffle_tables.hpp" #include "svb16.h" // svb16_key_length +#include + #include #include @@ -71,10 +73,9 @@ template template [[gnu::target("sse4.1")]] uint8_t const * decode_sse( - Int16T * out, - uint8_t const * SVB_RESTRICT keys, - uint8_t const * SVB_RESTRICT data, - uint32_t count, + gsl::span out_span, + gsl::span keys_span, + gsl::span data_span, Int16T prev = 0) { auto store_8 = [](Int16T * to, __m128i value, __m128i * prev) { @@ -83,6 +84,11 @@ template // this code treats all input as uint16_t (except the zigzag code, which treats it as int16_t) // this isn't a problem, as the scalar code does the same + auto out = out_span.begin(); + auto const count = out_span.size(); + auto keys_it = keys_span.begin(); + auto data = data_span.begin(); + // handle blocks of 32 values if (count >= 64) { size_t const key_bytes = count / 8; @@ -91,7 +97,7 @@ template SVB16_IF_CONSTEXPR(UseDelta) { prev_reg = _mm_set1_epi16(prev); } int64_t offset = 
-static_cast(key_bytes) / 8 + 1; // 8 -> 4? - uint64_t const * keyPtr64 = reinterpret_cast(keys) - offset; + uint64_t const * keyPtr64 = reinterpret_cast(keys_it) - offset; uint64_t nextkeys; memcpy(&nextkeys, keyPtr64 + offset, sizeof(nextkeys)); @@ -153,6 +159,14 @@ template keys >>= 16; data_reg = detail::unpack((keys & 0x00FF), &data); store_8(out + 48, data_reg, &prev_reg); + + // Note we load at least sizeof(__m128i) bytes from the end of data + // here, need to ensure that is available to read. + // + // But we might not use it all depending on the unpacking. + // + // This is ok due to `decode_input_buffer_padding_byte_count` ensuring + // extra space on the input buffer. data_reg = detail::unpack((keys & 0xFF00) >> 8, &data); store_8(out + 56, data_reg, &prev_reg); @@ -183,8 +197,9 @@ template data_reg = _mm_cvtepu8_epi16( _mm_lddqu_si128(reinterpret_cast<__m128i const *>(data + 48))); store_8(out + 48, data_reg, &prev_reg); + // Only load the first 8 bytes here, otherwise we may run off the end of the buffer data_reg = _mm_cvtepu8_epi16( - _mm_lddqu_si128(reinterpret_cast<__m128i const *>(data + 56))); + _mm_loadl_epi64(reinterpret_cast<__m128i const *>(data + 56))); store_8(out + 56, data_reg, &prev_reg); out += 64; data += 64; @@ -218,10 +233,21 @@ template } prev = out[-1]; - keys += key_bytes - (key_bytes & 7); + keys_it += key_bytes - (key_bytes & 7); } - return decode_scalar(out, keys, data, count & 63, prev); + assert(out <= out_span.end()); + assert(keys_it <= keys_span.end()); + assert(data <= data_span.end()); + + auto out_scalar_span = gsl::make_span(out, out_span.end()); + assert(out_scalar_span.size() == (count & 63)); + + auto keys_scalar_span = gsl::make_span(keys_it, keys_span.end()); + auto data_scalar_span = gsl::make_span(data, data_span.end()); + + return decode_scalar( + out_scalar_span, keys_scalar_span, data_scalar_span, prev); } #endif // SVB16_X64 diff --git a/c++/test/svb16_scalar_tests.cpp b/c++/test/svb16_scalar_tests.cpp index 
d9fe1e2..a8b4a5f 100644 --- a/c++/test/svb16_scalar_tests.cpp +++ b/c++/test/svb16_scalar_tests.cpp @@ -31,11 +31,12 @@ void test_scalar_encode_scalar_decode() CHECK(encoded_count <= svb16_max_encoded_length(data.size())); std::vector decoded(DATA_COUNT); + auto const encoded_span = gsl::make_span(encoded); + auto const key_length = svb16_key_length(data.size()); auto const consumed = svb16::decode_scalar( - decoded.data(), - encoded.data(), - encoded.data() + svb16_key_length(data.size()), - DATA_COUNT) + gsl::make_span(decoded), + encoded_span.subspan(0, key_length), + encoded_span.subspan(key_length)) - encoded.data(); CHECK(consumed == encoded_count); diff --git a/c++/test/svb16_x64_tests.cpp b/c++/test/svb16_x64_tests.cpp index 27d239e..2610d09 100644 --- a/c++/test/svb16_x64_tests.cpp +++ b/c++/test/svb16_x64_tests.cpp @@ -45,11 +45,12 @@ void test_sse_encode_scalar_decode() CHECK(encoded == encoded_scalar); std::vector decoded(DATA_COUNT); + auto const encoded_span = gsl::make_span(encoded); + auto const key_length = svb16_key_length(data.size()); auto const consumed = svb16::decode_sse( - decoded.data(), - encoded.data(), - encoded.data() + svb16_key_length(data.size()), - DATA_COUNT) + gsl::make_span(decoded), + encoded_span.subspan(0, key_length), + encoded_span.subspan(key_length)) - encoded.data(); CHECK(consumed == encoded_count);