From b3c667a3cdb8a6bca6491b87fc82ebd110c5c852 Mon Sep 17 00:00:00 2001 From: Paul Lietar Date: Wed, 28 Feb 2024 17:48:01 +0000 Subject: [PATCH] Improve fast_bernouilli --- inst/include/IterableBitset.h | 36 +++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/inst/include/IterableBitset.h b/inst/include/IterableBitset.h index 08b9964..c1eb738 100644 --- a/inst/include/IterableBitset.h +++ b/inst/include/IterableBitset.h @@ -581,31 +581,47 @@ inline void bitset_choose_internal( } } +//' A struct to generate the gap lengths between successful Bernoulli trials. +//' This allows us to skip over many bits when sampling from a bitset. +//' +//' https://searchfox.org/mozilla-central/rev/aff9f084/mfbt/FastBernoulliTrial.h struct fast_bernouilli { fast_bernouilli(double probability) : probability(probability) { - double probability_log = log(1 - probability); - if (probability_log == 0.0) { - probability = 0.; - } else { - inverse_log = 1 / probability_log; + // 0 and 1 probabilities are handled as special cases in skip_count. + if (0 < probability && probability < 1) { + double probability_log = log(1 - probability); + // Probabilities smaller than 2^-53 could end up with a + // `probability_log` rounded to zero, which we can't invert. Treat + // these probabilities the same as 0 (the parameter shadows the member, hence `this->`). + if (probability_log == 0.0) { + this->probability = 0.; + } else { + inverse_log = 1 / probability_log; + } } } //' Get the number of subsequent unsuccessful trials, until the next //' successful one. - uint64_t skip_count() { + //' + //' The returned value is capped at SIZE_MAX, including when probability is + //' zero. This is only a problem if trying to skip over a bit vector larger + //' than that. + size_t skip_count() { if (probability == 1) { return 0; - } else if (probability == 0.) 
{ - return UINT64_MAX; + } else if (probability == 0) { + return SIZE_MAX; } double x = R::runif(0.0, 1.0); double skip_count = floor(log(x) * inverse_log); - if (skip_count < double(UINT64_MAX)) { + if (skip_count < double(SIZE_MAX)) { return skip_count; } else { - return UINT64_MAX; + // For very small probability, the skip count can end up being very large + // and exceeding SIZE_MAX. Returning the double directly would be UB. + return SIZE_MAX; } }