From faa0af4369c1720295e3ecc2c1238a9d844b5afc Mon Sep 17 00:00:00 2001
From: Patrick Aschermayr
Date: Sat, 19 Aug 2023 15:21:47 +0200
Subject: [PATCH] add BaytesOptim
---
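Notes (kept outside the commit message): the test changes below construct the new optimizer kernel and pass it to the existing `sample` entry point. A minimal usage sketch of that flow, assuming, as the test suggests, that `Optimizer(OptimLBFG, :μ)` returns an `OptimConstructor` accepted by `sample`; `OptimLBFG`, `_obj`, and `sampledefault` are taken from the test setup and are not introduced by this patch.

    using Baytes            # re-exports Optimizer, OptimConstructor, OptimDefault after this patch
    using Optim, NLSolversBase
    using Random
    _rng = Random.MersenneTwister(1)

    ## Build an optimizer kernel for the :μ parameter block.
    ## OptimLBFG is the LBFGS backend used in the test suite (assumed to come from BaytesOptim).
    _oc = Optimizer(OptimLBFG, :μ)

    ## Same entry point as for the MCMC/SMC kernels; _obj.model, _obj.data and
    ## sampledefault are test fixtures defined earlier in test-construction.jl.
    trace, algorithms = sample(_rng, _obj.model, _obj.data, _oc; default = deepcopy(sampledefault))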
Project.toml | 2 +-
src/Baytes.jl | 7 ++++++-
src/sampling/utility.jl | 10 +++++-----
test/test-construction.jl | 33 +++++++++------------------------
4 files changed, 21 insertions(+), 31 deletions(-)
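The src/sampling/utility.jl hunk widens the previously MCMC-only `update` method to cover both kernel constructors. A compact, purely illustrative restatement of the resulting dispatch (signatures as in the patch; the fallback method only appears as context in the hunk):

    ## Batch data, no tempering adaption, single MCMC or Optimizer kernel:
    ## previous results can be captured, so no sampler update is needed.
    update(::DataTune{<:Batch}, ::BaytesCore.UpdateFalse,
           ::Union{MCMCConstructor, OptimConstructor}) = BaytesCore.UpdateFalse()

    ## All other combinations fall through to the generic
    ## update(datatune, temperingadaption, args...) method.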
diff --git a/Project.toml b/Project.toml
index bdeaf82..63797e4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr "]
-version = "0.3.10"
+version = "0.3.11"
[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
diff --git a/src/Baytes.jl b/src/Baytes.jl
index 2a08fb8..1e43c8b 100644
--- a/src/Baytes.jl
+++ b/src/Baytes.jl
@@ -67,7 +67,7 @@ using ModelWrappers:
UnflattenStrict,
UnflattenFlexible
-using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC
+using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC, BaytesOptim
#Utility tools
import Base: Base, summary
@@ -195,6 +195,11 @@ export
ParticleMetropolis,
ParticleGibbs,
+ ## Optimizer
+ Optimizer,
+ OptimConstructor,
+ OptimDefault,
+
## BaytesSMC
SMC,
SMCDefault,
diff --git a/src/sampling/utility.jl b/src/sampling/utility.jl
index 376be2e..2bdf149 100644
--- a/src/sampling/utility.jl
+++ b/src/sampling/utility.jl
@@ -72,13 +72,13 @@ Check if we can capture results from last proposal step. Typically only possible
"""
function update(datatune::DataTune, temperingadaption::B, smc::SMCConstructor) where {B<:BaytesCore.UpdateBool}
#!NOTE: If a single SMC constructor is used, we can capture previous results and do not need to update sampler before new iteration.
- #!NOTE: SMC only updates parameter if UpdateTrue(). Log-target/gradients will always be updated if jitter step applied.
+ #!NOTE: SMC only updates parameters if UpdateTrue(). Log-target/gradients will always be updated in either case if a jitter step is applied.
return BaytesCore.UpdateFalse()
end
-function update(datatune::DataTune{<:B}, temperingadaption::BaytesCore.UpdateFalse, mcmc::MCMCConstructor) where {B<:Batch}
- #!NOTE: If a MCMC constructor is used and no tempering is applied, we can capture previous results and do not need to update sampler before new iteration.
- #!NOTE: This only holds for Batch data in MCMC case
- #!NOTE: MCMC updates log target evaluation and eventual gradients if UpdateTrue()
+function update(datatune::DataTune{<:B}, temperingadaption::BaytesCore.UpdateFalse, algorithm::C) where {B<:Batch, C<:Union{MCMCConstructor, OptimConstructor}}
+ #!NOTE: If an MCMC/Optimizer constructor is used and no tempering is applied, we can capture previous results and do not need to update the sampler before a new iteration.
+ #!NOTE: This only holds for Batch data in the MCMC/Optimizer case.
+ #!NOTE: MCMC/Optimizer updates the log-target evaluation and gradients if UpdateTrue().
return BaytesCore.UpdateFalse()
end
function update(datatune::DataTune, temperingadaption, args...)
diff --git a/test/test-construction.jl b/test/test-construction.jl
index 1e97a4f..242872e 100644
--- a/test/test-construction.jl
+++ b/test/test-construction.jl
@@ -16,6 +16,7 @@ tempermethods = [
iter = 2
tempermethod = tempermethods[iter]
=#
+
############################################################################################
@testset "Sampling, type conversion" begin
for tempermethod in tempermethods
@@ -323,7 +324,8 @@ using Optim, NLSolversBase
)
Optimizer(OptimLBFG, :μ)
trace, algorithms = sample(_rng, _obj.model, _obj.data, _oc; default = deepcopy(sampledefault))
-
+ ## If a single optimizer kernel is assigned, previous results can be captured
+ @test isa(trace.summary.info.captured, typeof(temperupdate))
## Inference Section
transform = Baytes.TraceTransform(trace, _obj.model)
postmean = trace_to_posteriormean(trace, transform)
@@ -344,17 +346,19 @@ using Optim, NLSolversBase
#Check printing commands
printchainsummary(trace, transform, Val(:text))
printchainsummary(_obj.model, trace, transform, Val(:text))
-#=
- ## SMC
+
+ ## SMC via IBIS
ibis = SMCConstructor(_oc, SMCDefault(jitterthreshold=0.99, resamplingthreshold=1.0))
trace, algorithms = sample(_rng, _obj.model, _obj.data, ibis; default = deepcopy(sampledefault))
- ## If single mcmc kernel assigned, can capture previous results
+ ## Gradient results are always updated when new data is added
+ #!NOTE: But after the first iteration, previous results can be captured
@test isa(trace.summary.info.captured, UpdateFalse)
## Continue sampling
newdat = randn(_rng, length(_obj.data)+100)
trace2, algorithms2 = sample!(100, _rng, _obj.model, newdat, trace, algorithms)
+ #!NOTE: But after the first iteration, previous results can be captured
@test isa(trace2.summary.info.captured, UpdateFalse)
-=#
+
## Combinations
trace, algorithms = sample(_rng, _obj.model, _obj.data, mcmc, _oc; default = deepcopy(sampledefault))
transform = Baytes.TraceTransform(trace, _obj.model)
@@ -367,25 +371,6 @@ using Optim, NLSolversBase
end
end
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
############################################################################################
#Utility
@testset "Utility, maxiterations" begin