Skip to content

Commit 3750a8a

Browse files
some changes
1 parent 69a7a5f commit 3750a8a

File tree

1 file changed

+41
-90
lines changed

1 file changed

+41
-90
lines changed

benchmarks/Jumps/VR_Aggregator_Benchmark.jmd

Lines changed: 41 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ The test cases are:
2929
7. **Matrix ODE with Variable Rate Jump**: Solved with `Tsit5`.
3030
8. **Complex ODE with Variable Rate Jump**: Solved with `Tsit5`.
3131

32-
For visualization, we solve one trajectory per test case with 2 jumps (2x2 matrix for Test 7). For benchmarking, we vary jumps from 1 to 20 (2x2 to 10x10 for Test 7), running 100 trajectories per configuration.
32+
For visualization, we solve one trajectory per test case with 2 jumps. For benchmarking, we vary jumps from 1 to 20, running 100 trajectories per configuration.
3333

3434
# Benchmark and Visualization Setup
3535

36-
We define factories for each test case to create problems with a variable number of jumps (or matrix size for Test 7).
36+
We define factories for each test case to create problems with a variable number of jumps.
3737

3838
```julia
3939
algorithms = Tuple{Any, Any, String, String}[
@@ -45,18 +45,14 @@ algorithms = Tuple{Any, Any, String, String}[
4545
(VR_FRM(), Rosenbrock23(), "VR_FRM", "Test 1 Rosenbrock23 (autodiff, VR_FRM)"),
4646
(VR_Direct(), SRIW1(), "VR_Direct", "Test 2 SRIW1 (VR_Direct)"),
4747
(VR_FRM(), SRIW1(), "VR_FRM", "Test 2 SRIW1 (VR_FRM)"),
48-
(VR_Direct(), SRA1(), "VR_Direct", "Test 3 SRA1 (VR_Direct)"),
49-
(VR_FRM(), SRA1(), "VR_FRM", "Test 3 SRA1 (VR_FRM)"),
50-
(VR_Direct(), Tsit5(), "VR_Direct", "Test 4 Tsit5 (VR_Direct, ConstantRateJump)"),
51-
(VR_FRM(), Tsit5(), "VR_FRM", "Test 4 Tsit5 (VR_FRM, ConstantRateJump)"),
52-
(VR_Direct(), Tsit5(), "VR_Direct", "Test 5 Tsit5 (VR_Direct)"),
53-
(VR_FRM(), Tsit5(), "VR_FRM", "Test 5 Tsit5 (VR_FRM)"),
54-
(VR_Direct(), SRIW1(), "VR_Direct", "Test 6 SRIW1 (VR_Direct)"),
55-
(VR_FRM(), SRIW1(), "VR_FRM", "Test 6 SRIW1 (VR_FRM)"),
56-
(VR_Direct(), Tsit5(), "VR_Direct", "Test 7 Tsit5 (VR_Direct)"),
57-
(VR_FRM(), Tsit5(), "VR_FRM", "Test 7 Tsit5 (VR_FRM)"),
58-
(VR_Direct(), Tsit5(), "VR_Direct", "Test 8 Tsit5 (VR_Direct)"),
59-
(VR_FRM(), Tsit5(), "VR_FRM", "Test 8 Tsit5 (VR_FRM)"),
48+
(VR_Direct(), Tsit5(), "VR_Direct", "Test 3 Tsit5 (VR_Direct, ConstantRateJump)"),
49+
(VR_FRM(), Tsit5(), "VR_FRM", "Test 3 Tsit5 (VR_FRM, ConstantRateJump)"),
50+
(VR_Direct(), Tsit5(), "VR_Direct", "Test 4 Tsit5 (VR_Direct)"),
51+
(VR_FRM(), Tsit5(), "VR_FRM", "Test 4 Tsit5 (VR_FRM)"),
52+
(VR_Direct(), SRIW1(), "VR_Direct", "Test 5 SRIW1 (VR_Direct)"),
53+
(VR_FRM(), SRIW1(), "VR_FRM", "Test 5 SRIW1 (VR_FRM)"),
54+
(VR_Direct(), Tsit5(), "VR_Direct", "Test 6 Tsit5 (VR_Direct)"),
55+
(VR_FRM(), Tsit5(), "VR_FRM", "Test 6 Tsit5 (VR_FRM)"),
6056
]
6157

6258
function create_test1_problem(num_jumps, vr_aggregator, solver)
@@ -79,19 +75,6 @@ function create_test2_problem(num_jumps, vr_aggregator, solver)
7975
end
8076

8177
function create_test3_problem(num_jumps, vr_aggregator, solver)
82-
ff = (du, u, p, t) -> (du .= p == 0 ? 1.01u : 2.01u)
83-
gg = (du, u, p, t) -> begin
84-
du[1, 1] = 0.3u[1]; du[1, 2] = 0.6u[1]
85-
du[2, 1] = 1.2u[1]; du[2, 2] = 0.2u[2]
86-
end
87-
prob = SDEProblem(ff, gg, ones(2), (0.0, 1.0), 0, noise_rate_prototype=zeros(2, 2))
88-
jumps = [VariableRateJump((u, p, t) -> u[1] * 1.0, (integrator) -> (integrator.p = 1)) for _ in 1:num_jumps]
89-
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
90-
ensemble_prob = EnsembleProblem(prob)
91-
return ensemble_prob, jump_prob
92-
end
93-
94-
function create_test4_problem(num_jumps, vr_aggregator, solver)
9578
f2 = (du, u, p, t) -> (du[1] = u[1])
9679
prob = ODEProblem(f2, [0.2], (0.0, 10.0))
9780
jumps = [ConstantRateJump((u, p, t) -> 2, (integrator) -> (integrator.u[1] = integrator.u[1] / 2)) for _ in 1:num_jumps]
@@ -100,7 +83,7 @@ function create_test4_problem(num_jumps, vr_aggregator, solver)
10083
return ensemble_prob, jump_prob
10184
end
10285

103-
function create_test5_problem(num_jumps, vr_aggregator, solver)
86+
function create_test4_problem(num_jumps, vr_aggregator, solver)
10487
f2 = (du, u, p, t) -> (du[1] = u[1])
10588
prob = ODEProblem(f2, [0.2], (0.0, 10.0))
10689
jumps = [VariableRateJump((u, p, t) -> u[1], (integrator) -> (integrator.u[1] = integrator.u[1] / 2); interp_points=100) for _ in 1:num_jumps]
@@ -109,7 +92,7 @@ function create_test5_problem(num_jumps, vr_aggregator, solver)
10992
return ensemble_prob, jump_prob
11093
end
11194

112-
function create_test6_problem(num_jumps, vr_aggregator, solver)
95+
function create_test5_problem(num_jumps, vr_aggregator, solver)
11396
f2 = (du, u, p, t) -> (du[1] = u[1])
11497
g2 = (du, u, p, t) -> (du[1] = u[1])
11598
prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0))
@@ -119,19 +102,7 @@ function create_test6_problem(num_jumps, vr_aggregator, solver)
119102
return ensemble_prob, jump_prob
120103
end
121104

122-
function create_test7_problem(num_jumps, vr_aggregator, solver, matrix_size=2)
123-
f3 = (du, u, p, t) -> (du .= u)
124-
u0 = ones(matrix_size, matrix_size)
125-
prob = ODEProblem(f3, u0, (0.0, 1.0))
126-
rate3 = (u, p, t) -> sum(u[1, :])
127-
affect3! = (integrator) -> (integrator.u .= range(0.25, 1.0, length=matrix_size^2))
128-
jumps = [VariableRateJump(rate3, affect3!) for _ in 1:num_jumps]
129-
jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator=vr_aggregator, rng=rng)
130-
ensemble_prob = EnsembleProblem(prob)
131-
return ensemble_prob, jump_prob
132-
end
133-
134-
function create_test8_problem(num_jumps, vr_aggregator, solver)
105+
function create_test6_problem(num_jumps, vr_aggregator, solver)
135106
f4 = (dx, x, p, t) -> (dx[1] = x[1])
136107
rate4 = (x, p, t) -> t
137108
affect4! = (integrator) -> (integrator.u[1] = integrator.u[1] * 0.5)
@@ -145,19 +116,15 @@ end
145116

146117
# Solution Visualization
147118

148-
We solve one trajectory for each test case with 2 jumps (2x2 matrix for Test 7) using `VR_Direct` and plot the state variables vs. time.
119+
We solve one trajectory for each test case with 2 jumps using `VR_Direct` and plot the state variables vs. time.
149120

150121
```julia
151122
let figs = []
152-
for test_num in 1:8
123+
for test_num in 1:6
153124
# Select a representative solver for each test
154-
algo, stepper = if test_num == 1
155-
VR_Direct(), Tsit5()
156-
elseif test_num == 2 || test_num == 6
125+
algo, stepper = if test_num == 2 || test_num == 5
157126
VR_Direct(), SRIW1()
158-
elseif test_num == 3
159-
VR_Direct(), SRA1()
160-
elseif test_num in [4, 5, 7, 8]
127+
else
161128
VR_Direct(), Tsit5()
162129
end
163130
label = "Test $test_num"
@@ -175,29 +142,16 @@ let figs = []
175142
create_test5_problem(2, algo, stepper)
176143
elseif test_num == 6
177144
create_test6_problem(2, algo, stepper)
178-
elseif test_num == 7
179-
create_test7_problem(2, algo, stepper, 2)
180-
elseif test_num == 8
181-
create_test8_problem(2, algo, stepper)
182145
end
183-
184-
# Solve one trajectory
185-
solver_kwargs = test_num == 3 ? (dt=1.0,) : ()
186-
try
187-
sol = solve(jump_prob, stepper; saveat=0.01, solver_kwargs...)
188146

147+
try
148+
sol = solve(jump_prob, stepper; saveat=0.01)
149+
189150
# Plot solution
190151
fig = plot(title="Test $test_num: Solution Trajectory", xlabel="Time", ylabel="State")
191-
if test_num == 7
192-
# For matrix ODE, plot sum of elements
193-
plot!(sol.t, [sum(sol.u[i]) for i in 1:length(sol.u)], label="Sum of Matrix Elements")
194-
elseif test_num == 8
152+
if test_num == 6
195153
# For complex ODE, plot real part
196154
plot!(sol.t, real.(sol[1,:]), label="Real Part")
197-
elseif test_num == 3
198-
# For 2D SDE, plot both components
199-
plot!(sol.t, sol[1,:], label="u[1]")
200-
plot!(sol.t, sol[2,:], label="u[2]")
201155
else
202156
# For scalar problems, plot state
203157
plot!(sol.t, sol[1,:], label="u[1]")
@@ -213,11 +167,10 @@ end
213167

214168
# Benchmark Execution
215169

216-
We benchmark each test case for 1 to 20 jumps (2x2 to 10x10 for Test 7), running 100 trajectories. Errors are logged to diagnose failures.
170+
We benchmark each test case for 1 to 20 jumps, running 100 trajectories. Errors are logged to diagnose failures.
217171

218172
```julia
219173
num_jumps_range = append!([1], 5:5:20)
220-
matrix_sizes = [2, 4, 6, 8, 10]
221174
bs = Vector{Vector{BenchmarkTools.Trial}}()
222175
errors = Dict{String, Vector{String}}()
223176

@@ -227,38 +180,36 @@ for (algo, stepper, agg_name, label) in algorithms
227180
errors[label] = String[]
228181
_bs = bs[end]
229182
test_num = parse(Int, match(r"Test (\d+)", label).captures[1])
230-
is_matrix_test = test_num == 7
231-
range_var = is_matrix_test ? matrix_sizes : num_jumps_range
183+
range_var = num_jumps_range
232184
for (i, var) in enumerate(range_var)
233185
if test_num == 1
234-
ensemble_prob, jump_prob = create_test1_problem(is_matrix_test ? 2 : var, algo, stepper)
186+
ensemble_prob, jump_prob = create_test1_problem(var, algo, stepper)
235187
elseif test_num == 2
236-
ensemble_prob, jump_prob = create_test2_problem(is_matrix_test ? 2 : var, algo, stepper)
188+
ensemble_prob, jump_prob = create_test2_problem(var, algo, stepper)
237189
elseif test_num == 3
238-
ensemble_prob, jump_prob = create_test3_problem(is_matrix_test ? 2 : var, algo, stepper)
190+
ensemble_prob, jump_prob = create_test3_problem(var, algo, stepper)
239191
elseif test_num == 4
240-
ensemble_prob, jump_prob = create_test4_problem(is_matrix_test ? 2 : var, algo, stepper)
192+
ensemble_prob, jump_prob = create_test4_problem(var, algo, stepper)
241193
elseif test_num == 5
242-
ensemble_prob, jump_prob = create_test5_problem(is_matrix_test ? 2 : var, algo, stepper)
194+
ensemble_prob, jump_prob = create_test5_problem(var, algo, stepper)
243195
elseif test_num == 6
244-
ensemble_prob, jump_prob = create_test6_problem(is_matrix_test ? 2 : var, algo, stepper)
245-
elseif test_num == 7
246-
ensemble_prob, jump_prob = create_test7_problem(2, algo, stepper, var)
247-
elseif test_num == 8
248-
ensemble_prob, jump_prob = create_test8_problem(is_matrix_test ? 2 : var, algo, stepper)
196+
ensemble_prob, jump_prob = create_test6_problem(var, algo, stepper)
249197
end
250-
solver_kwargs = test_num == 3 ? (dt=1.0,) : ""
251198
trial = try
252-
@benchmark solve($ensemble_prob, $stepper, EnsembleSerial(), trajectories=100, jump_prob=$jump_prob; $solver_kwargs...) samples=50 evals=1 seconds=10
199+
@benchmark(
200+
solve($jump_prob, $stepper),
201+
samples=50,
202+
evals=1,
203+
seconds=10
204+
)
253205
catch e
254-
push!(errors[label], "Error at $(is_matrix_test ? "Matrix Size" : "Num Jumps") = $var: $(sprint(showerror, e))")
206+
push!(errors[label], "Error at Num Jumps = $var: $(sprint(showerror, e))")
255207
BenchmarkTools.Trial(BenchmarkTools.Parameters(samples=50, evals=1, seconds=10))
256208
end
257209
push!(_bs, trial)
258-
if (var == 1 || var % (is_matrix_test ? 2 : 5) == 0)
259-
median_time = length(trial) > 0 ? "$(BenchmarkTools.prettytime(median(trial.times)))" : "nan"
260-
println("algo=$label, $(is_matrix_test ? "Matrix Size" : "Num Jumps") = $var, length = $(length(trial.times)), median time = $median_time")
261-
end
210+
211+
median_time = length(trial) > 0 ? "$(BenchmarkTools.prettytime(median(trial.times)))" : "nan"
212+
println("algo=$label, Num Jumps = $var, length = $(length(trial.times)), median time = $median_time")
262213
end
263214
end
264215

@@ -279,7 +230,7 @@ We plot the median execution times for each test case, comparing `VR_Direct` and
279230

280231
```julia
281232
let figs = []
282-
for test_num in 1:8
233+
for test_num in 1:6
283234
test_algorithms = filter(a -> parse(Int, match(r"Test (\d+)", a[4]).captures[1]) == test_num, algorithms)
284235
is_matrix_test = test_num == 7
285236
range_var = is_matrix_test ? matrix_sizes : num_jumps_range
@@ -307,6 +258,6 @@ let figs = []
307258
end
308259
push!(figs, fig)
309260
end
310-
plot(figs..., layout=(4, 2), format=fmt, size=(width_px, 4*height_px/2))
261+
plot(figs..., layout=(6, 1), format=fmt, size=(width_px, 8*height_px/2))
311262
end
312263
```

0 commit comments

Comments
 (0)