From 700778fad2792f56d04711df518bae63100c6ced Mon Sep 17 00:00:00 2001 From: Patrick Aschermayr Date: Wed, 12 Jul 2023 23:41:21 +0200 Subject: [PATCH] 20230712 --- Project.toml | 4 ++-- src/Baytes.jl | 2 ++ src/sampling/chain.jl | 2 +- src/sampling/inference.jl | 21 ++++++++++----------- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index d739a52..4ff63ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Baytes" uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f" authors = ["Patrick Aschermayr "] -version = "0.3.6" +version = "0.3.7" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -31,7 +31,7 @@ BaytesSMC = "0.3" DocStringExtensions = "0.8, 0.9" JLD2 = "0.4" MCMCDiagnosticTools = "0.3" -ModelWrappers = "0.4" +ModelWrappers = "0.5" PrettyTables = "2" ProgressMeter = "1.7" SimpleUnPack = "1" diff --git a/src/Baytes.jl b/src/Baytes.jl index a85ff2c..2a08fb8 100644 --- a/src/Baytes.jl +++ b/src/Baytes.jl @@ -48,6 +48,8 @@ using ModelWrappers: ModelWrapper, Tagged, Objective, + length_constrained, + length_unconstrained, #= DiffObjective, AbstractDifferentiableTune, ℓObjectiveResult, diff --git a/src/sampling/chain.jl b/src/sampling/chain.jl index a36cbe7..7cd93c6 100644 --- a/src/sampling/chain.jl +++ b/src/sampling/chain.jl @@ -141,7 +141,7 @@ function chainsummary( @unpack Ndigits, quantiles = printdefault @unpack progress = trace.summary @unpack tagged, paramnames = transform - Nparams = length(tagged) + Nparams = length_constrained(tagged) Nchains = length(transform.chains) ## Flatten parameter to 3D array computingtime = progress.enabled ? (progress.tlast - progress.tinit) : NaN diff --git a/src/sampling/inference.jl b/src/sampling/inference.jl index dce8ab5..27826ce 100644 --- a/src/sampling/inference.jl +++ b/src/sampling/inference.jl @@ -140,7 +140,7 @@ function trace_to_3DArray( ## Get trace information @unpack tagged, chains, effective_iterations = transform ## Preallocate array - mcmcchain = zeros(length(effective_iterations), length(chains), length(tagged)) + mcmcchain = zeros(length(effective_iterations), length(chains), length_constrained(tagged)) ## Flatten corresponding parameter #!NOTE: Commented-out loop below is threadsave, but chain is not flattened in correct order, which might be troublesome for MCMC chain analysis. # Threads.@threads for (idx, chain) in collect(enumerate(chains)) @@ -171,7 +171,7 @@ function trace_to_3DArrayᵤ( ## Get trace information @unpack tagged, chains, effective_iterations = transform ## Preallocate array - mcmcchain = zeros(length(effective_iterations), length(chains), length(tagged)) + mcmcchain = zeros(length(effective_iterations), length(chains), length_unconstrained(tagged)) ## Flatten corresponding parameter #!NOTE: This is threadsave, but chain is not flattened in correct order, which might be troublesome for MCMC chain analysis. # Threads.@threads for (idx, chain) in collect(enumerate(chains)) @@ -179,9 +179,8 @@ function trace_to_3DArrayᵤ( for (iter0, iterburnin) in enumerate(effective_iterations) # mcmcchain[iter0, :, idx] .= mcmcchain[iter0, idx, :] .= - flatten(tagged.info.reconstruct, - unconstrain(tagged.info.transform, subset(trace.val[chain][iterburnin], tagged.parameter) ) - ) +# flatten(tagged.info.reconstruct, unconstrain(tagged.info.transform, subset(trace.val[chain][iterburnin], tagged.parameter) ) ) + ModelWrappers.unconstrain_flatten(tagged.info, subset(trace.val[chain][iterburnin], tagged.parameter)) end end ## Return MCMCChain @@ -205,7 +204,7 @@ function trace_to_2DArray( ## Get trace information @unpack tagged, chains, effective_iterations = transform ## Preallocate array - mcmcchain = zeros(length(effective_iterations) * length(chains), length(tagged)) + mcmcchain = zeros(length(effective_iterations) * length(chains), length_constrained(tagged)) ## Flatten corresponding parameter #!NOTE: This is threadsave, but chain is not flattened in correct order, which might be troublesome for MCMC chain analysis. # Threads.@threads for (idx, chain) in collect(enumerate(chains)) @@ -235,7 +234,7 @@ function trace_to_2DArrayᵤ( ## Get trace information @unpack tagged, chains, effective_iterations = transform ## Preallocate array - mcmcchain = zeros(length(effective_iterations) * length(chains), length(tagged)) + mcmcchain = zeros(length(effective_iterations) * length(chains), length_unconstrained(tagged)) ## Flatten corresponding parameter #!NOTE: This is threadsave, but chain is not flattened in correct ordered, which might be troublesome for MCMC chain analysis. # Threads.@threads for (idx, chain) in collect(enumerate(chains)) @@ -243,9 +242,9 @@ function trace_to_2DArrayᵤ( for (idx, chain) in collect(enumerate(chains)) for (iter0, iterburnin) in enumerate(effective_iterations) iter += 1 - mcmcchain[iter, :] .= flatten(tagged.info.reconstruct, - unconstrain(tagged.info.transform, subset(trace.val[chain][iterburnin], tagged.parameter) ) - ) + mcmcchain[iter, :] .= +# flatten(tagged.info.reconstruct, unconstrain(tagged.info.transform, subset(trace.val[chain][iterburnin], tagged.parameter) ) ) + ModelWrappers.unconstrain_flatten(tagged.info, subset(trace.val[chain][iterburnin], tagged.parameter)) end end ## Return MCMCChain @@ -329,7 +328,7 @@ function flatten_chainvals( ## Get trace information @unpack tagged, chains, effective_iterations = transform ## Preallocate array - mcmcchain = [ [ zeros(tagged.info.reconstruct.default.output, length(tagged)) for _ in eachindex(effective_iterations) ] for _ in eachindex(chains) ] + mcmcchain = [ [ zeros(tagged.info.reconstruct.default.output, length_constrained(tagged)) for _ in eachindex(effective_iterations) ] for _ in eachindex(chains) ] ## Flatten corresponding parameter #!NOTE: This is threadsave, but chain is not flattened in correct order, which might be troublesome for MCMC chain analysis, hence we opt out of it. # Threads.@threads for (idx, chain) in collect(enumerate(chains))