@@ -31,6 +31,76 @@ def __init__(
] = None,
do_sub_stacking: bool = True,
):
+ """
35
+ Diverted chain (rollout) configuration with branch length fixed to one.
36
+
37
+ Essentially, this amounts to a one-step difference to a reference
38
+ (create on the fly by the differentiable `ref_stepper`). Falls back to
39
+ classical one-step supervised training for `num_rollout_steps=1`
40
+ (default).
41
+
42
+ Args:
+ data_trajectories (PyTree[Float[Array, "num_samples trj_len ..."]]):
+ The batch of trajectories to slice. This must be a PyTree of
+ Arrays that have at least two leading axes: a batch axis and a
+ time axis. For example, the zeroth axis can be associated with
+ multiple initial conditions or constitutive parameters, and the
+ first axis represents all temporal snapshots. A PyTree can also
+ just be an array. You can provide additional leaves in the
+ PyTree, e.g., for the corresponding constitutive parameters.
+ Make sure that the emulator has the corresponding signature.
+ ref_stepper (eqx.Module): The reference stepper to use for the
+ diverted chain. This is called on-the-fly. (keyword-only
+ argument)
+ residuum_fn (eqx.Module): For compatibility with other
+ configurations; not used. (keyword-only argument)
+ optimizer (optax.GradientTransformation): The optimizer to use for
+ training. For example, this can be `optax.adam(LEARNING_RATE)`.
+ Also use this to supply an optimizer with learning rate decay,
+ for example `optax.adam(optax.exponential_decay(...))`. If your
+ learning rate decay is designed for a certain number of update
+ steps, make sure that it aligns with `num_training_steps`.
+ (keyword-only argument)
+ callback_fn (BaseCallback, optional): A callback to use during
+ training. Defaults to None. (keyword-only argument)
+ num_training_steps (int): The number of training steps to perform.
+ (keyword-only argument)
+ batch_size (int): The batch size to use for training. Batches are
+ randomly sampled both across multiple trajectories and across
+ different windows within one trajectory. (keyword-only argument)
+ num_rollout_steps (int): The number of time steps to
+ autoregressively roll out the model. Defaults to 1. (keyword-only
+ argument)
+ time_level_loss (BaseLoss): The loss function to use at
+ each time step. Defaults to MSELoss(). (keyword-only argument)
+ cut_bptt (bool): Whether to cut the backpropagation through time
+ (BPTT), i.e., insert a `jax.lax.stop_gradient` into the
+ autoregressive network main chain. Defaults to False.
+ (keyword-only argument)
+ cut_bptt_every (int): The frequency at which to cut the BPTT.
+ Only relevant if `cut_bptt` is True. Defaults to 1 (meaning
+ after each step). (keyword-only argument)
+ cut_div_chain (bool): Whether to cut the diverted chain, i.e.,
+ insert a `jax.lax.stop_gradient` to not have cotangents flow
+ over the `ref_stepper`. In this case, the `ref_stepper` does not
+ have to be differentiable. Defaults to False. (keyword-only
+ argument)
+ time_level_weights (array[float], optional): An array of length
+ `num_rollout_steps` that contains the weights for each time
+ step. Defaults to None, which means that all time steps have the
+ same weight (=1.0). (keyword-only argument)
+
+ Info:
+ * The `ref_stepper` is called on-the-fly. If its forward (and vjp)
+ execution is expensive, this will dominate the computational
+ cost of this configuration.
+ * The usage of the `ref_stepper` includes the first branch starting
+ from the initial condition. Hence, no reference trajectory is
+ required.
+ * Under reverse-mode automatic differentiation, memory usage grows
+ linearly with `num_rollout_steps`.
+ """
34
104
trajectory_sub_stacker = TrajectorySubStacker (
35
105
data_trajectories ,
36
106
sub_trajectory_len = num_rollout_steps + 1 , # +1 for the IC
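
For orientation, below is a minimal illustrative JAX sketch of the loss structure the docstring describes: each prediction along the autoregressive main chain is compared against a one-step reference branch produced on the fly by `ref_stepper`, with optional `stop_gradient` cuts mirroring `cut_bptt` and `cut_div_chain`. The function name and the fixed MSE (standing in for the configurable `time_level_loss`) are assumptions for illustration, not code from this diff.

import jax
import jax.numpy as jnp


def diverted_chain_branch_one_loss(
    emulator,                 # differentiable network stepper: u_t -> u_{t+1}
    ref_stepper,              # reference stepper, called on the fly
    ic,                       # initial condition array
    num_rollout_steps=1,
    time_level_weights=None,
    cut_bptt=False,
    cut_bptt_every=1,
    cut_div_chain=False,
):
    # Illustrative sketch only; not the library's implementation.
    if time_level_weights is None:
        time_level_weights = jnp.ones(num_rollout_steps)
    loss = 0.0
    state = ic
    for t in range(num_rollout_steps):
        pred = emulator(state)      # main (autoregressive) chain
        ref = ref_stepper(state)    # branch of length one, created on the fly
        if cut_div_chain:
            # Prevent cotangents from flowing over the ref_stepper branch.
            ref = jax.lax.stop_gradient(ref)
        # MSE stands in for the configurable `time_level_loss`.
        loss = loss + time_level_weights[t] * jnp.mean((pred - ref) ** 2)
        state = pred
        if cut_bptt and (t + 1) % cut_bptt_every == 0:
            # Cut BPTT in the main chain every `cut_bptt_every` steps.
            state = jax.lax.stop_gradient(state)
    return loss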
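
A hypothetical construction sketch using the signature documented above; `DivertedChainBranchOneTrainer`, `my_data_trajectories`, and `my_ref_stepper` are placeholder names not shown in this diff. It mainly illustrates aligning an optax learning-rate schedule with `num_training_steps`.

import optax

NUM_TRAINING_STEPS = 10_000

# A decay schedule designed for exactly NUM_TRAINING_STEPS updates, so it
# stays aligned with the `num_training_steps` argument below.
schedule = optax.exponential_decay(
    init_value=1e-3,
    transition_steps=NUM_TRAINING_STEPS,
    decay_rate=0.1,
)
optimizer = optax.adam(schedule)

# Placeholder class and variable names, for illustration only.
trainer = DivertedChainBranchOneTrainer(
    my_data_trajectories,          # PyTree of (num_samples, trj_len, ...) arrays
    ref_stepper=my_ref_stepper,    # differentiable reference stepper (eqx.Module)
    optimizer=optimizer,
    num_training_steps=NUM_TRAINING_STEPS,
    batch_size=32,
    num_rollout_steps=5,           # five-step diverted-chain rollout
)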