- from typing import Optional
+ from typing import Optional, Union

import equinox as eqx
import jax.numpy as jnp
import optax
- from jaxtyping import PRNGKeyArray, PyTree
+ from jaxtyping import Array, Float, PRNGKeyArray, PyTree
from tqdm.autonotebook import tqdm

from ._mixer import PermutationMixer, TrajectorySubStacker
@@ -34,31 +34,35 @@ def __init__(
callback_fn: Optional[BaseCallback] = None,
):
"""
- Abstract training for an autoregressive neural emulator on a collection of
- trajectories.
-
- The length of (sub-)trajectories returned by `trajectory_sub_stacker` must
- match the requires length of reference for the used `loss_configuration`.
-
- Args:
-     trajectory_sub_stacker (TrajectorySubStacker): A callable that takes a
-         list of indices and returns a collection of (sub-)trajectories.
-     loss_configuration (BaseConfiguration): A configuration that defines the
-         loss function to be minimized.
-     ref_stepper (eqx.Module, optional): A reference stepper that is used to
-         compute the residuum. Supply this if the loss configuration requires
-         a reference stepper. Defaults to None.
-     residuum_fn (eqx.Module, optional): A residuum function that computes the
-         discrete residuum between two consecutive states. Supply this if the
-         loss configuration requires a residuum function. Defaults to None.
-     optimizer (optax.GradientTransformation): An optimizer that updates the
-         parameters of the stepper given the gradient.
-     num_minibatches (int): The number of minibatches to train on. This equals
-         the total number of update steps performed. The number of epochs is
-         determined based on this and the `batch_size`.
-     batch_size (int): The size of each batch.
-     callback_fn (BaseCallback, optional): A callback function that is called
-         at the end of each minibatch. Defaults to None.
+ Abstract training for an autoregressive neural emulator on a collection
+ of trajectories.
+
+ !!! info
+     The length of (sub-)trajectories returned by
+     `trajectory_sub_stacker` must match the required length of reference
+     for the used `loss_configuration`.
+
+ **Arguments:**
+
+ - `trajectory_sub_stacker`: A callable that takes a
+     list of indices and returns a collection of (sub-)trajectories.
+ - `loss_configuration`: A configuration that defines the
+     loss function to be minimized.
+ - `ref_stepper`: A reference stepper that is used to
+     compute the residuum. Supply this if the loss configuration requires
+     a reference stepper.
+ - `residuum_fn`: A residuum function that computes the
+     discrete residuum between two consecutive states. Supply this if the
+     loss configuration requires a residuum function. Defaults to None.
+ - `optimizer`: An optimizer that updates the
+     parameters of the stepper given the gradient.
+ - `num_minibatches`: The number of minibatches to train on. This equals
+     the total number of update steps performed. The number of epochs is
+     automatically determined based on this and the `batch_size`.
+ - `batch_size`: The size of each minibatch, i.e., how many samples are
+     included within.
+ - `callback_fn`: A callback function that is called
+     at the end of each minibatch. Defaults to None.
"""
self.trajectory_sub_stacker = trajectory_sub_stacker
self.loss_configuration = loss_configuration
@@ -75,6 +79,17 @@ def full_loss(
) -> float:
"""
Compute the loss on the entire dataset.
+
+ !!! warning
+     This can lead to out of memory errors if the dataset is too large.
+
+ **Arguments:**
+
+ - `stepper`: The stepper to compute the loss with.
+
+ **Returns:**
+
+ - The loss value.
"""
return self.loss_configuration(
stepper,
@@ -87,19 +102,22 @@ def step_fn(
self,
stepper: eqx.Module,
opt_state: optax.OptState,
- data: PyTree,
+ data: PyTree[Float[Array, "batch_size sub_trj_len ..."]],
) -> tuple[eqx.Module, optax.OptState, float]:
"""
Perform a single update step to the `stepper`'s parameters.

- Args:
-     stepper (eqx.Module): The stepper to be updated.
-     opt_state (optax.OptState): The optimizer state.
-     data (PyTree): The data for the current minibatch.
+ **Arguments:**
+
+ - `stepper`: The equinox module to be updated.
+ - `opt_state`: The current optimizer state.
+ - `data`: The data for the current minibatch.
+
+ **Returns:**

- Returns:
-     tuple[eqx.Module, optax.OptState, float]: The updated stepper, the
-     updated optimizer state, and the loss value.
+ - The updated equinox module
+ - The updated optimizer state
+ - The loss value
"""
loss, grad = eqx.filter_value_and_grad(
lambda m: self.loss_configuration(
@@ -114,16 +132,20 @@ def __call__(
self,
stepper: eqx.Module,
key: PRNGKeyArray,
+ opt_state: Optional[optax.OptState] = None,
*,
return_loss_history: bool = True,
record_loss_every: int = 1,
- ):
+ spawn_tqdm: bool = True,
+ ) -> Union[
+     tuple[eqx.Module, Float[Array, "num_minibatches"]],
+     eqx.Module,
+     tuple[eqx.Module, Float[Array, "num_minibatches"], list],
+     tuple[eqx.Module, list],
+ ]:
"""
- Perform the entire training of an autoregressive neural emulator
- `stepper`.
-
- This method spawns a `tqdm` progress meter showing the current update
- step and displaying the epoch with its respetive minibatch counter.
+ Perform the entire training of an autoregressive neural emulator given
+ an initial state as `stepper`.

This method's return signature depends on the presence of a callback
function. If a callback function is provided, this function has at max
@@ -133,25 +155,32 @@ def __call__(
values of the callback function at each minibatch. If no callback
function is provided, this function has at max two return values. The
first return value is the trained stepper, and the second return value
- is the loss history.
+ is the loss history. If `return_loss_history` is set to `False`, the
+ loss history will not be returned.
+
+ **Arguments:**
+
+ - `stepper`: The equinox Module to be trained.
+ - `key`: The random key to be used for shuffling the minibatches.
+ - `opt_state`: The initial optimizer state. Defaults to None, meaning
+     the optimizer will be reinitialized.
+ - `return_loss_history`: Whether to return the loss history.
+ - `record_loss_every`: Record the loss every `record_loss_every`
+     minibatches. Defaults to 1, i.e., record every minibatch.
+ - `spawn_tqdm`: Whether to spawn the tqdm progress meter showing the
+     current update step and displaying the epoch with its respective
+     minibatch counter.

- Args:
-     stepper (eqx.Module): The stepper to be trained.
-     key (PRNGKeyArray): The random key to be used for shuffling the
-         minibatches.
-     return_loss_history (bool, optional): Whether to return the loss
-         history. Defaults to True.
-     record_loss_every (int, optional): Record the loss every
-         `record_loss_every` minibatches. Defaults to 1.
+ **Returns:**

- Returns:
-     Varying, see above.
+ - Varying, see above. It will always return the trained stepper as the
+     first return value.

- Tipp:
+ !!! tip
    You can use `equinox.filter_vmap` to train multiple networks (of the
-     same architecture) at the same time. For example, if your GPU
-     is not fully utilized yet, this will give you a init-seed
-     statistic basically for free.
+     same architecture) at the same time. For example, if your GPU is not
+     fully utilized yet, this will give you an init-seed statistic
+     basically for free.
"""
loss_history = []
if self.callback_fn is not None:
@@ -164,15 +193,17 @@ def __call__(
shuffle_key=key,
)

- p_meter = tqdm(
-     total=self.num_minibatches,
-     desc=f"E: {0:05d}, B: {0:05d}",
- )
+ if spawn_tqdm:
+     p_meter = tqdm(
+         total=self.num_minibatches,
+         desc=f"E: {0:05d}, B: {0:05d}",
+     )

update_fn = eqx.filter_jit(self.step_fn)

trained_stepper = stepper
- opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))
+ if opt_state is None:
+     opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))

for update_i in range(self.num_minibatches):
    batch_indices, (expoch_id, batch_id) = mixer(update_i, return_info=True)
@@ -185,13 +216,15 @@ def __call__(
)
if update_i % record_loss_every == 0:
    loss_history.append(loss)
- p_meter.update(1)
+ if spawn_tqdm:
+     p_meter.update(1)

- p_meter.set_description(
-     f"E: {expoch_id:05d}, B: {batch_id:05d}",
- )
+     p_meter.set_description(
+         f"E: {expoch_id:05d}, B: {batch_id:05d}",
+     )

- p_meter.close()
+ if spawn_tqdm:
+     p_meter.close()

loss_history = jnp.array(loss_history)
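For orientation, here is a minimal usage sketch of the extended `__call__` signature. It is not part of the commit; `trainer` stands for any concrete trainer built on this abstract class, `stepper` for the Equinox model to be trained, and no `callback_fn` is assumed.

```python
import equinox as eqx
import jax

key = jax.random.PRNGKey(0)

# Default behaviour (unchanged): fresh optimizer state, tqdm progress bar,
# and the loss history as the second return value.
trained_stepper, loss_history = trainer(stepper, key)

# New options from this commit: pass a pre-initialized optimizer state and
# silence the progress bar, e.g. when running many trainings in a script.
opt_state = trainer.optimizer.init(eqx.filter(stepper, eqx.is_array))
trained_stepper = trainer(
    stepper,
    key,
    opt_state=opt_state,
    return_loss_history=False,
    spawn_tqdm=False,
)
```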
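The `!!! tip` about `equinox.filter_vmap` could look roughly like the sketch below. `MyStepper` is a hypothetical model constructor and `trainer` an assumed trainer instance without a `callback_fn`; the new `spawn_tqdm=False` flag keeps the vmapped training from trying to drive a progress bar.

```python
import equinox as eqx
import jax

# Build an ensemble of identically shaped models, one per init seed.
@eqx.filter_vmap
def make_ensemble(key):
    return MyStepper(key=key)  # hypothetical constructor

ensemble = make_ensemble(jax.random.split(jax.random.PRNGKey(0), 4))

# Train all ensemble members at once; array leaves (weights and keys) are
# mapped over their leading axis, everything else is broadcast.
@eqx.filter_vmap
def train_ensemble(model, key):
    return trainer(model, key, return_loss_history=False, spawn_tqdm=False)

trained_ensemble = train_ensemble(
    ensemble, jax.random.split(jax.random.PRNGKey(1), 4)
)
```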