
Commit 1c4612e

Authored by chualanagit (Alan Chu), with pre-commit-ci[bot], lantiga (Luca Antiga), and Alan Chu
Add doc for TBPTT (#20422)
* Add doc for TBPTT
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove url to prevent linting error
* attempt to fix linter
* add tbptt.rst file
* adjust doc
* nit
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* make example easily copy and runnable
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* address comments
* fix doc test warning
* Update docs/source-pytorch/common/tbptt.rst

---------

Co-authored-by: Alan Chu <alanchu@Alans-Air.lan>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Alan Chu <alanchu@Alans-Air.Home>
Co-authored-by: Luca Antiga <luca@lightning.ai>
1 parent ca59e4e commit 1c4612e

File tree

docs/source-pytorch/common/index.rst
docs/source-pytorch/common/tbptt.rst

2 files changed: +66 -0 lines changed

docs/source-pytorch/common/index.rst (+7)

@@ -202,6 +202,13 @@ How-to Guides
     :col_css: col-md-4
     :height: 180

+.. displayitem::
+   :header: Truncated Back-Propagation Through Time
+   :description: Efficiently step through time when training recurrent models
+   :button_link: ../common/tbptt.html
+   :col_css: col-md-4
+   :height: 180
+
 .. raw:: html

     </div>

docs/source-pytorch/common/tbptt.rst (new file, +59)
##############################################
Truncated Backpropagation Through Time (TBPTT)
##############################################

Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
a much longer sequence. This is made possible by splitting training batches
along the time dimension into chunks of size k and passing each chunk through
``training_step``. To keep the forward propagation behavior consistent, the
hidden state must be carried over between time-dimension splits.
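The example below relies on a ``split_batch`` helper that is not defined on this
page. A minimal sketch of what such a helper could look like, assuming each batch
is an ``(input, target)`` pair of tensors laid out as ``(batch, time, features)``
(the layout and the helper itself are illustrative assumptions, not part of the
Lightning API):

.. code-block:: python

    def split_batch(batch, split_size):
        # Illustrative helper: chop an (input, target) pair into chunks of
        # ``split_size`` steps along the time dimension (dim=1).
        x, y = batch
        splits = []
        for t in range(0, x.size(1), split_size):
            splits.append((x[:, t : t + split_size], y[:, t : t + split_size]))
        return splits

With such a helper in place, the full example from the commit reads: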
.. code-block:: python

    import torch
    import torch.optim as optim
    import pytorch_lightning as pl
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()

            # 1. Switch to manual optimization
            self.automatic_optimization = False

            self.truncated_bptt_steps = 10
            self.my_rnn = ParityModuleRNN()  # Define your own RNN; it must return (loss, hiddens)

        # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):
            # 3. Split the batch into chunks along the time dimension
            split_batches = split_batch(batch, self.truncated_bptt_steps)

            batch_size = 10
            hidden_dim = 20
            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)

            opt = self.optimizers()
            for split in split_batches:
                # 4. Perform the optimization in a loop
                loss, hiddens = self.my_rnn(split, hiddens)
                self.manual_backward(loss)
                opt.step()
                opt.zero_grad()

                # 5. "Truncate": detach the hidden state so gradients do not
                # flow back beyond the current chunk
                hiddens = hiddens.detach()

            # 6. Remove the return of `hiddens`
            # Returning loss in manual optimization is not needed
            return None

        def configure_optimizers(self):
            return optim.Adam(self.my_rnn.parameters(), lr=0.001)


    if __name__ == "__main__":
        model = LitModel()
        trainer = pl.Trainer(max_epochs=5)
        trainer.fit(model, train_dataloader)  # Define your own dataloader
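The example leaves ``ParityModuleRNN`` and ``train_dataloader`` for the reader to
define. For a quick smoke test, a toy dataset along the following lines could be
used; the names, shapes, and sizes here are illustrative assumptions matching the
sketch above, not part of the commit:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, Dataset


    class RandomSequenceDataset(Dataset):
        # Synthetic (input, target) pairs: sequences of 50 time steps in the
        # (batch, time, features) layout assumed by split_batch above.
        def __init__(self, num_sequences=100, seq_len=50, input_dim=20):
            self.x = torch.randn(num_sequences, seq_len, input_dim)
            self.y = torch.randn(num_sequences, seq_len, input_dim)

        def __len__(self):
            return len(self.x)

        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]


    train_dataloader = DataLoader(RandomSequenceDataset(), batch_size=10)

Detaching ``hiddens`` at the end of each chunk is what makes the truncation work:
the value of the hidden state is carried forward, but the autograd graph is cut,
so each ``manual_backward`` call only backpropagates through the most recent
``truncated_bptt_steps`` steps.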
