From faa0af4369c1720295e3ecc2c1238a9d844b5afc Mon Sep 17 00:00:00 2001 From: Patrick Aschermayr Date: Sat, 19 Aug 2023 15:21:47 +0200 Subject: [PATCH] add BaytesOptim --- Project.toml | 2 +- src/Baytes.jl | 7 ++++++- src/sampling/utility.jl | 10 +++++----- test/test-construction.jl | 33 +++++++++------------------------ 4 files changed, 21 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index bdeaf82..63797e4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Baytes" uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f" authors = ["Patrick Aschermayr "] -version = "0.3.10" +version = "0.3.11" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/Baytes.jl b/src/Baytes.jl index 2a08fb8..1e43c8b 100644 --- a/src/Baytes.jl +++ b/src/Baytes.jl @@ -67,7 +67,7 @@ using ModelWrappers: UnflattenStrict, UnflattenFlexible -using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC +using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC, BaytesOptim #Utility tools import Base: Base, summary @@ -195,6 +195,11 @@ export ParticleMetropolis, ParticleGibbs, + ## Optimizer + Optimizer, + OptimConstructor, + OptimDefault, + ## BaytesSMC SMC, SMCDefault, diff --git a/src/sampling/utility.jl b/src/sampling/utility.jl index 376be2e..2bdf149 100644 --- a/src/sampling/utility.jl +++ b/src/sampling/utility.jl @@ -72,13 +72,13 @@ Check if we can capture results from last proposal step. Typically only possible """ function update(datatune::DataTune, temperingadaption::B, smc::SMCConstructor) where {B<:BaytesCore.UpdateBool} #!NOTE: If a single SMC constructor is used, we can capture previous results and do not need to update sampler before new iteration. - #!NOTE: SMC only updates parameter if UpdateTrue(). Log-target/gradients will always be updated if jitter step applied. + #!NOTE: SMC only updates parameter if UpdateTrue(). Log-target/gradients will always be updated if jitter step applied in either case return BaytesCore.UpdateFalse() end -function update(datatune::DataTune{<:B}, temperingadaption::BaytesCore.UpdateFalse, mcmc::MCMCConstructor) where {B<:Batch} - #!NOTE: If a MCMC constructor is used and no tempering is applied, we can capture previous results and do not need to update sampler before new iteration. - #!NOTE: This only holds for Batch data in MCMC case - #!NOTE: MCMC updates log target evaluation and eventual gradients if UpdateTrue() +function update(datatune::DataTune{<:B}, temperingadaption::BaytesCore.UpdateFalse, algorithm::C) where {B<:Batch, C<:Union{MCMCConstructor, OptimConstructor}} + #!NOTE: If a MCMC/Optimizer constructor is used and no tempering is applied, we can capture previous results and do not need to update sampler before new iteration. + #!NOTE: This only holds for Batch data in MCMC/Optimizer case + #!NOTE: MCMC/Optimizer updates log target evaluation and gradients if UpdateTrue() return BaytesCore.UpdateFalse() end function update(datatune::DataTune, temperingadaption, args...) diff --git a/test/test-construction.jl b/test/test-construction.jl index 1e97a4f..242872e 100644 --- a/test/test-construction.jl +++ b/test/test-construction.jl @@ -16,6 +16,7 @@ tempermethods = [ iter = 2 tempermethod = tempermethods[iter] =# + ############################################################################################ @testset "Sampling, type conversion" begin for tempermethod in tempermethods @@ -323,7 +324,8 @@ using Optim, NLSolversBase ) Optimizer(OptimLBFG, :μ) trace, algorithms = sample(_rng, _obj.model, _obj.data, _oc; default = deepcopy(sampledefault)) - + ## If single optimizer kernel assigned, can capture previous results + @test isa(trace.summary.info.captured, typeof(temperupdate)) ## Inference Section transform = Baytes.TraceTransform(trace, _obj.model) postmean = trace_to_posteriormean(trace, transform) @@ -344,17 +346,19 @@ using Optim, NLSolversBase #Check printing commands printchainsummary(trace, transform, Val(:text)) printchainsummary(_obj.model, trace, transform, Val(:text)) -#= - ## SMC + +## SMC via IBIS ibis = SMCConstructor(_oc, SMCDefault(jitterthreshold=0.99, resamplingthreshold=1.0)) trace, algorithms = sample(_rng, _obj.model, _obj.data, ibis; default = deepcopy(sampledefault)) - ## If single mcmc kernel assigned, can capture previous results + ## Always update Gradient Result if new data is added + #!NOTE: But after first iteration, can capture results @test isa(trace.summary.info.captured, UpdateFalse) ## Continue sampling newdat = randn(_rng, length(_obj.data)+100) trace2, algorithms2 = sample!(100, _rng, _obj.model, newdat, trace, algorithms) + #!NOTE: But after first iteration, can capture results @test isa(trace2.summary.info.captured, UpdateFalse) -=# + ## Combinations trace, algorithms = sample(_rng, _obj.model, _obj.data, mcmc, _oc; default = deepcopy(sampledefault)) transform = Baytes.TraceTransform(trace, _obj.model) @@ -367,25 +371,6 @@ using Optim, NLSolversBase end end - - - - - - - - - - - - - - - - - - - ############################################################################################ #Utility @testset "Utility, maxiterations" begin