@@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.

 .. code-block:: python

     import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
     import torch.optim as optim
-    import pytorch_lightning as pl
-    from pytorch_lightning import LightningModule

-    class LitModel(LightningModule):
+    from torch.utils.data import Dataset, DataLoader
+
+    import lightning as L
+
+
+    class AverageDataset(Dataset):
+        def __init__(self, dataset_len=300, sequence_len=100):
+            self.dataset_len = dataset_len
+            self.sequence_len = sequence_len
+            self.input_seq = torch.randn(dataset_len, sequence_len, 10)
+            top, bottom = self.input_seq.chunk(2, -1)
+            self.output_seq = top + bottom.roll(shifts=1, dims=-1)
+
+        def __len__(self):
+            return self.dataset_len
+
+        def __getitem__(self, item):
+            return self.input_seq[item], self.output_seq[item]
+
+
+    class LitModel(L.LightningModule):

         def __init__(self):
             super().__init__()

+            self.batch_size = 10
+            self.in_features = 10
+            self.out_features = 5
+            self.hidden_dim = 20
+
             # 1. Switch to manual optimization
             self.automatic_optimization = False
-
             self.truncated_bptt_steps = 10
-            self.my_rnn = ParityModuleRNN()  # Define RNN model using ParityModuleRNN
+
+            self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
+            self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)
+
+        def forward(self, x, hs):
+            seq, hs = self.rnn(x, hs)
+            return self.linear_out(seq), hs

         # 2. Remove the `hiddens` argument
         def training_step(self, batch, batch_idx):
-
             # 3. Split the batch in chunks along the time dimension
-            split_batches = split_batch(batch, self.truncated_bptt_steps)
-
-            batch_size = 10
-            hidden_dim = 20
-            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
-            for split_batch in range(split_batches):
-                # 4. Perform the optimization in a loop
-                loss, hiddens = self.my_rnn(split_batch, hiddens)
-                self.backward(loss)
-                self.optimizer.step()
-                self.optimizer.zero_grad()
+            x, y = batch
+            split_x, split_y = [
+                x.tensor_split(self.truncated_bptt_steps, dim=1),
+                y.tensor_split(self.truncated_bptt_steps, dim=1)
+            ]
+
+            hiddens = None
+            optimizer = self.optimizers()
+            losses = []
+
+            # 4. Perform the optimization in a loop
+            for x, y in zip(split_x, split_y):
+                y_pred, hiddens = self(x, hiddens)
+                loss = F.mse_loss(y_pred, y)
+
+                optimizer.zero_grad()
+                self.manual_backward(loss)
+                optimizer.step()

                 # 5. "Truncate"
-                hiddens = hiddens.detach()
+                hiddens = [h.detach() for h in hiddens]
+                losses.append(loss.detach())
+
+            avg_loss = sum(losses) / len(losses)
+            self.log("train_loss", avg_loss, prog_bar=True)

             # 6. Remove the return of `hiddens`
             # Returning loss in manual optimization is not needed
             return None

         def configure_optimizers(self):
-            return optim.Adam(self.my_rnn.parameters(), lr=0.001)
+            return optim.Adam(self.parameters(), lr=0.001)
+
+        def train_dataloader(self):
+            return DataLoader(AverageDataset(), batch_size=self.batch_size)
+

     if __name__ == "__main__":
         model = LitModel()
-        trainer = pl.Trainer(max_epochs=5)
-        trainer.fit(model, train_dataloader)  # Define your own dataloader
+        trainer = L.Trainer(max_epochs=5)
+        trainer.fit(model)
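The core of step 3 above is ``tensor_split``, which chunks the batch along the time dimension (``dim=1``) while leaving the batch and feature dimensions untouched. Below is a minimal sketch of that behaviour, assuming the shapes used in the updated example (batch size 10, sequence length 100, 10 input features) and ``truncated_bptt_steps = 10``; the exact sizes are illustrative, not part of the change itself.

.. code-block:: python

    import torch

    # Assumed shapes from the example above: (batch, time, features) = (10, 100, 10)
    x = torch.randn(10, 100, 10)

    # Split into `truncated_bptt_steps` chunks along the time dimension (dim=1)
    chunks = x.tensor_split(10, dim=1)

    print(len(chunks))      # 10 chunks
    print(chunks[0].shape)  # torch.Size([10, 10, 10]) -- 10 time steps per chunk

Each chunk is then fed through the LSTM with the hidden state carried over from the previous chunk; detaching that state after every optimizer step is what truncates the backward graph at the chunk boundary.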