Skip to content

Commit

Permalink
Add LatchingBitSet.
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyastolfi committed Feb 26, 2024
1 parent 514519d commit 8bf2134
Show file tree
Hide file tree
Showing 3 changed files with 383 additions and 0 deletions.
164 changes: 164 additions & 0 deletions src/llfs/latching_bit_set.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//#=##=##=#==#=#==#===#+==#+==========+==+=+=+=+=+=++=+++=+++++=-++++=-+++++++++++
//
// Part of the LLFS Project, under Apache License v2.0.
// See https://www.apache.org/licenses/LICENSE-2.0 for license information.
// SPDX short identifier: Apache-2.0
//
//+++++++++++-+-+--+----- --- -- - - - -

#include <llfs/latching_bit_set.hpp>
//

namespace llfs {

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
/*explicit*/ LatchingBitSet::LatchingBitSet(usize n) noexcept : upper_bound_{n}
{
usize total_size = 0;
usize level_size = this->upper_bound_;

this->start_of_level_.emplace_back(0);

batt::SmallVec<u64, 12> last_word;

for (;;) {
const usize leftover_bits = level_size % 64;
if (leftover_bits == 0) {
last_word.emplace_back(0);
} else {
last_word.emplace_back(~((u64{1} << leftover_bits) - 1));
}

level_size = (level_size + 63) / 64;
total_size += level_size;
if (level_size <= 1) {
break;
}
this->start_of_level_.emplace_back(this->start_of_level_.back() + level_size);
}

if (total_size > 0) {
this->data_.reset(new u64[total_size]);
std::memset(this->data_.get(), 0, sizeof(u64) * total_size);

// Set any leftover bits at the end of the last word in each level to 1, to simplify the code
// to check for a block containing all 1's.
//
for (usize i = 0; i < this->start_of_level_.size() - 1; ++i) {
const usize last_word_offset = this->start_of_level_[i + 1] - 1;
this->data_.get()[last_word_offset] = last_word[i];
}
this->data_.get()[this->start_of_level_.back()] = last_word.back();
}
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
usize LatchingBitSet::upper_bound() const noexcept
{
return this->upper_bound_;
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
usize LatchingBitSet::first_missing() const noexcept
{
if (this->is_full()) {
return this->upper_bound();
}

usize depth = this->start_of_level_.size();
usize lower_bound = 0;
while (depth > 0) {
--depth;

const usize index = this->start_of_level_[depth] + lower_bound;

// Even if the set wasn't full
//
if (depth + 1 < this->start_of_level_.size() && index >= this->start_of_level_[depth + 1]) {
return this->upper_bound();
}

const u64 value = this->data()[index].load();
i32 first_zero_bit = [&] {
if (value == 0) {
return 0;
}
return __builtin_ctzll(~value);
}();

lower_bound = lower_bound * 64 + first_zero_bit;
}

return lower_bound;
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
bool LatchingBitSet::contains(usize i) noexcept
{
BATT_CHECK_LT(i, this->upper_bound());

const usize word_i = i / 64;
const usize bit_i = i % 64;
const u64 mask = u64{1} << bit_i;

return (this->data()[word_i].load() & mask) != 0;
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
bool LatchingBitSet::insert(usize i) noexcept
{
BATT_CHECK_LT(i, this->upper_bound());

bool changed = false;

usize depth = 0;

for (;;) {
const usize word_i = i / 64;
const usize bit_i = i % 64;
const u64 mask = u64{1} << bit_i;

const u64 old_value = this->data()[word_i + this->start_of_level_[depth]].fetch_or(mask);
const u64 new_value = old_value | mask;

changed = changed || (old_value != new_value);

if (!changed || batt::bit_count(new_value) < 64) {
break;
}

++depth;
if (depth == this->start_of_level_.size()) {
break;
}
i = word_i;
}

return changed;
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
bool LatchingBitSet::is_full() const noexcept
{
if (this->upper_bound() == 0u) {
return true;
}
return this->data()[this->start_of_level_.back()].load() == ~u64{0};
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
std::atomic<u64>* LatchingBitSet::data() const noexcept
{
static_assert(sizeof(std::atomic<u64>) == sizeof(u64));

return reinterpret_cast<std::atomic<u64>*>(this->data_.get());
}

} //namespace llfs
89 changes: 89 additions & 0 deletions src/llfs/latching_bit_set.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//#=##=##=#==#=#==#===#+==#+==========+==+=+=+=+=+=++=+++=+++++=-++++=-+++++++++++
//
// Part of the LLFS Project, under Apache License v2.0.
// See https://www.apache.org/licenses/LICENSE-2.0 for license information.
// SPDX short identifier: Apache-2.0
//
//+++++++++++-+-+--+----- --- -- - - - -

#pragma once
#ifndef LLFS_LATCHING_BIT_SET_HPP
#define LLFS_LATCHING_BIT_SET_HPP

#include <llfs/config.hpp>
//
#include <llfs/int_types.hpp>

#include <batteries/math.hpp>
#include <batteries/small_vec.hpp>

namespace llfs {

/** \brief A lock-free concurrent set of integers with a fixed upper bound, which supports insertion
* but not erasure of items.
*/
class LatchingBitSet
{
public:
/** \brief Creates a new set with space for `n_bits` bits, allowing integers from 0 to n_bits-1 to
* be inserted.
*/
explicit LatchingBitSet(usize n_bits) noexcept;

//+++++++++++-+-+--+----- --- -- - - - -

/** \brief Returns one past the maximum integer that can be inserted into the set.
*
* Note: this is a property of the container that is fixed at construction time; it does not
* report the current actual largest value in the set.
*/
usize upper_bound() const noexcept;

/** \brief Returns the index of the first unset (0) bit.
*/
usize first_missing() const noexcept;

/** \brief Returns true iff the set contains the given integer.
*
* \param i The integer to test for; MUST be less than this->upper_bound(), or we PANIC.
*/
bool contains(usize i) noexcept;

/** \brief Inserts the passed integer `i`, returning true iff it was not previously contained by
* the set.
*/
bool insert(usize i) noexcept;

/** \brief Returns true iff all integers from 0 to this->upper_bound() - 1 (inclusive) are present
* in the set.
*/
bool is_full() const noexcept;

//+++++++++++-+-+--+----- --- -- - - - -
private:
/** \brief Returns the backing storage memory as an array of std::atomic<u64>.
*/
std::atomic<u64>* data() const noexcept;

//+++++++++++-+-+--+----- --- -- - - - -

/** \brief The number of bits stored in this set.
*/
usize upper_bound_;

/** \brief Precomputed/cached starts of the levels of the data structure, in index offsets into
* this->data_; level 0 is the flat array of all per-integer bits, the next level is a 1/64-sized
* summary of those (where a 1 bit means all 64 corresponding bits from the previous level are
* also 1; 0 otherwise); the last level is always a single u64.
*/
batt::SmallVec<usize, 12> start_of_level_;

/** \brief The backing memory for the bit set. This is allocated as u64 and then cast to
* std::atomic<u64> so we can speed up initialization by just memset-ing the array to all zeros.
*/
std::unique_ptr<u64[]> data_;
};

} //namespace llfs

#endif // LLFS_LATCHING_BIT_SET_HPP
130 changes: 130 additions & 0 deletions src/llfs/latching_bit_set.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
//#=##=##=#==#=#==#===#+==#+==========+==+=+=+=+=+=++=+++=+++++=-++++=-+++++++++++
//
// Part of the LLFS Project, under Apache License v2.0.
// See https://www.apache.org/licenses/LICENSE-2.0 for license information.
// SPDX short identifier: Apache-2.0
//
//+++++++++++-+-+--+----- --- -- - - - -

#include <llfs/latching_bit_set.hpp>
//
#include <llfs/latching_bit_set.hpp>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <vector>

namespace {

using namespace llfs::int_types;

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
TEST(LatchingBitSetTest, Size0)
{
llfs::LatchingBitSet s{0};

EXPECT_EQ(s.upper_bound(), 0u);
EXPECT_EQ(s.first_missing(), 0u);
EXPECT_TRUE(s.is_full());
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
TEST(LatchingBitSetTest, Size1)
{
llfs::LatchingBitSet s{1};

EXPECT_EQ(s.upper_bound(), 1u);
EXPECT_EQ(s.first_missing(), 0u);
EXPECT_FALSE(s.is_full());

EXPECT_DEATH(s.insert(1), ".*[Aa]ssert.*fail.*i.*<.*size.*");

EXPECT_TRUE(s.insert(0));
EXPECT_FALSE(s.insert(0));

EXPECT_EQ(s.first_missing(), 1u);
EXPECT_TRUE(s.is_full());
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
TEST(LatchingBitSetTest, Size40)
{
llfs::LatchingBitSet s{40};

EXPECT_EQ(s.upper_bound(), 40u);
EXPECT_EQ(s.first_missing(), 0u);
EXPECT_FALSE(s.is_full());

EXPECT_TRUE(s.insert(0));
EXPECT_EQ(s.first_missing(), 1u);

EXPECT_TRUE(s.insert(10));
EXPECT_EQ(s.first_missing(), 1u);

for (usize i = 1; i < 9; ++i) {
for (usize j = 0; j < i; ++j) {
EXPECT_TRUE(s.contains(j));
EXPECT_FALSE(s.insert(j));
}
EXPECT_FALSE(s.contains(i));
EXPECT_TRUE(s.insert(i));
EXPECT_EQ(s.first_missing(), i + 1);
EXPECT_FALSE(s.is_full());
}

EXPECT_TRUE(s.insert(9));
EXPECT_EQ(s.first_missing(), 11u);
EXPECT_FALSE(s.is_full());

for (usize i = 11; i < 39; ++i) {
for (usize j = 0; j < i; ++j) {
EXPECT_FALSE(s.insert(j));
}
EXPECT_TRUE(s.insert(i));
EXPECT_EQ(s.first_missing(), i + 1);
EXPECT_FALSE(s.is_full());
}

EXPECT_TRUE(s.insert(39));
EXPECT_EQ(s.first_missing(), 40);
EXPECT_TRUE(s.is_full());
}

//==#==========+==+=+=++=+++++++++++-+-+--+----- --- -- - - - -
//
TEST(LatchingBitSetTest, MultiLevelTest)
{
for (usize n : std::vector<usize>{
64, // one level, exactly full
277, // a prime number
64 * 64, // two full levels
7777, // three levels, not full
64 * 64 * 64, // three exactly full levels
}) {
llfs::LatchingBitSet s{n};

EXPECT_EQ(s.upper_bound(), n);

for (usize i = 0; i < s.upper_bound(); ++i) {
for (usize j = i - std::min<usize>(i, 100); j < i; ++j) {
EXPECT_TRUE(s.contains(j));
EXPECT_FALSE(s.insert(j));
}

EXPECT_EQ(s.first_missing(), i);
EXPECT_TRUE(s.insert(i));
EXPECT_EQ(s.first_missing(), i + 1);
if (i + 1 < s.upper_bound()) {
EXPECT_FALSE(s.is_full());
} else {
EXPECT_TRUE(s.is_full());
}
}
}
}

} // namespace

0 comments on commit 8bf2134

Please sign in to comment.