Skip to content

Commit

Permalink
Improve fast_bernouilli
Browse files Browse the repository at this point in the history
  • Loading branch information
plietar committed Feb 28, 2024
1 parent 1b2eb25 commit b3c667a
Showing 1 changed file with 26 additions and 10 deletions.
36 changes: 26 additions & 10 deletions inst/include/IterableBitset.h
Original file line number Diff line number Diff line change
Expand Up @@ -581,31 +581,47 @@ inline void bitset_choose_internal(
}
}

//' A struct to generate the gap lengths between unsuccessful bernouilli trials.
//' This allows us to skip over many bits when sampling from a bitset.
//'
//' https://searchfox.org/mozilla-central/rev/aff9f084/mfbt/FastBernoulliTrial.h
struct fast_bernouilli {
fast_bernouilli(double probability) : probability(probability) {
double probability_log = log(1 - probability);
if (probability_log == 0.0) {
probability = 0.;
} else {
inverse_log = 1 / probability_log;
// 0 and 1 probabilities are handled as special cases in skip_count.
if (0 < probability && probability < 1) {
double probability_log = log(1 - probability);
// Probabilities smaller than 2^-53 could end up with a
// `probability_log` rounded to zero, which we can't inverse. Treat
// these probabilties the same as 0.
if (probability_log == 0.0) {
probability = 0.;
} else {
inverse_log = 1 / probability_log;
}
}
}

//' Get the number of subsequent unsuccessful trials, until the next
//' successful one.
uint64_t skip_count() {
//'
//' The returned value is capped at SIZE_MAX, including when probability is
//' zero. This is only a problem if trying to skip over a bit vector larger
//' than that.
size_t skip_count() {
if (probability == 1) {
return 0;
} else if (probability == 0.) {
return UINT64_MAX;
} else if (probability == 0) {
return SIZE_MAX;
}

double x = R::runif(0.0, 1.0);
double skip_count = floor(log(x) * inverse_log);
if (skip_count < double(UINT64_MAX)) {
if (skip_count < double(SIZE_MAX)) {
return skip_count;
} else {
return UINT64_MAX;
// For very small probability, the skip count can end up being very large
// and exceeding SIZE_MAX. Returning the double directly would be UB.
return SIZE_MAX;
}
}

Expand Down

0 comments on commit b3c667a

Please sign in to comment.