From 9f1c7bc7eff6023a202f5d71f9d2cd19a99ae3f2 Mon Sep 17 00:00:00 2001 From: Jason Christopherson Date: Tue, 18 Feb 2025 07:46:46 -0600 Subject: [PATCH] Update Metropolis-Hastings code --- examples/mcmc_example.f90 | 54 ++++++++++++++++-- src/fstats_mcmc.f90 | 114 ++++++++++++++++++++++++++++++++++---- 2 files changed, 150 insertions(+), 18 deletions(-) diff --git a/examples/mcmc_example.f90 b/examples/mcmc_example.f90 index 59fd86e..5043c3c 100644 --- a/examples/mcmc_example.f90 +++ b/examples/mcmc_example.f90 @@ -10,25 +10,67 @@ program example type(metropolis_hastings) :: mcmc ! Plot Variables - type(plot_2d) :: plt + type(multiplot) :: mplt + type(plot_2d) :: plt, plt1, plt2, plt3 + class(plot_axis), pointer :: xAxis, yAxis type(plot_data_2d) :: pd + class(legend), pointer :: lgnd + type(plot_data_histogram) :: pdh + class(terminal), pointer :: term - ! Create an initial estimate - call random_number(xi) + ! Create an initial estimate - intentionally starting outside of the + ! target distribution for illustration purposes only. + xi = [1.0d1, 1.0d1] ! Sample a multivariate normal distribution using MH call mcmc%sample(xi) - ! Get the chains + ! Get the chains - keep burn-in points for illustration purposes chains = mcmc%get_chain() - ! Plot the results - call plt%initialize() +! ------------------------------------------------------------------------------ + ! Plot histograms of the chains + call mplt%initialize(1, 3) + term => mplt%get_terminal() + call plt1%initialize() + call plt2%initialize() + call plt3%initialize() + xAxis => plt3%get_x_axis() + yAxis => plt3%get_y_axis() + call term%set_window_height(500) + call term%set_window_width(1500) + call plt1%set_title("x_{1}") + call plt2%set_title("x_{2}") + call plt3%set_title("x_{1} vs. x_{2}") + call xAxis%set_title("x_{1}") + call yAxis%set_title("x_{2}") + call pdh%define_data(chains(:,1)) + call pdh%set_transparency(0.5) + call plt1%push(pdh) + call pdh%define_data(chains(:,2)) + call plt2%push(pdh) call pd%define_data(chains(:,1), chains(:,2)) call pd%set_draw_line(.false.) call pd%set_draw_markers(.true.) call pd%set_marker_style(MARKER_FILLED_CIRCLE) call pd%set_marker_scaling(0.5) + call plt3%push(pd) + call mplt%set(1, 1, plt1) + call mplt%set(1, 2, plt2) + call mplt%set(1, 3, plt3) + call mplt%draw() + + ! Plot the chain + call plt%initialize() + lgnd => plt%get_legend() + call lgnd%set_is_visible(.true.) + call pd%define_data(chains(:,1)) + call pd%set_draw_line(.true.) + call pd%set_draw_markers(.false.) + call pd%set_name("x_{1}") + call plt%push(pd) + call pd%define_data(chains(:,2)) + call pd%set_name("x_{2}") call plt%push(pd) call plt%draw() end program \ No newline at end of file diff --git a/src/fstats_mcmc.f90 b/src/fstats_mcmc.f90 index 18754ad9..ae6d559 100644 --- a/src/fstats_mcmc.f90 +++ b/src/fstats_mcmc.f90 @@ -14,7 +14,13 @@ module fstats_mcmc type metropolis_hastings !! An implementation of the Metropolis-Hastings algorithm for the - !! generation of a Markov chain. + !! generation of a Markov chain. This is a default implementation + !! that allows sampling of normally distributed posterior distributions + !! centered on zero with unit standard deviations. Proposals are + !! generated from a multivariate normal distribution with an identity + !! covariance matrix and centered on zero. To alter these sampling + !! and target distributions simply create a new class inheriting from + !! this class and override the appropriate routines. integer(int32), private :: initial_iteration_estimate = 10000 !! An initial estimate at the number of allowed iterations. integer(int32), private :: m_bufferSize = 0 @@ -34,6 +40,9 @@ module fstats_mcmc procedure, public :: compute_hastings_ratio => mh_hastings_ratio procedure, public :: target_distribution => mh_target procedure, public :: sample => mh_sample + procedure, public :: reset => mh_clear_chain + procedure, public :: on_acceptance => mh_on_success + procedure, public :: on_rejection => mh_on_rejection procedure, private :: resize_buffer => mh_resize_buffer procedure, private :: get_buffer_length => mh_get_buffer_length end type @@ -178,17 +187,22 @@ subroutine mh_push(this, x, err) end subroutine ! ------------------------------------------------------------------------------ -function mh_get_chain(this, err) result(rst) +function mh_get_chain(this, bin, err) result(rst) !! Gets a copy of the stored Markov chain. class(metropolis_hastings), intent(in) :: this !! The metropolis_hastings object. + real(real64), intent(in), optional :: bin + !! An optional input allowing for a burn-in region. The parameter + !! represents the amount (percentage-based) of the overall chain to + !! disregard as "burn-in" values. The value shoud exist on [0, 1). + !! The default value is 0 such that no values are disregarded. class(errors), intent(inout), optional, target :: err !! The error handling object. real(real64), allocatable, dimension(:,:) :: rst !! The resulting chain with each parameter represented by a column. ! Local Variables - integer(int32) :: npts, nvar, flag + integer(int32) :: npts, nvar, flag, nstart class(errors), pointer :: errmgr type(errors), target :: deferr @@ -200,10 +214,16 @@ function mh_get_chain(this, err) result(rst) end if npts = this%get_chain_length() nvar = this%get_state_variable_count() + if (present(bin)) then + nstart = floor(bin * npts) + npts = npts - nstart + else + nstart = 1 + end if ! Process allocate(rst(npts, nvar), stat = flag, & - source = this%m_buffer(1:npts,1:nvar)) + source = this%m_buffer(nstart:,1:nvar)) if (flag /= 0) then call report_memory_error(errmgr, "mh_get_chain", flag) return @@ -268,7 +288,7 @@ function mh_target(this, x) result(rst) !! define the desired distribution. The default behavior of this !! routine is to sammple a multivariate normal distribution with a mean !! of zero and a variance of one (identity covariance matrix). - class(metropolis_hastings), intent(in) :: this + class(metropolis_hastings), intent(inout) :: this !! The metropolis_hastings object. real(real64), intent(in), dimension(:) :: x !! The state vector. @@ -299,7 +319,8 @@ subroutine mh_sample(this, xi, niter, err) class(metropolis_hastings), intent(inout) :: this !! The metropolis_hastings object. real(real64), intent(in), dimension(:) :: xi - !! An initial estimate of the state variables. + !! An N-element array containing initial starting values of the state + !! variables. integer(int32), intent(in), optional :: niter !! An optional input defining the number of iterations to take. The !! default is 10,000. @@ -307,9 +328,10 @@ subroutine mh_sample(this, xi, niter, err) !! The error handling object. ! Local Variables - integer(int32) :: i, npts + integer(int32) :: i, n, npts, flag real(real64) :: r, pp, pc, a, a1, a2, alpha - real(real64), allocatable, dimension(:) :: xc, xp + real(real64), allocatable, dimension(:) :: xc, xp, means + real(real64), allocatable, dimension(:,:) :: sigma class(errors), pointer :: errmgr type(errors), target :: deferr @@ -324,11 +346,20 @@ subroutine mh_sample(this, xi, niter, err) else npts = this%initial_iteration_estimate end if + n = size(xi) - ! TO DO: Reset the buffer state - - ! Initialize the stored distribution - ! TO DO: Figure out a means for generating a meaningful covariance matrix. + ! Initialize the proposal distribution. Use an identity matrix for the + ! covariance matrix and assume a zero mean. + allocate(sigma(n, n), means(n), source = 0.0d0, stat = flag) + if (flag /= 0) then + call report_memory_error(errmgr, "mh_sample", flag) + return + end if + do i = 1, n + sigma(i,i) = 1.0d0 + end do + call this%m_propDist%initialize(means, sigma, err = errmgr) + if (errmgr%has_error_occurred()) return ! Store the initial value call this%push_new_state(xi, err = errmgr) @@ -357,13 +388,72 @@ subroutine mh_sample(this, xi, niter, err) ! Update the values xc = xp pc = pp + + ! Take additional actions on success??? + call this%on_acceptance(i, alpha, xc, xp, err = errmgr) + if (errmgr%has_error_occurred()) return else ! Keep our current estimate call this%push_new_state(xc, err = errmgr) if (errmgr%has_error_occurred()) return + + ! Take additional actions on failure??? + call this%on_rejection(i, alpha, xc, xp, err = errmgr) + if (errmgr%has_error_occurred()) return end if end do end subroutine +! ------------------------------------------------------------------------------ +subroutine mh_clear_chain(this) + !! Resets the object and clears out the buffer storing the chain values. + class(metropolis_hastings), intent(inout) :: this + !! The metropolis_hastings object. + + ! Clear the buffer + this%m_bufferSize = 0 + this%m_numVars = 0 +end subroutine + +! ------------------------------------------------------------------------------ +subroutine mh_on_success(this, iter, alpha, xc, xp, err) + !! Currently, this routine does nothing and is a placeholder for the user + !! that inherits this class to provide functionallity upon acceptance of + !! a proposed value. + class(metropolis_hastings), intent(inout) :: this + !! The metropolis_hastings object. + integer(int32), intent(in) :: iter + !! The current iteration number. + real(real64), intent(in) :: alpha + !! The proposal probabilty term used for acceptance criteria. + real(real64), intent(in), dimension(:) :: xc + !! An N-element array containing the current state variables. + real(real64), intent(in), dimension(size(xc)) :: xp + !! An N-element array containing the proposed state variables that + !! were just accepted. + class(errors), intent(inout), optional, target :: err + !! An error handling object. +end subroutine + +! ------------------------------------------------------------------------------ +subroutine mh_on_rejection(this, iter, alpha, xc, xp, err) + !! Currently, this routine does nothing and is a placeholder for the user + !! that inherits this class to provide functionallity upon rejection of + !! a proposed value. + class(metropolis_hastings), intent(inout) :: this + !! The metropolis_hastings object. + integer(int32), intent(in) :: iter + !! The current iteration number. + real(real64), intent(in) :: alpha + !! The proposal probabilty term used for acceptance criteria. + real(real64), intent(in), dimension(:) :: xc + !! An N-element array containing the current state variables. + real(real64), intent(in), dimension(size(xc)) :: xp + !! An N-element array containing the proposed state variables that + !! were just rejected. + class(errors), intent(inout), optional, target :: err + !! An error handling object. +end subroutine + ! ------------------------------------------------------------------------------ end module \ No newline at end of file