
Commit abbc029

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent ec6f5a8 commit abbc029
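
The changes below are the mechanical kind that pre-commit.ci applies from whatever hooks the repository declares in its .pre-commit-config.yaml: import reordering, call and signature reflowing, blank-line and docstring normalization, and end-of-file newlines. As a minimal sketch only, assuming a typical ruff-based setup rather than this repository's actual configuration, a config producing similar fixes might look like:

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: end-of-file-fixer      # ensure files end with a single newline
      - id: trailing-whitespace    # strip trailing whitespace
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.0
    hooks:
      - id: ruff                   # lint and sort imports
        args: [--fix]
      - id: ruff-format            # Black-style reformatting of calls and signatures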

File tree
  • examples/pytorch

2 files changed: +28, -24 lines
examples/pytorch/custom_handler_fp8_fsdp1n2_compile/train.py (+27, -22)
@@ -1,20 +1,21 @@
 import argparse
-from dataclasses import dataclass
 import logging
+from dataclasses import dataclass

-import torch.distributed as dist
 import lightning as L
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-from lightning.pytorch.demos import Transformer, WikiText2
+from lightning.pytorch.demos import WikiText2
 from lightning.pytorch.strategies import FSDPStrategy, ModelParallelStrategy
 from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
 from torch.utils.data import DataLoader

 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 log = logging.getLogger(__name__)

+
 @dataclass
 class Args:
     vocab_size: int = 32000
@@ -24,6 +25,7 @@ class Args:
     enable_gradient_checkpointing: bool = False
     enable_fsdp2: bool = False

+
 class SimpleLayer(nn.Module):
     def __init__(self, hidden_size):
         super(SimpleLayer, self).__init__()
@@ -37,6 +39,7 @@ def forward(self, x):
         x = self.activation(x)
         return x

+
 class InnerModel(nn.Module):
     def __init__(self, num_layers, hidden_size, vocab_size=32000):
         super(InnerModel, self).__init__()
@@ -46,7 +49,6 @@ def __init__(self, num_layers, hidden_size, vocab_size=32000):
         self.layers = nn.ModuleList([SimpleLayer(hidden_size) for _ in range(num_layers)])
         self.lm_head = nn.Linear(hidden_size, vocab_size)

-
     def forward(self, x):
         x = self.embedding(x)
         # Pass the input through each layer sequentially
@@ -66,14 +68,15 @@ def forward(self, *args, **kwargs):


 class LanguageModel(L.LightningModule):
-    def __init__(self,
-                 vocab_size=32000,
-                 enable_fp8 = False,
-                 enable_fsdp2 = False,
-                 enable_torch_compile = False,
-                 enable_gradient_checkpointing = False,
-                 enable_cpu_offload = False
-                 ):
+    def __init__(
+        self,
+        vocab_size=32000,
+        enable_fp8=False,
+        enable_fsdp2=False,
+        enable_torch_compile=False,
+        enable_gradient_checkpointing=False,
+        enable_cpu_offload=False,
+    ):
         super().__init__()
         self.model = None
         self.vocab_size = vocab_size
@@ -88,10 +91,11 @@ def __init__(self,
         } # only used for FP8 training

     def log_model_stage(self, stage: str):
-        """
-        Logs the current state of the model with a description of the stage.
+        """Logs the current state of the model with a description of the stage.
+
         Args:
             stage (str): Description of the current model stage.
+
         """
         log.warning(f"Model at stage: {stage}\n{self.model}")

@@ -129,7 +133,7 @@ def configure_fsdp2(self):

     def configure_fp8(self):
         # Setup fp8 training, if enable_fp8 is false, it will create a fake handler
-        from handlers.fp8_training_handler import FP8Config, Float8TrainingHandler
+        from handlers.fp8_training_handler import Float8TrainingHandler, FP8Config

         fp8_config = FP8Config(
             enable_fp8=self.enable_fp8,
@@ -207,13 +211,14 @@ def train(args):
     dataset = WikiText2()
     train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

-    model = LanguageModel(vocab_size=args.vocab_size,
-                          enable_fp8 = args.enable_fp8,
-                          enable_fsdp2 = args.enable_fsdp2,
-                          enable_torch_compile = args.enable_torch_compile,
-                          enable_gradient_checkpointing = args.enable_gradient_checkpointing,
-                          enable_cpu_offload = args.enable_cpu_offload,
-                          )
+    model = LanguageModel(
+        vocab_size=args.vocab_size,
+        enable_fp8=args.enable_fp8,
+        enable_fsdp2=args.enable_fsdp2,
+        enable_torch_compile=args.enable_torch_compile,
+        enable_gradient_checkpointing=args.enable_gradient_checkpointing,
+        enable_cpu_offload=args.enable_cpu_offload,
+    )

     if args.enable_fsdp2:
         strategy = ModelParallelStrategy(

examples/pytorch/fp8_fsdp_compile/train.py (+1, -2)
@@ -50,7 +50,6 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str):

         self.model = torch.compile(model)

-
     def training_step(self, batch):
         input, target = batch
         output = self.model(input, target)
@@ -85,4 +84,4 @@ def train():
 if __name__ == "__main__":
     torch.set_float32_matmul_precision("high")

-    train()
+    train()

0 commit comments
