Skip to content

Commit

Permalink
rename spectral_states to single_states, replica_states to spectral_s…
Browse files Browse the repository at this point in the history
…tates
  • Loading branch information
joachimbrand committed May 8, 2024
1 parent 2dd3e40 commit 25c57ee
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 77 deletions.
64 changes: 31 additions & 33 deletions src/QMCSimulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ See also [`state_vectors`](@ref), [`single_states`](@ref),
[`ProjectorMonteCarloProblem`](@ref), [`init`](@ref), [`solve!`](@ref).
"""
mutable struct QMCSimulation
qmc_problem::ProjectorMonteCarloProblem
problem::ProjectorMonteCarloProblem
algorithm # currently only FCIQMC() is implemented
qmc_state::ReplicaState
state::ReplicaState
report::Report
modified::Bool
aborted::Bool
Expand Down Expand Up @@ -76,9 +76,9 @@ function QMCSimulation(problem::ProjectorMonteCarloProblem; copy_vectors=true)
end
wm = working_memory(first(vectors))

# set up the replica_states
# set up the spectral_states
if n_replicas == 1
replica_states = (SpectralState(
spectral_states = (SpectralState(
(SingleState(
hamiltonian,
only(vectors),
Expand All @@ -92,7 +92,7 @@ function QMCSimulation(problem::ProjectorMonteCarloProblem; copy_vectors=true)
spectral_strategy
),)
else
replica_states = ntuple(n_replicas) do i
spectral_states = ntuple(n_replicas) do i
v, sp = vectors[i], shift_parameters[i]
rwm = (typeof(v) == typeof(first(vectors))) ? wm : working_memory(v)
SpectralState(
Expand All @@ -110,12 +110,12 @@ function QMCSimulation(problem::ProjectorMonteCarloProblem; copy_vectors=true)
)
end
end
@assert replica_states isa NTuple{n_replicas, <:SpectralState}
@assert spectral_states isa NTuple{n_replicas, <:SpectralState}

# set up the initial state
qmc_state = ReplicaState(
state = ReplicaState(
hamiltonian, # hamiltonian
replica_states, # replica_states
spectral_states, # spectral_states
Ref(maxlength), # maxlength
Ref(simulation_plan.starting_step), # step
simulation_plan, # simulation_plan
Expand All @@ -125,19 +125,19 @@ function QMCSimulation(problem::ProjectorMonteCarloProblem; copy_vectors=true)
)
report = Report()
report_metadata!(report, "algorithm", algorithm)
report_default_metadata!(report, qmc_state)
report_default_metadata!(report, state)
report_metadata!(report, metadata) # add user metadata
# Sanity checks.
check_transform(qmc_state.replica_strategy, qmc_state.hamiltonian)
check_transform(state.replica_strategy, state.hamiltonian)

return QMCSimulation(
problem, algorithm, qmc_state, report, false, false, false, "", 0.0
problem, algorithm, state, report, false, false, false, "", 0.0
)
end

function Base.show(io::IO, sm::QMCSimulation)
print(io, "QMCSimulation")
st = sm.qmc_state
st = sm.state
print(io, " with ", num_replicas(st), " replica(s) and ")
print(io, num_spectral_states(st), " spectral state(s).")
print(io, "\n Algorithm: ", sm.algorithm)
Expand All @@ -147,8 +147,8 @@ function Base.show(io::IO, sm::QMCSimulation)
sm.message == "" || print(io, "\n message: ", sm.message)
end

num_spectral_states(sm::QMCSimulation) = num_spectral_states(sm.qmc_state)
num_replicas(sm::QMCSimulation) = num_replicas(sm.qmc_state)
num_spectral_states(sm::QMCSimulation) = num_spectral_states(sm.state)
num_replicas(sm::QMCSimulation) = num_replicas(sm.state)

function report_simulation_status_metadata!(report::Report, sm::QMCSimulation)
@unpack modified, aborted, success, message, elapsed_time = sm
Expand All @@ -167,7 +167,7 @@ function Base.iterate(sm::QMCSimulation, state=1)
if state == 1
return DataFrame(sm), 2
elseif state == 2
return sm.qmc_state, 3
return sm.state, 3
else
return nothing
end
Expand All @@ -177,8 +177,6 @@ end
function Base.getproperty(sm::QMCSimulation, key::Symbol)
if key == :df
return DataFrame(sm)
elseif key == :state
return sm.qmc_state
else
return getfield(sm, key)
end
Expand All @@ -192,8 +190,8 @@ Tables.columnaccess(::Type{<:QMCSimulation}) = true
Tables.columns(sm::QMCSimulation) = Tables.columns(sm.report.data)
Tables.schema(sm::QMCSimulation) = Tables.schema(sm.report.data)

state_vectors(sim::QMCSimulation) = state_vectors(sim.qmc_state)
single_states(sim::QMCSimulation) = single_states(sim.qmc_state)
state_vectors(sim::QMCSimulation) = state_vectors(sim.state)
single_states(sim::QMCSimulation) = single_states(sim.state)

# TODO: interface for reading results

Expand Down Expand Up @@ -224,9 +222,9 @@ See also [`ProjectorMonteCarloProblem`](@ref), [`init`](@ref), [`solve!`](@ref),
[`Rimu.QMCSimulation`](@ref).
"""
function CommonSolve.step!(sm::QMCSimulation)
@unpack qmc_state, report, algorithm = sm
@unpack replica_states, simulation_plan, step, reporting_strategy,
replica_strategy = qmc_state
@unpack state, report, algorithm = sm
@unpack spectral_states, simulation_plan, step, reporting_strategy,
replica_strategy = state

if sm.aborted || sm.success
@warn "Simulation is already aborted or finished."
Expand All @@ -245,17 +243,17 @@ function CommonSolve.step!(sm::QMCSimulation)
end

proceed = true
# advance all replica_states
for replica in replica_states
proceed &= advance!(algorithm, report, qmc_state, replica)
# advance all spectral_states
for replica in spectral_states
proceed &= advance!(algorithm, report, state, replica)
end
sm.modified = true

# report replica stats
if step[] % reporting_interval(reporting_strategy) == 0
replica_names, replica_values = replica_stats(replica_strategy, replica_states)
replica_names, replica_values = replica_stats(replica_strategy, spectral_states)
report!(reporting_strategy, step[], report, replica_names, replica_values)
report_after_step!(reporting_strategy, step[], report, qmc_state)
report_after_step!(reporting_strategy, step[], report, state)
ensure_correct_lengths(report)
end

Expand Down Expand Up @@ -298,21 +296,21 @@ function CommonSolve.solve!(sm::QMCSimulation;
walltime = nothing,
reset_time = false,
)
@unpack qmc_state = sm
@unpack state = sm

reset_flags = reset_time # reset flags if resetting time
if !isnothing(last_step)
sm.qmc_state = @set qmc_state.simulation_plan.last_step = last_step
sm.state = @set state.simulation_plan.last_step = last_step
report_metadata!(sm.report, "laststep", last_step)
reset_flags = true
end
if !isnothing(walltime)
sm.qmc_state = @set qmc_state.simulation_plan.walltime = walltime
sm.state = @set state.simulation_plan.walltime = walltime
reset_flags = true
end

@unpack report = sm
@unpack simulation_plan, step, reporting_strategy = sm.qmc_state
@unpack simulation_plan, step, reporting_strategy = sm.state

last_step = simulation_plan.last_step
initial_step = step[]
Expand Down Expand Up @@ -365,9 +363,9 @@ function lomc!(state::ReplicaState, df=DataFrame(); laststep=0, name="lomc!", me
if !iszero(laststep)
state = @set state.simulation_plan.last_step = laststep
end
@unpack hamiltonian, replica_states, maxlength, step, simulation_plan,
@unpack hamiltonian, spectral_states, maxlength, step, simulation_plan,
reporting_strategy, post_step_strategy, replica_strategy = state
first_replica = only(first(replica_states).spectral_states) # SingleState
first_replica = only(first(spectral_states).single_states) # SingleState
@assert step[] simulation_plan.starting_step
problem = ProjectorMonteCarloProblem(hamiltonian;
start_at = first_replica.v,
Expand Down
6 changes: 3 additions & 3 deletions src/lomc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,15 @@ function advance!(::FCIQMC, report, state::ReplicaState, replica::SingleState)
end

if len == 0
if length(state.replica_states) > 1
if length(state.spectral_states) > 1
@error "population in replica $(replica.id) is dead. Aborting."
else
@error "population is dead. Aborting."
end
return false
end
if len > state.maxlength[]
if length(state.replica_states) > 1
if length(state.spectral_states) > 1
@error "`maxlength` reached in replica $(replica.id). Aborting."
else
@error "`maxlength` reached. Aborting."
Expand All @@ -249,6 +249,6 @@ function advance!(::FCIQMC, report, state::ReplicaState, replica::SingleState)
end

function advance!(algorithm, report, state::ReplicaState, replica::SpectralState{1})
return advance!(algorithm, report, state, only(replica.spectral_states))
return advance!(algorithm, report, state, only(replica.single_states))
end
# TODO: add advance! for SpectralState{N} where N > 1
30 changes: 15 additions & 15 deletions src/qmc_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct SpectralState{
NS<:NTuple{N,SingleState},
SS<:SpectralStrategy{N}
}
spectral_states::NS # Tuple of SingleState
single_states::NS # Tuple of SingleState
spectral_strategy::SS # SpectralStrategy
id::String # id is appended to column names
end
Expand All @@ -89,17 +89,17 @@ function Base.show(io::IO, s::SpectralState)
print(io, "SpectralState")
print(io, " with ", num_spectral_states(s), " spectral states")
print(io, "\n spectral_strategy: ", s.spectral_strategy)
for (i, r) in enumerate(s.spectral_states)
for (i, r) in enumerate(s.single_states)
print(io, "\n $i: ", r)
end
end

function state_vectors(state::SpectralState)
return mapreduce(state_vectors, vcat, state.spectral_states)
return mapreduce(state_vectors, vcat, state.single_states)
end

function single_states(state::SpectralState)
return mapreduce(single_states, vcat, state.spectral_states)
return mapreduce(single_states, vcat, state.single_states)
end

"""
Expand Down Expand Up @@ -133,7 +133,7 @@ struct ReplicaState{
PS<:NTuple{<:Any,PostStepStrategy},
}
hamiltonian::H
replica_states::R
spectral_states::R
maxlength::Ref{Int}
step::Ref{Int}
simulation_plan::SimulationPlan
Expand Down Expand Up @@ -223,14 +223,14 @@ function ReplicaState(
post_step_strategy = (post_step_strategy,)
end

# set up spectral_states
# set up single_states
n_spectral_states = num_spectral_states(spectral_strategy)
n_spectral_states == 1 || throw(ArgumentError("Only one spectral state is supported."))

# Set up replica_states
# Set up spectral_states
nreplicas = num_replicas(replica_strategy)
if nreplicas > 1
replica_states = ntuple(nreplicas) do i
spectral_states = ntuple(nreplicas) do i
SpectralState(
(SingleState(
hamiltonian,
Expand All @@ -246,14 +246,14 @@ function ReplicaState(
)
end
else
replica_states = (SpectralState(
spectral_states = (SpectralState(
(SingleState(hamiltonian, v, wm, s_strat, τ_strat, shift, dτ),),
spectral_strategy
),)
end

return ReplicaState(
hamiltonian, replica_states, Ref(Int(maxlength)),
hamiltonian, spectral_states, Ref(Int(maxlength)),
Ref(simulation_plan.starting_step), # step
simulation_plan,
# Ref(Int(laststep)),
Expand All @@ -271,7 +271,7 @@ function Base.show(io::IO, st::ReplicaState)
print(io, "\n H: ", st.hamiltonian)
print(io, "\n step: ", st.step[], " / ", st.simulation_plan.last_step)
print(io, "\n replicas: ")
for (i, r) in enumerate(st.replica_states)
for (i, r) in enumerate(st.spectral_states)
print(io, "\n $i: ", r)
end
end
Expand All @@ -288,9 +288,9 @@ See also [`single_states`](@ref), [`SingleState`](@ref), [`ReplicaState`](@ref),
@inline function state_vectors(state::ReplicaState{N,S}) where {N,S}
# Annoyingly this function is allocating if N > 1
return SMatrix{S,N}(
state.replica_states[fld1(i,S)].spectral_states[mod1(i,S)].v for i in 1:N*S
state.spectral_states[fld1(i,S)].single_states[mod1(i,S)].v for i in 1:N*S
)
# return SMatrix{S,N}(mapreduce(state_vectors, hcat, state.replica_states))
# return SMatrix{S,N}(mapreduce(state_vectors, hcat, state.spectral_states))
end

"""
Expand All @@ -303,13 +303,13 @@ See also [`state_vectors`](@ref), [`SingleState`](@ref), [`ReplicaState`](@ref),
[`SpectralState`](@ref), [`QMCSimulation`](@ref).
"""
function single_states(state::ReplicaState{N,S}) where {N,S}
return SMatrix{S,N}(mapreduce(single_states, hcat, state.replica_states))
return SMatrix{S,N}(mapreduce(single_states, hcat, state.spectral_states))
end

function report_default_metadata!(report::Report, state::ReplicaState)
report_metadata!(report, "Rimu.PACKAGE_VERSION", Rimu.PACKAGE_VERSION)
# add metadata from state
replica = state.replica_states[1].spectral_states[1]
replica = state.spectral_states[1].single_states[1]
shift_parameters = replica.shift_parameters
report_metadata!(report, "laststep", state.simulation_plan.last_step)
report_metadata!(report, "num_replicas", num_replicas(state))
Expand Down
8 changes: 4 additions & 4 deletions src/strategies_and_params/replicastrategy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Return the number of replicas used in the simulation.
num_replicas(::ReplicaStrategy{N}) where {N} = N

"""
replica_stats(RS::ReplicaStrategy{N}, replica_states::NTuple{N,SingleState}) -> (names, values)
replica_stats(RS::ReplicaStrategy{N}, spectral_states::NTuple{N,SingleState}) -> (names, values)
Return the names and values of statistics related to `N` replica states consistent with the
[`ReplicaStrategy`](@ref) `RS`. `names`
Expand Down Expand Up @@ -129,11 +129,11 @@ function AllOverlaps(num_replicas=2; operator=nothing, transform=nothing, vecnor
end
@deprecate AllOverlaps(num_replicas, operator) AllOverlaps(num_replicas; operator)

function replica_stats(rs::AllOverlaps{N,<:Any,<:Any,B}, replica_states::NTuple{N}) where {N,B}
function replica_stats(rs::AllOverlaps{N,<:Any,<:Any,B}, spectral_states::NTuple{N}) where {N,B}
# Not using broadcasting because it wasn't inferred properly.
# For now implement this assuming only a single spectral state; generalise later
vecs = ntuple(i -> only(replica_states[i].spectral_states).v, Val(N))
wms = ntuple(i -> only(replica_states[i].spectral_states).wm, Val(N))
vecs = ntuple(i -> only(spectral_states[i].single_states).v, Val(N))
wms = ntuple(i -> only(spectral_states[i].single_states).wm, Val(N))
return all_overlaps(rs.operators, vecs, wms, Val(B))
end

Expand Down
Loading

0 comments on commit 25c57ee

Please sign in to comment.