Skip to content

Commit

Permalink
Update naming convention for trace
Browse files Browse the repository at this point in the history
  • Loading branch information
paschermayr committed Oct 24, 2022
1 parent 130c762 commit 1e45efb
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Baytes"
uuid = "72ddfcfc-6e9d-43df-829b-7aed7c549d4f"
authors = ["Patrick Aschermayr <p.aschermayr@gmail.com>"]
version = "0.1.16"
version = "0.2.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
2 changes: 1 addition & 1 deletion src/sampling/chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ function chainsummary(
)
## Assign utility values
@unpack Ndigits, quantiles = printdefault
@unpack progress = trace.info
@unpack progress = trace.summary
@unpack tagged, paramnames = transform
Nparams = length(tagged)
Nchains = length(transform.chains)
Expand Down
2 changes: 1 addition & 1 deletion src/sampling/diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function printdiagnosticssummary(
## Assign utility variables
@unpack effective_iterations = transform
@unpack Ndigits, quantiles = printdefault
@unpack Nchains, Nalgorithms, burnin = trace.info.sampling
@unpack Nchains, Nalgorithms, burnin = trace.summary.info
## Print diagnostics for each sampler for each chain
println(
"#####################################################################################",
Expand Down
12 changes: 6 additions & 6 deletions src/sampling/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ end
function TraceTransform(
trace::Trace,
model::ModelWrapper,
tagged::Tagged = Tagged(model, trace.info.sampling.printedparam.printed),
tagged::Tagged = Tagged(model, trace.summary.info.printedparam.printed),
info::TransformInfo = TransformInfo(
collect(Base.OneTo(trace.info.sampling.Nchains)),
collect(Base.OneTo(trace.info.sampling.Nalgorithms)),
trace.info.sampling.burnin,
trace.info.sampling.thinning,
trace.info.sampling.iterations
collect(Base.OneTo(trace.summary.info.Nchains)),
collect(Base.OneTo(trace.summary.info.Nalgorithms)),
trace.summary.info.burnin,
trace.summary.info.thinning,
trace.summary.info.iterations
)
)
@unpack chains, algorithms, burnin, thinning, maxiterations = info
Expand Down
2 changes: 1 addition & 1 deletion src/sampling/logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Return `Progress` struct with arguments from `info` for sampling session.
```
"""
function progress(report::ProgressReport, info::SamplingInfo)
function progress(report::ProgressReport, info::SampleInfo)
return ProgressMeter.Progress(
info.iterations * info.Nalgorithms * info.Nchains;
enabled=report.bar,
Expand Down
16 changes: 8 additions & 8 deletions src/sampling/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ function sample(
ArgCheck.@argcheck iterations > burnin "Burnin set higher than number of iterations."
## Check if we can capture previous samples
updatesampler = update(datatune, tempering.adaption, args...)
## Construct SamplingInfo and ProgressLog
## Construct SampleInfo and ProgressLog
printedparameter = PrintedParameter(showparam(model, datatune, args...)...)
info = SamplingInfo(printedparameter, iterations, burnin, thinning, length(args), chains, updatesampler, tempering.adaption)
info = SampleInfo(printedparameter, iterations, burnin, thinning, length(args), chains, updatesampler, tempering.adaption)
progressmeter = progress(report, info)
## Initialize algorithms
println("Constructing new sampler...")
Expand All @@ -56,7 +56,7 @@ function sample(
## Initialize trace
trace = Trace(
_rng, algorithmᵛ, model, BaytesCore.adjust(datatune, data),
TraceInfo(tempertune, datatuneᵛ, default, info, progressmeter)
TraceSummary(tempertune, datatuneᵛ, default, info, progressmeter)
)
## Loop through iterations
println("Sampling starts...")
Expand Down Expand Up @@ -91,22 +91,22 @@ function sample!(iterations::Integer,
_rng::Random.AbstractRNG, model::M, data::D,
trace::Trace, algorithmᵛ
) where {M<:ModelWrapper,D}
@unpack tempertune, datatune, sampling, default = trace.info
@unpack Nalgorithms, Nchains, burnin, thinning, captured, tempered = sampling
@unpack tempertune, datatune, info, default = trace.summary
@unpack Nalgorithms, Nchains, burnin, thinning, captured, tempered = info
@unpack safeoutput, printoutput, printdefault, report = default
## Create new DataTune struct, taking into account current Index and data dimension
datatune_new = update(datatune, data)
## Check if iterations have to be adjusted if sequential data is used
iterations = maxiterations(datatune_new, iterations)
ArgCheck.@argcheck iterations > burnin "Burnin set higher than number of iterations."
info = SamplingInfo(sampling.printedparam, iterations, burnin, thinning, Nalgorithms, Nchains, captured, tempered)
progressmeter = progress(report, info)
info_new = SampleInfo(info.printedparam, iterations, burnin, thinning, Nalgorithms, Nchains, captured, tempered)
progressmeter = progress(report, info_new)
## Construct new models for algorithms
modelᵛ, datatuneᵛ = construct(model, datatune_new, Nchains, algorithmᵛ)
## Construct new trace to store new samples
trace_new = Trace(
_rng, algorithmᵛ, model, BaytesCore.adjust(datatune_new, data),
TraceInfo(tempertune, datatuneᵛ, default, info, progressmeter)
TraceSummary(tempertune, datatuneᵛ, default, info_new, progressmeter)
)
## Loop through iterations
println("Sampling starts...")
Expand Down
56 changes: 28 additions & 28 deletions src/sampling/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,36 @@ Contains useful information for post-sampling analysis. Also allows to continue
# Fields
$(TYPEDFIELDS)
"""
struct TraceInfo{
struct TraceSummary{
A<:TemperingMethod,
B<:Union{DataTune, Vector{<:DataTune}},
D<:BaytesCore.SampleDefault,
S<:SamplingInfo
S<:SampleInfo
}
"Tuning container for temperature tempering"
tempertune::A
"Tuning container for data tempering"
datatune::B
"Default information used for sample function"
"Default settings used for sample function"
default::D
"Information about trace used for postprocessing."
sampling::S
info::S
"Progress Log while sampling."
progress::ProgressMeter.Progress
function TraceInfo(
function TraceSummary(
tempertune::A,
datatune::B,
default::D,
sampling::S,
info::S,
progress::ProgressMeter.Progress,
) where {
A<:TemperingMethod,
B<:Union{DataTune, Vector{<:DataTune}},
D<:BaytesCore.SampleDefault,
S<:SamplingInfo
S<:SampleInfo
}
## Return info
return new{A,B,D,S}(tempertune, datatune, default, sampling, progress)
return new{A,B,D,S}(tempertune, datatune, default, info, progress)
end
end

Expand All @@ -49,20 +49,20 @@ Contains sampling chain and diagnostics for given algorithms.
# Fields
$(TYPEDFIELDS)
"""
struct Trace{C<:TraceInfo, A<:NamedTuple,B}
struct Trace{C<:TraceSummary, A<:NamedTuple,B}
"Model samples ~ out vector for corresponding chain, inner vector for iteration"
val::Vector{Vector{A}}
"Algorithm diagnostics ~ out vector for corresponding chain, inner vector for iteration"
diagnostics::Vector{B}
"Information about trace used for postprocessing."
info::C
summary::C
function Trace(
val::Vector{Vector{A}},
diagnostics::Vector{B},
info::C,
) where {A,B,C<:TraceInfo}
summary::C,
) where {A,B,C<:TraceSummary}
## Return trace
return new{C,A,B}(val, diagnostics, info)
return new{C,A,B}(val, diagnostics, summary)
end
end

Expand All @@ -71,16 +71,16 @@ function Trace(
algorithmᵛ::A,
model::ModelWrapper,
data::D,
info::TraceInfo,
summary::TraceSummary,
) where {A,D}
@unpack iterations, Nchains = info.sampling
@unpack iterations, Nchains = summary.info
## Create Model Parameter buffer
val = [Vector{typeof(model.val)}(undef, iterations) for _ in Base.OneTo(Nchains)]
## Create Diagnostics buffer for each algorithm used
diagtypes = infer(_rng, AbstractDiagnostics, algorithmᵛ, model, data)
diagnostics = diagnosticsbuffer(diagtypes, iterations, Nchains, algorithmᵛ)
## Return trace
return Trace(val, diagnostics, info)
return Trace(val, diagnostics, summary)
end

############################################################################################
Expand All @@ -98,13 +98,13 @@ Note that smc still works as intended if used alongside other mcmc sampler in `a
"""
function propose!(
_rng::Random.AbstractRNG,
trace::Trace{<:TraceInfo{<:BaytesCore.IterationTempering}},
trace::Trace{<:TraceSummary{<:BaytesCore.IterationTempering}},
algorithmᵛ::AbstractVector,
modelᵛ::Vector{M},
data::D,
) where {M<:ModelWrapper,D}
@unpack default, tempertune, datatune, sampling, progress = trace.info
@unpack iterations, Nchains, Nalgorithms, captured = sampling
@unpack default, tempertune, datatune, info, progress = trace.summary
@unpack iterations, Nchains, Nalgorithms, captured = info
@unpack log = default.report
## Propagate through data
Base.Threads.@threads for Nchain in Base.OneTo(Nchains)
Expand Down Expand Up @@ -136,13 +136,13 @@ end

function propose!(
_rng::Random.AbstractRNG,
trace::Trace{<:TraceInfo{<:BaytesCore.JointTempering}},
trace::Trace{<:TraceSummary{<:BaytesCore.JointTempering}},
algorithmᵛ::AbstractVector,
modelᵛ::Vector{M},
data::D,
) where {M<:ModelWrapper,D}
@unpack default, tempertune, datatune, sampling, progress = trace.info
@unpack iterations, Nchains, Nalgorithms, captured = sampling
@unpack default, tempertune, datatune, info, progress = trace.summary
@unpack iterations, Nchains, Nalgorithms, captured = info
@unpack log = default.report
## Compute initial temperature
temperature = BaytesCore.initial(tempertune)
Expand Down Expand Up @@ -179,8 +179,8 @@ function propose!(
modelᵛ::M,
data::D,
) where {T<:Trace,M<:ModelWrapper,D}
@unpack default, tempertune, datatune, sampling, progress = trace.info
@unpack iterations, Nchains, Nalgorithms, captured = sampling
@unpack default, tempertune, datatune, info, progress = trace.summary
@unpack iterations, Nchains, Nalgorithms, captured = info
@unpack log = default.report
## Compute initial temperature
temperature = BaytesCore.initial(tempertune)
Expand Down Expand Up @@ -257,11 +257,11 @@ function savetrace(trace::Trace, model::ModelWrapper, algorithm,
"M",
Dates.minute(Dates.now()),
"_Nchains",
trace.info.sampling.Nchains,
trace.summary.info.Nchains,
"_Iter",
trace.info.sampling.iterations,
trace.summary.info.iterations,
"_Burnin",
trace.info.sampling.burnin,
trace.summary.info.burnin,
))
)
JLD2.jldsave(
Expand All @@ -275,4 +275,4 @@ end

############################################################################################
#export
export TraceInfo, Trace, propose!, savetrace
export TraceSummary, Trace, propose!, savetrace
6 changes: 3 additions & 3 deletions src/sampling/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Contains several useful information for constructing sampler.
# Fields
$(TYPEDFIELDS)
"""
struct SamplingInfo{A<:PrintedParameter, U<:BaytesCore.UpdateBool, B<:BaytesCore.UpdateBool}
struct SampleInfo{A<:PrintedParameter, U<:BaytesCore.UpdateBool, B<:BaytesCore.UpdateBool}
"Parameter settings for printing."
printedparam::A
"Total number of sampling iterations."
Expand All @@ -46,7 +46,7 @@ struct SamplingInfo{A<:PrintedParameter, U<:BaytesCore.UpdateBool, B<:BaytesCore
captured::U
"Boolean if temperature is adapted for target function."
tempered::B
function SamplingInfo(
function SampleInfo(
printedparam::A,
iterations::Int64,
burnin::Int64,
Expand Down Expand Up @@ -214,4 +214,4 @@ end

############################################################################################
#export
export SamplingInfo, update, infer
export SampleInfo, update, infer
Loading

2 comments on commit 1e45efb

@paschermayr
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/70916

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 1e45efb3f78b7bcb887853063209e0c55d530393
git push origin v0.2.0

Please sign in to comment.