From 700778fad2792f56d04711df518bae63100c6ced Mon Sep 17 00:00:00 2001
From: Patrick Aschermayr
Date: Wed, 12 Jul 2023 23:41:21 +0200
Subject: [PATCH] 20230712
---
Project.toml | 4 ++--
src/Baytes.jl | 2 ++
src/sampling/chain.jl | 2 +-
src/sampling/inference.jl | 21 ++++++++++-----------
4 files changed, 15 insertions(+), 14 deletions(-)
diff --git a/Project.toml b/Project.toml
index d739a52..4ff63ee 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr "]
-version = "0.3.6"
+version = "0.3.7"
[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -31,7 +31,7 @@ BaytesSMC = "0.3"
DocStringExtensions = "0.8, 0.9"
JLD2 = "0.4"
MCMCDiagnosticTools = "0.3"
-ModelWrappers = "0.4"
+ModelWrappers = "0.5"
PrettyTables = "2"
ProgressMeter = "1.7"
SimpleUnPack = "1"
diff --git a/src/Baytes.jl b/src/Baytes.jl
index a85ff2c..2a08fb8 100644
--- a/src/Baytes.jl
+++ b/src/Baytes.jl
@@ -48,6 +48,8 @@ using ModelWrappers:
ModelWrapper,
Tagged,
Objective,
+ length_constrained,
+ length_unconstrained,
#= DiffObjective,
AbstractDifferentiableTune,
ℓObjectiveResult,
diff --git a/src/sampling/chain.jl b/src/sampling/chain.jl
index a36cbe7..7cd93c6 100644
--- a/src/sampling/chain.jl
+++ b/src/sampling/chain.jl
@@ -141,7 +141,7 @@ function chainsummary(
@unpack Ndigits, quantiles = printdefault
@unpack progress = trace.summary
@unpack tagged, paramnames = transform
- Nparams = length(tagged)
+ Nparams = length_constrained(tagged)
Nchains = length(transform.chains)
## Flatten parameter to 3D array
computingtime = progress.enabled ? (progress.tlast - progress.tinit) : NaN
diff --git a/src/sampling/inference.jl b/src/sampling/inference.jl
index dce8ab5..27826ce 100644
--- a/src/sampling/inference.jl
+++ b/src/sampling/inference.jl
@@ -140,7 +140,7 @@ function trace_to_3DArray(
## Get trace information
@unpack tagged, chains, effective_iterations = transform
## Preallocate array
- mcmcchain = zeros(length(effective_iterations), length(chains), length(tagged))
+ mcmcchain = zeros(length(effective_iterations), length(chains), length_constrained(tagged))
## Flatten corresponding parameter
#!NOTE: Commented-out loop below is threadsave, but chain is not flattened in correct order, which might be troublesome for MCMC chain analysis.
# Threads.@threads for (idx, chain) in collect(enumerate(chains))
@@ -171,7 +171,7 @@ function trace_to_3DArrayᵤ(
## Get trace information
@unpack tagged, chains, effective_iterations = transform
## Preallocate array
- mcmcchain = zeros(length(effective_iterations), length(chains), length(tagged))
+ mcmcchain = zeros(length(effective_iterations), length(chains), length_unconstrained(tagged))
## Flatten corresponding parameter
#!NOTE: This is threadsave, but chain is not flattened in correct order, which might be troublesome for MCMC chain analysis.
# Threads.@threads for (idx, chain) in collect(enumerate(chains))
@@ -179,9 +179,8 @@ function trace_to_3DArrayᵤ(
for (iter0, iterburnin) in enumerate(effective_iterations)
# mcmcchain[iter0, :, idx] .=
mcmcchain[iter0, idx, :] .=
- flatten(tagged.info.reconstruct,
- unconstrain(tagged.info.transform, subset(trace.val[chain][iterburnin], tagged.parameter) )
- )
+# flatten(tagged.info.reconstruct, unconstrain(tagged.info.transform, subset(trace.val[chain][iterburnin], tagged.parameter) ) )
+ ModelWrappers.unconstrain_flatten(tagged.info, subset(trace.val[chain][iterburnin], tagged.parameter))
end
end
## Return MCMCChain
@@ -205,7 +204,7 @@ function trace_to_2DArray(
## Get trace information
@unpack tagged, chains, effective_iterations = transform
## Preallocate array
- mcmcchain = zeros(length(effective_iterations) * length(chains), length(tagged))
+ mcmcchain = zeros(length(effective_iterations) * length(chains), length_constrained(tagged))
## Flatten corresponding parameter
#!NOTE: This is threadsave, but chain is not flattened in correct order, which might be troublesome for MCMC chain analysis.
# Threads.@threads for (idx, chain) in collect(enumerate(chains))
@@ -235,7 +234,7 @@ function trace_to_2DArrayᵤ(
## Get trace information
@unpack tagged, chains, effective_iterations = transform
## Preallocate array
- mcmcchain = zeros(length(effective_iterations) * length(chains), length(tagged))
+ mcmcchain = zeros(length(effective_iterations) * length(chains), length_unconstrained(tagged))
## Flatten corresponding parameter
#!NOTE: This is threadsave, but chain is not flattened in correct ordered, which might be troublesome for MCMC chain analysis.
# Threads.@threads for (idx, chain) in collect(enumerate(chains))
@@ -243,9 +242,9 @@ function trace_to_2DArrayᵤ(
for (idx, chain) in collect(enumerate(chains))
for (iter0, iterburnin) in enumerate(effective_iterations)
iter += 1
- mcmcchain[iter, :] .= flatten(tagged.info.reconstruct,
- unconstrain(tagged.info.transform, subset(trace.val[chain][iterburnin], tagged.parameter) )
- )
+ mcmcchain[iter, :] .=
+# flatten(tagged.info.reconstruct, unconstrain(tagged.info.transform, subset(trace.val[chain][iterburnin], tagged.parameter) ) )
+ ModelWrappers.unconstrain_flatten(tagged.info, subset(trace.val[chain][iterburnin], tagged.parameter))
end
end
## Return MCMCChain
@@ -329,7 +328,7 @@ function flatten_chainvals(
## Get trace information
@unpack tagged, chains, effective_iterations = transform
## Preallocate array
- mcmcchain = [ [ zeros(tagged.info.reconstruct.default.output, length(tagged)) for _ in eachindex(effective_iterations) ] for _ in eachindex(chains) ]
+ mcmcchain = [ [ zeros(tagged.info.reconstruct.default.output, length_constrained(tagged)) for _ in eachindex(effective_iterations) ] for _ in eachindex(chains) ]
## Flatten corresponding parameter
#!NOTE: This is threadsave, but chain is not flattened in correct order, which might be troublesome for MCMC chain analysis, hence we opt out of it.
# Threads.@threads for (idx, chain) in collect(enumerate(chains))