Skip to content

Commit

Permalink
Update ChainInformation
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Jul 27, 2022
1 parent 2566db4 commit 8c2e63c
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 75 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr <p.aschermayr@gmail.com>"]
version = "0.1.5"
version = "0.1.6"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
44 changes: 34 additions & 10 deletions src/sampling/chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,31 @@ function mergediagnostics(paramdiagnostic, chainparamdiagnostic, paramquantiles)
)
end

############################################################################################

"""
$(SIGNATURES)
Check if any parameter has been stuck at each iteration in any chain, in which case chainsummary will skip computations.
# Examples
```julia
```
"""
function is_stuck(arr3D::AbstractArray)
# Loop through chains and parameter to check if first parameter is equal to all samples == chain stuck
for Nchains in Base.OneTo( size(arr3D,3) )
for Nparams in Base.OneTo( size(arr3D,2) )
_benchmark = arr3D[begin,Nparams,Nchains]
stuck = all(val -> val == _benchmark, @view( arr3D[:,Nparams,Nchains] ))
if stuck
return true, (Nparams, Nchains)
end
end
end
return false, (0,0)
end

############################################################################################
"""
$(SIGNATURES)
Expand All @@ -113,25 +138,24 @@ Return summary for trace parameter chains. `Model` defines flattening type of pa
"""
function chainsummary(
trace::Trace,
model::ModelWrapper,
sym::S,
transform::TraceTransform,
backend, #i.e., Val(:text), or Val(:latex)
burnin::Integer,
thinning::Integer,
printdefault::PrintDefault=PrintDefault();
kwargs...,
) where {S<:Union{Symbol,NTuple{k,Symbol} where k}}
## Assign utility values
@unpack Ndigits, quantiles = printdefault
@unpack default = model.info.reconstruct
@unpack progress = trace.info
tagged = Tagged(model, sym)
@unpack tagged, paramnames = transform
Nparams = length(tagged)
paramnames = ModelWrappers.paramnames(
default, tagged.info.constraint, subset(model.val, tagged.parameter)
)
## Flatten parameter to 3D array
computingtime = progress.enabled ? (progress.tlast - progress.tinit) : NaN
arr3D = trace_to_3DArray(trace, model, tagged, burnin, thinning)
arr3D = trace_to_3DArray(trace, transform)
## Check if any MCMC sampler was stuck in any chain, in which case chainsummary will be skipped
stuck, paramchain = is_stuck(arr3D)
if stuck
println("Chain is first stuck in (Nparam, Nchain) = ", paramchain, " - skipping chainsummary.")
end
## Compute summary statistics
#!NOTE If more than 1 chain used, can use cross-chain diagnostics
if trace.info.sampling.Nchains > 1
Expand Down
33 changes: 8 additions & 25 deletions src/sampling/diagnostics.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
############################################################################################
"""
$(SIGNATURES)
Obtain parameter diagnostics from trace at chain `chain`, excluding first `burnin` samples.
# Examples
```julia
```
"""
function get_chaindiagnostics(
trace::Trace, chain::Integer, Nalgorithm::Integer, burnin::Integer, thinning::Integer
)
return @view(trace.diagnostics[chain][Nalgorithm][(1 + burnin):thinning:end])
end

############################################################################################
"""
$(SIGNATURES)
Expand All @@ -27,46 +11,45 @@ Return summary for trace parameter diagnostics, `backend` may be Val(:text), or
function diagnosticssummary(
trace::Trace,
algorithmᵛ::SMC,
transform::TraceTransform,
backend::Nothing, #i.e., Val(:text), or Val(:latex)
burnin::Integer,
thinning::Integer,
printdefault::PrintDefault=PrintDefault();
kwargs...,
)
## Assign utility variables
@unpack effective_iterations = transform
@unpack Ndigits, quantiles = printdefault
@unpack Nchains, Nalgorithms, burnin = trace.info.sampling
## Print diagnostics for each sampler for each chain
println(
"#####################################################################################",
)
return results(
@view(trace.diagnostics[(1 + burnin):thinning:end]), algorithmᵛ, Ndigits, quantiles
@view(trace.diagnostics[effective_iterations]), algorithmᵛ, Ndigits, quantiles
)
end

function diagnosticssummary(
trace::Trace,
algorithmᵛ::AbstractVector,
transform::TraceTransform,
backend::Nothing, #i.e., Val(:text), or Val(:latex)
burnin::Integer,
thinning::Integer,
printdefault::PrintDefault=PrintDefault();
kwargs...,
)
## Assign utility variables
@unpack Ndigits, quantiles = printdefault
@unpack Nchains, Nalgorithms = trace.info.sampling
@unpack chains, algorithms, effective_iterations = transform
## Print diagnostics for each sampler for each chain
for Nalgorithm in Base.OneTo(Nalgorithms)
for Nalgorithm in algorithms
println(
"#####################################################################################",
)
for Nchain in Base.OneTo(Nchains)
for Nchain in chains
println("########################################## Chain ", Nchain, ":")
println(Base.nameof(typeof(algorithmᵛ[Nchain][Nalgorithm])), " Diagnostics: ")
results(
get_chaindiagnostics(trace, Nchain, Nalgorithm, burnin, thinning),
get_chaindiagnostics(trace.diagnostics, Nchain, Nalgorithm, effective_iterations),
algorithmᵛ[Nchain][Nalgorithm],
Ndigits,
quantiles,
Expand Down
165 changes: 140 additions & 25 deletions src/sampling/inference.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,57 @@
############################################################################################
################################################################################
struct TraceTransform{T<:Tagged, P}
"Contains parameter where output information is printed."
tagged :: T
"Parameter names based on tagged model parameter."
paramnames :: P
"Chain indices that are used for output diagnostics."
chains :: Vector{Int64}
"Algorithm indices that are used for output diagnostics."
algorithms :: Vector{Int64}
"Number of burnin steps before output diagnostics are taken."
burnin :: Int64
"Number of steps that are set between 2 consecutive samples."
thinning :: Int64
"Maximum number of iterations to be collected for each chain."
maxiterations :: Int64
"StepRange for indices of effective samples"
effective_iterations :: StepRange{Int64, Int64}
function TraceTransform(
tagged::T,
paramnames::P,
chains::Vector{Int64},
algorithms::Vector{Int64},
burnin::Int64,
thinning::Int64,
maxiterations::Int64
) where {
T<:Tagged,P
}
ArgCheck.@argcheck maxiterations >= burnin >= 0
ArgCheck.@argcheck thinning > 0
ArgCheck.@argcheck maxiterations > 0
#Assign indices for subsetting trace
effective_iterations = (burnin+1):thinning:maxiterations
return new{T,P}(tagged, paramnames, chains, algorithms, burnin, thinning, maxiterations, effective_iterations)
end
end
function TraceTransform(trace::Trace, model::ModelWrapper)
tagged = Tagged(model, trace.info.sampling.printedparam.printed)
paramnames = ModelWrappers.paramnames(
tagged.info.reconstruct.default, tagged.info.constraint, subset(model.val, tagged.parameter)
)
return TraceTransform(
tagged,
paramnames,
collect(Base.OneTo(trace.info.sampling.Nchains)),
collect(Base.OneTo(trace.info.sampling.Nalgorithms)),
trace.info.sampling.burnin,
trace.info.sampling.thinning,
trace.info.sampling.iterations
)
end

################################################################################
"""
$(SIGNATURES)
Change trace.val to 3d Array that is consistent with MCMCCHains dimensons. First dimension is iterations, second number of parameter, third number of chains.
Expand All @@ -10,27 +63,22 @@ Change trace.val to 3d Array that is consistent with MCMCCHains dimensons. First
"""
function trace_to_3DArray(
trace::Trace,
model::ModelWrapper,
tagged::Tagged,
burnin::Integer,
thinning::Integer
transform::TraceTransform
)
## Get trace information
Nparams = length(tagged)
@unpack Nchains, iterations = trace.info.sampling
effective_iterations = (burnin+1):thinning:iterations
@unpack tagged, chains, effective_iterations = transform
## Preallocate array
mcmcchain = zeros(length(effective_iterations), Nparams, Nchains)
mcmcchain = zeros(length(effective_iterations), length(tagged), length(chains))
## Flatten corresponding parameter
Threads.@threads for chain in Base.OneTo(Nchains)
for (iter, index) in enumerate(effective_iterations)
mcmcchain[iter, :, chain] .= flatten(tagged.info.reconstruct, subset(trace.val[chain][index], tagged.parameter))
Threads.@threads for (idx, chain) in collect(enumerate(chains))
for (iter0, iterburnin) in enumerate(effective_iterations)
mcmcchain[iter0, :, idx] .= flatten(tagged.info.reconstruct, subset(trace.val[chain][iterburnin], tagged.parameter))
end
end
## Return MCMCChain
return mcmcchain
end
############################################################################################

"""
$(SIGNATURES)
Change trace.val to 3d Array and return Posterior mean as NamedTuple and as Vector
Expand All @@ -42,28 +90,95 @@ Change trace.val to 3d Array and return Posterior mean as NamedTuple and as Vect
"""
function trace_to_posteriormean(
mod_array::AbstractArray,
model::ModelWrapper,
tagged::Tagged,
burnin::Integer,
thinning::Integer
transform::TraceTransform
)
@unpack tagged = transform
mod_array_mean = map(iter -> mean(view(mod_array, :, iter, :)), Base.OneTo(size(mod_array, 2)))
mod_nt_mean = ModelWrappers.unflatten(model, tagged, mod_array_mean)
mod_nt_mean = ModelWrappers.unflatten(tagged.info.reconstruct, mod_array_mean)
return mod_array_mean, mod_nt_mean
end
function trace_to_posteriormean(
trace::Trace,
model::ModelWrapper,
tagged::Tagged,
burnin::Integer,
thinning::Integer
transform::TraceTransform
)
return trace_to_posteriormean(
trace_to_3DArray(trace, model, tagged, burnin, thinning),
model, tagged, burnin, thinning
trace_to_3DArray(trace, transform), transform
)
end

############################################################################################
"""
$(SIGNATURES)
Return a view of a specific index 'chain' for Vector of Parameter chains of NamedTuples with index 'effective_iterations'.
# Examples
```julia
```
"""
function get_chainvals(val::F, chain::Integer, effective_iterations::StepRange{Int64}) where {F<:Vector{<:Vector{<:NamedTuple}} }
return @view(val[chain][effective_iterations])
end

function get_chainvals(trace::Trace, transform::TraceTransform)
@unpack tagged, chains, effective_iterations = transform
ArgCheck.@argcheck length(chains) <= length(trace.val)
return [map(x -> subset(x, tagged.parameter), get_chainvals(trace.val, chain, effective_iterations)) for chain in chains]
end

"""
$(SIGNATURES)
Merge Vector of Parameter chains of NamedTuples into a single vector.
# Examples
```julia
```
"""
function merge_chainvals(trace::Trace, transform::TraceTransform)
return reduce(vcat, get_chainvals(trace, transform))
end

"""
$(SIGNATURES)
Flatten Vector of Parameter NamedTuples into a Matrix, where each row represents draws for a single parameter.
# Examples
```julia
```
"""
function flatten_chainvals(trace::Trace, transform::TraceTransform)
@unpack tagged = transform
return reduce(hcat, map(x -> flatten(tagged.info.reconstruct, x), merge_chainvals(trace, transform) ) )
end

################################################################################
"""
$(SIGNATURES)
Obtain parameter diagnostics from trace at chain `chain`, excluding first `burnin` samples.
# Examples
```julia
```
"""
function get_chaindiagnostics(diagnostics, chain::Integer, Nalgorithm::Integer, effective_iterations::StepRange{Int64})
return @view(diagnostics[chain][Nalgorithm][effective_iterations])
end

function get_chaindiagnostics(trace::Trace, transform::TraceTransform)
@unpack chains, algorithms, effective_iterations = transform
return [ map(algorithm -> get_chaindiagnostics(trace.diagnostics, chain, algorithm, effective_iterations), algorithms) for chain in chains]
end

############################################################################################
#export
export trace_to_3DArray, trace_to_posteriormean
export
TraceTransform,
trace_to_3DArray,
trace_to_posteriormean,
get_chainvals,
get_chaindiagnostics,
merge_chainvals,
flatten_chainvals
Loading

2 comments on commit 8c2e63c

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/65099

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.6 -m "<description of version>" 8c2e63c54e3fb8d7adaaa98f51267653dbaab83a
git push origin v0.1.6

Please sign in to comment.