Skip to content

Commit

Permalink
update chainoutput
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Oct 10, 2022
1 parent 9e00dd1 commit c067ad5
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr <p.aschermayr@gmail.com>"]
version = "0.1.15"
version = "0.1.16"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
86 changes: 70 additions & 16 deletions src/sampling/chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ end
############################################################################################
"""
$(SIGNATURES)
Return summary for trace parameter chains. `Model` defines flattening type of parameter,
`sym` defines parameter to be flattened, `backend` may be Val(:text), or Val(:latex).
Return summary for trace parameter chains. 'printdefault' defines quantiles and number of digits for printing.
# Examples
```julia
Expand All @@ -139,10 +138,8 @@ Return summary for trace parameter chains. `Model` defines flattening type of pa
function chainsummary(
trace::Trace,
transform::TraceTransform,
backend, #i.e., Val(:text), or Val(:latex)
printdefault::PrintDefault=PrintDefault();
kwargs...,
) where {S<:Union{Symbol,NTuple{k,Symbol} where k}}
printdefault::PrintDefault=PrintDefault()
)
## Assign utility values
@unpack Ndigits, quantiles = printdefault
@unpack progress = trace.info
Expand All @@ -155,11 +152,7 @@ function chainsummary(
## Check if any MCMC sampler was stuck in any chain, in which case chainsummary will be skipped
stuck, paramchain = is_stuck(arr3D)
if stuck
println(
"#####################################################################################",
)
println("Chain is first stuck in (Nparam, Nchain) = ", paramchain, " - skipping chainsummary.")
return nothing
return stuck, paramchain, nothing
end
## Compute summary statistics
#!NOTE If more than 1 chain used, can use cross-chain diagnostics
Expand All @@ -181,14 +174,75 @@ function chainsummary(
_reconstruct = ModelWrappers.ReConstructor(diag)
diag_flattened = flatten(_reconstruct, diag)
table = round.(reshape(diag_flattened, Nstats, Nparams)'; digits=Ndigits)
## Print table
PrettyTables.pretty_table(
table, backend=backend, header=tablenames, row_labels=paramnames, kwargs...
)
## Return table arguments
return table, tablenames, paramnames
end

"""
$(SIGNATURES)
Print summary for trace parameter chains. `backend` may be Val(:text), or Val(:latex).
# Examples
```julia
```
"""
function printchainsummary(
trace::Trace,
transform::TraceTransform,
backend, #i.e., Val(:text), or Val(:latex)
printdefault::PrintDefault=PrintDefault();
kwargs...,
)
table, tablenames, paramnames = chainsummary(trace, transform, printdefault)
if table isa Bool
println(
"#####################################################################################",
)
println("Chain is first stuck in (Nparam, Nchain) = ", tablenames, " - skipping chainsummary.")
return nothing, nothing, nothing
## Print table
else
PrettyTables.pretty_table(
table, backend=backend, header=tablenames, row_labels=paramnames, kwargs...
)
end
end

"""
$(SIGNATURES)
Add a ModelWrapper struct 'model' as a function argument to print model.val as "true" parameter in table.
# Examples
```julia
```
"""
function printchainsummary(
model::ModelWrapper,
trace::Trace,
transform::TraceTransform,
backend, #i.e., Val(:text), or Val(:latex)
printdefault::PrintDefault=PrintDefault();
kwargs...,
)
table, tablenames, paramnames = chainsummary(trace, transform, printdefault)
if table isa Bool
println(
"#####################################################################################",
)
println("Chain is first stuck in (Nparam, Nchain) = ", tablenames, " - skipping chainsummary.")
return nothing, nothing, nothing
## Print table
else
#Obtain true parameter from model
θ_true = round.(flatten(model, transform.tagged); digits = printdefault.Ndigits)
PrettyTables.pretty_table(
hcat(θ_true, table), backend=backend, header=vcat("True", tablenames), row_labels=paramnames, kwargs...
)
end
end

############################################################################################
#export
export chainsummary
export chainsummary, printchainsummary
6 changes: 3 additions & 3 deletions src/sampling/diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Return summary for trace parameter diagnostics, `backend` may be Val(:text), or
```
"""
function diagnosticssummary(
function printdiagnosticssummary(
trace::Trace,
algorithmᵛ::SMC,
transform::TraceTransform,
Expand All @@ -29,7 +29,7 @@ function diagnosticssummary(
)
end

function diagnosticssummary(
function printdiagnosticssummary(
trace::Trace,
algorithmᵛ::AbstractVector,
transform::TraceTransform,
Expand Down Expand Up @@ -60,4 +60,4 @@ end

############################################################################################
#export
export diagnosticssummary
export printdiagnosticssummary
4 changes: 2 additions & 2 deletions src/sampling/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ function summary(
printdefault::PrintDefault=PrintDefault(),
) where {S<:Union{Symbol,NTuple{k,Symbol} where k}}
## Print Diagnostics summary
diagnosticssummary(trace, algorithmᵛ, transform, nothing, printdefault)
printdiagnosticssummary(trace, algorithmᵛ, transform, nothing, printdefault)
## Print Chain summary
chainsummary(trace, transform, Val(:text), printdefault)
printchainsummary(trace, transform, Val(:text), printdefault)
## Return
return nothing
end
Expand Down
42 changes: 21 additions & 21 deletions src/sampling/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,34 +238,34 @@ end
############################################################################################
"""
$(SIGNATURES)
Safe `trace`, `model` and `algorithm` to current working directory.
Save `trace`, `model` and `algorithm` to current working directory with name 'name'.
# Examples
```julia
```
"""
function savetrace(trace::Trace, model::ModelWrapper, algorithm)
@unpack iterations, burnin, Nchains = trace.info.sampling
function savetrace(trace::Trace, model::ModelWrapper, algorithm,
name = join((
Base.nameof(typeof(model.id)),
"_",
Base.nameof(typeof(algorithm)),
"_",
Dates.today(),
"_H",
Dates.hour(Dates.now()),
"M",
Dates.minute(Dates.now()),
"_Nchains",
trace.info.sampling.Nchains,
"_Iter",
trace.info.sampling.iterations,
"_Burnin",
trace.info.sampling.burnin,
))
)
JLD2.jldsave(
join((
Base.nameof(typeof(model.id)),
"_",
Base.nameof(typeof(algorithm)),
"_",
Dates.today(),
"_H",
Dates.hour(Dates.now()),
"M",
Dates.minute(Dates.now()),
"_Nchains",
Nchains,
"_Iter",
iterations,
"_Burnin",
burnin,
".jld2",
));
join((name, ".jld2"));
trace=trace,
model=model,
algorithm=algorithm,
Expand Down

2 comments on commit c067ad5

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/69832

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.16 -m "<description of version>" c067ad58ce788ebcba0e9f9a87d6529d0a85544a
git push origin v0.1.16

Please sign in to comment.