Skip to content

feat: Introducing vr_aggregator with VRDirectCB and VRFRMODE #477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
105 commits
Select commit Hold shift + click to select a range
33a255d
trying to change extend_problem for odeprobelems
sivasathyaseeelan Feb 15, 2025
049a54a
trying to change extend_problem for odeprobelems
sivasathyaseeelan Feb 15, 2025
12db876
added DiffEqCallbacks
sivasathyaseeelan Feb 15, 2025
f060dbf
added for sde
sivasathyaseeelan Feb 16, 2025
3d30414
added for sde
sivasathyaseeelan Feb 16, 2025
0b13cd1
done mostly
sivasathyaseeelan Feb 16, 2025
c21ffd6
removed extend_problem
sivasathyaseeelan Feb 19, 2025
01efc5f
removed extend_problem
sivasathyaseeelan Feb 19, 2025
a7d99dd
added callback for variableratejump
sivasathyaseeelan Feb 19, 2025
5ec5e50
added callback for variableratejump
sivasathyaseeelan Feb 19, 2025
d10bfca
resetted
sivasathyaseeelan Feb 20, 2025
a6bcf94
resetted
sivasathyaseeelan Feb 20, 2025
36ec0c1
added callable cache
sivasathyaseeelan Feb 21, 2025
9f3b3b6
added callable cache
sivasathyaseeelan Feb 21, 2025
8de786b
refactored
sivasathyaseeelan Mar 10, 2025
53ca71f
removed tests of previous implementation
sivasathyaseeelan Mar 12, 2025
6abee57
Project.toml fixed
sivasathyaseeelan Mar 15, 2025
37bdb97
added test_broken
sivasathyaseeelan Mar 15, 2025
955b2a8
added test_broken
sivasathyaseeelan Mar 15, 2025
0ec4886
broken tests are seperated
sivasathyaseeelan Mar 15, 2025
58e5b6a
added variablerate_aggregator
sivasathyaseeelan Mar 15, 2025
77ac240
added variablerate_aggregator
sivasathyaseeelan Mar 15, 2025
3400dbe
added variablerate_aggregator
sivasathyaseeelan Mar 15, 2025
90bddf7
added variablerate_aggregator
sivasathyaseeelan Mar 15, 2025
39b98be
using a callable type
sivasathyaseeelan Mar 16, 2025
c49496d
added performance test
sivasathyaseeelan Mar 16, 2025
290d1c2
added benchmark
sivasathyaseeelan Mar 18, 2025
9925057
added benchmark
sivasathyaseeelan Mar 18, 2025
d190e69
added performance tests
sivasathyaseeelan Mar 19, 2025
a7cf216
some changes
sivasathyaseeelan Mar 19, 2025
28d5d08
added a test
sivasathyaseeelan Mar 19, 2025
f206d58
added a test
sivasathyaseeelan Mar 19, 2025
d185b0f
some changes
sivasathyaseeelan Mar 19, 2025
98ce7be
refactor phase 1 as per review
sivasathyaseeelan Mar 22, 2025
cbc3129
added DiffEqCallbacks in compat entry
sivasathyaseeelan Mar 22, 2025
b5198cc
typo fix
sivasathyaseeelan Mar 22, 2025
9218cc2
refactor phase 2
sivasathyaseeelan Mar 24, 2025
92ee7c9
refactor resolve reviews
sivasathyaseeelan Mar 27, 2025
1fbfc7c
some changes
sivasathyaseeelan Mar 27, 2025
6eb1aca
some changes
sivasathyaseeelan Mar 27, 2025
954c728
some changes
sivasathyaseeelan Mar 27, 2025
c74d853
some changes
sivasathyaseeelan Mar 27, 2025
40957d2
some changes
sivasathyaseeelan Mar 27, 2025
4b3ff00
some changes
sivasathyaseeelan Mar 27, 2025
f1c3fac
added functor for VRDirectCBEventCache
sivasathyaseeelan Mar 27, 2025
4927545
n_sims set to 1000
sivasathyaseeelan Mar 27, 2025
a3c7161
Project.toml
sivasathyaseeelan Mar 27, 2025
23ed262
some test changes
sivasathyaseeelan Mar 27, 2025
d0abaa1
benchmark updated
sivasathyaseeelan Mar 27, 2025
7fa65a4
integcallback fix
sivasathyaseeelan Mar 30, 2025
8b0a6aa
Update src/variable_rate.jl
ChrisRackauckas Apr 21, 2025
b6b8bee
Update src/variable_rate.jl
sivasathyaseeelan Apr 30, 2025
da5e1a8
Update src/variable_rate.jl
sivasathyaseeelan Apr 30, 2025
372a18f
bug fixed
sivasathyaseeelan May 2, 2025
f32b867
Update src/variable_rate.jl
sivasathyaseeelan May 3, 2025
5463eac
jump count test added
sivasathyaseeelan May 15, 2025
45c5d4b
Merge branch 'SciML:master' into removing-exrendedjumparray
sivasathyaseeelan May 15, 2025
0045e2a
thread safety test added
sivasathyaseeelan May 16, 2025
d8a6116
added hawkes test
sivasathyaseeelan May 16, 2025
3cb9cc3
docstring added
sivasathyaseeelan May 16, 2025
e532dc7
Update src/variable_rate.jl
sivasathyaseeelan May 16, 2025
8c886c9
Update src/variable_rate.jl
sivasathyaseeelan May 16, 2025
801c658
cache refactor
sivasathyaseeelan May 16, 2025
2025e76
made all concrete
sivasathyaseeelan May 16, 2025
a905426
Delete benchmarks/variable_rate.jl
ChrisRackauckas May 20, 2025
8a8d5e8
Update src/problem.jl
ChrisRackauckas May 20, 2025
68b77f7
updates
isaacsas May 22, 2025
15e6d21
fix docstrings
isaacsas May 22, 2025
9ecef11
refactor variable_rate.jl
isaacsas May 22, 2025
b321102
more updates
isaacsas May 22, 2025
07d8114
refactor
isaacsas May 22, 2025
7e84f93
rename methods
isaacsas May 22, 2025
adefba6
Update src/variable_rate.jl
sivasathyaseeelan May 22, 2025
3cfea37
initialization and single callback
sivasathyaseeelan May 22, 2025
3bc7a34
cleaned some tests
sivasathyaseeelan May 22, 2025
7192aa7
Update test/hawkes_test.jl
sivasathyaseeelan May 22, 2025
3236a26
cleaned a test
sivasathyaseeelan May 22, 2025
6126983
cleaned a test
sivasathyaseeelan May 22, 2025
86d394f
cleaned a test:
sivasathyaseeelan May 22, 2025
458c8eb
Project.toml fix
sivasathyaseeelan May 22, 2025
2924d5d
test clean
sivasathyaseeelan May 22, 2025
e527a41
cleaned
sivasathyaseeelan May 22, 2025
91bde9f
affect cleaned
sivasathyaseeelan May 22, 2025
1663e88
test fix
sivasathyaseeelan May 23, 2025
bfa450d
bug fix
sivasathyaseeelan May 23, 2025
aefda35
update variable rate implementation
isaacsas May 23, 2025
7e30e01
updates
isaacsas May 23, 2025
2c4417a
update monte carlo test
isaacsas May 23, 2025
66d72c5
don't assume float64
isaacsas May 23, 2025
87e23a6
use uniform solution indexing
isaacsas May 23, 2025
2a9dd27
don't save in initialization
isaacsas May 23, 2025
b0d990b
fix save_positions test
isaacsas May 23, 2025
a4e84e5
fix
isaacsas May 23, 2025
f0f97a4
remove redundant test
isaacsas May 23, 2025
280d4bb
fix thread_safety test
isaacsas May 23, 2025
3116de2
tweak name
isaacsas May 23, 2025
3fd9d43
fix variable rate tests
isaacsas May 23, 2025
c574873
fixes
isaacsas May 23, 2025
7177bbf
more fixes
isaacsas May 23, 2025
2c7bb20
more fixes
isaacsas May 23, 2025
6df9c3a
variable rate updates
isaacsas May 23, 2025
6f09943
update
isaacsas May 23, 2025
9bc7de1
update Hawkes test
isaacsas May 23, 2025
a41a475
fix gene expr test indexing
isaacsas May 23, 2025
1bd8a9a
Update variable_rate.jl
ChrisRackauckas May 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "9.14.3"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand All @@ -30,6 +31,7 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ArrayInterface = "7.9"
DataStructures = "0.18"
DiffEqBase = "6.154"
DiffEqCallbacks = "4.3.0"
DocStringExtensions = "0.9"
FastBroadcast = "0.3"
FunctionWrappers = "1.1"
Expand All @@ -46,7 +48,6 @@ UnPack = "1.0.2"
julia = "1.10"

[extras]
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Expand All @@ -58,5 +59,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DiffEqCallbacks", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq",
"SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"]
test = ["LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test", "FastBroadcast"]
7 changes: 7 additions & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using SciMLBase: SciMLBase, isdenseplot
using Base.FastMath: add_fast
using Setfield: @set, @set!

import DiffEqCallbacks: gauss_points, gauss_weights
import DiffEqBase: DiscreteCallback, init, solve, solve!, plot_indices, initialize!,
get_tstops, get_tstops_array, get_tstops_max
import Base: size, getindex, setindex!, length, similar, show, merge!, merge
Expand All @@ -21,6 +22,8 @@ import RecursiveArrayTools: recursivecopy!
using StaticArrays, Base.Threads
import SymbolicIndexingInterface as SII

import Random: AbstractRNG

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
abstract type AbstractAggregatorAlgorithm end
Expand Down Expand Up @@ -70,6 +73,7 @@ include("spatial/directcrdirect.jl")
include("aggregators/aggregated_api.jl")

include("extended_jump_array.jl")
include("variable_rate.jl")
include("problem.jl")
include("solve.jl")
include("coupled_array.jl")
Expand Down Expand Up @@ -98,6 +102,9 @@ export reset_aggregated_jumps!

export ExtendedJumpArray

# Export VariableRateAggregator types
export VariableRateAggregator, VR_FRM, VR_Direct

# spatial structs and functions
export CartesianGrid, CartesianGridRej
export SpatialMassActionJump
Expand Down
202 changes: 19 additions & 183 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ $(TYPEDEF)

Defines a collection of jump processes to associate with another problem type.
- [Documentation Page](https://docs.sciml.ai/JumpProcesses/stable/jump_types/)
- [Tutorial Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/)
- [FAQ Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/#FAQ)
- [Tutorial
Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/)
- [FAQ
Page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/#FAQ)

### Constructors

Expand Down Expand Up @@ -44,20 +46,21 @@ then be passed within a single [`JumpSet`](@ref) or as subsequent sequential arg
$(FIELDS)

## Keyword Arguments
- `rng`, the random number generator to use. Defaults to Julia's built-in
generator.
- `save_positions=(true,true)`, specifies whether to save the system's state (before, after)
the jump occurs.
- `rng`, the random number generator to use. Defaults to Julia's built-in generator.
- `save_positions=(true,true)` when including variable rates and `(false,true)` for constant
rates, specifies whether to save the system's state (before, after) the jump occurs.
- `spatial_system`, for spatial problems the underlying spatial structure.
- `hopping_constants`, for spatial problems the spatial transition rate coefficients.
- `use_vrj_bounds = true`, set to false to disable handling bounded `VariableRateJump`s
with a supporting aggregator (such as `Coevolve`). They will then be handled via the
continuous integration interface, and treated like general `VariableRateJump`s.
- `use_vrj_bounds = true`, set to false to disable handling bounded `VariableRateJump`s with
a supporting aggregator (such as `Coevolve`). They will then be handled via the continuous
integration interface, and treated like general `VariableRateJump`s.
- `vr_aggregator`, indicates the aggregator to use for sampling variable rate jumps. Current
default is `VR_FRM`.

Please see the [tutorial
page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in the
DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage examples and
commonly asked questions.
page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in
the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage
examples and commonly asked questions.
"""
mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J2,
J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J}
Expand Down Expand Up @@ -213,6 +216,7 @@ end
make_kwarg(; kwargs...) = kwargs

function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpSet;
vr_aggregator::VariableRateAggregator = VR_FRM(),
save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ?
(false, true) : (true, true),
rng = DEFAULT_RNG, scale_rates = true, useiszero = true,
Expand Down Expand Up @@ -270,9 +274,9 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS

# handle any remaining vrjs
if length(cvrjs) > 0
new_prob = extend_problem(prob, cvrjs; rng)
variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng)
cont_agg = cvrjs
# Handle variable rate jumps based on vr_aggregator
new_prob, variable_jump_callback, cont_agg = configure_jump_problem(prob,
vr_aggregator, jumps, cvrjs; rng)
else
new_prob = prob
variable_jump_callback = CallbackSet()
Expand All @@ -293,163 +297,6 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS
solkwargs)
end

# extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values,
# of type prob.tspan
function extend_u0(prob, Njumps, rng)
ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:Njumps])
return u0
end

function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps; rng = DEFAULT_RNG)
error("General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs.")
end

function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAULT_RNG)
_f = SciMLBase.unwrapped_f(prob.f)

if isinplace(prob)
jump_f = let _f = _f
function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t)
_f(du.u, u.u, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
end
end
else
jump_f = let _f = _f
function (u::ExtendedJumpArray, p, t)
du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u)
update_jumps!(du, u, p, t, length(u.u), jumps...)
return du
end
end
end

u0 = extend_u0(prob, length(jumps), rng)
f = ODEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys,
observed = prob.f.observed)
remake(prob; f, u0)
end

function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG)
_f = SciMLBase.unwrapped_f(prob.f)

if isinplace(prob)
jump_f = let _f = _f
function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t)
_f(du.u, u.u, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
end
end
else
jump_f = let _f = _f
function (u::ExtendedJumpArray, p, t)
du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u)
update_jumps!(du, u, p, t, length(u.u), jumps...)
return du
end
end
end

if prob.noise_rate_prototype === nothing
jump_g = function (du, u, p, t)
prob.g(du.u, u.u, p, t)
end
else
jump_g = function (du, u, p, t)
prob.g(du, u.u, p, t)
end
end

u0 = extend_u0(prob, length(jumps), rng)
f = SDEFunction{isinplace(prob)}(jump_f, jump_g; sys = prob.f.sys,
observed = prob.f.observed)
remake(prob; f, g = jump_g, u0)
end

function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG)
_f = SciMLBase.unwrapped_f(prob.f)

if isinplace(prob)
jump_f = let _f = _f
function (du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t)
_f(du.u, u.u, h, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
end
end
else
jump_f = let _f = _f
function (u::ExtendedJumpArray, h, p, t)
du = ExtendedJumpArray(_f(u.u, h, p, t), u.jump_u)
update_jumps!(du, u, p, t, length(u.u), jumps...)
return du
end
end
end

u0 = extend_u0(prob, length(jumps), rng)
f = DDEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys,
observed = prob.f.observed)
remake(prob; f, u0)
end

# Not sure if the DAE one is correct: Should be a residual of sorts
function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAULT_RNG)
_f = SciMLBase.unwrapped_f(prob.f)

if isinplace(prob)
jump_f = let _f = _f
function (out, du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t)
_f(out, du.u, u.u, h, p, t)
update_jumps!(out, u, p, t, length(u.u), jumps...)
end
end
else
jump_f = let _f = _f
function (du, u::ExtendedJumpArray, h, p, t)
out = ExtendedJumpArray(_f(du.u, u.u, h, p, t), u.jump_u)
update_jumps!(du, u, p, t, length(u.u), jumps...)
return du
end
end
end

u0 = extend_u0(prob, length(jumps), rng)
f = DAEFunction{isinplace(prob)}(jump_f, sys = prob.f.sys,
observed = prob.f.observed)
remake(prob; f, u0)
end

function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG)
condition = function(u, t, integrator)
u.jump_u[idx]
end
affect! = function(integrator)
jump.affect!(integrator)
integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t))
nothing
end
new_cb = ContinuousCallback(condition, affect!;
idxs = jump.idxs,
rootfind = jump.rootfind,
interp_points = jump.interp_points,
save_positions = jump.save_positions,
abstol = jump.abstol,
reltol = jump.reltol)
return new_cb
end

function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG)
idx += 1
new_cb = wrap_jump_in_callback(idx, jump; rng)
build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG)
end

function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG)
idx += 1
CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng))
end

aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A

@inline function extend_tstops!(tstops,
Expand All @@ -458,17 +305,6 @@ aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A
push!(tstops, jp.jump_callback.discrete_callbacks[1].condition.next_jump_time)
end

@inline function update_jumps!(du, u, p, t, idx, jump)
idx += 1
du[idx] = jump.rate(u.u, p, t)
end

@inline function update_jumps!(du, u, p, t, idx, jump, jumps...)
idx += 1
du[idx] = jump.rate(u.u, p, t)
update_jumps!(du, u, p, t, idx, jumps...)
end

### Displays
num_constant_rate_jumps(aggregator::AbstractSSAJumpAggregator) = length(aggregator.rates)

Expand Down
8 changes: 3 additions & 5 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ function resetted_jump_problem(_jump_prob, seed)
end
end

if !isempty(jump_prob.variable_jumps)
@assert jump_prob.prob.u0 isa ExtendedJumpArray
if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray
randexp!(_jump_prob.rng, jump_prob.prob.u0.jump_u)
jump_prob.prob.u0.jump_u .*= -1
end
Expand All @@ -69,9 +68,8 @@ function reset_jump_problem!(jump_prob, seed)
Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, seed)
end

if !isempty(jump_prob.variable_jumps)
@assert jump_prob.prob.u0 isa ExtendedJumpArray
if !isempty(jump_prob.variable_jumps) && jump_prob.prob.u0 isa ExtendedJumpArray
randexp!(jump_prob.rng, jump_prob.prob.u0.jump_u)
jump_prob.prob.u0.jump_u .*= -1
end
end
end
Loading
Loading