Commit ec6f5a8

update example
1 parent 887199a commit ec6f5a8

File tree

3 files changed: +302 −41 lines changed


_notebooks

Submodule _notebooks deleted from b83fde0
@@ -1,87 +1,261 @@

import argparse
from dataclasses import dataclass
import logging

import torch.distributed as dist
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.strategies import FSDPStrategy, ModelParallelStrategy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
from torch.utils.data import DataLoader

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger(__name__)


@dataclass
class Args:
    vocab_size: int = 32000
    enable_fp8: bool = False
    enable_torch_compile: bool = False
    enable_cpu_offload: bool = False
    enable_gradient_checkpointing: bool = False
    enable_fsdp2: bool = False


class SimpleLayer(nn.Module):
    def __init__(self, hidden_size):
        super(SimpleLayer, self).__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.ReLU()

    def forward(self, x):
        print(f"Input shape before Linear: {x.shape}")
        x = self.linear(x)
        print(f"Output shape after Linear: {x.shape}")
        x = self.activation(x)
        return x


class InnerModel(nn.Module):
    def __init__(self, num_layers, hidden_size, vocab_size=32000):
        super(InnerModel, self).__init__()
        # Embedding layer
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
        # Initialize a ModuleList to store the intermediate layers
        self.layers = nn.ModuleList([SimpleLayer(hidden_size) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        # Pass the input through each layer sequentially
        for layer in self.layers:
            x = layer(x)
        x = self.lm_head(x)
        return x


class ModelWrapper(nn.Module):
    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model  # the wrapped inner model

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

class LanguageModel(L.LightningModule):
    def __init__(
        self,
        vocab_size=32000,
        enable_fp8=False,
        enable_fsdp2=False,
        enable_torch_compile=False,
        enable_gradient_checkpointing=False,
        enable_cpu_offload=False,
    ):
        super().__init__()
        self.model = None
        self.vocab_size = vocab_size
        self.enable_fp8 = enable_fp8
        self.enable_fsdp2 = enable_fsdp2
        self.enable_torch_compile = enable_torch_compile
        self.enable_gradient_checkpointing = enable_gradient_checkpointing
        self.enable_cpu_offload = enable_cpu_offload
        self.model_path = "dummy"  # placeholder
        self.parallel_dims = {
            "dp_shard_enabled": torch.cuda.device_count() > 1
        }  # only used for FP8 training

    def log_model_stage(self, stage: str):
        """
        Logs the current state of the model with a description of the stage.

        Args:
            stage (str): Description of the current model stage.
        """
        log.warning(f"Model at stage: {stage}\n{self.model}")

    def configure_torch_compile(self):
        if self.enable_torch_compile:
            from handlers.torch_compile_handler import TorchCompileHandler

            torch_compile_handler = TorchCompileHandler(
                enable_compile=self.enable_torch_compile,
                model_path=self.model_path,
                # The handler's defaults only cover HuggingFace Llama MLP blocks and
                # Mixtral MixtralBlockSparseTop2MLP blocks, so the layers to compile
                # are listed explicitly here.
                compile_layers=["SimpleLayer"],
                compile_args=None,
            )
            torch_compile_handler.compile_model(self.model)

        self.log_model_stage("Model after compile")

    def configure_fsdp2(self):
        if self.enable_fsdp2:
            self.all_gpus = dist.new_group(backend="nccl")
            dp_mesh = self.device_mesh["data_parallel"]
            assert dp_mesh.size() > 1

            from handlers.fsdp2_handler import FSDP2Config, FSDP2Handler

            fsdp2_config = FSDP2Config(
                enable_cpu_offload=self.enable_cpu_offload,
                enable_gradient_checkpointing=self.enable_gradient_checkpointing,
            )
            fsdp2_handler = FSDP2Handler(fsdp2_config, self.device_mesh)
            self.model = fsdp2_handler.wrap_model(self.model)

        self.log_model_stage("Model after FSDP wrapper")

    def configure_fp8(self):
        # Set up FP8 training; if enable_fp8 is False, the handler is a no-op.
        from handlers.fp8_training_handler import FP8Config, Float8TrainingHandler

        fp8_config = FP8Config(
            enable_fp8=self.enable_fp8,
            enable_amax_init=False,
            scaling_type_input="delayed",
            scaling_type_weight="delayed",
            scaling_type_grad_output="delayed",
            enable_fsdp_float8_all_gather=False,
            precompute_float8_dynamic_scale_for_fsdp=False,
            pad_inner_dim=True,
            emulate_fp8=False,  # Set to True for testing without FP8 hardware
            enable_torch_compile=self.enable_torch_compile,
            enable_pre_and_post_forward=False,
        )
        self.fp8_handler = Float8TrainingHandler(fp8_config, self.model_path, self.parallel_dims)
        self.fp8_handler.convert_to_float8_training(self.model)
        self.log_model_stage("Model after FP8 wrapper")

    def configure_model(self):
        if self.model is not None:
            return

        with torch.device("meta"):
            self.model = ModelWrapper(
                InnerModel(
                    num_layers=16,
                    hidden_size=1024,
                    vocab_size=self.vocab_size,
                )
            )
        self.configure_fp8()
        self.configure_fsdp2()
        self.configure_torch_compile()
        self.model.train()

    def on_train_batch_start(self, batch, batch_idx):
        super().on_train_batch_start(batch, batch_idx)
        self.hand_roll_base_zero_grad()

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
        super().on_validation_batch_start(batch, batch_idx, dataloader_idx)
        self.hand_roll_base_zero_grad()

    def hand_roll_base_zero_grad(self):
        # to resolve the torch compile + FSDP1 issue https://github.com/pytorch/pytorch/issues/139110
        if self.enable_torch_compile and not self.enable_fsdp2:
            self.zero_grad(set_to_none=True)
            for p in self.parameters():
                if p._base is not None and p._base.grad is not None:
                    p._base._grad = None

    def on_before_optimizer_step(self, optimizer):
        self.fp8_handler.sync_float8_amax_and_scale_history(self.model)
        super().on_before_optimizer_step(optimizer)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        self.fp8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)
        super().on_train_batch_end(outputs, batch, batch_idx)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input)
        # Normalize over the vocabulary dimension before computing the NLL loss.
        log_softmax = nn.LogSoftmax(dim=-1)
        loss = F.nll_loss(log_softmax(output).view(-1, self.vocab_size), target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


def train(args):
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(
        vocab_size=args.vocab_size,
        enable_fp8=args.enable_fp8,
        enable_fsdp2=args.enable_fsdp2,
        enable_torch_compile=args.enable_torch_compile,
        enable_gradient_checkpointing=args.enable_gradient_checkpointing,
        enable_cpu_offload=args.enable_cpu_offload,
    )

    if args.enable_fsdp2:
        strategy = ModelParallelStrategy(
            data_parallel_size=1,
            tensor_parallel_size=1,
        )
    else:
        layers = {SimpleLayer}
        strategy = FSDPStrategy(
            auto_wrap_policy=layers,
            sharding_strategy="FULL_SHARD",
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            sync_module_states=True,
            activation_checkpointing_policy=layers if args.enable_gradient_checkpointing else None,
            # For FSDP, mixed precision is set here instead of passing precision to the PL trainer:
            # precision="bf16-true" in the PL trainer means pure half precision (including the optimizer update),
            # while precision="bf16-mixed" performs the unsharded all-gather in fp32:
            # https://github.com/Lightning-AI/pytorch-lightning/blob/bf25167bbf64f50ba335aa759318946b21775cd2/src/lightning/fabric/plugins/precision/fsdp.py#L83
            mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
            cpu_offload=args.enable_cpu_offload,
        )
    trainer = L.Trainer(strategy=strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8)

    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())


def parse_args():
    parser = argparse.ArgumentParser(description="Train a language model.")
    parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size. Default is 32000.")
    parser.add_argument("--enable_fp8", action="store_true", help="Enable FP8 precision.")
    parser.add_argument("--enable_torch_compile", action="store_true", help="Enable Torch Compile.")
    parser.add_argument("--enable_cpu_offload", action="store_true", help="Enable CPU offload.")
    parser.add_argument("--enable_gradient_checkpointing", action="store_true", help="Enable gradient checkpointing.")
    parser.add_argument("--enable_fsdp2", action="store_true", help="Enable FSDP2.")
    args = parser.parse_args()
    return Args(**vars(args))


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    args = parse_args()
    train(args)
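
The configure_fp8 hook above delegates to a Float8TrainingHandler from a handlers package that is not part of this diff. The original example (second file below) shows the underlying idea directly: torchao's convert_to_float8_training swaps eligible nn.Linear layers, skipping modules whose dimensions are not divisible by 16. A minimal, hypothetical equivalent for InnerModel could look like the sketch below; convert_model_to_fp8 and the lm_head skip rule are assumptions made here, not taken from the handler's source.

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def convert_model_to_fp8(model: nn.Module, vocab_size: int) -> None:
    # Hypothetical stand-in for Float8TrainingHandler.convert_to_float8_training.
    config = Float8LinearConfig(pad_inner_dim=True)

    def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
        # Skip the LM head when the vocabulary size is not divisible by 16,
        # mirroring the decoder filter in the original example below.
        if fqn.endswith("lm_head") and vocab_size % 16 != 0:
            return False
        return True

    # Swaps eligible nn.Linear modules for float8 training in place.
    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)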
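configure_fsdp2 likewise relies on an external FSDP2Handler. The original example below shows the FSDP2 pattern it presumably builds on: fully_shard is applied to each repeated block and then to the root module over a device mesh. The following is a hedged sketch of that pattern for InnerModel; wrap_model_with_fsdp2 is a name invented here, the class-name match is used only to keep the sketch self-contained, and real handler features such as CPU offload and gradient checkpointing are not shown.

import torch.nn as nn
from torch.distributed._composable.fsdp.fully_shard import fully_shard  # import path as used in the original example


def wrap_model_with_fsdp2(model: nn.Module, mesh) -> nn.Module:
    # Hypothetical stand-in for FSDP2Handler.wrap_model: shard each repeated block,
    # then the root module, over the given device mesh (e.g. the "data_parallel" submesh).
    for module in model.modules():
        if type(module).__name__ == "SimpleLayer":
            fully_shard(module, mesh=mesh)
    fully_shard(model, mesh=mesh)
    return model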
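configure_torch_compile defers to a TorchCompileHandler configured with compile_layers=["SimpleLayer"]. One plausible way to compile only those submodules, assumed here rather than taken from the handler's source, is to walk the module tree and replace each matching child with its torch.compile-wrapped version while leaving the rest of the model eager:

import torch
import torch.nn as nn


def compile_selected_layers(model: nn.Module, layer_class_names: set) -> None:
    # Collect matching children first, then swap each one for a torch.compile-wrapped
    # version; the rest of the model keeps running eagerly.
    replacements = []
    for parent in model.modules():
        for child_name, child in parent.named_children():
            if type(child).__name__ in layer_class_names:
                replacements.append((parent, child_name, child))
    for parent, child_name, child in replacements:
        setattr(parent, child_name, torch.compile(child))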
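The training_step above computes a token-level NLL loss from an explicit log-softmax over the vocabulary dimension. An equivalent, slightly more compact formulation applies F.cross_entropy to the raw logits, since it fuses the log-softmax and NLL steps; the helper name below is illustrative only.

import torch
import torch.nn.functional as F


def lm_loss(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, vocab_size); target: (batch, seq_len) token ids.
    # cross_entropy applies log_softmax over the last dimension of the flattened
    # logits and then the NLL loss, matching the explicit version above.
    return F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))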
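Since parse_args maps each --enable_* flag onto the Args dataclass, the ModelParallelStrategy (FSDP2) path is selected only when --enable_fsdp2 is passed; otherwise the script falls back to FSDP1 via FSDPStrategy. The script's filename is not shown in this diff, but an invocation would look like `python <script>.py --enable_fsdp2 --enable_fp8 --enable_torch_compile`.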
+88

@@ -0,0 +1,88 @@

import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = None

    def configure_model(self):
        if self.model is not None:
            return

        with torch.device("meta"):
            model = Transformer(
                vocab_size=self.vocab_size,
                nlayers=16,
                nhid=4096,
                ninp=1024,
                nhead=32,
            )

        float8_config = Float8LinearConfig(
            # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
            pad_inner_dim=True,
        )

        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # Skip the decoder: its vocabulary size is typically not divisible by 16,
            # as required by float8.
            if fqn == "decoder":
                return False
            return True

        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)

        for module in model.modules():
            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
                fully_shard(module, mesh=self.device_mesh)

        fully_shard(model, mesh=self.device_mesh)

        self.model = torch.compile(model)

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    mp_strategy = ModelParallelStrategy(
        data_parallel_size=1,
        tensor_parallel_size=1,
    )

    trainer = L.Trainer(strategy=mp_strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8)

    trainer.fit(model, train_dataloader)

    trainer.print(torch.cuda.memory_summary())


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")

    train()
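
In this original example, the float8 conversion, fully_shard wrapping, and torch.compile all happen in place inside configure_model, which is why the updated script above logs the model after each stage. A quick way to sanity-check the conversion without relying on torchao's internal class names is to count module types before and after; the helper below is a minimal sketch, and its name is illustrative rather than part of either script.

import torch.nn as nn


def summarize_module_types(model: nn.Module) -> dict:
    # Count module classes so you can confirm which nn.Linear layers were swapped
    # out by convert_to_float8_training (they no longer appear as nn.Linear).
    counts = {}
    for module in model.modules():
        name = type(module).__name__
        counts[name] = counts.get(name, 0) + 1
    return counts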
