
Commit 36fdb57

Merge branch 'master' into feature/9947_dataloader-string
2 parents: f492167 + 601c060

6 files changed (+73, -1 lines)

docs/source-pytorch/common/checkpointing_basic.rst (+23, -1)

@@ -20,6 +20,13 @@ PyTorch Lightning checkpoints are fully usable in plain PyTorch.

 ----

+.. important::
+
+   **Important Update: Deprecated Method**
+
+   Starting from PyTorch Lightning v1.0.0, the `resume_from_checkpoint` argument has been deprecated. To resume training from a checkpoint, use the `ckpt_path` argument in the `fit()` method.
+   Please update your code accordingly to avoid potential compatibility issues.
+
 ************************
 Contents of a checkpoint
 ************************

@@ -197,16 +204,31 @@ You can disable checkpointing by passing:

 ----

+
 *********************
 Resume training state
 *********************

 If you don't just want to load weights, but instead restore the full training, do the following:

+Correct usage:
+
 .. code-block:: python

     model = LitModel()
     trainer = Trainer()

     # automatically restores model, epoch, step, LR schedulers, etc...
-    trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
+    trainer.fit(model, ckpt_path="path/to/your/checkpoint.ckpt")
+
+.. warning::
+
+   The argument `resume_from_checkpoint` has been deprecated in versions of PyTorch Lightning >= 1.0.0.
+   To resume training from a checkpoint, use the `ckpt_path` argument in the `fit()` method instead.
+
+Incorrect (deprecated) usage:
+
+.. code-block:: python
+
+    trainer = Trainer(resume_from_checkpoint="path/to/your/checkpoint.ckpt")
+    trainer.fit(model)
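
As a usage note (not part of the commit), the corrected pattern above can be exercised end to end with a minimal script. This is only a sketch: the LitModel definition, the random dataset, and the checkpoint path are illustrative placeholders.

    # Minimal sketch of resuming full training state via ckpt_path.
    # The model, data, and checkpoint path below are placeholders.
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import lightning.pytorch as pl


    class LitModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    train_loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
    model = LitModel()
    trainer = pl.Trainer(max_epochs=2)

    # Restores model weights, epoch, global step, LR schedulers, etc. from the
    # checkpoint before continuing to train.
    trainer.fit(model, train_loader, ckpt_path="path/to/your/checkpoint.ckpt")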

src/lightning/fabric/plugins/precision/fsdp.py (+6)

@@ -74,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         }
         self._desired_input_dtype = precision_to_type[self.precision]

+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
     @property
     def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision
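
The override added here casts parameters to the target dtype only for the "-true" precision settings; mixed settings leave the module in float32 and rely on FSDP's mixed-precision config at compute time. A rough, illustrative sketch of the effect, assuming `FSDPPrecision` is imported from the module shown above (layer sizes are arbitrary):

    # Illustrative only: dtype conversion performed by the new convert_module
    # override for "-true" vs. mixed precision settings.
    import torch
    from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

    module = torch.nn.Linear(1024, 1024)
    print(module.weight.dtype)  # torch.float32

    precision = FSDPPrecision(precision="bf16-true")
    module = precision.convert_module(module)
    print(module.weight.dtype)  # torch.bfloat16 -- roughly half the parameter memory

    # A mixed setting leaves the module untouched; the cast happens at compute time.
    mixed = FSDPPrecision(precision="bf16-mixed")
    assert mixed.convert_module(torch.nn.Linear(2, 2)).weight.dtype == torch.float32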

src/lightning/pytorch/CHANGELOG.md (+1)

@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
 - Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
 - Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
+- Fixed PyTorch Lightning FSDP takes more memory than PyTorch FSDP ([#20323](https://github.com/Lightning-AI/pytorch-lightning/pull/20323))


 ## [2.3.0] - 2024-06-13

src/lightning/pytorch/plugins/precision/fsdp.py (+7)

@@ -17,6 +17,7 @@
 import torch
 from lightning_utilities import apply_to_collection
 from torch import Tensor
+from torch.nn import Module
 from typing_extensions import get_args, override

 import lightning.pytorch as pl

@@ -73,6 +74,12 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca
         }
         self._desired_input_dtype = precision_to_type[self.precision]

+    @override
+    def convert_module(self, module: Module) -> Module:
+        if "true" in self.precision:
+            return module.to(dtype=self._desired_input_dtype)
+        return module
+
     @override
     def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
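
On the Trainer side, this code path is exercised when the FSDP strategy is combined with a "-true" precision setting. A hedged sketch of the user-facing configuration (it assumes a multi-GPU machine; the device count is arbitrary):

    # Sketch only: selecting FSDP with a "-true" precision routes module
    # conversion through FSDPPrecision.convert_module (requires GPUs).
    import lightning.pytorch as pl

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy="fsdp",
        precision="bf16-true",  # parameters are cast to bfloat16 up front
    )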

tests/tests_fabric/plugins/precision/test_fsdp.py (+18)

@@ -127,3 +127,21 @@ def test_invalid_precision_with_fsdp_precision():

     with pytest.raises(ValueError, match="is not supported in FSDP. `precision` must be one of"):
         FSDPPrecision(precision="64-true")
+
+
+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = FSDPPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype

tests/tests_pytorch/plugins/precision/test_fsdp.py (+18)

@@ -40,6 +40,24 @@ def test_fsdp_precision_config(precision, expected):
     assert config.reduce_dtype == expected[2]


+@pytest.mark.parametrize(
+    ("precision", "expected_dtype"),
+    [
+        ("32-true", torch.float32),
+        ("bf16-mixed", torch.float32),
+        ("16-mixed", torch.float32),
+        ("bf16-true", torch.bfloat16),
+        ("16-true", torch.float16),
+    ],
+)
+def test_convert_module(precision, expected_dtype):
+    precision = FSDPPrecision(precision=precision)
+    module = torch.nn.Linear(2, 2)
+    assert module.weight.dtype == module.bias.dtype == torch.float32
+    module = precision.convert_module(module)
+    assert module.weight.dtype == module.bias.dtype == expected_dtype
+
+
 def test_fsdp_precision_default_scaler():
     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
