Skip to content

Commit

Permalink
Update Metropolis-Hastings code
Browse files Browse the repository at this point in the history
  • Loading branch information
jchristopherson committed Feb 18, 2025
1 parent 4e0f4de commit 9f1c7bc
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 18 deletions.
54 changes: 48 additions & 6 deletions examples/mcmc_example.f90
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,67 @@ program example
type(metropolis_hastings) :: mcmc

! Plot Variables
type(plot_2d) :: plt
type(multiplot) :: mplt
type(plot_2d) :: plt, plt1, plt2, plt3
class(plot_axis), pointer :: xAxis, yAxis
type(plot_data_2d) :: pd
class(legend), pointer :: lgnd
type(plot_data_histogram) :: pdh
class(terminal), pointer :: term

! Create an initial estimate
call random_number(xi)
! Create an initial estimate - intentionally starting outside of the
! target distribution for illustration purposes only.
xi = [1.0d1, 1.0d1]

! Sample a multivariate normal distribution using MH
call mcmc%sample(xi)

! Get the chains
! Get the chains - keep burn-in points for illustration purposes
chains = mcmc%get_chain()

! Plot the results
call plt%initialize()
! ------------------------------------------------------------------------------
! Plot histograms of the chains
call mplt%initialize(1, 3)
term => mplt%get_terminal()
call plt1%initialize()
call plt2%initialize()
call plt3%initialize()
xAxis => plt3%get_x_axis()
yAxis => plt3%get_y_axis()
call term%set_window_height(500)
call term%set_window_width(1500)
call plt1%set_title("x_{1}")
call plt2%set_title("x_{2}")
call plt3%set_title("x_{1} vs. x_{2}")
call xAxis%set_title("x_{1}")
call yAxis%set_title("x_{2}")
call pdh%define_data(chains(:,1))
call pdh%set_transparency(0.5)
call plt1%push(pdh)
call pdh%define_data(chains(:,2))
call plt2%push(pdh)
call pd%define_data(chains(:,1), chains(:,2))
call pd%set_draw_line(.false.)
call pd%set_draw_markers(.true.)
call pd%set_marker_style(MARKER_FILLED_CIRCLE)
call pd%set_marker_scaling(0.5)
call plt3%push(pd)
call mplt%set(1, 1, plt1)
call mplt%set(1, 2, plt2)
call mplt%set(1, 3, plt3)
call mplt%draw()

! Plot the chain
call plt%initialize()
lgnd => plt%get_legend()
call lgnd%set_is_visible(.true.)
call pd%define_data(chains(:,1))
call pd%set_draw_line(.true.)
call pd%set_draw_markers(.false.)
call pd%set_name("x_{1}")
call plt%push(pd)
call pd%define_data(chains(:,2))
call pd%set_name("x_{2}")
call plt%push(pd)
call plt%draw()
end program
114 changes: 102 additions & 12 deletions src/fstats_mcmc.f90
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@ module fstats_mcmc

type metropolis_hastings
!! An implementation of the Metropolis-Hastings algorithm for the
!! generation of a Markov chain.
!! generation of a Markov chain. This is a default implementation
!! that allows sampling of normally distributed posterior distributions
!! centered on zero with unit standard deviations. Proposals are
!! generated from a multivariate normal distribution with an identity
!! covariance matrix and centered on zero. To alter these sampling
!! and target distributions simply create a new class inheriting from
!! this class and override the appropriate routines.
integer(int32), private :: initial_iteration_estimate = 10000
!! An initial estimate at the number of allowed iterations.
integer(int32), private :: m_bufferSize = 0
Expand All @@ -34,6 +40,9 @@ module fstats_mcmc
procedure, public :: compute_hastings_ratio => mh_hastings_ratio
procedure, public :: target_distribution => mh_target
procedure, public :: sample => mh_sample
procedure, public :: reset => mh_clear_chain
procedure, public :: on_acceptance => mh_on_success
procedure, public :: on_rejection => mh_on_rejection
procedure, private :: resize_buffer => mh_resize_buffer
procedure, private :: get_buffer_length => mh_get_buffer_length
end type
Expand Down Expand Up @@ -178,17 +187,22 @@ subroutine mh_push(this, x, err)
end subroutine

! ------------------------------------------------------------------------------
function mh_get_chain(this, err) result(rst)
function mh_get_chain(this, bin, err) result(rst)
!! Gets a copy of the stored Markov chain.
class(metropolis_hastings), intent(in) :: this
!! The metropolis_hastings object.
real(real64), intent(in), optional :: bin
!! An optional input allowing for a burn-in region. The parameter
!! represents the amount (percentage-based) of the overall chain to
!! disregard as "burn-in" values. The value shoud exist on [0, 1).
!! The default value is 0 such that no values are disregarded.
class(errors), intent(inout), optional, target :: err
!! The error handling object.
real(real64), allocatable, dimension(:,:) :: rst
!! The resulting chain with each parameter represented by a column.

! Local Variables
integer(int32) :: npts, nvar, flag
integer(int32) :: npts, nvar, flag, nstart
class(errors), pointer :: errmgr
type(errors), target :: deferr

Expand All @@ -200,10 +214,16 @@ function mh_get_chain(this, err) result(rst)
end if
npts = this%get_chain_length()
nvar = this%get_state_variable_count()
if (present(bin)) then
nstart = floor(bin * npts)
npts = npts - nstart
else
nstart = 1
end if

! Process
allocate(rst(npts, nvar), stat = flag, &
source = this%m_buffer(1:npts,1:nvar))
source = this%m_buffer(nstart:,1:nvar))
if (flag /= 0) then
call report_memory_error(errmgr, "mh_get_chain", flag)
return
Expand Down Expand Up @@ -268,7 +288,7 @@ function mh_target(this, x) result(rst)
!! define the desired distribution. The default behavior of this
!! routine is to sammple a multivariate normal distribution with a mean
!! of zero and a variance of one (identity covariance matrix).
class(metropolis_hastings), intent(in) :: this
class(metropolis_hastings), intent(inout) :: this
!! The metropolis_hastings object.
real(real64), intent(in), dimension(:) :: x
!! The state vector.
Expand Down Expand Up @@ -299,17 +319,19 @@ subroutine mh_sample(this, xi, niter, err)
class(metropolis_hastings), intent(inout) :: this
!! The metropolis_hastings object.
real(real64), intent(in), dimension(:) :: xi
!! An initial estimate of the state variables.
!! An N-element array containing initial starting values of the state
!! variables.
integer(int32), intent(in), optional :: niter
!! An optional input defining the number of iterations to take. The
!! default is 10,000.
class(errors), intent(inout), optional, target :: err
!! The error handling object.

! Local Variables
integer(int32) :: i, npts
integer(int32) :: i, n, npts, flag
real(real64) :: r, pp, pc, a, a1, a2, alpha
real(real64), allocatable, dimension(:) :: xc, xp
real(real64), allocatable, dimension(:) :: xc, xp, means
real(real64), allocatable, dimension(:,:) :: sigma
class(errors), pointer :: errmgr
type(errors), target :: deferr

Expand All @@ -324,11 +346,20 @@ subroutine mh_sample(this, xi, niter, err)
else
npts = this%initial_iteration_estimate
end if
n = size(xi)

! TO DO: Reset the buffer state

! Initialize the stored distribution
! TO DO: Figure out a means for generating a meaningful covariance matrix.
! Initialize the proposal distribution. Use an identity matrix for the
! covariance matrix and assume a zero mean.
allocate(sigma(n, n), means(n), source = 0.0d0, stat = flag)
if (flag /= 0) then
call report_memory_error(errmgr, "mh_sample", flag)
return
end if
do i = 1, n
sigma(i,i) = 1.0d0
end do
call this%m_propDist%initialize(means, sigma, err = errmgr)
if (errmgr%has_error_occurred()) return

! Store the initial value
call this%push_new_state(xi, err = errmgr)
Expand Down Expand Up @@ -357,13 +388,72 @@ subroutine mh_sample(this, xi, niter, err)
! Update the values
xc = xp
pc = pp

! Take additional actions on success???
call this%on_acceptance(i, alpha, xc, xp, err = errmgr)
if (errmgr%has_error_occurred()) return
else
! Keep our current estimate
call this%push_new_state(xc, err = errmgr)
if (errmgr%has_error_occurred()) return

! Take additional actions on failure???
call this%on_rejection(i, alpha, xc, xp, err = errmgr)
if (errmgr%has_error_occurred()) return
end if
end do
end subroutine

! ------------------------------------------------------------------------------
subroutine mh_clear_chain(this)
!! Resets the object and clears out the buffer storing the chain values.
class(metropolis_hastings), intent(inout) :: this
!! The metropolis_hastings object.

! Clear the buffer
this%m_bufferSize = 0
this%m_numVars = 0
end subroutine

! ------------------------------------------------------------------------------
subroutine mh_on_success(this, iter, alpha, xc, xp, err)
!! Currently, this routine does nothing and is a placeholder for the user
!! that inherits this class to provide functionallity upon acceptance of
!! a proposed value.
class(metropolis_hastings), intent(inout) :: this
!! The metropolis_hastings object.
integer(int32), intent(in) :: iter
!! The current iteration number.
real(real64), intent(in) :: alpha
!! The proposal probabilty term used for acceptance criteria.
real(real64), intent(in), dimension(:) :: xc
!! An N-element array containing the current state variables.
real(real64), intent(in), dimension(size(xc)) :: xp
!! An N-element array containing the proposed state variables that
!! were just accepted.
class(errors), intent(inout), optional, target :: err
!! An error handling object.
end subroutine

! ------------------------------------------------------------------------------
subroutine mh_on_rejection(this, iter, alpha, xc, xp, err)
!! Currently, this routine does nothing and is a placeholder for the user
!! that inherits this class to provide functionallity upon rejection of
!! a proposed value.
class(metropolis_hastings), intent(inout) :: this
!! The metropolis_hastings object.
integer(int32), intent(in) :: iter
!! The current iteration number.
real(real64), intent(in) :: alpha
!! The proposal probabilty term used for acceptance criteria.
real(real64), intent(in), dimension(:) :: xc
!! An N-element array containing the current state variables.
real(real64), intent(in), dimension(size(xc)) :: xp
!! An N-element array containing the proposed state variables that
!! were just rejected.
class(errors), intent(inout), optional, target :: err
!! An error handling object.
end subroutine

! ------------------------------------------------------------------------------
end module

0 comments on commit 9f1c7bc

Please sign in to comment.