Skip to content

Commit ec88f7c

Browse files
authored
Lorenz extended (#4)
* Start writing an example of using more sophisticated learning configurations for the autoregressive Lorenzemulator * Change default dt to the value of Lorenz 63 * Advanced training methodologies using unrolling * Add example to docs * Add documentation
1 parent 88b6380 commit ec88f7c

File tree

5 files changed

+1487
-40
lines changed

5 files changed

+1487
-40
lines changed

docs/examples/advanced_lorenz_emulation.ipynb

Lines changed: 1395 additions & 0 deletions
Large diffs are not rendered by default.
Loading

docs/examples/lorenz_emulator.ipynb

Lines changed: 55 additions & 37 deletions
Large diffs are not rendered by default.

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ nav:
8787
- Examples:
8888
- Introductory:
8989
- Lorenz Emulator: 'examples/lorenz_emulator.ipynb'
90+
- Advanced Lorenz Emulation: 'examples/advanced_lorenz_emulation.ipynb'
9091
- FOU one step learning is convex: 'examples/FOU_one_step_learning_is_convex.ipynb'
9192
- Configuration showcase: 'examples/configuration_showcase.ipynb'
9293
- Advanced:

trainax/_sample_data.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,42 @@ def _lorenz_rhs(
133133
return jnp.array([x_dot, y_dot, z_dot])
134134

135135

136+
def make_lorenz_stepper_rk4(
137+
dt: float = 0.01,
138+
*,
139+
sigma: float = 10.0,
140+
rho: float = 28.0,
141+
beta: float = 8.0 / 3.0,
142+
) -> Callable[[Float[Array, "3"]], Float[Array, "3"]]:
143+
r"""
144+
Produces a timestepper for the Lorenz system using a fixed-size Runge-Kutta
145+
4th order scheme.
146+
147+
**Arguments**:
148+
149+
- `dt`: The timestep size. Depending on the values of `sigma`, `rho`, and
150+
`beta`, the system might be hard to integrate. Usually, a time step
151+
$\Delta t \in [0.01, 0.1]$ is a good choice. The default is `0.01` which
152+
matches https://doi.org/10.1175/1520-0469(1963)020%3C0130:DNF%3E2.0.CO;2
153+
- `sigma`: The $\sigma$ parameter of the Lorenz system. The default is `10.0`.
154+
- `rho`: The $\rho$ parameter of the Lorenz system. The default is `28.0`.
155+
- `beta`: The $\beta$ parameter of the Lorenz system. The default is `8.0/3.0`.
156+
157+
**Returns**:
158+
159+
- A function that takes a state vector of shape `(3,)` and returns the next
160+
state vector of shape `(3,)`.
161+
"""
162+
lorenz_rhs_params_fixed = lambda u: _lorenz_rhs(u, sigma=sigma, rho=rho, beta=beta)
163+
lorenz_stepper = lambda u: _step_rk4(lorenz_rhs_params_fixed, u, dt=dt)
164+
return lorenz_stepper
165+
166+
136167
def lorenz_rk4(
137168
num_samples: int = 20,
138169
*,
139170
temporal_horizon: int = 1000,
140-
dt: float = 0.05,
171+
dt: float = 0.01,
141172
num_warmup_steps: int = 500,
142173
sigma: float = 10.0,
143174
rho: float = 28.0,
@@ -184,8 +215,10 @@ def lorenz_rk4(
184215

185216
u_0_set = jax.random.normal(key, shape=(num_samples, 3)) * init_std
186217

187-
lorenz_rhs_params_fixed = lambda u: _lorenz_rhs(u, sigma=sigma, rho=rho, beta=beta)
188-
lorenz_stepper = lambda u: _step_rk4(lorenz_rhs_params_fixed, u, dt=dt)
218+
# lorenz_rhs_params_fixed = lambda u: _lorenz_rhs(u, sigma=sigma, rho=rho, beta=beta)
219+
# lorenz_stepper = lambda u: _step_rk4(lorenz_rhs_params_fixed, u, dt=dt)
220+
221+
lorenz_stepper = make_lorenz_stepper_rk4(dt=dt, sigma=sigma, rho=rho, beta=beta)
189222

190223
def scan_fn(u, _):
191224
u_next = lorenz_stepper(u)

0 commit comments

Comments
 (0)