@@ -66,6 +66,7 @@
     _TORCH_GREATER_EQUAL_2_0,
     _TORCH_GREATER_EQUAL_2_1,
     _TORCH_GREATER_EQUAL_2_2,
+    _TORCH_GREATER_EQUAL_2_3,
 )
 from lightning.fabric.utilities.init import _EmptyInit
 from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into
@@ -448,7 +449,6 @@ def save_checkpoint(
         if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
             raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
 
-        from torch.distributed.checkpoint import FileSystemWriter, save_state_dict
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
         modules = [module for module in state.values() if _has_fsdp_modules(module)]
@@ -491,9 +491,7 @@ def save_checkpoint(
                     target_dict = metadata
                 _apply_filter(key, filter or {}, converted, target_dict)
 
-            # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
-            writer = FileSystemWriter(path=path, single_file_per_rank=True)
-            save_state_dict(converted_state, writer)
+            _distributed_checkpoint_save(converted_state, path)
 
             if self.global_rank == 0:
                 torch.save(metadata, path / _METADATA_FILENAME)
@@ -555,16 +553,10 @@ def load_checkpoint(
                 "Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy."
             )
 
-        from torch.distributed.checkpoint import FileSystemReader
         from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
         from torch.distributed.fsdp import OptimStateKeyType
 
-        if _TORCH_GREATER_EQUAL_2_2:
-            from torch.distributed.checkpoint import load
-        else:
-            from torch.distributed.checkpoint import load_state_dict as load  # deprecated
-
         modules = {key: module for key, module in state.items() if _has_fsdp_modules(module)}
         if len(modules) == 0:
             raise ValueError(
@@ -583,26 +575,30 @@ def load_checkpoint(
 
         if _is_sharded_checkpoint(path):
             state_dict_ctx = _get_sharded_state_dict_context(module)
-            reader = FileSystemReader(path=path)
 
             with state_dict_ctx:
                 module_state = {module_key: module.state_dict()}
-                load(module_state, reader)
+                _distributed_checkpoint_load(module_state, path)
                 module.load_state_dict(module_state[module_key], strict=strict)
 
-                # the optimizer states must be loaded separately
-                for optim_key, optim in optimizers.items():
-                    optim_state = load_sharded_optimizer_state_dict(
-                        model_state_dict=module_state[module_key],
-                        optimizer_key=optim_key,
-                        storage_reader=reader,
-                    )
-                    flattened_osd = FSDP.optim_state_dict_to_load(
-                        optim_state_dict=optim_state[optim_key],
-                        model=module,
-                        optim=optim,
-                    )
-                    optim.load_state_dict(flattened_osd)
+                if optimizers:
+                    from torch.distributed.checkpoint import FileSystemReader
+                    # TODO: replace with newer APIs
+                    # https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271
+                    reader = FileSystemReader(path=path)
+                    # the optimizer states must be loaded separately
+                    for optim_key, optim in optimizers.items():
+                        optim_state = load_sharded_optimizer_state_dict(
+                            model_state_dict=module_state[module_key],
+                            optimizer_key=optim_key,
+                            storage_reader=reader,
+                        )
+                        flattened_osd = FSDP.optim_state_dict_to_load(
+                            optim_state_dict=optim_state[optim_key],
+                            model=module,
+                            optim=optim,
+                        )
+                        optim.load_state_dict(flattened_osd)
 
         # Load metadata (anything not a module or optimizer)
         metadata = torch.load(path / _METADATA_FILENAME)
@@ -920,3 +916,35 @@ def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device)
 
     for metric in (m for m in module.modules() if isinstance(m, Metric)):
         metric.to(device)  # `.to()` is in-place
+
+
+def _distributed_checkpoint_save(converted_state: Dict[str, Any], path: Path) -> None:
+    if _TORCH_GREATER_EQUAL_2_3:
+        from torch.distributed.checkpoint import save
+        # let torch automatically infer the writer to use. This might also support fsspec paths in the future
+        # https://github.com/pytorch/pytorch/issues/118036
+        save(converted_state, checkpoint_id=path)  # type: ignore[call-arg]
+    else:  # deprecated
+        from torch.distributed.checkpoint import FileSystemWriter
+        if _TORCH_GREATER_EQUAL_2_2:
+            from torch.distributed.checkpoint import save
+        else:
+            from torch.distributed.checkpoint import save_state_dict as save
+        # FSDP's FileSystemWriter streams the tensors to disk to minimize memory peaks
+        writer = FileSystemWriter(path=path, single_file_per_rank=True)
+        save(converted_state, writer)
+
+
+def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path) -> None:
+    if _TORCH_GREATER_EQUAL_2_3:
+        from torch.distributed.checkpoint import load
+        # let torch automatically infer the reader to use. This might also support fsspec paths in the future
+        # https://github.com/pytorch/pytorch/issues/118036
+        load(module_state, checkpoint_id=path)  # type: ignore[call-arg]
+    else:  # deprecated
+        from torch.distributed.checkpoint import FileSystemReader
+        if _TORCH_GREATER_EQUAL_2_2:
+            from torch.distributed.checkpoint import load
+        else:
+            from torch.distributed.checkpoint import load_state_dict as load
+        reader = FileSystemReader(path=path)
+        load(module_state, reader)
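For context, a minimal standalone sketch of the torch >= 2.3 call pattern the new helpers select: passing `checkpoint_id` lets torch infer the storage backend instead of constructing a `FileSystemWriter`/`FileSystemReader` by hand. The `roundtrip` function and its `state` argument are hypothetical, and the sketch assumes an initialized process group and a state dict produced under a sharded state-dict context:

```python
from pathlib import Path
from typing import Any, Dict

import torch.distributed.checkpoint as dcp


def roundtrip(state: Dict[str, Any], path: Path) -> None:
    # torch >= 2.3: each rank streams its local shards to `path`
    dcp.save(state, checkpoint_id=path)
    # loads in-place into the (sharded) tensors already present in `state`
    dcp.load(state, checkpoint_id=path)
```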
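Design note on the optimizer path: it still builds a `FileSystemReader` explicitly because `load_sharded_optimizer_state_dict` takes a `storage_reader` argument rather than a `checkpoint_id`; the TODO in the diff links the PyTorch issue comment tracking the newer replacement APIs. Moving that import and the reader construction under `if optimizers:` also means no reader is created when only module weights are being restored.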