Skip to content

Commit 76f0c54

Browse files
authored
Fix TBPTT example (#20528)
* Fix TBPTT example
* Make example self-contained
* Update imports
* Add test
1 parent ee7fa43 commit 76f0c54

File tree

3 files changed

+123
-22
lines changed

3 files changed

+123
-22
lines changed

docs/source-pytorch/common/tbptt.rst

+64-21
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.
1212
.. code-block:: python
1313
1414
import torch
15+
import torch.nn as nn
16+
import torch.nn.functional as F
1517
import torch.optim as optim
16-
import pytorch_lightning as pl
17-
from pytorch_lightning import LightningModule
18+
from torch.utils.data import Dataset, DataLoader
1819
19-
class LitModel(LightningModule):
20+
import lightning as L
21+
22+
23+
class AverageDataset(Dataset):
    """Synthetic sequence-regression dataset held entirely in memory.

    Inputs are standard-normal tensors of shape
    ``(dataset_len, sequence_len, 10)``. Each target is built from its input
    by splitting the last (feature) dimension into two halves of width 5 and
    adding the first half to the second half rolled by one position along
    that same dimension, giving targets of shape
    ``(dataset_len, sequence_len, 5)``.

    Args:
        dataset_len: Number of sequences in the dataset.
        sequence_len: Number of time steps in every sequence.
    """

    def __init__(self, dataset_len=300, sequence_len=100):
        self.dataset_len = dataset_len
        self.sequence_len = sequence_len

        # All samples are drawn once, up front; indexing later is a plain lookup.
        self.input_seq = torch.randn(dataset_len, sequence_len, 10)

        first_half, second_half = torch.chunk(self.input_seq, 2, dim=-1)
        shifted_second_half = torch.roll(second_half, shifts=1, dims=-1)
        self.output_seq = first_half + shifted_second_half

    def __len__(self):
        """Return the number of sequences."""
        return self.dataset_len

    def __getitem__(self, item):
        """Return the ``(input, target)`` pair stored at index ``item``."""
        inputs = self.input_seq[item]
        targets = self.output_seq[item]
        return inputs, targets
36+
37+
38+
class LitModel(L.LightningModule):
    """LSTM regressor trained with truncated backpropagation through time.

    The module uses manual optimization: every training batch is cut along
    the time dimension into chunks of ``truncated_bptt_steps`` steps, one
    optimizer step is taken per chunk, and the recurrent state is detached
    between chunks so gradients never flow across a chunk boundary.
    """

    def __init__(self):
        super().__init__()

        self.batch_size = 10
        self.in_features = 10
        self.out_features = 5
        self.hidden_dim = 20

        # 1. Switch to manual optimization
        self.automatic_optimization = False
        # Number of time steps in each chunk that is back-propagated through.
        self.truncated_bptt_steps = 10

        self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
        self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)

    def forward(self, x, hs):
        """Run the LSTM and the output projection.

        Args:
            x: Input passed to the LSTM (the LSTM is built with
                ``batch_first=True``).
            hs: Initial LSTM state, or ``None`` to let the LSTM start from
                its default state.

        Returns:
            A pair ``(prediction, hs)`` with the projected LSTM output and
            the LSTM state after the last time step.
        """
        seq, hs = self.rnn(x, hs)
        return self.linear_out(seq), hs

    # 2. Remove the `hiddens` argument
    def training_step(self, batch, batch_idx):
        """Optimize the model on one batch, one time chunk at a time.

        Args:
            batch: Pair ``(x, y)`` of input and target sequences whose
                dimension 1 is time.
            batch_idx: Index of the batch (unused).

        Returns:
            ``None``; with manual optimization no loss has to be returned.
        """
        # 3. Split the batch in chunks along the time dimension
        # `split` takes the chunk *size*, so each chunk holds at most
        # `truncated_bptt_steps` time steps whatever the sequence length is
        # (`tensor_split` would interpret the number as a chunk *count*).
        x, y = batch
        split_x = x.split(self.truncated_bptt_steps, dim=1)
        split_y = y.split(self.truncated_bptt_steps, dim=1)

        hiddens = None
        optimizer = self.optimizers()
        losses = []

        # 4. Perform the optimization in a loop
        for x_chunk, y_chunk in zip(split_x, split_y):
            y_pred, hiddens = self(x_chunk, hiddens)
            loss = F.mse_loss(y_pred, y_chunk)

            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()

            # 5. "Truncate"
            # Detaching cuts the autograd graph at the chunk boundary; the
            # state is kept as the (h, c) tuple that nn.LSTM documents.
            hiddens = tuple(h.detach() for h in hiddens)
            losses.append(loss.detach())

        avg_loss = sum(losses) / len(losses)
        self.log("train_loss", avg_loss, prog_bar=True)

        # 6. Remove the return of `hiddens`
        # Returning loss in manual optimization is not needed
        return None

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of the module."""
        return optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        """Return a loader over a freshly generated ``AverageDataset``."""
        return DataLoader(AverageDataset(), batch_size=self.batch_size)
97+
5598
5699
if __name__ == "__main__":
    # The model supplies its own `train_dataloader`, so `fit` needs no data argument.
    tbptt_model = LitModel()
    L.Trainer(max_epochs=5).fit(tbptt_model)

tests/tests_pytorch/helpers/advanced_models.py

+51
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,54 @@ def configure_optimizers(self):
219219

220220
def train_dataloader(self):
    """Return a loader over the MNIST training split (``download=True``)."""
    mnist_train = MNIST(root=_PATH_DATASETS, train=True, download=True)
    return DataLoader(mnist_train, batch_size=128, num_workers=1)
222+
223+
224+
class TBPTTModule(LightningModule):
    """LSTM regressor trained with truncated backpropagation through time.

    Uses manual optimization: each batch is cut along the time dimension
    into chunks of ``truncated_bptt_steps`` steps, one optimizer step is
    taken per chunk, and the recurrent state is detached between chunks.
    """

    def __init__(self):
        super().__init__()

        self.batch_size = 10
        self.in_features = 10
        self.out_features = 5
        self.hidden_dim = 20

        self.automatic_optimization = False
        # Number of time steps in each chunk that is back-propagated through.
        self.truncated_bptt_steps = 10

        self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
        self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)

    def forward(self, x, hs):
        """Run the LSTM and the output projection.

        Args:
            x: Input passed to the LSTM (the LSTM is built with
                ``batch_first=True``).
            hs: Initial LSTM state, or ``None`` to let the LSTM start from
                its default state.

        Returns:
            A pair ``(prediction, hs)`` with the projected LSTM output and
            the LSTM state after the last time step.
        """
        seq, hs = self.rnn(x, hs)
        return self.linear_out(seq), hs

    def training_step(self, batch, batch_idx):
        """Optimize the model on one batch, one time chunk at a time.

        Args:
            batch: Pair ``(x, y)`` of input and target sequences whose
                dimension 1 is time.
            batch_idx: Index of the batch (unused).

        Returns:
            ``None``; manual optimization does not need a returned loss.
        """
        x, y = batch
        # `split` takes the chunk *size*, so each chunk holds at most
        # `truncated_bptt_steps` time steps whatever the sequence length is
        # (`tensor_split` would interpret the number as a chunk *count*).
        split_x = x.split(self.truncated_bptt_steps, dim=1)
        split_y = y.split(self.truncated_bptt_steps, dim=1)

        hiddens = None
        optimizer = self.optimizers()

        for x_chunk, y_chunk in zip(split_x, split_y):
            y_pred, hiddens = self(x_chunk, hiddens)
            loss = F.mse_loss(y_pred, y_chunk)

            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()

            # "Truncate": detaching cuts the autograd graph at the chunk
            # boundary; the state stays the (h, c) tuple nn.LSTM documents.
            hiddens = tuple(h.detach() for h in hiddens)

        return

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of the module."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        """Return a loader over a freshly constructed ``AverageDataset``."""
        return DataLoader(AverageDataset(), batch_size=self.batch_size)

tests/tests_pytorch/helpers/test_models.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from lightning.pytorch import Trainer
1818
from lightning.pytorch.demos.boring_classes import BoringModel
1919

20-
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN
20+
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
2121
from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
2222
from tests_pytorch.helpers.runif import RunIf
2323
from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
@@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
4949
model.to_torchscript()
5050
if data_class:
5151
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)
52+
53+
54+
def test_tbptt(tmp_path):
    """Smoke test: fit ``TBPTTModule`` for one epoch and expect no exception."""
    module = TBPTTModule()
    Trainer(default_root_dir=tmp_path, max_epochs=1).fit(module)

0 commit comments

Comments
 (0)