
Commit 35b009f

Add thoughts on software design
1 parent a0bd404 commit 35b009f

File tree

1 file changed (+36 −1 lines)

README.md

Lines changed: 36 additions & 1 deletion
@@ -84,4 +84,39 @@ Additional axes are:
* Cutting the backpropagation through time in the main chain (after each
  step, or sparse); a per-step sketch follows below
* Cutting the diverted physics
* Cutting one or both levels of the inputs to a residuum function.
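To make the first axis concrete, here is a minimal sketch (not code from this
commit) of cutting the backpropagation through time after every step of the
main chain with `jax.lax.stop_gradient`; the `stepper`, `u_init`, and
`u_ref_trajectory` names are placeholders:

```python
import jax
import jax.numpy as jnp


def rollout_loss_cut_bptt(stepper, u_init, u_ref_trajectory):
    """Supervised rollout loss in which gradients do not flow across steps.

    `stepper` maps a state to the next state; `u_ref_trajectory` has shape
    `(num_rollout_steps, num_channels, ...)`.
    """
    loss = 0.0
    u = u_init
    for u_ref in u_ref_trajectory:
        u = stepper(u)
        loss += jnp.mean((u - u_ref) ** 2)
        # Cut the main chain: the next autoregressive input is detached, so
        # gradients reach the parameters only through the most recent step.
        u = jax.lax.stop_gradient(u)
    return loss
```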

### Implementation details

There are three levels of hierarchy:

1. The `loss` submodule defines time-level-wise comparisons between two states.
   A state is either a tensor of shape `(num_channels, ...)` (with the ellipsis
   indicating an arbitrary number of spatial dimensions) or a tensor of shape
   `(num_batches, num_channels, ...)`. The time-level loss is implemented for
   the former but can additionally be applied to the latter in a vectorized and
   (mean-)aggregated way. (In the schematic above, the time-level loss is the
   green circle.) A sketch of such a loss is given after this list.
2. The `configuration` submodule devises how the neural time stepper $f_\theta$
   (denoted *NN* in the schematic) interplays with the numerical simulator
   $\mathcal{P}$. Similar to the time-level loss, a configuration is a callable
   PyTree which, when called, requires the neural stepper and some data. What
   this data contains depends on the concrete configuration. For supervised
   rollout training it is the batch of (sub-)trajectories to be considered.
   Other configurations might also require the reference stepper or a residuum
   function based on two consecutive time levels. Each configuration is
   essentially an abstract implementation of one of the major methodologies
   (supervised, diverted-chain, mix-chain, residuum). The most general
   diverted-chain implementation contains supervised and branch-one diverted
   chain as special cases. See the section "Relation between Diverted Chain and
   Residuum Training" for details on how residuum training fits into the
   picture. All configurations allow setting additional constructor arguments
   to, e.g., cut the backpropagation through time (sparsely) or to supply
   time-level weightings (for example, to exponentially discount contributions
   over long rollouts). A sketch of a configuration's call signature follows
   this list.
3. The `training` submodule combines a configuration with stochastic
   minibatching on a set of reference trajectories. For each configuration,
   there is a corresponding trainer that is essentially sugarcoating around
   combining the relevant configuration with the `GeneralTrainer` and a
   trajectory substacker. A sketch of such a training loop follows this list.
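
The following is a minimal sketch of the first level: a time-level MSE defined
on a single state of shape `(num_channels, ...)` and then vectorized plus
mean-aggregated over a batched state. The function names are illustrative, not
the submodule's actual API:

```python
import jax
import jax.numpy as jnp


def time_level_mse(u_pred, u_ref):
    """Compare two states of shape (num_channels, ...) at one time level."""
    return jnp.mean((u_pred - u_ref) ** 2)


def batched_time_level_mse(u_pred, u_ref):
    """Vectorize over (num_batches, num_channels, ...) and mean-aggregate."""
    return jnp.mean(jax.vmap(time_level_mse)(u_pred, u_ref))
```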
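For the second level, a configuration's call signature could roughly look like
the sketch below: a supervised rollout over a batch of sub-trajectories. In the
library the configurations are callable PyTrees; here a plain Python class and
made-up names stand in for brevity:

```python
import jax
import jax.numpy as jnp


class SupervisedRolloutSketch:
    """Illustrative supervised-rollout configuration (names are invented).

    Calling it with the current neural stepper and a batch of
    (sub-)trajectories of shape
    `(num_batches, num_rollout_steps + 1, num_channels, ...)` returns a
    scalar training loss.
    """

    def __init__(self, num_rollout_steps: int):
        self.num_rollout_steps = num_rollout_steps

    def __call__(self, stepper, data):
        def single_trajectory_loss(trajectory):
            u = trajectory[0]
            loss = 0.0
            for t in range(1, self.num_rollout_steps + 1):
                u = stepper(u)  # autoregressive main chain
                loss += jnp.mean((u - trajectory[t]) ** 2)
            return loss / self.num_rollout_steps

        # Vectorize over the batch axis and mean-aggregate.
        return jnp.mean(jax.vmap(single_trajectory_loss)(data))
```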
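And for the third level, a training loop that combines a configuration with
stochastic minibatching on a set of reference trajectories might look roughly
as follows. The optax-based optimizer interface and all names here are
assumptions, and the slicing of full trajectories into sub-trajectories
("substacking") is omitted:

```python
import jax
import optax  # assumed optimizer library, not necessarily what the repo uses


def train_sketch(params, build_stepper, configuration, sub_trajectories,
                 optimizer, num_updates, batch_size, key):
    """Illustrative loop: configuration + stochastic minibatching.

    `sub_trajectories` has shape
    `(num_samples, num_rollout_steps + 1, num_channels, ...)` and
    `build_stepper(params)` returns the neural stepper u_next = stepper(u).
    """
    opt_state = optimizer.init(params)

    @jax.jit
    def update(params, opt_state, batch):
        # The configuration turns (stepper, data) into a scalar loss.
        loss, grads = jax.value_and_grad(
            lambda p: configuration(build_stepper(p), batch)
        )(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    for _ in range(num_updates):
        # Stochastic minibatching over the reference (sub-)trajectories.
        key, subkey = jax.random.split(key)
        idx = jax.random.choice(
            subkey, sub_trajectories.shape[0], (batch_size,), replace=False
        )
        params, opt_state, loss = update(params, opt_state, sub_trajectories[idx])
    return params
```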
### Relation between Diverted Chain and Residuum Training
