@@ -133,11 +133,42 @@ def _lorenz_rhs(
133
133
return jnp .array ([x_dot , y_dot , z_dot ])
134
134
135
135
136
+ def make_lorenz_stepper_rk4 (
137
+ dt : float = 0.01 ,
138
+ * ,
139
+ sigma : float = 10.0 ,
140
+ rho : float = 28.0 ,
141
+ beta : float = 8.0 / 3.0 ,
142
+ ) -> Callable [[Float [Array , "3" ]], Float [Array , "3" ]]:
143
+ r"""
144
+ Produces a timestepper for the Lorenz system using a fixed-size Runge-Kutta
145
+ 4th order scheme.
146
+
147
+ **Arguments**:
148
+
149
+ - `dt`: The timestep size. Depending on the values of `sigma`, `rho`, and
150
+ `beta`, the system might be hard to integrate. Usually, a time step
151
+ $\Delta t \in [0.01, 0.1]$ is a good choice. The default is `0.01` which
152
+ matches https://doi.org/10.1175/1520-0469(1963)020%3C0130:DNF%3E2.0.CO;2
153
+ - `sigma`: The $\sigma$ parameter of the Lorenz system. The default is `10.0`.
154
+ - `rho`: The $\rho$ parameter of the Lorenz system. The default is `28.0`.
155
+ - `beta`: The $\beta$ parameter of the Lorenz system. The default is `8.0/3.0`.
156
+
157
+ **Returns**:
158
+
159
+ - A function that takes a state vector of shape `(3,)` and returns the next
160
+ state vector of shape `(3,)`.
161
+ """
162
+ lorenz_rhs_params_fixed = lambda u : _lorenz_rhs (u , sigma = sigma , rho = rho , beta = beta )
163
+ lorenz_stepper = lambda u : _step_rk4 (lorenz_rhs_params_fixed , u , dt = dt )
164
+ return lorenz_stepper
165
+
166
+
136
167
def lorenz_rk4 (
137
168
num_samples : int = 20 ,
138
169
* ,
139
170
temporal_horizon : int = 1000 ,
140
- dt : float = 0.05 ,
171
+ dt : float = 0.01 ,
141
172
num_warmup_steps : int = 500 ,
142
173
sigma : float = 10.0 ,
143
174
rho : float = 28.0 ,
@@ -184,8 +215,10 @@ def lorenz_rk4(
184
215
185
216
u_0_set = jax .random .normal (key , shape = (num_samples , 3 )) * init_std
186
217
187
- lorenz_rhs_params_fixed = lambda u : _lorenz_rhs (u , sigma = sigma , rho = rho , beta = beta )
188
- lorenz_stepper = lambda u : _step_rk4 (lorenz_rhs_params_fixed , u , dt = dt )
218
+ # lorenz_rhs_params_fixed = lambda u: _lorenz_rhs(u, sigma=sigma, rho=rho, beta=beta)
219
+ # lorenz_stepper = lambda u: _step_rk4(lorenz_rhs_params_fixed, u, dt=dt)
220
+
221
+ lorenz_stepper = make_lorenz_stepper_rk4 (dt = dt , sigma = sigma , rho = rho , beta = beta )
189
222
190
223
def scan_fn (u , _ ):
191
224
u_next = lorenz_stepper (u )
0 commit comments