
Commit 6745994

Avoid FSDP deprecations during save/load with newer torch versions (#19463)
* Avoid FSDP deprecations during save/load with newer torch versions
* Refactor
* Tests
1 parent 59e45d6 commit 6745994
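For orientation, the deprecation warnings this commit avoids are emitted on the sharded checkpoint save/load path. A minimal, hypothetical Fabric example of that path (the model, optimizer, device count and checkpoint location below are illustrative assumptions, not part of the commit):

# Hypothetical sketch: a sharded FSDP save/load with Fabric, the user-facing path
# that previously triggered torch.distributed.checkpoint deprecation warnings.
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

fabric = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy(state_dict_type="sharded"))
fabric.launch()

model = fabric.setup_module(torch.nn.Linear(32, 32))
optimizer = fabric.setup_optimizers(torch.optim.Adam(model.parameters()))
state = {"model": model, "optimizer": optimizer, "step": 0}

# With torch >= 2.2 installed, these calls now dispatch to torch.distributed.checkpoint.save/load
# instead of the deprecated save_state_dict/load_state_dict entry points.
fabric.save("checkpoints/last", state)
fabric.load("checkpoints/last", state)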

File tree

5 files changed: +92 -59 lines

src/lightning/fabric/strategies/fsdp.py
src/lightning/fabric/utilities/imports.py
src/lightning/pytorch/strategies/fsdp.py
tests/tests_fabric/strategies/test_fsdp.py
tests/tests_pytorch/strategies/test_fsdp.py

src/lightning/fabric/strategies/fsdp.py

+53 -25
@@ -66,6 +66,7 @@
     _TORCH_GREATER_EQUAL_2_0,
     _TORCH_GREATER_EQUAL_2_1,
     _TORCH_GREATER_EQUAL_2_2,
+    _TORCH_GREATER_EQUAL_2_3,
 )
 from lightning.fabric.utilities.init import _EmptyInit
 from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into
@@ -448,7 +449,6 @@ def save_checkpoint(
         if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
             raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

-        from torch.distributed.checkpoint import FileSystemWriter, save_state_dict
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

         modules = [module for module in state.values() if _has_fsdp_modules(module)]
@@ -491,9 +491,7 @@ def save_checkpoint(
                     target_dict = metadata
                 _apply_filter(key, filter or {}, converted, target_dict)

-        # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
-        writer = FileSystemWriter(path=path, single_file_per_rank=True)
-        save_state_dict(converted_state, writer)
+        _distributed_checkpoint_save(converted_state, path)

         if self.global_rank == 0:
             torch.save(metadata, path / _METADATA_FILENAME)
@@ -555,16 +553,10 @@ def load_checkpoint(
                 "Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy."
             )

-        from torch.distributed.checkpoint import FileSystemReader
         from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
         from torch.distributed.fsdp import OptimStateKeyType

-        if _TORCH_GREATER_EQUAL_2_2:
-            from torch.distributed.checkpoint import load
-        else:
-            from torch.distributed.checkpoint import load_state_dict as load  # deprecated
-
         modules = {key: module for key, module in state.items() if _has_fsdp_modules(module)}
         if len(modules) == 0:
             raise ValueError(
@@ -583,26 +575,30 @@ def load_checkpoint(

         if _is_sharded_checkpoint(path):
             state_dict_ctx = _get_sharded_state_dict_context(module)
-            reader = FileSystemReader(path=path)

             with state_dict_ctx:
                 module_state = {module_key: module.state_dict()}
-                load(module_state, reader)
+                _distributed_checkpoint_load(module_state, path)
                 module.load_state_dict(module_state[module_key], strict=strict)

-            # the optimizer states must be loaded separately
-            for optim_key, optim in optimizers.items():
-                optim_state = load_sharded_optimizer_state_dict(
-                    model_state_dict=module_state[module_key],
-                    optimizer_key=optim_key,
-                    storage_reader=reader,
-                )
-                flattened_osd = FSDP.optim_state_dict_to_load(
-                    optim_state_dict=optim_state[optim_key],
-                    model=module,
-                    optim=optim,
-                )
-                optim.load_state_dict(flattened_osd)
+            if optimizers:
+                from torch.distributed.checkpoint import FileSystemReader
+                # TODO: replace with newer APIs
+                # https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271
+                reader = FileSystemReader(path=path)
+                # the optimizer states must be loaded separately
+                for optim_key, optim in optimizers.items():
+                    optim_state = load_sharded_optimizer_state_dict(
+                        model_state_dict=module_state[module_key],
+                        optimizer_key=optim_key,
+                        storage_reader=reader,
+                    )
+                    flattened_osd = FSDP.optim_state_dict_to_load(
+                        optim_state_dict=optim_state[optim_key],
+                        model=module,
+                        optim=optim,
+                    )
+                    optim.load_state_dict(flattened_osd)

             # Load metadata (anything not a module or optimizer)
             metadata = torch.load(path / _METADATA_FILENAME)
@@ -920,3 +916,35 @@ def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device)

     for metric in (m for m in module.modules() if isinstance(m, Metric)):
         metric.to(device)  # `.to()` is in-place
+
+
+def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) -> None:
+    if _TORCH_GREATER_EQUAL_2_3:
+        from torch.distributed.checkpoint import save
+        # let torch automatically infer the writer to use. This might also support fsspec paths in the future
+        # https://github.com/pytorch/pytorch/issues/118036
+        save(converted_state, checkpoint_id=path)  # type: ignore[call-arg]
+    else:  # deprecated
+        from torch.distributed.checkpoint import FileSystemWriter
+        if _TORCH_GREATER_EQUAL_2_2:
+            from torch.distributed.checkpoint import save
+        else:
+            from torch.distributed.checkpoint import save_state_dict as save
+        # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
+        writer = FileSystemWriter(path=path, single_file_per_rank=True)
+        save(converted_state, writer)
+
+
+def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
+    if _TORCH_GREATER_EQUAL_2_3:
+        from torch.distributed.checkpoint import load
+        # let torch automatically infer the reader to use. This might also support fsspec paths in the future
+        # https://github.com/pytorch/pytorch/issues/118036
+        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
+    else:  # deprecated
+        from torch.distributed.checkpoint import FileSystemReader
+        if _TORCH_GREATER_EQUAL_2_2:
+            from torch.distributed.checkpoint import load
+        else:
+            from torch.distributed.checkpoint import load_state_dict as load
+        reader = FileSystemReader(path=path)
+        load(module_state, reader)
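The two new module-level helpers concentrate the version dispatch in one place. A rough standalone sketch of what the torch >= 2.3 branch reduces to (directory path and tensors are placeholders; running this on a single process without an initialized process group assumes torch falls back to a non-distributed save, which is an assumption here, not something the commit depends on):

# Illustrative only: the torch >= 2.3 path taken by _distributed_checkpoint_save/_load.
from pathlib import Path

import torch
from torch.distributed.checkpoint import load, save

checkpoint_dir = Path("checkpoints/sharded")  # hypothetical location
state = {"weight": torch.zeros(4, 4)}

# `checkpoint_id` lets torch pick the storage backend (a filesystem writer/reader here);
# older torch versions require passing FileSystemWriter/FileSystemReader explicitly.
save(state, checkpoint_id=checkpoint_dir)
load(state, checkpoint_id=checkpoint_dir)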

src/lightning/fabric/utilities/imports.py

+3 -2
@@ -26,8 +26,9 @@
 _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)

 _TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
-_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
-_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0", use_base_version=True)
+_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0")
+_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
+_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0", use_base_version=True)
 _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1

 _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
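Note the flags: torch 2.1 and 2.2 are final releases, so they can be compared against the full installed version again, while torch 2.3 only exists as pre-release/nightly builds at this point, so the new gate keeps use_base_version=True to let versions such as 2.3.0.devYYYYMMDD satisfy it. A small sketch of the base-version idea using packaging directly (this approximates what the compare_version helper does with that flag; it is not the helper itself):

from operator import ge

from packaging.version import Version

installed = Version("2.3.0.dev20240214")  # e.g. a hypothetical torch nightly build

# Full-version comparison: a .dev pre-release sorts below the final release.
print(ge(installed, Version("2.3.0")))                         # False
# Base-version comparison, roughly what use_base_version=True enables.
print(ge(Version(installed.base_version), Version("2.3.0")))   # True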

src/lightning/pytorch/strategies/fsdp.py

+9 -15
@@ -33,6 +33,8 @@
     _METADATA_FILENAME,
     _activation_checkpointing_kwargs,
     _auto_wrap_policy_kwargs,
+    _distributed_checkpoint_load,
+    _distributed_checkpoint_save,
     _get_full_state_dict_context,
     _get_sharded_state_dict_context,
     _has_meta_device_parameters,
@@ -55,7 +57,6 @@
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
     _TORCH_GREATER_EQUAL_2_1,
-    _TORCH_GREATER_EQUAL_2_2,
 )
 from lightning.fabric.utilities.init import _EmptyInit
 from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
@@ -561,8 +562,6 @@ def save_checkpoint(
             raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

         if self._state_dict_type == "sharded":
-            from torch.distributed.checkpoint import FileSystemWriter, save_state_dict
-
             if path.is_file():
                 path.unlink()
             path.mkdir(parents=True, exist_ok=True)
@@ -572,9 +571,7 @@ def save_checkpoint(
                 {f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))}
             )

-            # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
-            writer = FileSystemWriter(path=path, single_file_per_rank=True)
-            save_state_dict(converted_state, writer)
+            _distributed_checkpoint_save(converted_state, path)

             if self.global_rank == 0:
                 torch.save(checkpoint, path / _METADATA_FILENAME)
@@ -596,23 +593,20 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         assert self.lightning_module is not None

         if _is_sharded_checkpoint(path):
-            from torch.distributed.checkpoint import FileSystemReader
             from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

-            if _TORCH_GREATER_EQUAL_2_2:
-                from torch.distributed.checkpoint import load
-            else:
-                from torch.distributed.checkpoint import load_state_dict as load  # deprecated
-
             state_dict_ctx = _get_sharded_state_dict_context(self.model)
-            reader = FileSystemReader(path=path)

             with state_dict_ctx:
                 module_state = {"model": self.model.state_dict()}
-                load(module_state, reader)
+                _distributed_checkpoint_load(module_state, path)
             self.model.load_state_dict(module_state["model"], strict=self.lightning_module.strict_loading)

-            if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            if self.lightning_module.trainer.state.fn == TrainerFn.FITTING and self.optimizers:
+                from torch.distributed.checkpoint import FileSystemReader
+                # TODO: replace with newer APIs
+                # https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271
+                reader = FileSystemReader(path=path)
                 # the optimizer states must be loaded separately
                 for idx, optim in enumerate(self.optimizers):
                     optim_key = f"optimizer_{idx}"
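On the Trainer side the same helpers now back FSDPStrategy.save_checkpoint and load_checkpoint, and the optimizer-loading branch is additionally skipped when no optimizers are configured. A hedged sketch of the user-facing flow these methods serve (BoringModel, the device count and the checkpoint path are placeholders, not part of the commit):

# Illustrative Trainer usage that exercises the sharded save/load path.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import FSDPStrategy

trainer = Trainer(
    accelerator="cuda",
    devices=2,
    strategy=FSDPStrategy(state_dict_type="sharded"),
    max_steps=10,
)
trainer.fit(BoringModel())  # save_checkpoint -> _distributed_checkpoint_save

# Resuming goes through load_checkpoint -> _distributed_checkpoint_load; optimizer
# states are only restored while fitting and only if optimizers are present.
trainer = Trainer(
    accelerator="cuda",
    devices=2,
    strategy=FSDPStrategy(state_dict_type="sharded"),
    max_steps=20,
)
trainer.fit(BoringModel(), ckpt_path="lightning_logs/version_0/checkpoints/epoch=0-step=10.ckpt")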

tests/tests_fabric/strategies/test_fsdp.py

+13 -9
@@ -29,7 +29,7 @@
     _has_meta_device_parameters,
     _is_sharded_checkpoint,
 )
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 from torch.optim import Adam

@@ -241,13 +241,12 @@ def test_fsdp_save_checkpoint_storage_options(tmp_path):


 @RunIf(min_torch="2.0.0")
-@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
 @mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
-@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
-@mock.patch("lightning.fabric.strategies.fsdp.torch.save", return_value=Mock())
-@mock.patch("lightning.fabric.strategies.fsdp.shutil", return_value=MagicMock())
-def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
+@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context")
+@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context")
+@mock.patch("lightning.fabric.strategies.fsdp.torch.save")
+@mock.patch("lightning.fabric.strategies.fsdp.shutil")
+def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
     strategy = FSDPStrategy(state_dict_type="full")

     # state_dict_type='full', path exists, path is not a sharded checkpoint: error
@@ -278,22 +277,27 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,
     torch_save_mock.assert_called_once()

     strategy = FSDPStrategy(state_dict_type="sharded")
+    save_mock = mock.patch(
+        "torch.distributed.checkpoint.save"
+        if _TORCH_GREATER_EQUAL_2_2 else "torch.distributed.checkpoint.save_state_dict")

     # state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
     path = tmp_path / "not-empty-2"
     path.mkdir()
     (path / "file").touch()
     model = Mock(spec=FullyShardedDataParallel)
     model.modules.return_value = [model]
-    strategy.save_checkpoint(path=path, state={"model": model})
+    with save_mock:
+        strategy.save_checkpoint(path=path, state={"model": model})
     assert (path / "file").exists()

     # state_dict_type='sharded', path exists, path is a file: no error (overwrite)
     path = tmp_path / "file-2.pt"
     path.touch()
     model = Mock(spec=FullyShardedDataParallel)
     model.modules.return_value = [model]
-    strategy.save_checkpoint(path=path, state={"model": model})
+    with save_mock:
+        strategy.save_checkpoint(path=path, state={"model": model})
     assert path.is_dir()
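The patch target now depends on the installed torch version, because the strategy calls torch.distributed.checkpoint.save on torch >= 2.2 and the deprecated save_state_dict otherwise; building the patcher once lets both save_checkpoint calls reuse it. A small standalone illustration of that pattern (the body of the with-block is a placeholder, not part of the test):

from unittest import mock

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2

# A mock.patch(...) patcher can be entered as a context manager more than once,
# so one conditional definition covers both save_checkpoint calls.
save_mock = mock.patch(
    "torch.distributed.checkpoint.save"
    if _TORCH_GREATER_EQUAL_2_2
    else "torch.distributed.checkpoint.save_state_dict"
)

with save_mock as mocked_save:
    ...  # exercise the sharded save path here; `mocked_save` records any save calls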
tests/tests_pytorch/strategies/test_fsdp.py

+14 -8
@@ -17,6 +17,7 @@
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
     _TORCH_GREATER_EQUAL_2_1,
+    _TORCH_GREATER_EQUAL_2_2,
 )
 from lightning.fabric.utilities.load import _load_distributed_checkpoint
 from lightning.pytorch import Trainer
@@ -801,13 +802,12 @@ def test_save_checkpoint_storage_options(tmp_path):


 @RunIf(min_torch="2.0.0")
-@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
 @mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
-@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
-@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
-@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save", return_value=Mock())
-@mock.patch("lightning.pytorch.strategies.fsdp.shutil", return_value=MagicMock())
-def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
+@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context")
+@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context")
+@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save")
+@mock.patch("lightning.pytorch.strategies.fsdp.shutil")
+def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
     strategy = FSDPStrategy(state_dict_type="full")

     # state_dict_type='full', path exists, path is not a sharded checkpoint: error
@@ -839,21 +839,27 @@ def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___,

     strategy = FSDPStrategy(state_dict_type="sharded")

+    save_mock = mock.patch(
+        "torch.distributed.checkpoint.save"
+        if _TORCH_GREATER_EQUAL_2_2 else "torch.distributed.checkpoint.save_state_dict")
+
     # state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
     path = tmp_path / "not-empty-2"
     path.mkdir()
     (path / "file").touch()
     model = Mock(spec=FullyShardedDataParallel)
     model.modules.return_value = [model]
-    strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
+    with save_mock:
+        strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
     assert (path / "file").exists()

     # state_dict_type='sharded', path exists, path is a file: no error (overwrite)
     path = tmp_path / "file-2.pt"
     path.touch()
     model = Mock(spec=FullyShardedDataParallel)
     model.modules.return_value = [model]
-    strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
+    with save_mock:
+        strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
     assert path.is_dir()
