Skip to content

Commit

Permalink
Add functions for retrieving the number of replicas, spectral states,…
Browse files Browse the repository at this point in the history
… and overlaps from DataFrames
  • Loading branch information
jamie-tay committed Feb 14, 2025
1 parent c6bd84b commit 0fa0aa3
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 20 deletions.
10 changes: 10 additions & 0 deletions src/Interfaces/Interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,18 @@ Follow the links for the definitions of the interfaces!
* [`apply_column!`](@ref)
* [`apply_operator!`](@ref)
* [`step_stats`](@ref)
## Functions for retrieving information from DataFrames:
* [`num_replicas`](@ref)
* [`num_spectral_states`](@ref)
* [`num_overlaps`](@ref)
"""
module Interfaces

using LinearAlgebra: LinearAlgebra, diag
using VectorInterface: VectorInterface, add, add!, zerovector!, scalartype
using DataFrames: DataFrame, metadata

import OrderedCollections: freeze

Expand All @@ -54,9 +61,12 @@ export
random_offdiagonal, starting_address, allows_address_type,
LOStructure, IsDiagonal, IsHermitian, AdjointKnown, AdjointUnknown, has_adjoint,
AbstractOperator, AbstractObservable
export
num_replicas, num_spectral_states, num_overlaps

include("stochasticstyles.jl")
include("dictvectors.jl")
include("hamiltonians.jl")
include("dataframes.jl")

end
46 changes: 46 additions & 0 deletions src/Interfaces/dataframes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
num_replicas(df::DataFrame)
Return the number of replicas used in the simulation that produced `df`.
"""
function num_replicas(df::DataFrame)
if haskey(metadata(df), "num_replicas")
return parse(Int, metadata(df, "num_replicas"))
else
num = length(filter(startswith("norm"), names(df)))
num > 0 || throw(ArgumentError("No replicas found in DataFrame"))
return num
end
end

"""
num_spectral_states(df::DataFrame)
Return the number of spectral states used in the simulation that produced `df`.
"""
function num_spectral_states(df::DataFrame)
if haskey(metadata(df), "num_spectral_states")
return parse(Int, metadata(df, "num_spectral_states"))
else
if length(filter(startswith("norm"), names(df))) == 0
throw(ArgumentError("No spectral states found in DataFrame"))
end
return 1
end
end

"""
num_overlaps(df::DataFrame)
Return the number of overlaps between replicas.
"""
function num_overlaps(df::DataFrame)
if haskey(metadata(df), "num_overlaps")
return parse(Int, metadata(df, "num_overlaps"))
else
if length(filter(startswith("norm"), names(df))) == 0
throw(ArgumentError("No replicas found in DataFrame"))
end
return length(filter(startswith(r"c[0-9]+_dot"), names(df)))
end
end
4 changes: 2 additions & 2 deletions src/Rimu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ include("mpi_helpers.jl")

include("Interfaces/Interfaces.jl")
@reexport using .Interfaces
using .Interfaces: dot_from_right
import .Interfaces: dot_from_right, num_replicas, num_overlaps, num_spectral_states
include("BitStringAddresses/BitStringAddresses.jl")
@reexport using .BitStringAddresses
include("Hamiltonians/Hamiltonians.jl")
Expand Down Expand Up @@ -78,7 +78,7 @@ export TimeStepStrategy, ConstantTimeStep, OvershootControl
export localpart, walkernumber
export smart_logger, default_logger
export ProjectorMonteCarloProblem, SimulationPlan, state_vectors
export FCIQMC, num_replicas, num_spectral_states, GramSchmidt
export FCIQMC, num_replicas, num_spectral_states, num_overlaps, GramSchmidt

function __init__()
# Turn on smart logging once at runtime. Turn off with `default_logger()`.
Expand Down
1 change: 1 addition & 0 deletions src/StatsTools/StatsTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ using SpecialFunctions: SpecialFunctions, erf
using Statistics: Statistics
using StrFormat: StrFormat, @f_str
using StrLiterals: StrLiterals
import ..Interfaces: num_replicas, num_spectral_states, num_overlaps

import ProgressLogging, Folds
import MacroTools
Expand Down
8 changes: 3 additions & 5 deletions src/StatsTools/reweighting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ Compute the [`growth_estimator`](@ref) on a `DataFrame` `df` or
depths.
Returns a `NamedTuple` with the fields
* `df_ge`: `DataFrame` with reweighting depth and `growth_estiamator` data. See example
* `df_ge`: `DataFrame` with reweighting depth and `growth_estimator` data. See example
below.
* `correlation_estimate`: estimated correlation time from blocking analysis
* `se, se_l, se_u`: [`shift_estimator`](@ref) and error
Expand Down Expand Up @@ -252,7 +252,6 @@ function growth_estimator_analysis(
df = DataFrame(sim)
shift_v = Vector(getproperty(df, Symbol(shift_name))) # casting to `Vector` to make SIMD loops efficient
norm_v = Vector(getproperty(df, Symbol(norm_name)))
num_reps = length(filter(startswith("norm"), names(df)))
time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step
se = blocking_analysis(shift_v; skip)
E_r = se.mean
Expand Down Expand Up @@ -423,7 +422,6 @@ function mixed_estimator_analysis(
shift_v = Vector(getproperty(df, Symbol(shift_name))) # casting to `Vector` to make SIMD loops efficient
hproj_v = Vector(getproperty(df, Symbol(hproj_name)))
vproj_v = Vector(getproperty(df, Symbol(vproj_name)))
num_reps = length(filter(startswith("norm"), names(df)))

time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step
se = blocking_analysis(shift_v; skip)
Expand Down Expand Up @@ -568,7 +566,7 @@ function rayleigh_replica_estimator(
kwargs...
)
df = DataFrame(sim)
num_reps = parse(Int, metadata(df, "num_replicas"))
num_reps = num_replicas(df)
time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step
T = eltype(df[!, Symbol(shift_name, "_r1s1")])
shift_v = Vector{T}[]
Expand Down Expand Up @@ -641,7 +639,7 @@ function rayleigh_replica_estimator_analysis(
kwargs...
)
df = DataFrame(sim)
num_reps = parse(Int, metadata(df, "num_replicas"))
num_reps = num_replicas(df)
time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step
# estimate the correlation time by blocking the shift data
T = eltype(df[!, Symbol(shift_name, "_r1s1")])
Expand Down
26 changes: 16 additions & 10 deletions src/StatsTools/variational_energy_estimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,38 @@ end

function variational_energy_estimator(sim; max_replicas=:all, spectral_state=1, kwargs...)
df = DataFrame(sim)
num_replicas = parse(Int, metadata(df, "num_replicas"))
if num_replicas == 1
num_reps = num_replicas(df)
if num_reps == 1
throw(ArgumentError(
"No replicas found. Use keyword \
`replica_strategy=AllOverlaps(n)` with n≥2 in `ProjectorMonteCarloProblem` to set up replicas!"
))
end
@assert num_replicas 2 "At least two replicas are needed, found $num_replicas"
@assert num_reps 2 "At least two replicas are needed, found $num_replicas"

num_overlaps = length(filter(startswith(Regex("r[0-9]+s$(spectral_state)_dot_r[0-9]+s$(spectral_state)")), names(df)))
@assert num_overlaps == binomial(num_replicas, 2) "Unexpected number of overlaps."
num_olaps = num_overlaps(df)
if num_olaps == 0
throw(ArgumentError(
"No overlaps found. Use keyword \
`replica_strategy=AllOverlaps(n)` with n≥2 in `ProjectorMonteCarloProblem` to set up replicas!"
))
end
@assert num_olaps == binomial(num_reps, 2) "Unexpected number of overlaps."

# process at most `max_replicas` but at least 2 replicas
if max_replicas isa Integer
num_replicas = max(2, min(max_replicas, num_replicas))
num_reps = max(2, min(max_replicas, num_reps))
end

shiftnames = [Symbol("shift_r$(i)s$(spectral_state)") for i in 1:num_replicas]
shiftnames = [Symbol("shift_r$(i)s$(spectral_state)") for i in 1:num_reps]
shifts = map(name -> getproperty(df, name), shiftnames)
@assert length(shifts) == num_replicas
@assert length(shifts) == num_reps

overlap_names = [
Symbol("r$(i)s$(spectral_state)_dot_r$(j)s$(spectral_state)") for i in 1:num_replicas for j in i+1:num_replicas
Symbol("r$(i)s$(spectral_state)_dot_r$(j)s$(spectral_state)") for i in 1:num_reps for j in i+1:num_reps
]
overlaps = map(name -> getproperty(df, name), overlap_names)
@assert length(overlaps) num_overlaps
@assert length(overlaps) num_olaps

return variational_energy_estimator(shifts, overlaps; kwargs...)
end
1 change: 1 addition & 0 deletions src/pmc_simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ end

num_spectral_states(sm::PMCSimulation) = num_spectral_states(sm.state)
num_replicas(sm::PMCSimulation) = num_replicas(sm.state)
num_overlaps(sm::PMCSimulation) = num_overlaps(sm.state)

function report_simulation_status_metadata!(report::Report, sm::PMCSimulation)
@unpack modified, aborted, success, message, elapsed_time = sm
Expand Down
3 changes: 3 additions & 0 deletions src/qmc_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ end

num_replicas(::ReplicaState{N}) where {N} = N
num_spectral_states(::ReplicaState{<:Any, S}) where {S} = S
num_overlaps(::ReplicaState{<:Any,<:Any,<:Any,<:Any,<:NoStats}) = 0
num_overlaps(::ReplicaState{N,<:Any,<:Any,<:Any,<:AllOverlaps{N,<:Any,<:Any,B}}) where {N,B} = B*N*(N-1)÷2

Base.show(io::IO, r::ReplicaState) = show(io, MIME("text/plain"), r)
function Base.show(io::IO, ::MIME"text/plain", st::ReplicaState)
Expand Down Expand Up @@ -187,6 +189,7 @@ function report_default_metadata!(report::Report, state::ReplicaState)
report_metadata!(report, "laststep", state.simulation_plan.last_step)
report_metadata!(report, "num_replicas", num_replicas(state))
report_metadata!(report, "num_spectral_states", num_spectral_states(state))
report_metadata!(report, "num_overlaps", num_overlaps(state))
report_metadata!(report, "hamiltonian", s_state.hamiltonian)
report_metadata!(report, "reporting_strategy", state.reporting_strategy)
report_metadata!(report, "shift_strategy", algorithm.shift_strategy)
Expand Down
7 changes: 7 additions & 0 deletions test/Interfaces.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using LinearAlgebra
using Rimu
using Test
using DataFrames

@testset "Interface basics" begin
@test eltype(StyleUnknown{String}()) == String
Expand Down Expand Up @@ -36,6 +37,12 @@ using Test
@test_throws ArgumentError Interfaces.dot_from_right(1, 2, 3)
end

@testset "DataFrame interfaces" begin
@test_throws ArgumentError num_replicas(DataFrame())
@test_throws ArgumentError num_spectral_states(DataFrame())
@test_throws ArgumentError num_overlaps(DataFrame())
end

# using lomc! with a matrix was removed in Rimu.jl v0.12.0
@testset "lomc! with matrix" begin
ham = [1 1 2 3 2;
Expand Down
12 changes: 9 additions & 3 deletions test/excited_states_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Rimu
using Test


@testset "excited state energies" begin
@testset "excited states" begin
ham = HubbardReal1D(BoseFS(1,1,1,1,1))
pr = ExactDiagonalizationProblem(ham)
result = solve(pr)
Expand All @@ -22,6 +22,7 @@ using Test
@test energy1.mean vals[1]
@test energy2.mean vals[2]
@test energy3.mean vals[3]
@test num_spectral_states(df) == 3

n_replicas = 2
p = ProjectorMonteCarloProblem(ham; spectral_strategy, last_step, style, n_replicas)
Expand All @@ -40,13 +41,18 @@ using Test
@test energy5.mean vals[2]
@test energy6.mean vals[3]

@test num_overlaps(df) == 0
@test_throws ArgumentError variational_energy_estimator(df)

replica_strategy = AllOverlaps(n_replicas; operator=G2RealCorrelator(0), mixed_spectral_overlaps=true)
p = ProjectorMonteCarloProblem(ham; spectral_strategy, last_step, style, replica_strategy)
df = DataFrame(solve(p))
@test num_replicas(df) == 2
@test num_overlaps(df) == 1
for state in 1:3
r = rayleigh_replica_estimator(df; spectral_state=state)
@test r.f g2s[state] atol=0.01
end
num_overlaps = length(filter(startswith(r"r[0-9]+s[0-9]+_dot"), names(df)))
@test num_overlaps == 15
num_olaps = length(filter(startswith(r"r[0-9]+s[0-9]+_dot"), names(df)))
@test num_olaps == 15
end

0 comments on commit 0fa0aa3

Please sign in to comment.