+import argparse
+from dataclasses import dataclass
+import logging
+
+import torch.distributed as dist
 import lightning as L
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from lightning.pytorch.demos import Transformer, WikiText2
-from lightning.pytorch.strategies import ModelParallelStrategy
-from torch.distributed._composable.fsdp.fully_shard import fully_shard
+from lightning.pytorch.strategies import FSDPStrategy, ModelParallelStrategy
+from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
 from torch.utils.data import DataLoader
-from torchao.float8 import Float8LinearConfig, convert_to_float8_training
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+log = logging.getLogger(__name__)
+
+@dataclass
+class Args:
+    vocab_size: int = 32000
+    enable_fp8: bool = False
+    enable_torch_compile: bool = False
+    enable_cpu_offload: bool = False
+    enable_gradient_checkpointing: bool = False
+    enable_fsdp2: bool = False
+
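+# Toy building block: a Linear layer followed by ReLU. The shape prints make the
+# per-rank tensor shapes visible once the layers are sharded.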
+class SimpleLayer(nn.Module):
+    def __init__(self, hidden_size):
+        super(SimpleLayer, self).__init__()
+        self.linear = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.ReLU()
+
+    def forward(self, x):
+        print(f"Input shape before Linear: {x.shape}")
+        x = self.linear(x)
+        print(f"Output shape after Linear: {x.shape}")
+        x = self.activation(x)
+        return x
+
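+# Minimal language model: token embedding -> stack of SimpleLayers -> lm_head that
+# projects hidden states back onto the vocabulary.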
+class InnerModel(nn.Module):
+    def __init__(self, num_layers, hidden_size, vocab_size=32000):
+        super(InnerModel, self).__init__()
+        # Embedding layer
+        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
+        # Initialize a ModuleList to store the intermediate layers
+        self.layers = nn.ModuleList([SimpleLayer(hidden_size) for _ in range(num_layers)])
+        self.lm_head = nn.Linear(hidden_size, vocab_size)
+
+    def forward(self, x):
+        x = self.embedding(x)
+        # Pass the input through each layer sequentially
+        for layer in self.layers:
+            x = layer(x)
+        x = self.lm_head(x)
+        return x
+
+
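+# Thin pass-through wrapper that adds an extra `.model` level to the module tree
+# (similar to HuggingFace checkpoints); presumably this matches the layout the
+# compile/FSDP2/FP8 handlers below expect.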
+class ModelWrapper(nn.Module):
+    def __init__(self, model):
+        super(ModelWrapper, self).__init__()
+        self.model = model  # The wrapped InnerModel instance
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
 
 
 class LanguageModel(L.LightningModule):
-    def __init__(self, vocab_size):
+    def __init__(
+        self,
+        vocab_size=32000,
+        enable_fp8=False,
+        enable_fsdp2=False,
+        enable_torch_compile=False,
+        enable_gradient_checkpointing=False,
+        enable_cpu_offload=False,
+    ):
         super().__init__()
-        self.vocab_size = vocab_size
         self.model = None
+        self.vocab_size = vocab_size
+        self.enable_fp8 = enable_fp8
+        self.enable_fsdp2 = enable_fsdp2
+        self.enable_torch_compile = enable_torch_compile
+        self.enable_gradient_checkpointing = enable_gradient_checkpointing
+        self.enable_cpu_offload = enable_cpu_offload
+        self.model_path = "dummy"  # placeholder
+        self.parallel_dims = {
+            "dp_shard_enabled": torch.cuda.device_count() > 1
+        }  # only used for FP8 training
+
+    def log_model_stage(self, stage: str):
+        """Log the current state of the model with a description of the stage.
+
+        Args:
+            stage (str): Description of the current model stage.
+        """
+        log.warning(f"Model at stage: {stage}\n{self.model}")
+
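+    # TorchCompileHandler comes from a project-local `handlers` package that is not
+    # part of this diff; it is expected to walk the model and apply torch.compile to
+    # the layer classes listed in `compile_layers`.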
+    def configure_torch_compile(self):
+        if self.enable_torch_compile:
+            from handlers.torch_compile_handler import TorchCompileHandler
+
+            torch_compile_handler = TorchCompileHandler(
+                enable_compile=self.enable_torch_compile,
+                model_path=self.model_path,
+                # Explicitly list the layer classes to compile. By default the handler
+                # only knows how to compile HuggingFace Llama MLP blocks and Mixtral
+                # MixtralBlockSparseTop2MLP blocks.
+                compile_layers=["SimpleLayer"],
+                compile_args=None,
+            )
+            torch_compile_handler.compile_model(self.model)
+
+        self.log_model_stage("Model after compile")
+
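+    # FSDP2Handler (also project-local, not shown here) is expected to shard the
+    # model with the newer FSDP2 / fully_shard API over the device mesh provided by
+    # ModelParallelStrategy, optionally adding CPU offload and activation checkpointing.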
+    def configure_fsdp2(self):
+        if self.enable_fsdp2:
+            self.all_gpus = dist.new_group(backend="nccl")
+            dp_mesh = self.device_mesh["data_parallel"]
+            assert dp_mesh.size() > 1
+
+            from handlers.fsdp2_handler import FSDP2Config, FSDP2Handler
+
+            fsdp2_config = FSDP2Config(
+                enable_cpu_offload=self.enable_cpu_offload,
+                enable_gradient_checkpointing=self.enable_gradient_checkpointing,
+            )
+            fsdp2_handler = FSDP2Handler(fsdp2_config, self.device_mesh)
+            self.model = fsdp2_handler.wrap_model(self.model)
+
+        self.log_model_stage("Model after FSDP wrapper")
+
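+    # Float8TrainingHandler is project-local as well; it presumably wraps
+    # torchao.float8's convert_to_float8_training, swapping nn.Linear modules for
+    # Float8Linear. With delayed scaling the amax/scale histories are kept in sync
+    # by the hooks further down.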
+    def configure_fp8(self):
+        # Set up FP8 training; when enable_fp8 is False the handler is created as a no-op.
+        from handlers.fp8_training_handler import FP8Config, Float8TrainingHandler
+
+        fp8_config = FP8Config(
+            enable_fp8=self.enable_fp8,
+            enable_amax_init=False,
+            scaling_type_input="delayed",
+            scaling_type_weight="delayed",
+            scaling_type_grad_output="delayed",
+            enable_fsdp_float8_all_gather=False,
+            precompute_float8_dynamic_scale_for_fsdp=False,
+            pad_inner_dim=True,
+            emulate_fp8=False,  # Set to True for testing without FP8 hardware
+            enable_torch_compile=self.enable_torch_compile,
+            enable_pre_and_post_forward=False,
+        )
+        self.fp8_handler = Float8TrainingHandler(fp8_config, self.model_path, self.parallel_dims)
+        self.fp8_handler.convert_to_float8_training(self.model)
+        self.log_model_stage("Model after FP8 wrapper")
 
     def configure_model(self):
         if self.model is not None:
             return
 
         with torch.device("meta"):
-            model = Transformer(
-                vocab_size=self.vocab_size,
-                nlayers=16,
-                nhid=4096,
-                ninp=1024,
-                nhead=32,
+            self.model = ModelWrapper(
+                InnerModel(
+                    num_layers=16,
+                    hidden_size=1024,
+                    vocab_size=self.vocab_size,
+                )
             )
+        self.configure_fp8()
+        self.configure_fsdp2()
+        self.configure_torch_compile()
+        self.model.train()
 
-        float8_config = Float8LinearConfig(
-            # pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
-            pad_inner_dim=True,
-        )
-
-        def module_filter_fn(mod: torch.nn.Module, fqn: str):
-            # we skip the decoder because it typically vocabulary size
-            # is not divisible by 16 as required by float8
-            if fqn == "decoder":
-                return False
-            return True
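+    # The two batch-start hooks below clear stale gradients by hand; see
+    # hand_roll_base_zero_grad for the torch.compile + FSDP1 workaround.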
+    def on_train_batch_start(self, batch, batch_idx):
+        super().on_train_batch_start(batch, batch_idx)
+        self.hand_roll_base_zero_grad()
 
-        convert_to_float8_training(model, config=float8_config, module_filter_fn=module_filter_fn)
+    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
+        super().on_validation_batch_start(batch, batch_idx, dataloader_idx)
+        self.hand_roll_base_zero_grad()
 
-        for module in model.modules():
-            if isinstance(module, (nn.TransformerEncoderLayer, nn.TransformerDecoderLayer)):
-                fully_shard(module, mesh=self.device_mesh)
+    def hand_roll_base_zero_grad(self):
+        # Work around the torch.compile + FSDP1 interaction tracked in
+        # https://github.com/pytorch/pytorch/issues/139110 by also clearing
+        # gradients held on the underlying base tensors of parameter views.
+        if self.enable_torch_compile and not self.enable_fsdp2:
+            self.zero_grad(set_to_none=True)
+            for p in self.parameters():
+                if p._base is not None and p._base.grad is not None:
+                    p._base._grad = None
 
-        fully_shard(model, mesh=self.device_mesh)
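+    # With delayed FP8 scaling, amax/scale histories must be synced across ranks
+    # before every optimizer step (presumably a no-op when FP8 is disabled).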
+    def on_before_optimizer_step(self, optimizer):
+        self.fp8_handler.sync_float8_amax_and_scale_history(self.model)
+        super().on_before_optimizer_step(optimizer)
 
-        self.model = torch.compile(model)
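+    # Optionally precompute dynamic FP8 scales for the FSDP float8 all-gather after
+    # each step; only relevant when that option is enabled in FP8Config above.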
+    def on_train_batch_end(self, outputs, batch, batch_idx):
+        self.fp8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)
+        super().on_train_batch_end(outputs, batch, batch_idx)
 
     def training_step(self, batch):
         input, target = batch
-        output = self.model(input, target)
-        loss = F.nll_loss(output, target.view(-1))
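+        # The wrapped model returns (batch, seq_len, vocab) logits; take the
+        # log-softmax over the vocabulary dimension and flatten for NLL loss.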
+        output = self.model(input)
+        log_probs = F.log_softmax(output, dim=-1)
+        loss = F.nll_loss(log_probs.view(-1, self.vocab_size), target.view(-1))
         self.log("train_loss", loss, prog_bar=True)
         return loss
 
     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=1e-4)
 
 
-def train():
+def train(args):
     L.seed_everything(42)
 
     dataset = WikiText2()
     train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)
 
-    model = LanguageModel(vocab_size=dataset.vocab_size)
+    model = LanguageModel(
+        vocab_size=args.vocab_size,
+        enable_fp8=args.enable_fp8,
+        enable_fsdp2=args.enable_fsdp2,
+        enable_torch_compile=args.enable_torch_compile,
+        enable_gradient_checkpointing=args.enable_gradient_checkpointing,
+        enable_cpu_offload=args.enable_cpu_offload,
+    )
 
-    mp_strategy = ModelParallelStrategy(
-        data_parallel_size=1,
-        tensor_parallel_size=1,
-    )
-
-    trainer = L.Trainer(strategy=mp_strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8)
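+    # With --enable_fsdp2, ModelParallelStrategy supplies the device mesh used by
+    # configure_fsdp2 above; otherwise classic FSDP1 is configured here via
+    # FSDPStrategy with the wrapping, prefetch, and mixed-precision settings below.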
+    if args.enable_fsdp2:
+        strategy = ModelParallelStrategy(
+            data_parallel_size=1,
+            tensor_parallel_size=1,
+        )
+    else:
+        layers = {SimpleLayer}
+        strategy = FSDPStrategy(
+            auto_wrap_policy=layers,
+            sharding_strategy="FULL_SHARD",
+            forward_prefetch=True,
+            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
+            sync_module_states=True,
+            activation_checkpointing_policy=layers if args.enable_gradient_checkpointing else None,
+            # For FSDP we configure mixed precision here rather than through the Trainer's
+            # `precision` flag: precision="bf16-true" means pure bf16 everywhere (including
+            # the optimizer step), while precision="bf16-mixed" makes the unsharded
+            # all-gather run in fp32:
+            # https://github.com/Lightning-AI/pytorch-lightning/blob/bf25167bbf64f50ba335aa759318946b21775cd2/src/lightning/fabric/plugins/precision/fsdp.py#L83
+            mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
+            cpu_offload=args.enable_cpu_offload,
+        )
+    trainer = L.Trainer(strategy=strategy, max_steps=100, precision="bf16-true", accumulate_grad_batches=8)
 
     trainer.fit(model, train_dataloader)
 
     trainer.print(torch.cuda.memory_summary())
 
 
+def parse_args():
+    parser = argparse.ArgumentParser(description="Train a language model.")
+    parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size. Default is 32000.")
+    parser.add_argument("--enable_fp8", action="store_true", help="Enable FP8 precision.")
+    parser.add_argument("--enable_torch_compile", action="store_true", help="Enable Torch Compile.")
+    parser.add_argument("--enable_cpu_offload", action="store_true", help="Enable CPU offload.")
+    parser.add_argument("--enable_gradient_checkpointing", action="store_true", help="Enable gradient checkpointing.")
+    parser.add_argument("--enable_fsdp2", action="store_true", help="Enable FSDP2.")
+    args = parser.parse_args()
+    return Args(**vars(args))
+
+
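+# Example invocation (hypothetical script name; Lightning will launch one process
+# per visible GPU, and torchrun can be used as well):
+#   python train.py --enable_fsdp2 --enable_torch_compile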
 if __name__ == "__main__":
     torch.set_float32_matmul_precision("high")
-
-    train()
+    args = parse_args()
+    train(args)