Commit c12b4d3

Markdown rework (#2)
* Add markdown based documentation
* Show all of base loss
* Add constructor doc
* Update base config to markdown
* Format line length
* Rework supervised configuration doc
* Translate residuum config to markdown based syntax
* Translate leftover configuration to new docstrings
* Translate supervised trainer
* Update the other two convenience trainers
* Update documentation of mixer components
* Update docs of general trainer
* Display additional methods in documentation
* Allow supplying optstate
* Allow option to deactivate the tqdm progress meter
* Add documentation to all callbacks
1 parent 9c9c363, commit c12b4d3

22 files changed: +736 -539 lines changed

docs/api/general_trainer.md (2 additions, 0 deletions)

@@ -5,3 +5,5 @@
     members:
       - __init__
       - __call__
+      - full_loss
+      - step_fn

docs/api/loss.md (1 addition, 5 deletions)

@@ -32,8 +32,4 @@
 
 ---
 
-::: trainax.loss.BaseLoss
-    options:
-      members:
-        - __init__
-        - __call__
+::: trainax.loss.BaseLoss

trainax/_general_trainer.py (99 additions, 66 deletions)

@@ -1,9 +1,9 @@
-from typing import Optional
+from typing import Optional, Union
 
 import equinox as eqx
 import jax.numpy as jnp
 import optax
-from jaxtyping import PRNGKeyArray, PyTree
+from jaxtyping import Array, Float, PRNGKeyArray, PyTree
 from tqdm.autonotebook import tqdm
 
 from ._mixer import PermutationMixer, TrajectorySubStacker
@@ -34,31 +34,35 @@ def __init__(
         callback_fn: Optional[BaseCallback] = None,
     ):
         """
-        Abstract training for an autoregressive neural emulator on a collection of
-        trajectories.
-
-        The length of (sub-)trajectories returned by `trajectory_sub_stacker` must
-        match the requires length of reference for the used `loss_configuration`.
-
-        Args:
-            trajectory_sub_stacker (TrajectorySubStacker): A callable that takes a
-                list of indices and returns a collection of (sub-)trajectories.
-            loss_configuration (BaseConfiguration): A configuration that defines the
-                loss function to be minimized.
-            ref_stepper (eqx.Module, optional): A reference stepper that is used to
-                compute the residuum. Supply this if the loss configuration requires
-                a reference stepper. Defaults to None.
-            residuum_fn (eqx.Module, optional): A residuum function that computes the
-                discrete residuum between two consecutive states. Supply this if the
-                loss configuration requires a residuum function. Defaults to None.
-            optimizer (optax.GradientTransformation): An optimizer that updates the
-                parameters of the stepper given the gradient.
-            num_minibatches (int): The number of minibatches to train on. This equals
-                the total number of update steps performed. The number of epochs is
-                determined based on this and the `batch_size`.
-            batch_size (int): The size of each batch.
-            callback_fn (BaseCallback, optional): A callback function that is called
-                at the end of each minibatch. Defaults to None.
+        Abstract training for an autoregressive neural emulator on a collection
+        of trajectories.
+
+        !!! info
+            The length of (sub-)trajectories returned by
+            `trajectory_sub_stacker` must match the required length of reference
+            for the used `loss_configuration`.
+
+        **Arguments:**
+
+        - `trajectory_sub_stacker`: A callable that takes a
+            list of indices and returns a collection of (sub-)trajectories.
+        - `loss_configuration`: A configuration that defines the
+            loss function to be minimized.
+        - `ref_stepper`: A reference stepper that is used to
+            compute the residuum. Supply this if the loss configuration requires
+            a reference stepper.
+        - `residuum_fn`: A residuum function that computes the
+            discrete residuum between two consecutive states. Supply this if the
+            loss configuration requires a residuum function. Defaults to None.
+        - `optimizer`: An optimizer that updates the
+            parameters of the stepper given the gradient.
+        - `num_minibatches`: The number of minibatches to train on. This equals
+            the total number of update steps performed. The number of epochs is
+            automatically determined based on this and the `batch_size`.
+        - `batch_size`: The size of each minibatch, i.e., how many samples are
+            included within.
+        - `callback_fn`: A callback function that is called
+            at the end of each minibatch. Defaults to None.
         """
         self.trajectory_sub_stacker = trajectory_sub_stacker
         self.loss_configuration = loss_configuration
@@ -75,6 +79,17 @@ def full_loss(
     ) -> float:
         """
         Compute the loss on the entire dataset.
+
+        !!! warning
+            This can lead to out of memory errors if the dataset is too large.
+
+        **Arguments:**
+
+        - `stepper`: The stepper to compute the loss with.
+
+        **Returns:**
+
+        - The loss value.
         """
         return self.loss_configuration(
             stepper,
@@ -87,19 +102,22 @@ def step_fn(
         self,
         stepper: eqx.Module,
         opt_state: optax.OptState,
-        data: PyTree,
+        data: PyTree[Float[Array, "batch_size sub_trj_len ..."]],
     ) -> tuple[eqx.Module, optax.OptState, float]:
         """
         Perform a single update step to the `stepper`'s parameters.
 
-        Args:
-            stepper (eqx.Module): The stepper to be updated.
-            opt_state (optax.OptState): The optimizer state.
-            data (PyTree): The data for the current minibatch.
+        **Arguments:**
+
+        - `stepper`: The equinox module to be updated.
+        - `opt_state`: The current optimizer state.
+        - `data`: The data for the current minibatch.
+
+        **Returns:**
 
-        Returns:
-            tuple[eqx.Module, optax.OptState, float]: The updated stepper, the
-                updated optimizer state, and the loss value.
+        - The updated equinox module
+        - The updated optimizer state
+        - The loss value
         """
         loss, grad = eqx.filter_value_and_grad(
             lambda m: self.loss_configuration(
@@ -114,16 +132,20 @@ def __call__(
         self,
         stepper: eqx.Module,
         key: PRNGKeyArray,
+        opt_state: Optional[optax.OptState] = None,
         *,
         return_loss_history: bool = True,
         record_loss_every: int = 1,
-    ):
+        spawn_tqdm: bool = True,
+    ) -> Union[
+        tuple[eqx.Module, Float[Array, "num_minibatches"]],
+        eqx.Module,
+        tuple[eqx.Module, Float[Array, "num_minibatches"], list],
+        tuple[eqx.Module, list],
+    ]:
         """
-        Perform the entire training of an autoregressive neural emulator
-        `stepper`.
-
-        This method spawns a `tqdm` progress meter showing the current update
-        step and displaying the epoch with its respetive minibatch counter.
+        Perform the entire training of an autoregressive neural emulator given
+        in an initial state as `stepper`.
 
         This method's return signature depends on the presence of a callback
         function. If a callback function is provided, this function has at max
@@ -133,25 +155,32 @@ def __call__(
         values of the callback function at each minibatch. If no callback
         function is provided, this function has at max two return values. The
         first return value is the trained stepper, and the second return value
-        is the loss history.
+        is the loss history. If `return_loss_history` is set to `False`, the
+        loss history will not be returned.
+
+        **Arguments:**
+
+        - `stepper`: The equinox Module to be trained.
+        - `key`: The random key to be used for shuffling the minibatches.
+        - `opt_state`: The initial optimizer state. Defaults to None, meaning
+            the optimizer will be reinitialized.
+        - `return_loss_history`: Whether to return the loss history.
+        - `record_loss_every`: Record the loss every `record_loss_every`
+            minibatches. Defaults to 1, i.e., record every minibatch.
+        - `spawn_tqdm`: Whether to spawn the tqdm progress meter showing the
+            current update step and displaying the epoch with its respetive
+            minibatch counter.
 
-        Args:
-            stepper (eqx.Module): The stepper to be trained. key (PRNGKeyArray):
-                The random key to be used for shuffling the
-                minibatches.
-            return_loss_history (bool, optional): Whether to return the loss
-                history. Defaults to True.
-            record_loss_every (int, optional): Record the loss every
-                `record_loss_every` minibatches. Defaults to 1.
+        **Returns:**
 
-        Returns:
-            Varying, see above.
+        - Varying, see above. It will always return the trained stepper as the
+            first return value.
 
-        Tipp:
+        !!! tip
             You can use `equinox.filter_vmap` to train mulitple networks (of the
-            same architecture) at the same time. For example, if your GPU
-            is not fully utilized yet, this will give you a init-seed
-            statistic basically for free.
+            same architecture) at the same time. For example, if your GPU is not
+            fully utilized yet, this will give you a init-seed statistic
+            basically for free.
         """
         loss_history = []
         if self.callback_fn is not None:
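
The reworked `__call__` signature and docstring above introduce two optional knobs: `opt_state` to continue from an existing optimizer state, and `spawn_tqdm` to suppress the progress meter. A minimal usage sketch, assuming an already constructed trainer instance `trainer` from this module and an equinox module `stepper`; both are placeholders for objects not shown in this diff:

```python
import equinox as eqx
import jax

key = jax.random.PRNGKey(0)

# Default behaviour: the optimizer state is initialized internally, a tqdm
# meter is shown, and (trained stepper, loss history) is returned.
trained, history = trainer(stepper, key)

# New in this commit: hand over a pre-built optimizer state (e.g., to resume
# training) and switch the progress meter off for scripted runs.
opt_state = trainer.optimizer.init(eqx.filter(stepper, eqx.is_array))
trained, history = trainer(
    stepper,
    key,
    opt_state=opt_state,
    spawn_tqdm=False,
)
```

Here `trainer.optimizer.init(...)` simply mirrors the initialization the trainer performs itself when `opt_state` is `None`, as shown in the hunks below.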
@@ -164,15 +193,17 @@ def __call__(
             shuffle_key=key,
         )
 
-        p_meter = tqdm(
-            total=self.num_minibatches,
-            desc=f"E: {0:05d}, B: {0:05d}",
-        )
+        if spawn_tqdm:
+            p_meter = tqdm(
+                total=self.num_minibatches,
+                desc=f"E: {0:05d}, B: {0:05d}",
+            )
 
         update_fn = eqx.filter_jit(self.step_fn)
 
         trained_stepper = stepper
-        opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))
+        if opt_state is None:
+            opt_state = self.optimizer.init(eqx.filter(trained_stepper, eqx.is_array))
 
         for update_i in range(self.num_minibatches):
             batch_indices, (expoch_id, batch_id) = mixer(update_i, return_info=True)
@@ -185,13 +216,15 @@
             )
             if update_i % record_loss_every == 0:
                 loss_history.append(loss)
-            p_meter.update(1)
+            if spawn_tqdm:
+                p_meter.update(1)
 
-            p_meter.set_description(
-                f"E: {expoch_id:05d}, B: {batch_id:05d}",
-            )
+                p_meter.set_description(
+                    f"E: {expoch_id:05d}, B: {batch_id:05d}",
+                )
 
-        p_meter.close()
+        if spawn_tqdm:
+            p_meter.close()
 
         loss_history = jnp.array(loss_history)
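
With every tqdm call now guarded by `spawn_tqdm`, the docstring's tip about ensemble training becomes practical, since a progress meter cannot be driven from inside a vmapped call. A sketch of that pattern, assuming a hypothetical `make_stepper(key)` factory for the emulator architecture and the same placeholder `trainer` as above; neither is part of this commit:

```python
import equinox as eqx
import jax

# Build several emulators of the same architecture from different seeds.
# `make_stepper` is a hypothetical constructor returning an equinox module.
init_keys = jax.random.split(jax.random.PRNGKey(0), 4)
steppers = eqx.filter_vmap(make_stepper)(init_keys)

train_key = jax.random.PRNGKey(1)

# Train all ensemble members in one vmapped call; the progress meter stays
# off because it cannot be updated from within the vmapped computation.
ensemble, histories = eqx.filter_vmap(
    lambda stepper: trainer(stepper, train_key, spawn_tqdm=False)
)(steppers)

# `histories` stacks one loss curve per initialization seed.
```

This is the "init-seed statistic basically for free" use case mentioned in the tip above.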
