@@ -1,20 +1,21 @@
 import argparse
-from dataclasses import dataclass
 import logging
+from dataclasses import dataclass
 
-import torch.distributed as dist
 import lightning as L
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from lightning.pytorch.demos import Transformer, WikiText2
+from lightning.pytorch.demos import WikiText2
 from lightning.pytorch.strategies import FSDPStrategy, ModelParallelStrategy
 from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
 from torch.utils.data import DataLoader
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 log = logging.getLogger(__name__)
 
+
 @dataclass
 class Args:
     vocab_size: int = 32000
@@ -24,6 +25,7 @@ class Args:
     enable_gradient_checkpointing: bool = False
     enable_fsdp2: bool = False
 
+
 class SimpleLayer(nn.Module):
     def __init__(self, hidden_size):
         super(SimpleLayer, self).__init__()
@@ -37,6 +39,7 @@ def forward(self, x):
         x = self.activation(x)
         return x
 
+
 class InnerModel(nn.Module):
     def __init__(self, num_layers, hidden_size, vocab_size=32000):
         super(InnerModel, self).__init__()
@@ -46,7 +49,6 @@ def __init__(self, num_layers, hidden_size, vocab_size=32000):
         self.layers = nn.ModuleList([SimpleLayer(hidden_size) for _ in range(num_layers)])
         self.lm_head = nn.Linear(hidden_size, vocab_size)
 
-
     def forward(self, x):
         x = self.embedding(x)
         # Pass the input through each layer sequentially
@@ -66,14 +68,15 @@ def forward(self, *args, **kwargs):
 
 
 class LanguageModel(L.LightningModule):
-    def __init__(self,
-                 vocab_size=32000,
-                 enable_fp8=False,
-                 enable_fsdp2=False,
-                 enable_torch_compile=False,
-                 enable_gradient_checkpointing=False,
-                 enable_cpu_offload=False
-                 ):
+    def __init__(
+        self,
+        vocab_size=32000,
+        enable_fp8=False,
+        enable_fsdp2=False,
+        enable_torch_compile=False,
+        enable_gradient_checkpointing=False,
+        enable_cpu_offload=False,
+    ):
         super().__init__()
         self.model = None
         self.vocab_size = vocab_size
@@ -88,10 +91,11 @@ def __init__(self,
         }  # only used for FP8 training
 
     def log_model_stage(self, stage: str):
-        """
-        Logs the current state of the model with a description of the stage.
+        """Logs the current state of the model with a description of the stage.
+
         Args:
             stage (str): Description of the current model stage.
+
         """
         log.warning(f"Model at stage: {stage}\n{self.model}")
 
@@ -129,7 +133,7 @@ def configure_fsdp2(self):
 
     def configure_fp8(self):
         # Setup fp8 training, if enable_fp8 is false, it will create a fake handler
-        from handlers.fp8_training_handler import FP8Config, Float8TrainingHandler
+        from handlers.fp8_training_handler import Float8TrainingHandler, FP8Config
 
         fp8_config = FP8Config(
             enable_fp8=self.enable_fp8,
@@ -207,13 +211,14 @@ def train(args):
     dataset = WikiText2()
     train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)
 
-    model = LanguageModel(vocab_size=args.vocab_size,
-                          enable_fp8=args.enable_fp8,
-                          enable_fsdp2=args.enable_fsdp2,
-                          enable_torch_compile=args.enable_torch_compile,
-                          enable_gradient_checkpointing=args.enable_gradient_checkpointing,
-                          enable_cpu_offload=args.enable_cpu_offload,
-                          )
+    model = LanguageModel(
+        vocab_size=args.vocab_size,
+        enable_fp8=args.enable_fp8,
+        enable_fsdp2=args.enable_fsdp2,
+        enable_torch_compile=args.enable_torch_compile,
+        enable_gradient_checkpointing=args.enable_gradient_checkpointing,
+        enable_cpu_offload=args.enable_cpu_offload,
+    )
 
     if args.enable_fsdp2:
         strategy = ModelParallelStrategy(