Skip to content

Commit

Permalink
add BaytesOptim
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Aug 19, 2023
1 parent 78d59a2 commit faa0af4
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr <p.aschermayr@gmail.com>"]
version = "0.3.10"
version = "0.3.11"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
7 changes: 6 additions & 1 deletion src/Baytes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ using ModelWrappers:
UnflattenStrict,
UnflattenFlexible

using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC
using BaytesMCMC, BaytesFilters, BaytesPMCMC, BaytesSMC, BaytesOptim

#Utility tools
import Base: Base, summary
Expand Down Expand Up @@ -195,6 +195,11 @@ export
ParticleMetropolis,
ParticleGibbs,

## Optimizer
Optimizer,
OptimConstructor,
OptimDefault,

## BaytesSMC
SMC,
SMCDefault,
Expand Down
10 changes: 5 additions & 5 deletions src/sampling/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ Check if we can capture results from last proposal step. Typically only possible
"""
function update(datatune::DataTune, temperingadaption::B, smc::SMCConstructor) where {B<:BaytesCore.UpdateBool}
#!NOTE: If a single SMC constructor is used, we can capture previous results and do not need to update sampler before new iteration.
#!NOTE: SMC only updates parameter if UpdateTrue(). Log-target/gradients will always be updated if jitter step applied.
#!NOTE: SMC only updates parameter if UpdateTrue(). Log-target/gradients will always be updated if jitter step applied in either case
return BaytesCore.UpdateFalse()
end
function update(datatune::DataTune{<:B}, temperingadaption::BaytesCore.UpdateFalse, mcmc::MCMCConstructor) where {B<:Batch}
#!NOTE: If a MCMC constructor is used and no tempering is applied, we can capture previous results and do not need to update sampler before new iteration.
#!NOTE: This only holds for Batch data in MCMC case
#!NOTE: MCMC updates log target evaluation and eventual gradients if UpdateTrue()
function update(datatune::DataTune{<:B}, temperingadaption::BaytesCore.UpdateFalse, algorithm::C) where {B<:Batch, C<:Union{MCMCConstructor, OptimConstructor}}
#!NOTE: If a MCMC/Optimizer constructor is used and no tempering is applied, we can capture previous results and do not need to update sampler before new iteration.
#!NOTE: This only holds for Batch data in MCMC/Optimizer case
#!NOTE: MCMC/Optimizer updates log target evaluation and gradients if UpdateTrue()
return BaytesCore.UpdateFalse()
end
function update(datatune::DataTune, temperingadaption, args...)
Expand Down
33 changes: 9 additions & 24 deletions test/test-construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ tempermethods = [
iter = 2
tempermethod = tempermethods[iter]
=#

############################################################################################
@testset "Sampling, type conversion" begin
for tempermethod in tempermethods
Expand Down Expand Up @@ -323,7 +324,8 @@ using Optim, NLSolversBase
)
Optimizer(OptimLBFG, )
trace, algorithms = sample(_rng, _obj.model, _obj.data, _oc; default = deepcopy(sampledefault))

## If single optimizer kernel assigned, can capture previous results
@test isa(trace.summary.info.captured, typeof(temperupdate))
## Inference Section
transform = Baytes.TraceTransform(trace, _obj.model)
postmean = trace_to_posteriormean(trace, transform)
Expand All @@ -344,17 +346,19 @@ using Optim, NLSolversBase
#Check printing commands
printchainsummary(trace, transform, Val(:text))
printchainsummary(_obj.model, trace, transform, Val(:text))
#=
## SMC

## SMC via IBIS
ibis = SMCConstructor(_oc, SMCDefault(jitterthreshold=0.99, resamplingthreshold=1.0))
trace, algorithms = sample(_rng, _obj.model, _obj.data, ibis; default = deepcopy(sampledefault))
## If single mcmc kernel assigned, can capture previous results
## Always update Gradient Result if new data is added
#!NOTE: But after first iteration, can capture results
@test isa(trace.summary.info.captured, UpdateFalse)
## Continue sampling
newdat = randn(_rng, length(_obj.data)+100)
trace2, algorithms2 = sample!(100, _rng, _obj.model, newdat, trace, algorithms)
#!NOTE: But after first iteration, can capture results
@test isa(trace2.summary.info.captured, UpdateFalse)
=#

## Combinations
trace, algorithms = sample(_rng, _obj.model, _obj.data, mcmc, _oc; default = deepcopy(sampledefault))
transform = Baytes.TraceTransform(trace, _obj.model)
Expand All @@ -367,25 +371,6 @@ using Optim, NLSolversBase
end
end




















############################################################################################
#Utility
@testset "Utility, maxiterations" begin
Expand Down

2 comments on commit faa0af4

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/89926

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.11 -m "<description of version>" faa0af4369c1720295e3ecc2c1238a9d844b5afc
git push origin v0.3.11

Please sign in to comment.