From 7a2f516d9a7a62c15676ac450f7b4a58320481f4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 17 Mar 2020 13:49:28 +0000 Subject: [PATCH] Make Bernoulli produce Bools (#1079) * Bernoulli returns bool closes #1068 * fix bernoulli eltype * fix cdf and ccdf --- Project.toml | 5 ++--- src/univariate/discrete/bernoulli.jl | 10 ++++++---- test/bernoulli.jl | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index a5a44c0e3c..1d6320773b 100644 --- a/Project.toml +++ b/Project.toml @@ -29,11 +29,10 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", - "StaticArrays", "HypothesisTests", "Test"] +test = ["Calculus", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", "StaticArrays", "HypothesisTests", "Test"] diff --git a/src/univariate/discrete/bernoulli.jl b/src/univariate/discrete/bernoulli.jl index 713b6cd7bc..f39232c911 100644 --- a/src/univariate/discrete/bernoulli.jl +++ b/src/univariate/discrete/bernoulli.jl @@ -38,7 +38,9 @@ end Bernoulli(p::Integer) = Bernoulli(float(p)) Bernoulli() = Bernoulli(0.5, check_args=false) -@distr_support Bernoulli 0 1 +@distr_support Bernoulli false true + +Base.eltype(::Type{<:Bernoulli}) = Bool #### Conversions convert(::Type{Bernoulli{T}}, p::Real) where {T<:Real} = Bernoulli(T(p)) @@ -83,11 +85,11 @@ pdf(d::Bernoulli, x::Bool) = x ? succprob(d) : failprob(d) pdf(d::Bernoulli, x::Int) = x == 0 ? failprob(d) : x == 1 ? succprob(d) : zero(d.p) -cdf(d::Bernoulli, x::Bool) = x ? failprob(d) : one(d.p) +cdf(d::Bernoulli, x::Bool) = x ? one(d.p) : failprob(d) cdf(d::Bernoulli, x::Int) = x < 0 ? zero(d.p) : x < 1 ? failprob(d) : one(d.p) -ccdf(d::Bernoulli, x::Bool) = x ? succprob(d) : one(d.p) +ccdf(d::Bernoulli, x::Bool) = x ? zero(d.p) : succprob(d) ccdf(d::Bernoulli, x::Int) = x < 0 ? one(d.p) : x < 1 ? succprob(d) : zero(d.p) @@ -104,7 +106,7 @@ cf(d::Bernoulli, t::Real) = failprob(d) + succprob(d) * cis(t) #### Sampling -rand(rng::AbstractRNG, d::Bernoulli) = rand(rng) <= succprob(d) ? 1 : 0 +rand(rng::AbstractRNG, d::Bernoulli) = rand(rng) <= succprob(d) #### MLE fitting diff --git a/test/bernoulli.jl b/test/bernoulli.jl index 276a442ac2..e9961f8d39 100644 --- a/test/bernoulli.jl +++ b/test/bernoulli.jl @@ -1,5 +1,5 @@ using Distributions using Test, Random -@test typeof(rand(Bernoulli())) == Int -@test typeof(rand(Bernoulli(), 10)) == Vector{Int} +@test rand(Bernoulli()) isa Bool +@test rand(Bernoulli(), 10) isa Vector{Bool}