Skip to content

Commit fb4740d

Browse files
Merge pull request #477 from sivasathyaseeelan/removing-exrendedjumparray
feat: Introducing vr_aggregator with VRDirectCB and VRFRMODE
2 parents a404c82 + 1bd8a9a commit fb4740d

13 files changed

+696
-298
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "9.14.3"
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
99
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
10+
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1011
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1112
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1213
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -30,6 +31,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
3031
ArrayInterface = "7.9"
3132
DataStructures = "0.18"
3233
DiffEqBase = "6.154"
34+
DiffEqCallbacks = "4.3.0"
3335
DocStringExtensions = "0.9"
3436
FastBroadcast = "0.3"
3537
FunctionWrappers = "1.1"
@@ -46,7 +48,6 @@ UnPack = "1.0.2"
4648
julia = "1.10"
4749

4850
[extras]
49-
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
5051
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
5152
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5253
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
@@ -58,5 +59,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
5859
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5960

6061
[targets]
61-
test = ["DiffEqCallbacks", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq",
62-
"SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"]
62+
test = ["LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"]

src/JumpProcesses.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using SciMLBase: SciMLBase, isdenseplot
1111
using Base.FastMath: add_fast
1212
using Setfield: @set, @set!
1313

14+
import DiffEqCallbacks: gauss_points, gauss_weights
1415
import DiffEqBase: DiscreteCallback, init, solve, solve!, plot_indices, initialize!,
1516
get_tstops, get_tstops_array, get_tstops_max
1617
import Base: size, getindex, setindex!, length, similar, show, merge!, merge
@@ -21,6 +22,8 @@ import RecursiveArrayTools: recursivecopy!
2122
using StaticArrays, Base.Threads
2223
import SymbolicIndexingInterface as SII
2324

25+
import Random: AbstractRNG
26+
2427
abstract type AbstractJump end
2528
abstract type AbstractMassActionJump <: AbstractJump end
2629
abstract type AbstractAggregatorAlgorithm end
@@ -70,6 +73,7 @@ include("spatial/directcrdirect.jl")
7073
include("aggregators/aggregated_api.jl")
7174

7275
include("extended_jump_array.jl")
76+
include("variable_rate.jl")
7377
include("problem.jl")
7478
include("solve.jl")
7579
include("coupled_array.jl")
@@ -98,6 +102,9 @@ export reset_aggregated_jumps!
98102

99103
export ExtendedJumpArray
100104

105+
# Export VariableRateAggregator types
106+
export VariableRateAggregator, VR_FRM, VR_Direct
107+
101108
# spatial structs and functions
102109
export CartesianGrid, CartesianGridRej
103110
export SpatialMassActionJump

src/problem.jl

Lines changed: 19 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ $(TYPEDEF)
1313
1414
Defines a collection of jump processes to associate with another problem type.
1515
- [Documentation Page](https://docs.sciml.ai/JumpProcesses/stable/jump_types/)
16-
- [Tutorial Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/)
17-
- [FAQ Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/#FAQ)
16+
- [Tutorial
17+
Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/)
18+
- [FAQ
19+
Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/#FAQ)
1820
1921
### Constructors
2022
@@ -44,20 +46,21 @@ then be passed within a single [`JumpSet`](@ref) or as subsequent sequential arg
4446
$(FIELDS)
4547
4648
## Keyword Arguments
47-
- `rng`, the random number generator to use. Defaults to Julia's built-in
48-
generator.
49-
- `save_positions=(true,true)`, specifies whether to save the system's state (before, after)
50-
the jump occurs.
49+
- `rng`, the random number generator to use. Defaults to Julia's built-in generator.
50+
- `save_positions=(true,true)` when including variable rates and `(false,true)` for constant
51+
rates, specifies whether to save the system's state (before, after) the jump occurs.
5152
- `spatial_system`, for spatial problems the underlying spatial structure.
5253
- `hopping_constants`, for spatial problems the spatial transition rate coefficients.
53-
- `use_vrj_bounds = true`, set to false to disable handling bounded `VariableRateJump`s
54-
with a supporting aggregator (such as `Coevolve`). They will then be handled via the
55-
continuous integration interface, and treated like general `VariableRateJump`s.
54+
- `use_vrj_bounds = true`, set to false to disable handling bounded `VariableRateJump`s with
55+
a supporting aggregator (such as `Coevolve`). They will then be handled via the continuous
56+
integration interface, and treated like general `VariableRateJump`s.
57+
- `vr_aggregator`, indicates the aggregator to use for sampling variable rate jumps. Current
58+
default is `VR_FRM`.
5659
5760
Please see the [tutorial
58-
page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in the
59-
DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage examples and
60-
commonly asked questions.
61+
page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in
62+
the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage
63+
examples and commonly asked questions.
6164
"""
6265
mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J2,
6366
J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J}
@@ -213,6 +216,7 @@ end
213216
make_kwarg(; kwargs...) = kwargs
214217

215218
function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpSet;
219+
vr_aggregator::VariableRateAggregator = VR_FRM(),
216220
save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ?
217221
(false, true) : (true, true),
218222
rng = DEFAULT_RNG, scale_rates = true, useiszero = true,
@@ -270,9 +274,9 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
270274

271275
# handle any remaining vrjs
272276
if length(cvrjs) > 0
273-
new_prob = extend_problem(prob, cvrjs; rng)
274-
variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng)
275-
cont_agg = cvrjs
277+
# Handle variable rate jumps based on vr_aggregator
278+
new_prob, variable_jump_callback, cont_agg = configure_jump_problem(prob,
279+
vr_aggregator, jumps, cvrjs; rng)
276280
else
277281
new_prob = prob
278282
variable_jump_callback = CallbackSet()
@@ -293,163 +297,6 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
293297
solkwargs)
294298
end
295299

296-
# extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values,
297-
# of type prob.tspan
298-
function extend_u0(prob, Njumps, rng)
299-
ttype = eltype(prob.tspan)
300-
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:Njumps])
301-
return u0
302-
end
303-
304-
function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps; rng = DEFAULT_RNG)
305-
error("General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs.")
306-
end
307-
308-
function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAULT_RNG)
309-
_f = SciMLBase.unwrapped_f(prob.f)
310-
311-
if isinplace(prob)
312-
jump_f = let _f = _f
313-
function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t)
314-
_f(du.u, u.u, p, t)
315-
update_jumps!(du, u, p, t, length(u.u), jumps...)
316-
end
317-
end
318-
else
319-
jump_f = let _f = _f
320-
function (u::ExtendedJumpArray, p, t)
321-
du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u)
322-
update_jumps!(du, u, p, t, length(u.u), jumps...)
323-
return du
324-
end
325-
end
326-
end
327-
328-
u0 = extend_u0(prob, length(jumps), rng)
329-
f = ODEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys,
330-
observed = prob.f.observed)
331-
remake(prob; f, u0)
332-
end
333-
334-
function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG)
335-
_f = SciMLBase.unwrapped_f(prob.f)
336-
337-
if isinplace(prob)
338-
jump_f = let _f = _f
339-
function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t)
340-
_f(du.u, u.u, p, t)
341-
update_jumps!(du, u, p, t, length(u.u), jumps...)
342-
end
343-
end
344-
else
345-
jump_f = let _f = _f
346-
function (u::ExtendedJumpArray, p, t)
347-
du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u)
348-
update_jumps!(du, u, p, t, length(u.u), jumps...)
349-
return du
350-
end
351-
end
352-
end
353-
354-
if prob.noise_rate_prototype === nothing
355-
jump_g = function (du, u, p, t)
356-
prob.g(du.u, u.u, p, t)
357-
end
358-
else
359-
jump_g = function (du, u, p, t)
360-
prob.g(du, u.u, p, t)
361-
end
362-
end
363-
364-
u0 = extend_u0(prob, length(jumps), rng)
365-
f = SDEFunction{isinplace(prob)}(jump_f, jump_g; sys = prob.f.sys,
366-
observed = prob.f.observed)
367-
remake(prob; f, g = jump_g, u0)
368-
end
369-
370-
function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG)
371-
_f = SciMLBase.unwrapped_f(prob.f)
372-
373-
if isinplace(prob)
374-
jump_f = let _f = _f
375-
function (du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t)
376-
_f(du.u, u.u, h, p, t)
377-
update_jumps!(du, u, p, t, length(u.u), jumps...)
378-
end
379-
end
380-
else
381-
jump_f = let _f = _f
382-
function (u::ExtendedJumpArray, h, p, t)
383-
du = ExtendedJumpArray(_f(u.u, h, p, t), u.jump_u)
384-
update_jumps!(du, u, p, t, length(u.u), jumps...)
385-
return du
386-
end
387-
end
388-
end
389-
390-
u0 = extend_u0(prob, length(jumps), rng)
391-
f = DDEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys,
392-
observed = prob.f.observed)
393-
remake(prob; f, u0)
394-
end
395-
396-
# Not sure if the DAE one is correct: Should be a residual of sorts
397-
function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAULT_RNG)
398-
_f = SciMLBase.unwrapped_f(prob.f)
399-
400-
if isinplace(prob)
401-
jump_f = let _f = _f
402-
function (out, du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t)
403-
_f(out, du.u, u.u, h, p, t)
404-
update_jumps!(out, u, p, t, length(u.u), jumps...)
405-
end
406-
end
407-
else
408-
jump_f = let _f = _f
409-
function (du, u::ExtendedJumpArray, h, p, t)
410-
out = ExtendedJumpArray(_f(du.u, u.u, h, p, t), u.jump_u)
411-
update_jumps!(du, u, p, t, length(u.u), jumps...)
412-
return du
413-
end
414-
end
415-
end
416-
417-
u0 = extend_u0(prob, length(jumps), rng)
418-
f = DAEFunction{isinplace(prob)}(jump_f, sys = prob.f.sys,
419-
observed = prob.f.observed)
420-
remake(prob; f, u0)
421-
end
422-
423-
function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG)
424-
condition = function(u, t, integrator)
425-
u.jump_u[idx]
426-
end
427-
affect! = function(integrator)
428-
jump.affect!(integrator)
429-
integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t))
430-
nothing
431-
end
432-
new_cb = ContinuousCallback(condition, affect!;
433-
idxs = jump.idxs,
434-
rootfind = jump.rootfind,
435-
interp_points = jump.interp_points,
436-
save_positions = jump.save_positions,
437-
abstol = jump.abstol,
438-
reltol = jump.reltol)
439-
return new_cb
440-
end
441-
442-
function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG)
443-
idx += 1
444-
new_cb = wrap_jump_in_callback(idx, jump; rng)
445-
build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG)
446-
end
447-
448-
function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG)
449-
idx += 1
450-
CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng))
451-
end
452-
453300
aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A
454301

455302
@inline function extend_tstops!(tstops,
@@ -458,17 +305,6 @@ aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A
458305
push!(tstops, jp.jump_callback.discrete_callbacks[1].condition.next_jump_time)
459306
end
460307

461-
@inline function update_jumps!(du, u, p, t, idx, jump)
462-
idx += 1
463-
du[idx] = jump.rate(u.u, p, t)
464-
end
465-
466-
@inline function update_jumps!(du, u, p, t, idx, jump, jumps...)
467-
idx += 1
468-
du[idx] = jump.rate(u.u, p, t)
469-
update_jumps!(du, u, p, t, idx, jumps...)
470-
end
471-
472308
### Displays
473309
num_constant_rate_jumps(aggregator::AbstractSSAJumpAggregator) = length(aggregator.rates)
474310

src/solve.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ function resetted_jump_problem(_jump_prob, seed)
5656
end
5757
end
5858

59-
if !isempty(jump_prob.variable_jumps)
60-
@assert jump_prob.prob.u0 isa ExtendedJumpArray
59+
if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray
6160
randexp!(_jump_prob.rng, jump_prob.prob.u0.jump_u)
6261
jump_prob.prob.u0.jump_u .*= -1
6362
end
@@ -69,9 +68,8 @@ function reset_jump_problem!(jump_prob, seed)
6968
Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, seed)
7069
end
7170

72-
if !isempty(jump_prob.variable_jumps)
73-
@assert jump_prob.prob.u0 isa ExtendedJumpArray
71+
if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray
7472
randexp!(jump_prob.rng, jump_prob.prob.u0.jump_u)
7573
jump_prob.prob.u0.jump_u .*= -1
7674
end
77-
end
75+
end

0 commit comments

Comments
 (0)