diff --git a/README.md b/README.md index 11dafa8f0..1182d7898 100644 --- a/README.md +++ b/README.md @@ -1471,6 +1471,7 @@ $probabilities = ['a' => 0.3, 'b' => 0.2, 'c' => 0.5]; // probabilities for cate $categorical = new Discrete\Categorical($k, $probabilities); $pmf_a = $categorical->pmf('a'); $mode = $categorical->mode(); +$random = $categorical->rand(); // returns 'a' or 'b' or 'c' // Geometric distribution (failures before the first success) $p = 0.5; // success probability @@ -1549,6 +1550,7 @@ $cdf = $uniform->cdf($k); $μ = $uniform->mean(); $median = $uniform->median(); $σ² = $uniform->variance(); +$random = $uniform->rand(); // Zipf distribution $k = 2; // rank diff --git a/src/Probability/Distribution/Discrete/Categorical.php b/src/Probability/Distribution/Discrete/Categorical.php index 4c2c2cf2d..a3b1ce8de 100644 --- a/src/Probability/Distribution/Discrete/Categorical.php +++ b/src/Probability/Distribution/Discrete/Categorical.php @@ -27,6 +27,14 @@ class Categorical extends Discrete */ private $probabilities; + /** + * @var array|null + * Cached CDF when pmf sorted from most probable category + * to least probable category. + * This is only useful for repeated sampling using Categorical::rand() + */ + private $sorted_cdf = null; + /** * Distribution constructor * @@ -123,4 +131,45 @@ public function __get(string $name) throw new Exception\BadDataException("$name is not a valid gettable parameter"); } } + + /** + * Sample a random category and return its key + * + * @return int|string + */ + public function rand() + { + // calculate sorted cdf or use cached array + if (is_null($this->sorted_cdf)) { + // sort probabilities in descending order + $sorted_probabilities = $this->probabilities; // copy as arsort works in place + arsort($sorted_probabilities, SORT_NUMERIC); + + // calculate cdf + $cdf = []; + $sum = 0.0; + foreach ($sorted_probabilities as $category => $pᵢ) { + $sum += $pᵢ; + $cdf[$category] = $sum; + } + + $this->sorted_cdf = $cdf; + } + + $rand = \random_int(0, \PHP_INT_MAX) / \PHP_INT_MAX; // [0, 1] + + // find first element in sorted cdf that is larger than $rand + // for large arrays, performance could be improved by using binary search instead + // also possible with array_find_key in PHP >=8.4 + foreach ($this->sorted_cdf as $category => $v) { + if ($v >= $rand) { + return $category; + } + } + + // should only end up here if due to rounding errors the sum of probabilities + // is less than 1.0 and the generated random value is larger than the sum + // should be very unlikely, but possible + return array_key_last($this->sorted_cdf); + } } diff --git a/src/Probability/Distribution/Discrete/Uniform.php b/src/Probability/Distribution/Discrete/Uniform.php index f1eabf71d..86cebbf8e 100644 --- a/src/Probability/Distribution/Discrete/Uniform.php +++ b/src/Probability/Distribution/Discrete/Uniform.php @@ -154,4 +154,14 @@ public function variance(): float return (($b - $a + 1) ** 2 - 1) / 12; } + + /** + * Random number sampled from the distribution + * + * @return int + */ + public function rand(): int + { + return \random_int($this->a, $this->b); + } } diff --git a/tests/Probability/Distribution/Discrete/CategoricalTest.php b/tests/Probability/Distribution/Discrete/CategoricalTest.php index a54d48583..0983ce2ac 100644 --- a/tests/Probability/Distribution/Discrete/CategoricalTest.php +++ b/tests/Probability/Distribution/Discrete/CategoricalTest.php @@ -237,4 +237,39 @@ public function testGetException() // When $does_not_exist = $categorical->does_not_exist; } + + + /** + * @test rand + */ + public function testRand() + { + // Given + $k = 3; + $probabilities = ['a' => 0.2, 'b' => 0.5, 'c' => 0.3]; + $categorical = new Categorical($k, $probabilities); + + // When + $rand = $categorical->rand(); + + // Then + $this->assertContains($rand, ['a', 'b', 'c']); + } + + /** + * @test rand with certainty + */ + public function testRandCertain() + { + // Given + $k = 3; + $probabilities = ['a' => 0.0, 'b' => 1.0, 'c' => 0.0]; + $categorical = new Categorical($k, $probabilities); + + // When + $rand = $categorical->rand(); + + // Then + $this->assertEquals('b', $rand); + } } diff --git a/tests/Probability/Distribution/Discrete/UniformTest.php b/tests/Probability/Distribution/Discrete/UniformTest.php index ffafbdaa5..b881865ee 100644 --- a/tests/Probability/Distribution/Discrete/UniformTest.php +++ b/tests/Probability/Distribution/Discrete/UniformTest.php @@ -175,4 +175,23 @@ public function dataProviderForVariance(): array [2, 4, 0.66666666666667], ]; } + + /** + * @test rand + */ + public function testRand() + { + // Given + $a = 10; + $b = 11; + $uniform = new Uniform($a, $b); + + // When + $random = $uniform->rand(); + + // Then + $this->assertTrue(\is_numeric($random)); + $this->assertTrue($a <= $random); + $this->assertTrue($random <= $b); + } }