@@ -29,11 +29,11 @@ The test cases are:
29
29
7. **Matrix ODE with Variable Rate Jump**: Solved with `Tsit5`.
30
30
8. **Complex ODE with Variable Rate Jump**: Solved with `Tsit5`.
31
31
32
- For visualization, we solve one trajectory per test case with 2 jumps (2x2 matrix for Test 7) . For benchmarking, we vary jumps from 1 to 20 (2x2 to 10x10 for Test 7) , running 100 trajectories per configuration.
32
+ For visualization, we solve one trajectory per test case with 2 jumps. For benchmarking, we vary jumps from 1 to 20, running 100 trajectories per configuration.
33
33
34
34
# Benchmark and Visualization Setup
35
35
36
- We define factories for each test case to create problems with a variable number of jumps (or matrix size for Test 7) .
36
+ We define factories for each test case to create problems with a variable number of jumps.
37
37
38
38
```julia
39
39
algorithms = Tuple{Any, Any, String, String}[
@@ -45,18 +45,14 @@ algorithms = Tuple{Any, Any, String, String}[
45
45
(VR_FRM(), Rosenbrock23(), "VR_FRM", "Test 1 Rosenbrock23 (autodiff, VR_FRM)"),
46
46
(VR_Direct(), SRIW1(), "VR_Direct", "Test 2 SRIW1 (VR_Direct)"),
47
47
(VR_FRM(), SRIW1(), "VR_FRM", "Test 2 SRIW1 (VR_FRM)"),
48
- (VR_Direct(), SRA1(), "VR_Direct", "Test 3 SRA1 (VR_Direct)"),
49
- (VR_FRM(), SRA1(), "VR_FRM", "Test 3 SRA1 (VR_FRM)"),
50
- (VR_Direct(), Tsit5(), "VR_Direct", "Test 4 Tsit5 (VR_Direct, ConstantRateJump)"),
51
- (VR_FRM(), Tsit5(), "VR_FRM", "Test 4 Tsit5 (VR_FRM, ConstantRateJump)"),
52
- (VR_Direct(), Tsit5(), "VR_Direct", "Test 5 Tsit5 (VR_Direct)"),
53
- (VR_FRM(), Tsit5(), "VR_FRM", "Test 5 Tsit5 (VR_FRM)"),
54
- (VR_Direct(), SRIW1(), "VR_Direct", "Test 6 SRIW1 (VR_Direct)"),
55
- (VR_FRM(), SRIW1(), "VR_FRM", "Test 6 SRIW1 (VR_FRM)"),
56
- (VR_Direct(), Tsit5(), "VR_Direct", "Test 7 Tsit5 (VR_Direct)"),
57
- (VR_FRM(), Tsit5(), "VR_FRM", "Test 7 Tsit5 (VR_FRM)"),
58
- (VR_Direct(), Tsit5(), "VR_Direct", "Test 8 Tsit5 (VR_Direct)"),
59
- (VR_FRM(), Tsit5(), "VR_FRM", "Test 8 Tsit5 (VR_FRM)"),
48
+ (VR_Direct(), Tsit5(), "VR_Direct", "Test 3 Tsit5 (VR_Direct, ConstantRateJump)"),
49
+ (VR_FRM(), Tsit5(), "VR_FRM", "Test 3 Tsit5 (VR_FRM, ConstantRateJump)"),
50
+ (VR_Direct(), Tsit5(), "VR_Direct", "Test 4 Tsit5 (VR_Direct)"),
51
+ (VR_FRM(), Tsit5(), "VR_FRM", "Test 4 Tsit5 (VR_FRM)"),
52
+ (VR_Direct(), SRIW1(), "VR_Direct", "Test 5 SRIW1 (VR_Direct)"),
53
+ (VR_FRM(), SRIW1(), "VR_FRM", "Test 5 SRIW1 (VR_FRM)"),
54
+ (VR_Direct(), Tsit5(), "VR_Direct", "Test 6 Tsit5 (VR_Direct)"),
55
+ (VR_FRM(), Tsit5(), "VR_FRM", "Test 6 Tsit5 (VR_FRM)"),
60
56
]
61
57
62
58
function create_test1_problem(num_jumps, vr_aggregator, solver)
@@ -79,19 +75,6 @@ function create_test2_problem(num_jumps, vr_aggregator, solver)
79
75
end
80
76
81
77
function create_test3_problem(num_jumps, vr_aggregator, solver)
82
- ff = (du, u, p, t) -> (du .= p == 0 ? 1.01u : 2.01u)
83
- gg = (du, u, p, t) -> begin
84
- du[1, 1] = 0.3u[1]; du[1, 2] = 0.6u[1]
85
- du[2, 1] = 1.2u[1]; du[2, 2] = 0.2u[2]
86
- end
87
- prob = SDEProblem(ff, gg, ones(2), (0.0, 1.0), 0, noise_rate_prototype=zeros(2, 2))
88
- jumps = [VariableRateJump((u, p, t) -> u[1] * 1.0, (integrator) -> (integrator.p = 1)) for _ in 1:num_jumps]
89
- jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
90
- ensemble_prob = EnsembleProblem(prob)
91
- return ensemble_prob, jump_prob
92
- end
93
-
94
- function create_test4_problem(num_jumps, vr_aggregator, solver)
95
78
f2 = (du, u, p, t) -> (du[1] = u[1])
96
79
prob = ODEProblem(f2, [0.2], (0.0, 10.0))
97
80
jumps = [ConstantRateJump((u, p, t) -> 2, (integrator) -> (integrator.u[1] = integrator.u[1] / 2)) for _ in 1:num_jumps]
@@ -100,7 +83,7 @@ function create_test4_problem(num_jumps, vr_aggregator, solver)
100
83
return ensemble_prob, jump_prob
101
84
end
102
85
103
- function create_test5_problem (num_jumps, vr_aggregator, solver)
86
+ function create_test4_problem (num_jumps, vr_aggregator, solver)
104
87
f2 = (du, u, p, t) -> (du[1] = u[1])
105
88
prob = ODEProblem(f2, [0.2], (0.0, 10.0))
106
89
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
@@ -109,7 +92,7 @@ function create_test5_problem(num_jumps, vr_aggregator, solver)
109
92
return ensemble_prob, jump_prob
110
93
end
111
94
112
- function create_test6_problem (num_jumps, vr_aggregator, solver)
95
+ function create_test5_problem (num_jumps, vr_aggregator, solver)
113
96
f2 = (du, u, p, t) -> (du[1] = u[1])
114
97
g2 = (du, u, p, t) -> (du[1] = u[1])
115
98
prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0))
@@ -119,19 +102,7 @@ function create_test6_problem(num_jumps, vr_aggregator, solver)
119
102
return ensemble_prob, jump_prob
120
103
end
121
104
122
- function create_test7_problem(num_jumps, vr_aggregator, solver, matrix_size=2)
123
- f3 = (du, u, p, t) -> (du .= u)
124
- u0 = ones(matrix_size, matrix_size)
125
- prob = ODEProblem(f3, u0, (0.0, 1.0))
126
- rate3 = (u, p, t) -> sum(u[1, :])
127
- affect3! = (integrator) -> (integrator.u .= range(0.25, 1.0, length=matrix_size^2))
128
- jumps = [VariableRateJump(rate3, affect3!) for _ in 1:num_jumps]
129
- jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
130
- ensemble_prob = EnsembleProblem(prob)
131
- return ensemble_prob, jump_prob
132
- end
133
-
134
- function create_test8_problem(num_jumps, vr_aggregator, solver)
105
+ function create_test6_problem(num_jumps, vr_aggregator, solver)
135
106
f4 = (dx, x, p, t) -> (dx[1] = x[1])
136
107
rate4 = (x, p, t) -> t
137
108
affect4! = (integrator) -> (integrator.u[1] = integrator.u[1] * 0.5)
@@ -145,19 +116,15 @@ end
145
116
146
117
# Solution Visualization
147
118
148
- We solve one trajectory for each test case with 2 jumps (2x2 matrix for Test 7) using `VR_Direct` and plot the state variables vs. time.
119
+ We solve one trajectory for each test case with 2 jumps using `VR_Direct` and plot the state variables vs. time.
149
120
150
121
```julia
151
122
let figs = []
152
- for test_num in 1:8
123
+ for test_num in 1:6
153
124
# Select a representative solver for each test
154
- algo, stepper = if test_num == 1
155
- VR_Direct(), Tsit5()
156
- elseif test_num == 2 || test_num == 6
125
+ algo, stepper = if test_num == 2 || test_num == 5
157
126
VR_Direct(), SRIW1()
158
- elseif test_num == 3
159
- VR_Direct(), SRA1()
160
- elseif test_num in [4, 5, 7, 8]
127
+ else
161
128
VR_Direct(), Tsit5()
162
129
end
163
130
label = "Test $test_num"
@@ -175,29 +142,16 @@ let figs = []
175
142
create_test5_problem(2, algo, stepper)
176
143
elseif test_num == 6
177
144
create_test6_problem(2, algo, stepper)
178
- elseif test_num == 7
179
- create_test7_problem(2, algo, stepper, 2)
180
- elseif test_num == 8
181
- create_test8_problem(2, algo, stepper)
182
145
end
183
-
184
- # Solve one trajectory
185
- solver_kwargs = test_num == 3 ? (dt=1.0,) : ()
186
- try
187
- sol = solve(jump_prob, stepper; saveat=0.01, solver_kwargs...)
188
146
147
+ try
148
+ sol = solve(jump_prob, stepper; saveat=0.01)
149
+
189
150
# Plot solution
190
151
fig = plot(title="Test $test_num: Solution Trajectory", xlabel="Time", ylabel="State")
191
- if test_num == 7
192
- # For matrix ODE, plot sum of elements
193
- plot!(sol.t, [sum(sol.u[i]) for i in 1:length(sol.u)], label="Sum of Matrix Elements")
194
- elseif test_num == 8
152
+ if test_num == 6
195
153
# For complex ODE, plot real part
196
154
plot!(sol.t, real.(sol[1,:]), label="Real Part")
197
- elseif test_num == 3
198
- # For 2D SDE, plot both components
199
- plot!(sol.t, sol[1,:], label="u[1]")
200
- plot!(sol.t, sol[2,:], label="u[2]")
201
155
else
202
156
# For scalar problems, plot state
203
157
plot!(sol.t, sol[1,:], label="u[1]")
@@ -213,11 +167,10 @@ end
213
167
214
168
# Benchmark Execution
215
169
216
- We benchmark each test case for 1 to 20 jumps (2x2 to 10x10 for Test 7) , running 100 trajectories. Errors are logged to diagnose failures.
170
+ We benchmark each test case for 1 to 20 jumps, running 100 trajectories. Errors are logged to diagnose failures.
217
171
218
172
```julia
219
173
num_jumps_range = append!([1], 5:5:20)
220
- matrix_sizes = [2, 4, 6, 8, 10]
221
174
bs = Vector{Vector{BenchmarkTools.Trial}}()
222
175
errors = Dict{String, Vector{String}}()
223
176
@@ -227,38 +180,36 @@ for (algo, stepper, agg_name, label) in algorithms
227
180
errors[label] = String[]
228
181
_bs = bs[end]
229
182
test_num = parse(Int, match(r"Test (\d+)", label).captures[1])
230
- is_matrix_test = test_num == 7
231
- range_var = is_matrix_test ? matrix_sizes : num_jumps_range
183
+ range_var = num_jumps_range
232
184
for (i, var) in enumerate(range_var)
233
185
if test_num == 1
234
- ensemble_prob, jump_prob = create_test1_problem(is_matrix_test ? 2 : var, algo, stepper)
186
+ ensemble_prob, jump_prob = create_test1_problem(var, algo, stepper)
235
187
elseif test_num == 2
236
- ensemble_prob, jump_prob = create_test2_problem(is_matrix_test ? 2 : var, algo, stepper)
188
+ ensemble_prob, jump_prob = create_test2_problem(var, algo, stepper)
237
189
elseif test_num == 3
238
- ensemble_prob, jump_prob = create_test3_problem(is_matrix_test ? 2 : var, algo, stepper)
190
+ ensemble_prob, jump_prob = create_test3_problem(var, algo, stepper)
239
191
elseif test_num == 4
240
- ensemble_prob, jump_prob = create_test4_problem(is_matrix_test ? 2 : var, algo, stepper)
192
+ ensemble_prob, jump_prob = create_test4_problem(var, algo, stepper)
241
193
elseif test_num == 5
242
- ensemble_prob, jump_prob = create_test5_problem(is_matrix_test ? 2 : var, algo, stepper)
194
+ ensemble_prob, jump_prob = create_test5_problem(var, algo, stepper)
243
195
elseif test_num == 6
244
- ensemble_prob, jump_prob = create_test6_problem(is_matrix_test ? 2 : var, algo, stepper)
245
- elseif test_num == 7
246
- ensemble_prob, jump_prob = create_test7_problem(2, algo, stepper, var)
247
- elseif test_num == 8
248
- ensemble_prob, jump_prob = create_test8_problem(is_matrix_test ? 2 : var, algo, stepper)
196
+ ensemble_prob, jump_prob = create_test6_problem(var, algo, stepper)
249
197
end
250
- solver_kwargs = test_num == 3 ? (dt=1.0,) : ""
251
198
trial = try
252
- @benchmark solve($ensemble_prob, $stepper, EnsembleSerial(), trajectories=100, jump_prob=$jump_prob; $solver_kwargs...) samples=50 evals=1 seconds=10
199
+ @benchmark(
200
+ solve($jump_prob, $stepper),
201
+ samples=50,
202
+ evals=1,
203
+ seconds=10
204
+ )
253
205
catch e
254
- push!(errors[label], "Error at $(is_matrix_test ? "Matrix Size" : " Num Jumps") = $var: $(sprint(showerror, e))")
206
+ push!(errors[label], "Error at Num Jumps = $var: $(sprint(showerror, e))")
255
207
BenchmarkTools.Trial(BenchmarkTools.Parameters(samples=50, evals=1, seconds=10))
256
208
end
257
209
push!(_bs, trial)
258
- if (var == 1 || var % (is_matrix_test ? 2 : 5) == 0)
259
- median_time = length(trial) > 0 ? "$(BenchmarkTools.prettytime(median(trial.times)))" : "nan"
260
- println("algo=$label, $(is_matrix_test ? "Matrix Size" : "Num Jumps") = $var, length = $(length(trial.times)), median time = $median_time")
261
- end
210
+
211
+ median_time = length(trial) > 0 ? "$(BenchmarkTools.prettytime(median(trial.times)))" : "nan"
212
+ println("algo=$label, Num Jumps = $var, length = $(length(trial.times)), median time = $median_time")
262
213
end
263
214
end
264
215
@@ -279,7 +230,7 @@ We plot the median execution times for each test case, comparing `VR_Direct` and
279
230
280
231
```julia
281
232
let figs = []
282
- for test_num in 1:8
233
+ for test_num in 1:6
283
234
test_algorithms = filter(a -> parse(Int, match(r"Test (\d+)", a[4]).captures[1]) == test_num, algorithms)
284
235
is_matrix_test = test_num == 7
285
236
range_var = is_matrix_test ? matrix_sizes : num_jumps_range
@@ -307,6 +258,6 @@ let figs = []
307
258
end
308
259
push!(figs, fig)
309
260
end
310
- plot(figs..., layout=(4, 2 ), format=fmt, size=(width_px, 4 *height_px/2))
261
+ plot(figs..., layout=(6, 1 ), format=fmt, size=(width_px, 8 *height_px/2))
311
262
end
312
263
```
0 commit comments