Skip to content

Commit

Permalink
Optimize the bitset sampling algorithm.
Browse files Browse the repository at this point in the history
The existing bitset sampling implementation works by using a binomial
distribution to decide how many bits to remove, randomly choosing the
indices of those bits, sorting the vector of indices and finally iterating
over the set bits one by one to clear those whose position is in the vector.

This can be very inefficient, in particular when sampling over large
bit sets with a very low sampling rate. In that case almost every bit is
removed, so the list of sampled indices is roughly as large as the bitset
itself, and sorting it requires O(n log(n)) time, which ends up being
significant.

Additionally, walking over every single bit, set or not, to be cleared
or not, is pretty inefficient as well.

This commit optimizes the implementation through a few methods:
- Instead of sampling and sorting indices to keep, it randomly samples
  the size of the gaps between two successful Bernoulli trials.
  This was inspired by the [FastBernoulliTrial] class.
- When the sampling rate is higher than 1/2, it flips the sampling logic
  and uses the gaps between two unsuccessful trials, minimizing the
  number of loop iterations.
- Finally, in order to take full advantage of the gap lengths, it is
  able to quickly scan through the bitset to skip a given number of set
  bits, calling popcnt once per word rather than looking at each bit,
  and can clear entire ranges of bits at once by overwriting entire
  words, rather than masking bits one by one.

This implementation is faster than the previous one for the entire
parameter space. The difference is most drastic for very low sampling
rates where the new implementation is more than two orders of magnitude
faster.

[FastBernoulliTrial]: https://searchfox.org/mozilla-central/rev/a6d25de0c706dbc072407ed5d339aaed1cab43b7/mfbt/FastBernoulliTrial.h
  • Loading branch information
plietar committed Feb 28, 2024
1 parent e59b4f1 commit d881329
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 41 deletions.
205 changes: 165 additions & 40 deletions inst/include/IterableBitset.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class IterableBitset {
const_iterator end() const;
const_iterator cend() const;
void erase(size_t);
void erase(size_t start, size_t end);
const_iterator find(size_t) const;
template<class InputIterator>
void insert(InputIterator, InputIterator);
Expand All @@ -98,6 +99,7 @@ class IterableBitset {
bool empty() const;
void extend(size_t);
void shrink(const std::vector<size_t>&);
size_t next_position(size_t start, size_t n) const;
};


Expand Down Expand Up @@ -131,23 +133,37 @@ inline size_t popcount(uint64_t x) {
#endif
}

//' @title find the next set bit
//' @description given the current element p,
//' return the integer represented by the next set bit in the bitmap
//' Find the n-th set bit (zero-based) in a 64-bit integer.
//'
//' `n = 0` selects the lowest set bit, `n = 1` the second lowest, and so on.
//'
//' Returns the index of that bit, or 64 if `x` has `n` or fewer bits set.
inline size_t find_bit(uint64_t x, size_t n) {
    // A 64-bit word can never hold more than 64 set bits.
    if (n >= 64) {
        return 64;
    }

    // Drop the n lowest set bits: `x & (x - 1)` clears the lowest one.
    // Stop early once nothing is left to clear.
    for (size_t i = 0; i < n && x != 0; i++) {
        x &= x - 1;
    }

    // Handle the "not enough bits" case explicitly instead of relying on the
    // result of ctz(0): ctz is defined elsewhere, and count-trailing-zeros
    // intrinsics are commonly undefined for a zero argument.
    if (x == 0) {
        return 64;
    }
    return ctz(x);
}

//' Find the n-th set bit, starting at position p.
//'
//' Returns the index of the bit, or max_n if there are not enough bits set.
template<class A>
inline size_t next_position(const std::vector<A>& bitmap, size_t num_bits, size_t max_n, size_t p) {
++p;
auto bucket = p / num_bits;
auto excess = p % num_bits;
A bitset = bitmap.at(bucket) >> excess;
inline size_t IterableBitset<A>::next_position(size_t p, size_t n) const {
size_t bucket = p / num_bits;
size_t excess = p % num_bits;

while(bitset == 0 && bucket + 1 < bitmap.size()) {
bitset = bitmap.at(++bucket);
A bitset = bitmap[bucket] >> excess;
while (n >= popcount(bitset) && bucket + 1 < bitmap.size()) {
n -= popcount(bitset);
bucket += 1;
bitset = bitmap[bucket];
excess = 0;
}

auto lsb = bitset & -bitset;
auto r = ctz(lsb);
auto r = find_bit(bitset, n);
return std::min(bucket * num_bits + excess + r, max_n);
}

Expand All @@ -158,8 +174,8 @@ inline IterableBitset<A>::const_iterator::const_iterator(

template<class A>
inline IterableBitset<A>::const_iterator::const_iterator(
const IterableBitset& index) : index(index), p(static_cast<size_t>(-1)) {
p = next_position(index.bitmap, index.num_bits, index.max_n, p);
const IterableBitset& index) : index(index) {
p = index.next_position(0, 0);
}

template<class A>
Expand All @@ -176,7 +192,7 @@ inline bool IterableBitset<A>::const_iterator::operator !=(

template<class A>
inline typename IterableBitset<A>::const_iterator& IterableBitset<A>::const_iterator::operator ++() {
p = next_position(index.bitmap, index.num_bits, index.max_n, p);
p = index.next_position(p + 1, 0);
return *this;
}

Expand Down Expand Up @@ -357,6 +373,70 @@ inline void IterableBitset<A>::erase(size_t v) {
}
}

//' @title Erase all values in a given range.
//' @description Bits at indices [start, end) are set to zero. An empty or
//' inverted range (start >= end) leaves the bitset untouched.
//' NOTE(review): `end` is assumed not to exceed the bitset's capacity; this
//' is not checked here.
template<class A>
inline void IterableBitset<A>::erase(size_t start, size_t end) {
    // In the general case, bits to erase are split into three regions: a
    // prefix, a middle part and a suffix. The middle region is always aligned
    // on word boundaries.
    //
    // Consider the following bitset, stored using 4-bit words:
    //
    //   abcd efgh ijkl mnop
    //
    // Erasing the range [2, 14) requires clearing bits c to n, inclusive.
    // The prefix is [cd] and the suffix [mn]. The middle section is [efghijkl].
    //
    // The middle section can be erased by overwriting entire words with zeros.
    // The prefix and suffix must be cleared by applying a mask over the
    // existing bits.
    //
    // There are however a few special cases:
    // - The range could be empty, in which case nothing needs to be done.
    // - The range falls within a single word. In the example above that could
    //   be [5, 7), ie. [fg]. A single mask needs to be constructed, covering
    //   only the relevant bits.
    // - The range ends on a word boundary, in which case there is no suffix
    //   to erase.
    //
    // Anytime bits are cleared, whether by overwriting a word or using a mask,
    // the bitset's size `n` must be updated accordingly, using popcount to
    // find out how many bits have actually been cleared.

    // Nothing to do for an empty range. An inverted range is rejected too:
    // it would otherwise produce a wrapped-around mask, clear unrelated bits
    // and make `n` underflow.
    if (start >= end) {
        return;
    }

    const size_t first_word = start / num_bits;
    const size_t last_word = end / num_bits;
    const A one = static_cast<A>(1);

    if (first_word == last_word) {
        // Single word: the mask has ones exactly at the in-word offsets
        // [start % num_bits, end % num_bits).
        const A mask = (one << (end % num_bits)) - (one << (start % num_bits));
        n -= popcount(bitmap[first_word] & mask);
        bitmap[first_word] &= ~mask;
        return;
    }

    // Prefix: clear from `start` to the end of its word, preserving the bits
    // located before `start`. When `start` is word-aligned the mask is all
    // ones and the whole word is cleared.
    const A prefix_mask = ~((one << (start % num_bits)) - one);
    n -= popcount(bitmap[first_word] & prefix_mask);
    bitmap[first_word] &= ~prefix_mask;

    // Middle: whole words strictly between the first and last word. No
    // masking needed since entire words are being cleared.
    for (size_t word = first_word + 1; word < last_word; ++word) {
        n -= popcount(bitmap[word]);
        bitmap[word] = 0;
    }

    // Suffix: only present when `end` is not on a word boundary. This check
    // also avoids touching bitmap[last_word] when there is nothing to clear
    // in it.
    if (end % num_bits != 0) {
        const A suffix_mask = (one << (end % num_bits)) - one;
        n -= popcount(bitmap[last_word] & suffix_mask);
        bitmap[last_word] &= ~suffix_mask;
    }
}

//' @title find an element in the bitset
//' @description checks if the bit for `v` is set
template<class A>
Expand Down Expand Up @@ -501,33 +581,78 @@ inline void bitset_choose_internal(
}
}

//' @title sample the bitset
//' @description retain a subset of values contained in this bitset,
//' where each element has probability 'rate' to remain.
//' This function modifies the bitset.
//' Generator for the gaps between successful Bernoulli trials.
//'
//' Instead of drawing one random number per trial, `skip_count()` draws the
//' number of failed trials that precede the next success.
struct fast_bernouilli {
    //' @param probability success probability of a single trial. Values that
    //' are not greater than zero (including NaN) are treated as 0, values of
    //' one or more are treated as 1.
    fast_bernouilli(double probability)
        : probability(probability), inverse_log(0.) {
        // The parameter has the same name as the member and hides it inside
        // this body, so the member is always written through `this->`.
        if (!(probability > 0.)) {
            // Zero, negative or NaN: a trial never succeeds.
            this->probability = 0.;
        } else if (probability >= 1.) {
            // Every trial succeeds.
            this->probability = 1.;
        } else {
            const double probability_log = log(1 - probability);
            if (probability_log == 0.0) {
                // The probability is so small that `1 - probability` rounds
                // to 1. Treat it as zero rather than dividing by zero below.
                this->probability = 0.;
            } else {
                inverse_log = 1 / probability_log;
            }
        }
    }

    //' Get the number of subsequent unsuccessful trials, until the next
    //' successful one.
    //'
    //' Returns 0 when every trial succeeds and UINT64_MAX when no trial can
    //' succeed or when the drawn gap does not fit in 64 bits.
    uint64_t skip_count() {
        if (probability == 1) {
            return 0;
        } else if (probability == 0.) {
            return UINT64_MAX;
        }

        // Turn a uniform draw into a gap length:
        // floor(log(x) / log(1 - probability)).
        const double x = R::runif(0.0, 1.0);
        const double gap = floor(log(x) * inverse_log);

        // double(UINT64_MAX) rounds up to 2^64; anything at or above it
        // (including infinity) is clamped.
        if (gap < double(UINT64_MAX)) {
            return static_cast<uint64_t>(gap);
        } else {
            return UINT64_MAX;
        }
    }

private:
    double probability;  // success probability, kept within [0, 1]
    double inverse_log;  // 1 / log(1 - probability); only used when 0 < probability < 1
};

//' Sample values from the bitset.
//'
//' Each value contained in the bitset is retained with an equal probability
//' 'rate'. This function modifies the bitset in-place.
//'
//' Rather than performing a Bernoulli trial for every member of the bitset,
//' this function is implemented by generating the lengths of the gaps between
//' two successful trials. This allows it to efficiently skip from one positive
//' trial to the next.
//'
//' This technique comes from the FastBernoulliTrial class in Firefox:
//' https://searchfox.org/mozilla-central/rev/aff9f084/mfbt/FastBernoulliTrial.h
//'
//' As an additional optimization, we flip the behaviour and sampling rate in
//' order to maximize the lengths, depending on whether the rate was smaller or
//' greater than 1/2.
template<class A>
inline void bitset_sample_internal(
IterableBitset<A>& b,
const double rate
){
auto to_remove = Rcpp::sample(
b.size(),
Rcpp::rbinom(1, b.size(), 1 - std::min(rate, 1.))[0],
false, // replacement
R_NilValue, // evenly distributed
false // one based
);
std::sort(to_remove.begin(), to_remove.end());
auto bitset_i = 0;
auto bitset_it = b.cbegin();
for (auto i : to_remove) {
while(bitset_i != i) {
++bitset_i;
++bitset_it;
}
b.erase(*bitset_it);
++bitset_i;
++bitset_it;
IterableBitset<A>& b,
const double rate
){
if (rate < 0.5) {
fast_bernouilli bernouilli(rate);
size_t i = 0;
while (i < b.max_size()) {
size_t next = b.next_position(i, bernouilli.skip_count());
b.erase(i, next);
i = next + 1;
}
} else {
fast_bernouilli bernouilli(1 - rate);
size_t i = 0;
while (i < b.max_size()) {
size_t next = b.next_position(i, bernouilli.skip_count());
if (next < b.max_size()) {
b.erase(next);
}
i = next + 1;
}
}
}

Expand Down
45 changes: 44 additions & 1 deletion tests/testthat/test-bitset.R
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,47 @@ test_that("bitset filtering works when given empty bitset", {
f <- Bitset$new(10)
expect_equal(filter_bitset(b, f)$size(), 0)
expect_equal(filter_bitset(b, integer(0))$size(), 0)
})
})

# Degenerate sampling rates: a rate of 0 must leave the bitset empty and a
# rate of 1 must leave its size unchanged.
test_that("bitset sampling with extremes is correct", {
size <- 100
# NOTE(review): Bitset$new(size)$not() presumably yields a bitset with all
# `size` elements set -- confirm against the Bitset class.
expect_equal(Bitset$new(size)$not()$sample(0)$size(), 0)
expect_equal(Bitset$new(size)$not()$sample(1)$size(), size)
})

# Statistical check that sampling does not favour particular elements: for
# each rate, the per-element retention counts over N samples are compared
# against the expected count rate * N.
test_that("bitset is evenly sampled", {
# Seed R's random number generator so the p-values below are reproducible.
set.seed(123)

# Smallest p-value that is still accepted by the checks below.
threshold <- 0.05
rates <- c(0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)
size <- 100
N <- 1000

for (rate in rates) {
# freqs accumulates, per element, the number of samples that retained it.
freqs <- rep(0, size)
for (i in seq(N)) {
b <- Bitset$new(size)$not()$sample(rate)
xs <- b$to_vector()
# Increment the counter of every element returned by to_vector().
freqs[xs] <- freqs[xs] + 1
}
# One-sample t-test of the retention counts against mu = rate * N; the
# test passes when the p-value exceeds the threshold.
p <- t.test(freqs, mu=rate*N)$p.value
expect_gt(p, threshold)
}
})

# Statistical check on the size of the sampled bitset: for each rate, the
# sizes observed over N samples are compared against the expected size
# rate * size.
test_that("bitset sampling has correctly distributed size", {
# Seed R's random number generator so the p-values below are reproducible.
set.seed(123)

# Smallest p-value that is still accepted by the checks below.
threshold <- 0.05
rates <- c(0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)
size <- 100
N <- 1000

for (rate in rates) {
# data holds the size of the bitset after each of the N samplings.
data <- sapply(1:N, function(i) {
Bitset$new(size)$not()$sample(rate)$size()
})
# One-sample t-test of the observed sizes against mu = rate * size; the
# test passes when the p-value exceeds the threshold.
p <- t.test(data, mu=rate*size)$p.value
expect_gt(p, threshold)
}
})

0 comments on commit d881329

Please sign in to comment.