
Commit eb91bbe

Add docs
1 parent eadccdb commit eb91bbe

File tree

1 file changed: +70 -0


trainax/trainer/_diverted_chain_branch_one.py

Lines changed: 70 additions & 0 deletions
@@ -31,6 +31,76 @@ def __init__(
         ] = None,
         do_sub_stacking: bool = True,
     ):
+        """
+        Diverted chain (rollout) configuration with branch length fixed to one.
+
+        Essentially, this amounts to a one-step difference to a reference
+        (created on the fly by the differentiable `ref_stepper`). Falls back to
+        classical one-step supervised training for `num_rollout_steps=1`
+        (default).
+
+        Args:
+            data_trajectories (PyTree[Float[Array, "num_samples trj_len ..."]]):
+                The batch of trajectories to slice. This must be a PyTree of
+                Arrays that have at least two leading axes: a batch axis and a
+                time axis. For example, the zeroth axis can be associated with
+                multiple initial conditions or constitutive parameters and the
+                first axis represents all temporal snapshots. A PyTree can also
+                just be an array. You can provide additional leaves in the
+                PyTree, e.g., for the corresponding constitutive parameters.
+                Make sure that the emulator has the corresponding signature.
+            ref_stepper (eqx.Module): The reference stepper to use for the
+                diverted chain. This is called on-the-fly. (keyword-only
+                argument)
+            residuum_fn (eqx.Module): For compatibility with other
+                configurations; not used. (keyword-only argument)
+            optimizer (optax.GradientTransformation): The optimizer to use for
+                training. For example, this can be `optax.adam(LEARNING_RATE)`.
+                Also use this to supply an optimizer with learning rate decay,
+                for example `optax.adam(optax.exponential_decay(...))`. If your
+                learning rate decay is designed for a certain number of update
+                steps, make sure that it aligns with `num_training_steps`.
+                (keyword-only argument)
+            callback_fn (BaseCallback, optional): A callback to use during
+                training. Defaults to None. (keyword-only argument)
+            num_training_steps (int): The number of training steps to perform.
+                (keyword-only argument)
+            batch_size (int): The batch size to use for training. Batches are
+                randomly sampled both across multiple trajectories and over
+                different windows within one trajectory. (keyword-only)
+            num_rollout_steps (int): The number of time steps to
+                autoregressively roll out the model. Defaults to 1.
+                (keyword-only argument)
+            time_level_loss (BaseLoss): The loss function to use at each
+                time step. Defaults to MSELoss(). (keyword-only argument)
+            cut_bptt (bool): Whether to cut the backpropagation through time
+                (BPTT), i.e., insert a `jax.lax.stop_gradient` into the
+                autoregressive network main chain. Defaults to False.
+                (keyword-only argument)
+            cut_bptt_every (int): The frequency at which to cut the BPTT.
+                Only relevant if `cut_bptt` is True. Defaults to 1 (meaning
+                after each step). (keyword-only argument)
+            cut_div_chain (bool): Whether to cut the diverted chain, i.e.,
+                insert a `jax.lax.stop_gradient` so that no cotangents flow
+                over the `ref_stepper`. In this case, the `ref_stepper` does
+                not have to be differentiable. Defaults to False. (keyword-only
+                argument)
+            time_level_weights (array[float], optional): An array of length
+                `num_rollout_steps` that contains the weights for each time
+                step. Defaults to None, which means that all time steps have
+                the same weight (=1.0). (keyword-only argument)
+
+
+        Info:
+            * The `ref_stepper` is called on-the-fly. If its forward (and vjp)
+              execution are expensive, this will dominate the computational
+              cost of this configuration.
+            * The usage of the `ref_stepper` includes the first branch starting
+              from the initial condition. Hence, no reference trajectory is
+              required.
+            * Under reverse-mode automatic differentiation, memory usage grows
+              linearly with `num_rollout_steps`.
+        """
         trajectory_sub_stacker = TrajectorySubStacker(
             data_trajectories,
             sub_trajectory_len=num_rollout_steps + 1,  # +1 for the IC
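To make the mechanics concrete, below is a minimal sketch of the diverted-chain loss with branch length one, written against the semantics of the docstring above. The names `model`, `ref_stepper`, and `u_0` are placeholders, the loss is hard-coded to an unweighted MSE, and the code illustrates the technique rather than reproducing the library's actual implementation:

    import jax
    import jax.numpy as jnp

    def diverted_chain_branch_one_loss(
        model, ref_stepper, u_0, num_rollout_steps, cut_bptt=False, cut_div_chain=False
    ):
        # Main chain: autoregressive rollout of the emulator. At every step, a
        # one-step reference branch is created on the fly from the current
        # network state, so no precomputed reference trajectory is needed.
        loss = 0.0
        u = u_0
        for _ in range(num_rollout_steps):
            u_pred = model(u)       # next state along the main chain
            u_ref = ref_stepper(u)  # one-step branch off the same state
            if cut_div_chain:
                # No cotangents flow over the reference; the ref_stepper then
                # does not have to be differentiable.
                u_ref = jax.lax.stop_gradient(u_ref)
            loss += jnp.mean((u_pred - u_ref) ** 2)  # time-level MSE, weight 1.0
            u = u_pred  # continue the main chain with the prediction
            if cut_bptt:
                # Cut backpropagation through time along the main chain.
                u = jax.lax.stop_gradient(u)
        return loss

For `num_rollout_steps=1` this reduces to a single comparison of `model(u_0)` against `ref_stepper(u_0)`, i.e., classical one-step supervised training, matching the fallback described in the docstring.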
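For reference, here is a hypothetical instantiation of the configuration this file defines. The class name `DivertedChainBranchOne` and its import path are inferred from the file path, not verified against the trainax API, and any omitted keyword arguments (e.g., `residuum_fn`) are assumed to have defaults; check the actual module before use:

    import jax.numpy as jnp
    import optax

    from trainax.trainer import DivertedChainBranchOne  # assumed class name and path

    # Placeholder data: 16 trajectories with 101 snapshots of 64 spatial points.
    data_trajectories = jnp.zeros((16, 101, 64))

    trainer = DivertedChainBranchOne(
        data_trajectories,
        ref_stepper=ref_stepper,  # a differentiable eqx.Module, defined elsewhere
        optimizer=optax.adam(3e-4),
        num_training_steps=10_000,
        batch_size=32,
        num_rollout_steps=5,      # length of the main chain rollout
    )

As the Info block notes, the `ref_stepper` generates all branch targets on the fly, so no precomputed reference trajectory has to be supplied alongside the data.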
