diff --git a/Project.toml b/Project.toml index 37ddc1c3..6774279c 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "9.14.3" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -30,6 +31,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" ArrayInterface = "7.9" DataStructures = "0.18" DiffEqBase = "6.154" +DiffEqCallbacks = "4.3.0" DocStringExtensions = "0.9" FastBroadcast = "0.3" FunctionWrappers = "1.1" @@ -46,7 +48,6 @@ UnPack = "1.0.2" julia = "1.10" [extras] -DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" @@ -58,5 +59,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["DiffEqCallbacks", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", - "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"] +test = ["LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"] diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index a382d208..d3d27967 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -11,6 +11,7 @@ using SciMLBase: SciMLBase, isdenseplot using Base.FastMath: add_fast using Setfield: @set, @set! +import DiffEqCallbacks: gauss_points, gauss_weights import DiffEqBase: DiscreteCallback, init, solve, solve!, plot_indices, initialize!, get_tstops, get_tstops_array, get_tstops_max import Base: size, getindex, setindex!, length, similar, show, merge!, merge @@ -21,6 +22,8 @@ import RecursiveArrayTools: recursivecopy! using StaticArrays, Base.Threads import SymbolicIndexingInterface as SII +import Random: AbstractRNG + abstract type AbstractJump end abstract type AbstractMassActionJump <: AbstractJump end abstract type AbstractAggregatorAlgorithm end @@ -70,6 +73,7 @@ include("spatial/directcrdirect.jl") include("aggregators/aggregated_api.jl") include("extended_jump_array.jl") +include("variable_rate.jl") include("problem.jl") include("solve.jl") include("coupled_array.jl") @@ -98,6 +102,9 @@ export reset_aggregated_jumps! export ExtendedJumpArray +# Export VariableRateAggregator types +export VariableRateAggregator, VR_FRM, VR_Direct + # spatial structs and functions export CartesianGrid, CartesianGridRej export SpatialMassActionJump diff --git a/src/problem.jl b/src/problem.jl index 6ec14162..885f3ef0 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -13,8 +13,10 @@ $(TYPEDEF) Defines a collection of jump processes to associate with another problem type. - [Documentation Page](https://docs.sciml.ai/JumpProcesses/stable/jump_types/) -- [Tutorial Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) -- [FAQ Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/#FAQ) +- [Tutorial + Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) +- [FAQ + Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/#FAQ) ### Constructors @@ -44,20 +46,21 @@ then be passed within a single [`JumpSet`](@ref) or as subsequent sequential arg $(FIELDS) ## Keyword Arguments -- `rng`, the random number generator to use. Defaults to Julia's built-in - generator. -- `save_positions=(true,true)`, specifies whether to save the system's state (before, after) - the jump occurs. +- `rng`, the random number generator to use. Defaults to Julia's built-in generator. +- `save_positions=(true,true)` when including variable rates and `(false,true)` for constant + rates, specifies whether to save the system's state (before, after) the jump occurs. - `spatial_system`, for spatial problems the underlying spatial structure. - `hopping_constants`, for spatial problems the spatial transition rate coefficients. -- `use_vrj_bounds = true`, set to false to disable handling bounded `VariableRateJump`s - with a supporting aggregator (such as `Coevolve`). They will then be handled via the - continuous integration interface, and treated like general `VariableRateJump`s. +- `use_vrj_bounds = true`, set to false to disable handling bounded `VariableRateJump`s with + a supporting aggregator (such as `Coevolve`). They will then be handled via the continuous + integration interface, and treated like general `VariableRateJump`s. +- `vr_aggregator`, indicates the aggregator to use for sampling variable rate jumps. Current + default is `VR_FRM`. Please see the [tutorial -page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in the -DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage examples and -commonly asked questions. +page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in +the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage +examples and commonly asked questions. """ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J2, J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J} @@ -213,6 +216,7 @@ end make_kwarg(; kwargs...) = kwargs function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpSet; + vr_aggregator::VariableRateAggregator = VR_FRM(), save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), rng = DEFAULT_RNG, scale_rates = true, useiszero = true, @@ -270,9 +274,9 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS # handle any remaining vrjs if length(cvrjs) > 0 - new_prob = extend_problem(prob, cvrjs; rng) - variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng) - cont_agg = cvrjs + # Handle variable rate jumps based on vr_aggregator + new_prob, variable_jump_callback, cont_agg = configure_jump_problem(prob, + vr_aggregator, jumps, cvrjs; rng) else new_prob = prob variable_jump_callback = CallbackSet() @@ -293,163 +297,6 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS solkwargs) end -# extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values, -# of type prob.tspan -function extend_u0(prob, Njumps, rng) - ttype = eltype(prob.tspan) - u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:Njumps]) - return u0 -end - -function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps; rng = DEFAULT_RNG) - error("General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs.") -end - -function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAULT_RNG) - _f = SciMLBase.unwrapped_f(prob.f) - - if isinplace(prob) - jump_f = let _f = _f - function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t) - _f(du.u, u.u, p, t) - update_jumps!(du, u, p, t, length(u.u), jumps...) - end - end - else - jump_f = let _f = _f - function (u::ExtendedJumpArray, p, t) - du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u) - update_jumps!(du, u, p, t, length(u.u), jumps...) - return du - end - end - end - - u0 = extend_u0(prob, length(jumps), rng) - f = ODEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, - observed = prob.f.observed) - remake(prob; f, u0) -end - -function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG) - _f = SciMLBase.unwrapped_f(prob.f) - - if isinplace(prob) - jump_f = let _f = _f - function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t) - _f(du.u, u.u, p, t) - update_jumps!(du, u, p, t, length(u.u), jumps...) - end - end - else - jump_f = let _f = _f - function (u::ExtendedJumpArray, p, t) - du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u) - update_jumps!(du, u, p, t, length(u.u), jumps...) - return du - end - end - end - - if prob.noise_rate_prototype === nothing - jump_g = function (du, u, p, t) - prob.g(du.u, u.u, p, t) - end - else - jump_g = function (du, u, p, t) - prob.g(du, u.u, p, t) - end - end - - u0 = extend_u0(prob, length(jumps), rng) - f = SDEFunction{isinplace(prob)}(jump_f, jump_g; sys = prob.f.sys, - observed = prob.f.observed) - remake(prob; f, g = jump_g, u0) -end - -function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG) - _f = SciMLBase.unwrapped_f(prob.f) - - if isinplace(prob) - jump_f = let _f = _f - function (du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t) - _f(du.u, u.u, h, p, t) - update_jumps!(du, u, p, t, length(u.u), jumps...) - end - end - else - jump_f = let _f = _f - function (u::ExtendedJumpArray, h, p, t) - du = ExtendedJumpArray(_f(u.u, h, p, t), u.jump_u) - update_jumps!(du, u, p, t, length(u.u), jumps...) - return du - end - end - end - - u0 = extend_u0(prob, length(jumps), rng) - f = DDEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, - observed = prob.f.observed) - remake(prob; f, u0) -end - -# Not sure if the DAE one is correct: Should be a residual of sorts -function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAULT_RNG) - _f = SciMLBase.unwrapped_f(prob.f) - - if isinplace(prob) - jump_f = let _f = _f - function (out, du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t) - _f(out, du.u, u.u, h, p, t) - update_jumps!(out, u, p, t, length(u.u), jumps...) - end - end - else - jump_f = let _f = _f - function (du, u::ExtendedJumpArray, h, p, t) - out = ExtendedJumpArray(_f(du.u, u.u, h, p, t), u.jump_u) - update_jumps!(du, u, p, t, length(u.u), jumps...) - return du - end - end - end - - u0 = extend_u0(prob, length(jumps), rng) - f = DAEFunction{isinplace(prob)}(jump_f, sys = prob.f.sys, - observed = prob.f.observed) - remake(prob; f, u0) -end - -function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG) - condition = function(u, t, integrator) - u.jump_u[idx] - end - affect! = function(integrator) - jump.affect!(integrator) - integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t)) - nothing - end - new_cb = ContinuousCallback(condition, affect!; - idxs = jump.idxs, - rootfind = jump.rootfind, - interp_points = jump.interp_points, - save_positions = jump.save_positions, - abstol = jump.abstol, - reltol = jump.reltol) - return new_cb -end - -function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG) - idx += 1 - new_cb = wrap_jump_in_callback(idx, jump; rng) - build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG) -end - -function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG) - idx += 1 - CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng)) -end - aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A @inline function extend_tstops!(tstops, @@ -458,17 +305,6 @@ aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A push!(tstops, jp.jump_callback.discrete_callbacks[1].condition.next_jump_time) end -@inline function update_jumps!(du, u, p, t, idx, jump) - idx += 1 - du[idx] = jump.rate(u.u, p, t) -end - -@inline function update_jumps!(du, u, p, t, idx, jump, jumps...) - idx += 1 - du[idx] = jump.rate(u.u, p, t) - update_jumps!(du, u, p, t, idx, jumps...) -end - ### Displays num_constant_rate_jumps(aggregator::AbstractSSAJumpAggregator) = length(aggregator.rates) diff --git a/src/solve.jl b/src/solve.jl index dfd545e2..64b5c721 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -56,8 +56,7 @@ function resetted_jump_problem(_jump_prob, seed) end end - if !isempty(jump_prob.variable_jumps) - @assert jump_prob.prob.u0 isa ExtendedJumpArray + if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray randexp!(_jump_prob.rng, jump_prob.prob.u0.jump_u) jump_prob.prob.u0.jump_u .*= -1 end @@ -69,9 +68,8 @@ function reset_jump_problem!(jump_prob, seed) Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, seed) end - if !isempty(jump_prob.variable_jumps) - @assert jump_prob.prob.u0 isa ExtendedJumpArray + if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray randexp!(jump_prob.rng, jump_prob.prob.u0.jump_u) jump_prob.prob.u0.jump_u .*= -1 end -end +end \ No newline at end of file diff --git a/src/variable_rate.jl b/src/variable_rate.jl new file mode 100644 index 00000000..bb844e5d --- /dev/null +++ b/src/variable_rate.jl @@ -0,0 +1,418 @@ +""" +$(TYPEDEF) + +An abstract type for aggregators that manage the simulation of `VariableRateJump`s in jump processes. + +## Notes +- In hybrid ODE/SDE systems with general `VariableRateJump`s, `integrator.u` may be an + `ExtendedJumpArray` for some aggregators. +""" +abstract type VariableRateAggregator end + + +################################### VR_FRM #################################### + +""" +$(TYPEDEF) + +A concrete `VariableRateAggregator` implementing a first-reaction method variant for +simulating `VariableRateJump`s. `VR_FRM` (Variable Rate First Reaction Method with +Ordinary Differential Equation) uses a user-selected ODE solver to handle integrating each +jump's intensity / propensity. A callback is also used for each jump to determine when its +integrated intensity reaches a level corresponding to a firing time, and to then execute the +affect associated with the jump at that time. + +## Examples +Simulating a birth-death process with `VR_FRM`: +```julia +using JumpProcesses, OrdinaryDiffEq +u0 = [1.0] # Initial population +p = [10.0, 0.5] # [birth rate, death rate] +tspan = (0.0, 10.0) + +# Birth jump: ∅ → X +birth_rate(u, p, t) = p[1] +birth_affect!(integrator) = (integrator.u[1] += 1; nothing) +birth_jump = VariableRateJump(birth_rate, birth_affect!) + +# Death jump: X → ∅ +death_rate(u, p, t) = p[2] * u[1] +death_affect!(integrator) = (integrator.u[1] -= 1; nothing) +death_jump = VariableRateJump(death_rate, death_affect!) + +# Problem setup +oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p) +jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_FRM()) +sol = solve(jprob, Tsit5()) +``` + +## Notes +- Specify `VR_FRM` in a `JumpProblem` via the `vr_aggregator` keyword argument to select + its use for handling `VariableRateJump`s. +- While robust, it may be less performant than `VR_Direct` due to its integration of each + individual jump's intensity, and use of one continuous callback per jump to handle + detection of jump times and implementation of state changes from that jump. +""" +struct VR_FRM <: VariableRateAggregator end + +function configure_jump_problem(prob, vr_aggregator::VR_FRM, jumps, cvrjs; + rng = DEFAULT_RNG) + new_prob = extend_problem(prob, cvrjs; rng) + variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng) + cont_agg = cvrjs + return new_prob, variable_jump_callback, cont_agg +end + +# extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values, +# of type prob.tspan +function extend_u0(prob, Njumps, rng) + ttype = eltype(prob.tspan) + u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:Njumps]) + return u0 +end + +function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps; rng = DEFAULT_RNG) + error("General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs.") +end + +function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAULT_RNG) + _f = SciMLBase.unwrapped_f(prob.f) + + if isinplace(prob) + jump_f = let _f = _f + function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t) + _f(du.u, u.u, p, t) + update_jumps!(du, u, p, t, length(u.u), jumps...) + end + end + else + jump_f = let _f = _f + function (u::ExtendedJumpArray, p, t) + du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u) + update_jumps!(du, u, p, t, length(u.u), jumps...) + return du + end + end + end + + u0 = extend_u0(prob, length(jumps), rng) + f = ODEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, + observed = prob.f.observed) + remake(prob; f, u0) +end + +function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG) + _f = SciMLBase.unwrapped_f(prob.f) + + if isinplace(prob) + jump_f = let _f = _f + function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t) + _f(du.u, u.u, p, t) + update_jumps!(du, u, p, t, length(u.u), jumps...) + end + end + else + jump_f = let _f = _f + function (u::ExtendedJumpArray, p, t) + du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u) + update_jumps!(du, u, p, t, length(u.u), jumps...) + return du + end + end + end + + if prob.noise_rate_prototype === nothing + jump_g = function (du, u, p, t) + prob.g(du.u, u.u, p, t) + end + else + jump_g = function (du, u, p, t) + prob.g(du, u.u, p, t) + end + end + + u0 = extend_u0(prob, length(jumps), rng) + f = SDEFunction{isinplace(prob)}(jump_f, jump_g; sys = prob.f.sys, + observed = prob.f.observed) + remake(prob; f, g = jump_g, u0) +end + +function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG) + _f = SciMLBase.unwrapped_f(prob.f) + + if isinplace(prob) + jump_f = let _f = _f + function (du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t) + _f(du.u, u.u, h, p, t) + update_jumps!(du, u, p, t, length(u.u), jumps...) + end + end + else + jump_f = let _f = _f + function (u::ExtendedJumpArray, h, p, t) + du = ExtendedJumpArray(_f(u.u, h, p, t), u.jump_u) + update_jumps!(du, u, p, t, length(u.u), jumps...) + return du + end + end + end + + u0 = extend_u0(prob, length(jumps), rng) + f = DDEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, + observed = prob.f.observed) + remake(prob; f, u0) +end + +# Not sure if the DAE one is correct: Should be a residual of sorts +function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAULT_RNG) + _f = SciMLBase.unwrapped_f(prob.f) + + if isinplace(prob) + jump_f = let _f = _f + function (out, du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t) + _f(out, du.u, u.u, h, p, t) + update_jumps!(out, u, p, t, length(u.u), jumps...) + end + end + else + jump_f = let _f = _f + function (du, u::ExtendedJumpArray, h, p, t) + out = ExtendedJumpArray(_f(du.u, u.u, h, p, t), u.jump_u) + update_jumps!(du, u, p, t, length(u.u), jumps...) + return du + end + end + end + + u0 = extend_u0(prob, length(jumps), rng) + f = DAEFunction{isinplace(prob)}(jump_f, sys = prob.f.sys, + observed = prob.f.observed) + remake(prob; f, u0) +end + +function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG) + condition = function(u, t, integrator) + u.jump_u[idx] + end + affect! = function(integrator) + jump.affect!(integrator) + integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t)) + nothing + end + new_cb = ContinuousCallback(condition, affect!; + idxs = jump.idxs, + rootfind = jump.rootfind, + interp_points = jump.interp_points, + save_positions = jump.save_positions, + abstol = jump.abstol, + reltol = jump.reltol) + return new_cb +end + +function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG) + idx += 1 + new_cb = wrap_jump_in_callback(idx, jump; rng) + build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG) +end + +function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG) + idx += 1 + CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng)) +end + +@inline function update_jumps!(du, u, p, t, idx, jump) + idx += 1 + du[idx] = jump.rate(u.u, p, t) +end + +@inline function update_jumps!(du, u, p, t, idx, jump, jumps...) + idx += 1 + du[idx] = jump.rate(u.u, p, t) + update_jumps!(du, u, p, t, idx, jumps...) +end + +################################### VR_Direct #################################### + +""" +$(TYPEDEF) + +A concrete `VariableRateAggregator` implementing a direct method-based approach for +simulating `VariableRateJump`s. `VR_Direct` (Variable Rate Direct Callback) efficiently +samples jump times using one continuous callback to integrate the total intensity / +propensity for all `VariableRateJump`s, sample when the next jump occurs, and then sample +which jump occurs at this time. + +## Examples +Simulating a birth-death process with `VR_Direct` (default): +```julia +using JumpProcesses, OrdinaryDiffEq +u0 = [1.0] # Initial population +p = [10.0, 0.5] # [birth rate, death rate coefficient] +tspan = (0.0, 10.0) + +# Birth jump: ∅ → X +birth_rate(u, p, t) = p[1] +birth_affect!(integrator) = (integrator.u[1] += 1; nothing) +birth_jump = VariableRateJump(birth_rate, birth_affect!) + +# Death jump: X → ∅ +death_rate(u, p, t) = p[2] * u[1] +death_affect!(integrator) = (integrator.u[1] -= 1; nothing) +death_jump = VariableRateJump(death_rate, death_affect!) + +# Problem setup +oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p) +jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_Direct()) +sol = solve(jprob, Tsit5()) +``` + +## Notes +- `VR_Direct` is expected to generally be more performant than `VR_FRM`. +""" +struct VR_Direct <: VariableRateAggregator end + +mutable struct VR_DirectEventCache{T, RNG <: AbstractRNG} + prev_time::T + prev_threshold::T + current_time::T + current_threshold::T + total_rate_cache::T + rng::RNG + variable_jumps::Tuple{Vararg{VariableRateJump}} + cur_rates::Vector{T} + + function VR_DirectEventCache(jumps::JumpSet, ::Type{T}; rng = DEFAULT_RNG) where T + initial_threshold = randexp(rng, T) + vjumps = jumps.variable_jumps + cur_rates = Vector{T}(undef, length(vjumps)) + new{T, typeof(rng)}(zero(T), initial_threshold, zero(T), initial_threshold, + zero(T), rng, vjumps, cur_rates) + end +end + +# Initialization function for VR_DirectEventCache +function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrator) + cache.prev_time = zero(integrator.t) + cache.current_time = zero(integrator.t) + cache.prev_threshold = randexp(cache.rng, eltype(integrator.t)) + cache.current_threshold = cache.prev_threshold + cache.total_rate_cache = zero(integrator.t) + fill!(cache.cur_rates, zero(integrator.t)) + nothing +end + +# Wrapper for initialize to match ContinuousCallback signature +function initialize_vr_direct_wrapper(cb::ContinuousCallback, u, t, integrator) + initialize_vr_direct_cache!(cb.condition, u, t, integrator) + u_modified!(integrator, false) + nothing +end + + +# Merge callback parameters across all jumps for VR_Direct +function build_variable_integcallback(cache::VR_DirectEventCache, jumps::Tuple) + save_positions = (false, false) + abstol = jumps[1].abstol + reltol = jumps[1].reltol + + for jump in jumps + save_positions = save_positions .|| jump.save_positions + abstol = min(abstol, jump.abstol) + reltol = min(reltol, jump.reltol) + end + + return ContinuousCallback(cache, cache; initialize = initialize_vr_direct_wrapper, + save_positions, abstol, reltol) +end + +function configure_jump_problem(prob, vr_aggregator::VR_Direct, jumps, cvrjs; + rng = DEFAULT_RNG) + new_prob = prob + cache = VR_DirectEventCache(jumps, eltype(prob.tspan); rng) + variable_jump_callback = build_variable_integcallback(cache, cvrjs) + cont_agg = cvrjs + return new_prob, variable_jump_callback, cont_agg +end + +function total_variable_rate(vjumps, u, p, t, cur_rates, idx=1, prev_rate = zero(t)) + if idx > length(cur_rates) + return prev_rate + end + @inbounds begin + new_rate = vjumps[idx].rate(u, p, t) + sum_rate = add_fast(new_rate, prev_rate) + cur_rates[idx] = sum_rate + return total_variable_rate(vjumps, u, p, t, cur_rates, idx + 1, sum_rate) + end +end + +# how many quadrature points to use (i.e. determines the degree of the quadrature rule) +const NUM_GAUSS_QUAD_NODES = 4 + +# Condition functor defined directly on the cache +function (cache::VR_DirectEventCache)(u, t, integrator) + if integrator.t < cache.current_time + error("integrator.t < cache.current_time. $(integrator.t) < $(cache.current_time). This is not supported in the `VR_Direct` handling") + end + + if integrator.t != cache.current_time + cache.prev_time = cache.current_time + cache.prev_threshold = cache.current_threshold + cache.current_time = integrator.t + end + + dt = t - cache.prev_time + if dt == 0 + return cache.prev_threshold + end + + vjumps = cache.variable_jumps + cur_rates = cache.cur_rates + p = integrator.p + rate_increment = zero(t) + gps = gauss_points[NUM_GAUSS_QUAD_NODES] + weights = gauss_weights[NUM_GAUSS_QUAD_NODES] + tmid = (t + cache.prev_time) / 2 + halfdt = dt / 2 + for (i,τᵢ) in enumerate(gps) + τ = halfdt * τᵢ + tmid + u_τ = integrator(τ) + total_variable_rate_τ = total_variable_rate(vjumps, u_τ, p, τ, cur_rates) + rate_increment += weights[i] * total_variable_rate_τ + end + rate_increment *= halfdt + + cache.current_threshold = cache.prev_threshold - rate_increment + + return cache.current_threshold +end + +function execute_affect!(vjumps, integrator, idx) + @inbounds vjumps[idx].affect!(integrator) +end + +# Affect functor defined directly on the cache +function (cache::VR_DirectEventCache)(integrator) + t = integrator.t + u = integrator.u + p = integrator.p + rng = cache.rng + + cache.total_rate_cache = total_variable_rate(cache.variable_jumps, u, p, t, cache.cur_rates) + total_variable_rate_sum = cache.total_rate_cache + if total_variable_rate_sum <= 0 + return nothing + end + + r = rand(rng) * total_variable_rate_sum + vjumps = cache.variable_jumps + + @inbounds jump_idx = searchsortedfirst(cache.cur_rates, r) + execute_affect!(vjumps, integrator, jump_idx) + + cache.prev_time = t + cache.current_threshold = randexp(rng) + cache.prev_threshold = cache.current_threshold + cache.current_time = t + return nothing +end diff --git a/test/extended_jump_array.jl b/test/extended_jump_array.jl index 4bc52782..c90db6cd 100644 --- a/test/extended_jump_array.jl +++ b/test/extended_jump_array.jl @@ -72,7 +72,7 @@ oop_test_jump = VariableRateJump(oop_test_rate, oop_test_affect!) # Test in-place u₀ = [0.0] inplace_prob = ODEProblem((du, u, p, t) -> (du .= 0), u₀, (0.0, 2.0), nothing) -jump_prob = JumpProblem(inplace_prob, Direct(), oop_test_jump) +jump_prob = JumpProblem(inplace_prob, Direct(), oop_test_jump; vr_aggregator = VR_FRM()) sol = solve(jump_prob, Tsit5()) @test sol.retcode == ReturnCode.Success sol.u @@ -91,7 +91,7 @@ let rate(u, p, t) = u[1] affect!(integrator) = (integrator.u.u[1] = integrator.u.u[1] / 2; nothing) jump = VariableRateJump(rate, affect!) - jump_prob = JumpProblem(prob, Direct(), jump) + jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM()) sol = solve(jump_prob, Tsit5(); saveat = 0.5) times = range(0.0, 10.0; step = 0.5) @test issubset(times, sol.t) @@ -115,7 +115,7 @@ let end u₀ = [0, 0] oprob = ODEProblem(f!, u₀, (0.0, 10.0), p) - jprob = JumpProblem(oprob, Direct(), vrj, deathvrj) + jprob = JumpProblem(oprob, Direct(), vrj, deathvrj; vr_aggregator = VR_FRM()) sol = solve(jprob, Tsit5()) @test eltype(sol.u) <: ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}} @test SciMLBase.plottable_indices(sol.u[1]) == 1:length(u₀) diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index 1728a181..23e6ddf0 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -31,15 +31,16 @@ function runSSAs(jump_prob; use_stepper = true) mean(Psamp) end -function runSSAs_ode(jump_prob) - Psamp = zeros(Int, Nsims) +function runSSAs_ode(vrjprob) + Psamp = zeros(Float64, Nsims) for i in 1:Nsims - sol = solve(jump_prob, Tsit5(); saveat = jump_prob.prob.tspan[2]) + sol = solve(vrjprob, Tsit5(); saveat = vrjprob.prob.tspan[2]) Psamp[i] = sol[3, end] end - mean(Psamp) + return mean(Psamp) end + # MODEL SETUP # DNA repression model DiffEqBiological @@ -184,7 +185,10 @@ let crjmean = runSSAs(crjprob) f(du, u, p, t) = (du .= 0; nothing) oprob = ODEProblem(f, u0f, (0.0, tf / 5), rates) - vrjprob = JumpProblem(oprob, vrjs; save_positions = (false, false), rng) - vrjmean = runSSAs_ode(vrjprob) - @test abs(vrjmean - crjmean) < reltol * crjmean + + for vr_agg in (VR_FRM(), VR_Direct()) + vrjprob = JumpProblem(oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false), rng) + vrjmean = runSSAs_ode(vrjprob) + @test abs(vrjmean - crjmean) < reltol * crjmean + end end diff --git a/test/hawkes_test.jl b/test/hawkes_test.jl index 7e4623dd..35bc1791 100644 --- a/test/hawkes_test.jl +++ b/test/hawkes_test.jl @@ -60,10 +60,8 @@ function hawkes_jump(u, g, h; uselrate = true) return [hawkes_jump(i, g, h; uselrate) for i in 1:length(u)] end -function hawkes_problem(p, agg::Coevolve; u = [0.0], - tspan = (0.0, 50.0), - save_positions = (false, true), - g = [[1]], h = [[]], uselrate = true) +function hawkes_problem(p, agg::Coevolve; u = [0.0], tspan = (0.0, 50.0), + save_positions = (false, true), g = [[1]], h = [[]], uselrate = true, kwargs...) dprob = DiscreteProblem(u, tspan, p) jumps = hawkes_jump(u, g, h; uselrate) jprob = JumpProblem(dprob, agg, jumps...; dep_graph = g, save_positions, rng) @@ -76,11 +74,10 @@ function f!(du, u, p, t) end function hawkes_problem(p, agg; u = [0.0], tspan = (0.0, 50.0), - save_positions = (false, true), - g = [[1]], h = [[]], kwargs...) + save_positions = (false, true), g = [[1]], h = [[]], vr_aggregator = VR_FRM(), kwargs...) oprob = ODEProblem(f!, u, tspan, p) jumps = hawkes_jump(u, g, h) - jprob = JumpProblem(oprob, agg, jumps...; save_positions, rng) + jprob = JumpProblem(oprob, agg, jumps...; vr_aggregator, save_positions, rng) return jprob end @@ -118,15 +115,16 @@ for (i, alg) in enumerate(algs) else stepper = Tsit5() end - sols = Vector{ODESolution}(undef, Nsims) + sols = Vector{ODESolution}(undef, Nsims) for n in 1:Nsims reset_history!(h) sols[n] = solve(jump_prob, stepper) end + if alg isa Coevolve λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols)) - else - cols = length(sols[1].u[1].u) + else + cols = length(u0) λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols))[:, 1:cols] end @test isapprox(mean(λs), Eλ; atol = 0.01) @@ -135,34 +133,44 @@ end # test stepping Coevolve with continuous integrator and bounded jumps let alg = Coevolve() - oprob = ODEProblem(f!, u0, tspan, p) - jumps = hawkes_jump(u0, g, h) - jprob = JumpProblem(oprob, alg, jumps...; dep_graph = g, rng) - @test ((jprob.variable_jumps === nothing) || isempty(jprob.variable_jumps)) - sols = Vector{ODESolution}(undef, Nsims) - for n in 1:Nsims - reset_history!(h) - sols[n] = solve(jprob, Tsit5()) + for vr_aggregator in (VR_FRM(), VR_Direct()) + oprob = ODEProblem(f!, u0, tspan, p) + jumps = hawkes_jump(u0, g, h) + jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng) + @test ((jprob.variable_jumps === nothing) || isempty(jprob.variable_jumps)) + sols = Vector{ODESolution}(undef, Nsims) + for n in 1:Nsims + reset_history!(h) + sols[n] = solve(jprob, Tsit5()) + end + λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols)) + @test isapprox(mean(λs), Eλ; atol = 0.01) + @test isapprox(var(λs), Varλ; atol = 0.001) end - λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols)) - @test isapprox(mean(λs), Eλ; atol = 0.01) - @test isapprox(var(λs), Varλ; atol = 0.001) end # test disabling bounded jumps and using continuous integrator +Nsims = 500 let alg = Coevolve() - oprob = ODEProblem(f!, u0, tspan, p) - jumps = hawkes_jump(u0, g, h) - jprob = JumpProblem(oprob, alg, jumps...; dep_graph = g, rng, - use_vrj_bounds = false) - @test length(jprob.variable_jumps) == 1 - sols = Vector{ODESolution}(undef, Nsims) - for n in 1:Nsims - reset_history!(h) - sols[n] = solve(jprob, Tsit5()) + for vr_aggregator in (VR_FRM(), VR_Direct()) + oprob = ODEProblem(f!, u0, tspan, p) + jumps = hawkes_jump(u0, g, h) + jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng, + use_vrj_bounds = false) + @test length(jprob.variable_jumps) == 1 + sols = Vector{ODESolution}(undef, Nsims) + for n in 1:Nsims + reset_history!(h) + sols[n] = solve(jprob, Tsit5()) + end + + cols = length(u0) + if vr_aggregator isa VR_FRM + λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols))[:, 1:cols] + else + λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols)) + end + @test isapprox(mean(λs), Eλ; atol = 0.01) + @test isapprox(var(λs), Varλ; atol = 0.001) end - cols = length(sols[1].u[1].u) - λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols))[:, 1:cols] - @test isapprox(mean(λs), Eλ; atol = 0.01) - @test isapprox(var(λs), Varλ; atol = 0.001) -end +end \ No newline at end of file diff --git a/test/monte_carlo_test.jl b/test/monte_carlo_test.jl index 0563ea07..2235582a 100644 --- a/test/monte_carlo_test.jl +++ b/test/monte_carlo_test.jl @@ -8,15 +8,21 @@ prob = SDEProblem(f, g, [1.0], (0.0, 1.0)) rate = (u, p, t) -> 200.0 affect! = integrator -> (integrator.u[1] = integrator.u[1] / 2) jump = VariableRateJump(rate, affect!, save_positions = (false, true)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM(), rng) monte_prob = EnsembleProblem(jump_prob) sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, save_everystep = false, dt = 0.001, adaptive = false) @test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2] +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_Direct(), rng) +monte_prob = EnsembleProblem(jump_prob) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false) +@test allunique(sol.u[1].t) + jump = ConstantRateJump(rate, affect!) -jump_prob = JumpProblem(prob, Direct(), jump, save_positions = (true, false), rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false), rng) monte_prob = EnsembleProblem(jump_prob) sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, save_everystep = false, dt = 0.001, adaptive = false) -@test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2] +@test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2] \ No newline at end of file diff --git a/test/remake_test.jl b/test/remake_test.jl index 676e615d..2d5512d7 100644 --- a/test/remake_test.jl +++ b/test/remake_test.jl @@ -75,7 +75,10 @@ let rrate(u, p, t) = u[1] aaffect!(integrator) = (integrator.u[1] += 1; nothing) vrj = VariableRateJump(rrate, aaffect!) - jprob = JumpProblem(prob, vrj; rng) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng) + sol = solve(jprob, Tsit5()) + @test all(==(0.0), sol[1, :]) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng) sol = solve(jprob, Tsit5()) @test all(==(0.0), sol[1, :]) u0 = [4.0] @@ -101,7 +104,10 @@ let rrate(u, p, t) = u[1] aaffect!(integrator) = (integrator.u[1] += 1; nothing) vrj = VariableRateJump(rrate, aaffect!) - jprob = JumpProblem(prob, vrj; rng) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng) + sol = solve(jprob, Tsit5()) + @test all(==(0.0), sol[1, :]) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng) sol = solve(jprob, Tsit5()) @test all(==(0.0), sol[1, :]) u0 = [4.0] diff --git a/test/save_positions.jl b/test/save_positions.jl index 1e5ddc40..e9194557 100644 --- a/test/save_positions.jl +++ b/test/save_positions.jl @@ -22,7 +22,12 @@ let oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan) jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1; urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0) - jumpproblem = JumpProblem(oprob, alg, jump; dep_graph = [[1]], + jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = VR_Direct(), dep_graph = [[1]], + save_positions = (false, true), rng) + sol = solve(jumpproblem, Tsit5(); save_everystep = false) + @test sol.t == [0.0, 30.0] + + jumpproblem = JumpProblem(oprob, alg, jump; vr_aggregator = VR_FRM(), dep_graph = [[1]], save_positions = (false, true), rng) sol = solve(jumpproblem, Tsit5(); save_everystep = false) @test sol.t == [0.0, 30.0] diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 4be4d739..2c886672 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -22,14 +22,15 @@ let end u_0 = [1.0] ode_prob = ODEProblem(f!, u_0, (0.0, 10)) - rate(u, p, t) = 1.0 - jump!(integrator) = nothing - jump_prob = JumpProblem(ode_prob, Direct(), VariableRateJump(rate, jump!)) - prob_func(prob, i, repeat) = deepcopy(prob) - prob = EnsembleProblem(jump_prob,prob_func = prob_func) - solve(prob, Tsit5(), EnsembleThreads(), trajectories=10) + vrj = VariableRateJump((u,p,t) -> 1.0, integrator -> nothing) - sol = solve(prob, Tsit5(), EnsembleThreads(), trajectories=400) - init_props = [sol[i].u[1][2] for i = 1:length(sol)] - @test allunique(init_props) + for agg in (VR_FRM(), VR_Direct()) + jump_prob = JumpProblem(ode_prob, Direct(), vrj; vr_aggregator = agg) + prob_func(prob, i, repeat) = deepcopy(prob) + prob = EnsembleProblem(jump_prob,prob_func = prob_func) + sol = solve(prob, Tsit5(), EnsembleThreads(), trajectories=400, + save_everystep = false) + firstrx_time = [sol.u[i].t[2] for i = 1:length(sol)] + @test allunique(firstrx_time) + end end \ No newline at end of file diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 39c8d55b..20b012c0 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -1,5 +1,5 @@ using DiffEqBase, JumpProcesses, OrdinaryDiffEq, StochasticDiffEq, Test -using Random, LinearSolve +using Random, LinearSolve, Statistics using StableRNGs rng = StableRNG(12345) @@ -30,29 +30,28 @@ f = function (du, u, p, t) end prob = ODEProblem(f, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) - +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) integrator = init(jump_prob, Tsit5()) - sol = solve(jump_prob, Tsit5()) sol = solve(jump_prob, Rosenbrock23(autodiff = false)) sol = solve(jump_prob, Rosenbrock23()) -# @show sol[end] -# display(sol[end]) - +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) +integrator = init(jump_prob_gill, Tsit5()) +sol_gill = solve(jump_prob_gill, Tsit5()) +sol_gill = solve(jump_prob, Rosenbrock23(autodiff = false)) +sol_gill = solve(jump_prob, Rosenbrock23()) @test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12 @test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1e-12 g = function (du, u, p, t) du[1] = u[1] end - prob = SDEProblem(f, g, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) - -sol = solve(jump_prob, SRIW1()) - +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng = rng) +sol = solve(jump_prob, SRIW1()) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng=rng) +sol_gill = solve(jump_prob_gill, SRIW1()) @test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12 @test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1e-12 @@ -63,38 +62,36 @@ function ff(du, u, p, t) du .= 2.01u end end - function gg(du, u, p, t) du[1, 1] = 0.3u[1] du[1, 2] = 0.6u[1] du[2, 1] = 1.2u[1] du[2, 2] = 0.2u[2] end - rate_switch(u, p, t) = u[1] * 1.0 - function affect_switch!(integrator) integrator.p = 1 end - jump_switch = VariableRateJump(rate_switch, affect_switch!) - prob = SDEProblem(ff, gg, ones(2), (0.0, 1.0), 0, noise_rate_prototype = zeros(2, 2)) -jump_prob = JumpProblem(prob, Direct(), jump_switch; rng = rng) -solve(jump_prob, SRA1(), dt = 1.0) +jump_prob = JumpProblem(prob, jump_switch; vr_aggregator = VR_FRM(), rng) +jump_prob_gill = JumpProblem(prob, jump_switch; vr_aggregator = VR_Direct(), rng) +sol = solve(jump_prob, SRA1(), dt = 1.0) +sol_gill = solve(jump_prob_gill, SRA1(), dt = 1.0) ## Some integration tests function f2(du, u, p, t) du[1] = u[1] end - prob = ODEProblem(f2, [0.2], (0.0, 10.0)) rate2(u, p, t) = 2 affect2!(integrator) = (integrator.u[1] = integrator.u[1] / 2) jump = ConstantRateJump(rate2, affect2!) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) sol = solve(jump_prob, Tsit5()) +sol_gill = solve(jump_prob_gill, Tsit5()) sol(4.0) sol.u[4] @@ -102,25 +99,27 @@ rate2b(u, p, t) = u[1] affect2!(integrator) = (integrator.u[1] = integrator.u[1] / 2) jump = VariableRateJump(rate2b, affect2!) jump2 = deepcopy(jump) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) sol = solve(jump_prob, Tsit5()) +sol_gill = solve(jump_prob_gill, Tsit5()) sol(4.0) sol.u[4] function g2(du, u, p, t) du[1] = u[1] end - prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) sol = solve(jump_prob, SRIW1()) +sol_gill = solve(jump_prob_gill, SRIW1()) sol(4.0) sol.u[4] function f3(du, u, p, t) du .= u end - prob = ODEProblem(f3, [1.0 2.0; 3.0 4.0], (0.0, 1.0)) rate3(u, p, t) = u[1] + u[2] affect3!(integrator) = (integrator.u[1] = 0.25; @@ -128,8 +127,10 @@ integrator.u[2] = 0.5; integrator.u[3] = 0.75; integrator.u[4] = 1) jump = VariableRateJump(rate3, affect3!) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) sol = solve(jump_prob, Tsit5()) +sol_gill = solve(jump_prob_gill, Tsit5()) # test for https://discourse.julialang.org/t/differentialequations-jl-package-variable-rate-jumps-with-complex-variables/80366/2 function f4(dx, x, p, t) @@ -143,26 +144,19 @@ jump = VariableRateJump(rate4, affect4!) x₀ = 1.0 + 0.0im Δt = (0.0, 6.0) prob = ODEProblem(f4, [x₀], Δt) -jumpProblem = JumpProblem(prob, Direct(), jump) -sol = solve(jumpProblem, Tsit5()) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) +sol = solve(jump_prob, Tsit5()) +sol_gill = solve(jump_prob_gill, Tsit5()) # Out of place test - -function drift(x, p, t) - return p * x -end - -function rate2c(x, p, t) - return 3 * max(0.0, x[1]) -end - -function affect!2(integrator) - integrator.u ./= 2 -end +drift(x, p, t) = p * x +rate2c(x, p, t) = 3 * max(0.0, x[1]) +affect!2(integrator) = (integrator.u ./= 2; nothing) x0 = rand(2) prob = ODEProblem(drift, x0, (0.0, 10.0), 2.0) jump = VariableRateJump(rate2c, affect!2) -jump_prob = JumpProblem(prob, Direct(), jump) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM(), rng) # test to check lack of dependency graphs is caught in Coevolve for systems with non-maj # jumps @@ -260,7 +254,7 @@ let d_jump = VariableRateJump(d_rate, death!) ode_prob = ODEProblem(ode_fxn, u0, tspan, p) - sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; rng) + sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator = VR_FRM(), rng) @test allunique(sjm_prob.prob.u0.jump_u) u0old = copy(sjm_prob.prob.u0.jump_u) for i in 1:Nsims @@ -318,12 +312,127 @@ let d_jump = VariableRateJump(d_rate, death!) ode_prob = ODEProblem(ode_fxn, u0, tspan, p) - sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; rng) dt = 0.1 tsave = range(tspan[1], tspan[2]; step = dt) - for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization())) - umean = getmean(Nsims, sjm_prob, alg, dt, tsave, seed) - @test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave)) - seed += Nsims + for vr_aggregator in (VR_FRM(), VR_Direct()) + sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator, rng) + + for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization())) + umean = getmean(Nsims, sjm_prob, alg, dt, tsave, seed) + @test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave)) + seed += Nsims + end + end +end + +# Correctness test based on +# VR_Direct and VR_FRM +# Function to run ensemble and compute statistics +function run_ensemble(prob, alg, jumps...; vr_aggregator=VR_FRM(), Nsims=8000) + rng = StableRNG(12345) + jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator, rng) + ensemble = EnsembleProblem(jump_prob) + sol = solve(ensemble, alg, trajectories=Nsims, save_everystep=false) + return mean(sol.u[i][1,end] for i in 1:Nsims) +end + +# Test 1: Simple ODE with two variable rate jumps +let + rate = (u, p, t) -> u[1] + affect! = (integrator) -> (integrator.u[1] = integrator.u[1] / 2) + jump = VariableRateJump(rate, affect!, interp_points=1000) + jump2 = deepcopy(jump) + + f = (du, u, p, t) -> (du[1] = u[1]) + prob = ODEProblem(f, [0.2], (0.0, 10.0)) + + mean_vrfr = run_ensemble(prob, Tsit5(), jump, jump2) + mean_vrdcb = run_ensemble(prob, Tsit5(), jump, jump2; vr_aggregator=VR_Direct()) + + @test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05) +end + +# Test 2: SDE with two variable rate jumps +let + f = (du, u, p, t) -> (du[1] = -u[1] / 10.0) + g = (du, u, p, t) -> (du[1] = -u[1] / 10.0) + rate = (u, p, t) -> u[1] / 10.0 + affect! = (integrator) -> (integrator.u[1] = integrator.u[1] + 1) + jump = VariableRateJump(rate, affect!) + jump2 = deepcopy(jump) + + prob = SDEProblem(f, g, [10.0], (0.0, 10.0)) + + mean_vrfr = run_ensemble(prob, SRIW1(), jump, jump2) + mean_vrdcb = run_ensemble(prob, SRIW1(), jump, jump2; vr_aggregator=VR_Direct()) + + @test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05) +end + +# Test 3: ODE with analytical solution +let + λ = 2.0 + f = (du, u, p, t) -> (du[1] = -u[1]; nothing) + rate = (u, p, t) -> λ + affect! = (integrator) -> (integrator.u[1] += 1; nothing) + jump = VariableRateJump(rate, affect!) + + prob = ODEProblem(f, [0.2], (0.0, 10.0)) + + mean_vrfr = run_ensemble(prob, Tsit5(), jump) + mean_vrdcb = run_ensemble(prob, Tsit5(), jump; vr_aggregator = VR_Direct()) + + t = 10.0 + u0 = 0.2 + analytical_mean = u0 * exp(-t) + λ*(1 - exp(-t)) + + @test isapprox(mean_vrfr, analytical_mean, rtol=0.05) + @test isapprox(mean_vrfr, mean_vrdcb, rtol=0.05) +end + +# Test 4: No. of Jumps +let + f(du, u, p, t) = (du[1] = 0.0; nothing) + + # Define birth jump: ∅ → X + birth_rate(u, p, t) = 10.0 + function birth_affect!(integrator) + integrator.u[1] += 1 + integrator.p[3] += 1 + nothing end + birth_jump = VariableRateJump(birth_rate, birth_affect!) + + # Define death jump: X → ∅ + death_rate(u, p, t) = 0.5 * u[1] + function death_affect!(integrator) + integrator.u[1] -= 1 + integrator.p[3] += 1 + nothing + end + death_jump = VariableRateJump(death_rate, death_affect!) + + Nsims = 100 + results = Dict() + u0 = [1.0] + tspan = (0.0, 10.0) + for vr_aggregator in (VR_FRM(), VR_Direct()) + jump_counts = zeros(Int, Nsims) + p = [0.0, 0.0, 0] + prob = ODEProblem(f, u0, tspan, p) + jump_prob = JumpProblem(prob, Direct(), birth_jump, death_jump; vr_aggregator, rng) + + for i in 1:Nsims + sol = solve(jump_prob, Tsit5()) + jump_counts[i] = jump_prob.prob.p[3] + jump_prob.prob.p[3] = 0 + end + + results[vr_aggregator] = (mean_jumps=mean(jump_counts), jump_counts=jump_counts) + @test sum(jump_counts) > 1000 + end + + mean_jumps_vrfr = results[VR_FRM()].mean_jumps + mean_jumps_vrdcb = results[VR_Direct()].mean_jumps + @test isapprox(mean_jumps_vrfr, mean_jumps_vrdcb, rtol=0.1) end