 import warnings
 from functools import partial
 from io import BytesIO
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Union
+from pathlib import Path
+from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, OrderedDict, Sequence, Set, Tuple, Union
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch.nn import Parameter
 from typing_extensions import override
 
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_2_0,
+    _TORCH_GREATER_EQUAL_2_1,
+    _TORCH_GREATER_EQUAL_2_2,
+)
 from lightning.fabric.utilities.types import _PATH, _Stateful
 
+_METADATA_FILENAME = "meta.pt"
+
+
 if TYPE_CHECKING:
     from torch.storage import TypedStorage
 
@@ -227,3 +235,76 @@ def _move_state_into(
             destination[key].load_state_dict(state)
         else:
             destination[key] = state
+
+
+def _load_distributed_checkpoint(checkpoint_folder: Path) -> Dict[str, Any]:
+    """Loads a sharded checkpoint saved with `torch.distributed.checkpoint` into a full state dict.
+
+    The current implementation assumes that the entire checkpoint fits in CPU memory.
+
+    """
+    if not _TORCH_GREATER_EQUAL_2_1:
+        raise ImportError("Processing distributed checkpoints requires PyTorch >= 2.1.")
+
+    from torch.distributed.checkpoint import FileSystemReader
+    from torch.distributed.checkpoint.metadata import BytesStorageMetadata, TensorStorageMetadata
+
+    if _TORCH_GREATER_EQUAL_2_2:
+        from torch.distributed.checkpoint import load
+    else:
+        from torch.distributed.checkpoint import load_state_dict as load  # deprecated
+
+    reader = FileSystemReader(checkpoint_folder)
+    metadata = reader.read_metadata()
+
+    # TODO: Add sequential save to avoid storing the entire checkpoint in memory
+    checkpoint: Dict[str, Any] = {}
+    for tensor_name, sd_metadata in metadata.state_dict_metadata.items():
+        if isinstance(sd_metadata, BytesStorageMetadata):
+            checkpoint[tensor_name] = "<bytes_io>"
+        elif isinstance(sd_metadata, TensorStorageMetadata):
+            checkpoint[tensor_name] = torch.empty(
+                size=sd_metadata.size,
+                dtype=sd_metadata.properties.dtype,
+                device=torch.device("cpu"),
+                memory_format=sd_metadata.properties.memory_format,
+                layout=sd_metadata.properties.layout,
+                requires_grad=sd_metadata.properties.requires_grad,
+                pin_memory=sd_metadata.properties.pin_memory,
+            )
+
+    load(state_dict=checkpoint, storage_reader=reader, no_dist=True)
+    checkpoint = _unflatten_dict(checkpoint, key_map=metadata.planner_data)
+
+    # This is the extra file saved by Fabric, with user data separate from weights and optimizer states
+    extra_file = checkpoint_folder / _METADATA_FILENAME
+    extra = torch.load(extra_file, map_location="cpu") if extra_file.is_file() else {}
+    checkpoint.update(extra)
+
+    return checkpoint
+
+
+def _unflatten_dict(checkpoint: Dict[str, Any], key_map: Dict[str, Tuple[str, ...]]) -> Dict[str, Any]:
+    """Converts the flat dictionary with keys 'x.y.z...' to a nested dictionary using the provided key map.
+
+    Args:
+        checkpoint: The flat checkpoint dictionary.
+        key_map: A dictionary that maps the keys in flattened format 'x.y.z...' to a tuple representing
+            the index path into the nested dictionary that this function should construct.
+
+    """
+    assert checkpoint.keys() == key_map.keys()
+    converted: Dict[str, Any] = {}
+    for flat_key in checkpoint:
+        key_path = key_map[flat_key]
+        _set_nested_dict_value(converted, key_path, checkpoint[flat_key])
+    return converted
| 303 | + |
| 304 | +def _set_nested_dict_value(nested_dict: Dict[str, Any], key_path: Tuple[str, ...], value: Any) -> None: |
| 305 | + result = nested_dict |
| 306 | + for key in key_path[:-1]: |
| 307 | + if key not in result: |
| 308 | + result[key] = {} |
| 309 | + result = result[key] |
| 310 | + result[key_path[-1]] = value |
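
Below is a hedged usage sketch (not part of this commit) of how the new helper could be used to consolidate a sharded checkpoint folder back into a single-file checkpoint. The import path `lightning.fabric.utilities.load` and the folder name are illustrative assumptions; only `_load_distributed_checkpoint` from this diff and the standard `torch.save` are exercised.

    # Sketch only: the module path and folder name are assumptions, not confirmed by this diff.
    from pathlib import Path

    import torch

    from lightning.fabric.utilities.load import _load_distributed_checkpoint  # assumed location of the helper

    # Directory previously written by `torch.distributed.checkpoint` (e.g. a sharded FSDP save).
    checkpoint_folder = Path("checkpoints/last.ckpt")  # hypothetical path

    # Materializes the full state dict in CPU memory (note the docstring's memory caveat),
    # including any extra user data Fabric stored alongside the shards in `meta.pt`.
    full_checkpoint = _load_distributed_checkpoint(checkpoint_folder)

    # Write it back out as an ordinary single-file checkpoint.
    torch.save(full_checkpoint, "consolidated.ckpt")

    # Key handling: `_unflatten_dict` re-nests the planner-flattened keys, e.g. (toy example)
    # _unflatten_dict({"optimizer.state": 1}, key_map={"optimizer.state": ("optimizer", "state")})
    # returns {"optimizer": {"state": 1}}.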