Commit b1127e3

awaelchli and carmocca authored
Utility to consolidate sharded checkpoints (#19213)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
1 parent ed367ca commit b1127e3

File tree

11 files changed: +505 −7 lines changed

docs/source-fabric/guide/checkpoint/distributed_checkpoint.rst

+33 −1

@@ -183,4 +183,36 @@ Note that you can load the distributed checkpoint even if the world size has changed
 Convert a distributed checkpoint
 ********************************

-Coming soon.
+It is possible to convert a distributed checkpoint to a regular, single-file checkpoint with this utility:
+
+.. code-block:: bash
+
+    python -m lightning.fabric.utilities.consolidate_checkpoint path/to/my/checkpoint
+
+You will need to do this, for example, if you want to load the checkpoint into a script that doesn't use FSDP, or if you need to export the checkpoint to a different format for deployment, evaluation, etc.
+
+.. note::
+
+    All tensors in the checkpoint will be converted to CPU tensors, and no GPUs are required to run the conversion command.
+    This function assumes you have enough free CPU memory to hold the entire checkpoint in memory.
+
+.. collapse:: Full example
+
+    Assuming you have saved a checkpoint ``my-checkpoint.ckpt`` using the examples above, run the following command to convert it:
+
+    .. code-block:: bash
+
+        python -m lightning.fabric.utilities.consolidate_checkpoint my-checkpoint.ckpt
+
+    This saves a new file ``my-checkpoint.ckpt.consolidated`` next to the sharded checkpoint, which you can load normally in PyTorch:
+
+    .. code-block:: python
+
+        import torch
+
+        checkpoint = torch.load("my-checkpoint.ckpt.consolidated")
+        print(list(checkpoint.keys()))
+        print(checkpoint["model"]["transformer.decoder.layers.31.norm1.weight"])
+
+
+|
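
For context, a minimal sketch (not part of this diff) of loading the consolidated Fabric checkpoint into a plain, non-FSDP script; the model definition is a hypothetical stand-in whose parameter names must match the keys under checkpoint["model"]:

    import torch
    import torch.nn as nn

    # hypothetical stand-in for the module that produced the checkpoint;
    # its parameter names must match the keys in checkpoint["model"]
    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    checkpoint = torch.load("my-checkpoint.ckpt.consolidated")
    # Fabric saved the module under the "model" key of the state passed to fabric.save()
    model.load_state_dict(checkpoint["model"])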

docs/source-pytorch/common/checkpointing_expert.rst

+34 −1

@@ -136,4 +136,37 @@ Note that you can load the distributed checkpoint even if the world size has changed
 Convert a distributed checkpoint
 ********************************

-Coming soon.
+It is possible to convert a distributed checkpoint to a regular, single-file checkpoint with this utility:
+
+.. code-block:: bash
+
+    python -m lightning.pytorch.utilities.consolidate_checkpoint path/to/my/checkpoint
+
+You will need to do this, for example, if you want to load the checkpoint into a script that doesn't use FSDP, or if you need to export the checkpoint to a different format for deployment, evaluation, etc.
+
+.. note::
+
+    All tensors in the checkpoint will be converted to CPU tensors, and no GPUs are required to run the conversion command.
+    This function assumes you have enough free CPU memory to hold the entire checkpoint in memory.
+
+.. collapse:: Full example
+
+    Assuming you have saved a checkpoint ``epoch=0-step=3.ckpt`` using the examples above, run the following command to convert it:
+
+    .. code-block:: bash
+
+        cd lightning_logs/version_0/checkpoints
+        python -m lightning.pytorch.utilities.consolidate_checkpoint epoch=0-step=3.ckpt
+
+    This saves a new file ``epoch=0-step=3.ckpt.consolidated`` next to the sharded checkpoint, which you can load normally in PyTorch:
+
+    .. code-block:: python
+
+        import torch
+
+        checkpoint = torch.load("epoch=0-step=3.ckpt.consolidated")
+        print(list(checkpoint.keys()))
+        print(checkpoint["state_dict"]["model.transformer.decoder.layers.31.norm1.weight"])
+
+
+|
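
Since the consolidated Trainer checkpoint uses the standard single-file format, a rough sketch (MyLightningModule is a hypothetical placeholder, assuming its hyperparameters were saved or its __init__ needs no extra arguments) of resuming from it could look like:

    from lightning.pytorch import Trainer

    # MyLightningModule is a placeholder for the LightningModule that produced the checkpoint
    model = MyLightningModule.load_from_checkpoint("epoch=0-step=3.ckpt.consolidated")

    # or resume training from the consolidated file with a non-FSDP strategy
    trainer = Trainer(accelerator="gpu", devices=1)
    trainer.fit(model, ckpt_path="epoch=0-step=3.ckpt.consolidated")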

src/lightning/fabric/strategies/fsdp.py

+1 −2

@@ -69,7 +69,7 @@
     _TORCH_GREATER_EQUAL_2_2,
 )
 from lightning.fabric.utilities.init import _EmptyInit
-from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _move_state_into
+from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into
 from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH, _Stateful
@@ -87,7 +87,6 @@
 _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]

 _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
-_METADATA_FILENAME = "meta.pt"


 class FSDPStrategy(ParallelStrategy, _Sharded):

src/lightning/fabric/utilities/consolidate_checkpoint.py

+79

@@ -0,0 +1,79 @@
import logging
from argparse import ArgumentParser, Namespace
from pathlib import Path

import torch

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.load import _METADATA_FILENAME, _load_distributed_checkpoint

_log = logging.getLogger(__name__)


def _parse_cli_args() -> Namespace:
    parser = ArgumentParser(
        description=(
            "Converts a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`."
            " Only supports FSDP sharded checkpoints at the moment."
        ),
    )
    parser.add_argument(
        "checkpoint_folder",
        type=str,
        help=(
            "Path to a checkpoint folder, containing the sharded checkpoint files saved using the"
            " `torch.distributed.checkpoint` API."
        ),
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help=(
            "Path to the file where the converted checkpoint should be saved. The file should not already exist."
            " If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
            " and a '.consolidated' suffix."
        ),
    )
    return parser.parse_args()


def _process_cli_args(args: Namespace) -> Namespace:
    if not _TORCH_GREATER_EQUAL_2_1:
        _log.error("Processing distributed checkpoints requires PyTorch >= 2.1.")
        exit(1)

    checkpoint_folder = Path(args.checkpoint_folder)
    if not checkpoint_folder.exists():
        _log.error(f"The provided checkpoint folder does not exist: {checkpoint_folder}")
        exit(1)
    if not checkpoint_folder.is_dir():
        _log.error(
            f"The provided checkpoint path must be a folder, containing the checkpoint shards: {checkpoint_folder}"
        )
        exit(1)
    if not (checkpoint_folder / _METADATA_FILENAME).is_file():
        _log.error(
            "Only FSDP-sharded checkpoints saved with Lightning are supported for consolidation. The provided folder"
            f" is not in that format: {checkpoint_folder}"
        )
        exit(1)

    if args.output_file is None:
        output_file = checkpoint_folder.with_suffix(checkpoint_folder.suffix + ".consolidated")
    else:
        output_file = Path(args.output_file)
    if output_file.exists():
        _log.error(
            "The path for the converted checkpoint already exists. Choose a different path by providing"
            f" `--output_file` or move/delete the file first: {output_file}"
        )
        exit(1)

    return Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)


if __name__ == "__main__":
    args = _parse_cli_args()
    config = _process_cli_args(args)
    checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
    torch.save(checkpoint, config.output_file)
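
For orientation, the __main__ block above can also be reproduced programmatically; a small sketch (paths are placeholders) using the same helpers:

    from pathlib import Path

    import torch

    from lightning.fabric.utilities.load import _load_distributed_checkpoint

    # read every shard into one CPU state dict, then write a single torch.load-able file
    checkpoint = _load_distributed_checkpoint(Path("path/to/my/checkpoint"))
    torch.save(checkpoint, "path/to/my/checkpoint.consolidated")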

src/lightning/fabric/utilities/load.py

+83 −2

@@ -15,7 +15,8 @@
 import warnings
 from functools import partial
 from io import BytesIO
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
+from pathlib import Path
+from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Tuple, Union

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -24,9 +25,16 @@
 from torch.nn import Parameter
 from typing_extensions import override

-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_2_0,
+    _TORCH_GREATER_EQUAL_2_1,
+    _TORCH_GREATER_EQUAL_2_2,
+)
 from lightning.fabric.utilities.types import _PATH, _Stateful

+_METADATA_FILENAME = "meta.pt"
+
+
 if TYPE_CHECKING:
     from torch.storage import TypedStorage

@@ -227,3 +235,76 @@ def _move_state_into(
         destination[key].load_state_dict(state)
     else:
         destination[key] = state
+
+
+def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]:
+    """Loads a sharded checkpoint saved with the `torch.distributed.checkpoint` API into a full state dict.
+
+    The current implementation assumes that the entire checkpoint fits in CPU memory.
+
+    """
+    if not _TORCH_GREATER_EQUAL_2_1:
+        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.1.")
+
+    from torch.distributed.checkpoint import FileSystemReader
+    from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata
+
+    if _TORCH_GREATER_EQUAL_2_2:
+        from torch.distributed.checkpoint import load
+    else:
+        from torch.distributed.checkpoint import load_state_dict as load  # deprecated
+
+    reader = FileSystemReader(checkpoint_folder)
+    metadata = reader.read_metadata()
+
+    # TODO: Add sequential save to avoid storing the entire checkpoint in memory
+    checkpoint: Dict[str, Any] = {}
+    for tensor_name, sd_metadata in metadata.state_dict_metadata.items():
+        if isinstance(sd_metadata, BytesStorageMetadata):
+            checkpoint[tensor_name] = "<bytes_io>"
+        elif isinstance(sd_metadata, TensorStorageMetadata):
+            checkpoint[tensor_name] = torch.empty(
+                size=sd_metadata.size,
+                dtype=sd_metadata.properties.dtype,
+                device=torch.device("cpu"),
+                memory_format=sd_metadata.properties.memory_format,
+                layout=sd_metadata.properties.layout,
+                requires_grad=sd_metadata.properties.requires_grad,
+                pin_memory=sd_metadata.properties.pin_memory,
+            )
+
+    load(state_dict=checkpoint, storage_reader=reader, no_dist=True)
+    checkpoint = _unflatten_dict(checkpoint, key_map=metadata.planner_data)
+
+    # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
+    extra_file = checkpoint_folder / _METADATA_FILENAME
+    extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
+    checkpoint.update(extra)
+
+    return checkpoint
+
+
+def _unflatten_dict(checkpoint: Dict[str, Any], key_map: Dict[str, Tuple[str, ...]]) -> Dict[str, Any]:
+    """Converts the flat dictionary with keys 'x.y.z...' to a nested dictionary using the provided key map.
+
+    Args:
+        checkpoint: The flat checkpoint dictionary.
+        key_map: A dictionary that maps the keys in flattened format 'x.y.z...' to a tuple representing
+            the index path into the nested dictionary that this function should construct.
+
+    """
+    assert checkpoint.keys() == key_map.keys()
+    converted: Dict[str, Any] = {}
+    for flat_key in checkpoint:
+        key_path = key_map[flat_key]
+        _set_nested_dict_value(converted, key_path, checkpoint[flat_key])
+    return converted
+
+
+def _set_nested_dict_value(nested_dict: Dict[str, Any], key_path: Tuple[str, ...], value: Any) -> None:
+    result = nested_dict
+    for key in key_path[:-1]:
+        if key not in result:
+            result[key] = {}
+        result = result[key]
+    result[key_path[-1]] = value
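
To illustrate the key-map handling in _unflatten_dict above, a toy example (keys and values invented for demonstration):

    from lightning.fabric.utilities.load import _unflatten_dict

    flat = {"model.layer.weight": 1, "optimizer.state.step": 2}
    key_map = {
        "model.layer.weight": ("model", "layer.weight"),
        "optimizer.state.step": ("optimizer", "state", "step"),
    }
    # rebuilds the nested structure described by key_map:
    # {'model': {'layer.weight': 1}, 'optimizer': {'state': {'step': 2}}}
    print(_unflatten_dict(flat, key_map))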

src/lightning/pytorch/utilities/consolidate_checkpoint.py

+30

@@ -0,0 +1,30 @@
import re
from typing import Any, Dict

import torch

from lightning.fabric.utilities.consolidate_checkpoint import _parse_cli_args, _process_cli_args
from lightning.fabric.utilities.load import _load_distributed_checkpoint


def _format_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Converts the special FSDP checkpoint format to the standard format the Lightning Trainer can load."""
    # Rename the model key
    checkpoint["state_dict"] = checkpoint.pop("model")

    optimizer_keys = [key for key in checkpoint if re.match("optimizer_[0-9]+", key)]
    if not optimizer_keys:
        return checkpoint

    # Optimizers are saved in special keys named `optimizer_0`, `optimizer_1`, etc.
    # These need to be merged back into a Python list
    checkpoint["optimizer_states"] = [checkpoint.pop(f"optimizer_{opt_idx}") for opt_idx in range(len(optimizer_keys))]
    return checkpoint


if __name__ == "__main__":
    args = _parse_cli_args()
    config = _process_cli_args(args)
    checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
    checkpoint = _format_checkpoint(checkpoint)
    torch.save(checkpoint, config.output_file)
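
A toy illustration (made-up values; module path as referenced in the docs above) of the key remapping that _format_checkpoint performs:

    from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint

    fabric_style = {"model": {"layer.weight": 1}, "optimizer_0": {"state": {}}, "epoch": 0}
    # -> {'epoch': 0, 'state_dict': {'layer.weight': 1}, 'optimizer_states': [{'state': {}}]}
    print(_format_checkpoint(fabric_style))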

tests/tests_fabric/strategies/test_fsdp_integration.py

+49

@@ -22,6 +22,7 @@
 from lightning.fabric.plugins import FSDPPrecision
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1
+from lightning.fabric.utilities.load import _load_distributed_checkpoint
 from lightning.fabric.wrappers import _FabricOptimizer
 from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
 from torch.distributed.fsdp.wrap import always_wrap_policy, wrap
@@ -549,3 +550,51 @@ def test_clip_gradients(clip_type, precision):

     optimizer.step()
     optimizer.zero_grad()
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0")
+def test_save_sharded_and_consolidate_and_load(tmp_path):
+    """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
+
+    fabric = Fabric(
+        accelerator="cuda",
+        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="sharded"),
+        devices=2,
+    )
+    fabric.launch()
+
+    model = BoringModel()
+    optimizer = torch.optim.Adam(model.parameters())
+    model, optimizer = fabric.setup(model, optimizer)
+    state = {"model": model, "optimizer": optimizer, "steps": 1}
+
+    # run one iteration to init the state of the optimizer
+    model(torch.rand(1, 32, device=fabric.device)).sum().backward()
+    optimizer.step()
+
+    checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded"))
+    fabric.save(checkpoint_path_sharded, state)
+    assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
+
+    # consolidate the checkpoint to a single file
+    checkpoint_path_full = fabric.broadcast(str(tmp_path / "checkpoint_full.pt"))
+    if fabric.global_rank == 0:
+        checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded))
+        torch.save(checkpoint, checkpoint_path_full)
+    fabric.barrier()
+
+    # re-init and load from full checkpoint
+    fabric = Fabric(
+        accelerator="cuda",
+        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
+        devices=2,
+    )
+
+    # Hack: we already called launch() on another Fabric instance above
+    fabric._launched = True
+
+    model = BoringModel()
+    optimizer = torch.optim.Adam(model.parameters())
+    model, optimizer = fabric.setup(model, optimizer)
+    state = {"model": model, "optimizer": optimizer, "steps": 1}
+    fabric.load(checkpoint_path_full, state)
