
Commit c3c43b6

robieta (Taylor Robie) and lantiga (Luca Antiga) authored

Small changes to reduce peak memory. (#389)

Co-authored-by: Taylor Robie <taylor.robie@lightning.ai>
Co-authored-by: Luca Antiga <luca@lightning.ai>
1 parent 6d2c5ca commit c3c43b6

File tree

3 files changed (+6 -5 lines):

  finetune/full.py (+2 -2)
  pretrain/redpajama.py (+2 -1)
  pretrain/shakespeare.py (+2 -2)

finetune/full.py

+2 -2

@@ -55,7 +55,7 @@ def main(
 ):
 
     auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)
+    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
 
     fabric = L.Fabric(accelerator="cuda", devices=devices, precision="bf16-mixed", strategy=strategy)
     fabric.launch()
@@ -79,7 +79,7 @@ def main(
 
     model = fabric.setup_module(model)
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
     optimizer = fabric.setup_optimizers(optimizer)
 
     train(fabric, model, optimizer, train_data, val_data, out_dir)
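The limit_all_gathers=True flag is forwarded to PyTorch FSDP and turns on its all-gather rate limiter, which keeps the CPU thread from prefetching too many unsharded parameter blocks ahead of the compute stream; with fewer fully gathered parameters resident at once, peak GPU memory drops at a small throughput cost. A minimal sketch of the strategy setup, assuming the same lit_llama imports these scripts already use:

from functools import partial

import lightning as L
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from lit_llama.model import Block  # each transformer Block becomes one FSDP unit

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
strategy = FSDPStrategy(
    auto_wrap_policy=auto_wrap_policy,
    activation_checkpointing=Block,
    limit_all_gathers=True,  # throttle prefetching of unsharded parameters
)
fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy)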

pretrain/redpajama.py

+2 -1

@@ -69,7 +69,7 @@ def main(
         transformer_auto_wrap_policy, transformer_layer_cls={Block}
     )
     strategy = FSDPStrategy(
-        auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block
+        auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True
     )
 
     fabric = L.Fabric(
@@ -110,6 +110,7 @@ def main(
         lr=learning_rate,
         weight_decay=weight_decay,
         betas=(beta1, beta2),
+        foreach=False,
     )
 
     model, optimizer = fabric.setup(model, optimizer)
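foreach=False opts out of AdamW's multi-tensor ("foreach") implementation and falls back to the plain per-parameter loop. The foreach path fuses updates across parameters but allocates sizable temporary buffers to do so, so disabling it trades some optimizer-step speed for a lower memory high-water mark. A minimal sketch, with placeholder hyperparameter values standing in for the scripts' own settings:

import torch

model = torch.nn.Linear(4096, 4096)  # stand-in for the FSDP-wrapped LLaMA model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=6e-4,            # placeholder values; the scripts use their own settings
    weight_decay=0.1,
    betas=(0.9, 0.95),
    foreach=False,      # per-parameter loop: slightly slower, smaller temporaries
)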

pretrain/shakespeare.py

+2 -2

@@ -47,7 +47,7 @@
 
 def main() -> None:
     auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
-    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)
+    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
 
     fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy)
     fabric.launch()
@@ -70,7 +70,7 @@ def main() -> None:
 
     model = fabric.setup_module(model)
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False)
     optimizer = fabric.setup_optimizers(optimizer)
 
     train(fabric, model, optimizer, train_data, val_data)
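Whether the two flags actually shave the high-water mark can be checked from the CUDA allocator statistics around a single training step. A small sketch, with a hypothetical step_fn callable standing in for the body of the training loop:

import torch

def report_peak_memory(fabric, step_fn):
    # Hypothetical helper: run one step and print the allocator's peak usage.
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    fabric.print(f"peak CUDA memory: {peak_gb:.2f} GB")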
