
Commit 7cc79fe

Reapply torch.compile in Fabric.setup() (#19280)
1 parent 1faddcb commit 7cc79fe

File tree

9 files changed: +204 −38 lines changed

docs/source-fabric/api/fabric_methods.rst

-1
@@ -108,7 +108,6 @@ This is useful if your model experiences *exploding gradients* during training.
     fabric.clip_gradients(model, optimizer, max_norm=2.0, norm_type="inf")

 The :meth:`~lightning.fabric.fabric.Fabric.clip_gradients` method is agnostic to the precision and strategy being used.
-Note: Gradient clipping with FSDP is not yet fully supported.


 to_device
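
For context, the removed caveat goes along with the FSDP gradient-clipping work referenced in the CHANGELOG below. A minimal usage sketch of the documented call (the toy model, optimizer, and input are illustrative and not part of the docs page):

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices=1)
model = torch.nn.Linear(32, 2)  # illustrative toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

output = model(torch.randn(4, 32))
fabric.backward(output.sum())

# Same call as in the docs above; pass norm_type="inf" for the infinity norm
fabric.clip_gradients(model, optimizer, max_norm=2.0)
optimizer.step()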

src/lightning/fabric/CHANGELOG.md

+6
@@ -24,6 +24,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for clipping gradients by value with FSDP ([#19236](https://github.com/Lightning-AI/lightning/pull/19236))


+- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file ([#19213](https://github.com/Lightning-AI/lightning/pull/19213))
+
+
+- (Experimental) Added support for re-compiling the model inside `Fabric.setup()` over the FSDP/DDP wrappers ([#19280](https://github.com/Lightning-AI/lightning/pull/19280))
+
+
 ### Changed

 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
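
For reference, a minimal usage sketch of the experimental entry above; the `_reapply_compile` flag and the setup flow come from the fabric.py diff below, while the toy model is illustrative:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
fabric.launch()

model = torch.nn.Linear(32, 2)  # illustrative toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
compiled_model = torch.compile(model)  # compile before setup, as usual

# With _reapply_compile=True, setup() strips the OptimizedModule wrapper, lets the strategy
# wrap the raw module (DDP here), and then re-applies torch.compile with the originally
# captured arguments on top of the strategy wrapper.
fabric_model, optimizer = fabric.setup(compiled_model, optimizer, _reapply_compile=True)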

src/lightning/fabric/fabric.py

+29 −7

@@ -76,6 +76,7 @@
     _FabricDataLoader,
     _FabricModule,
     _FabricOptimizer,
+    _to_compiled,
     _unwrap_compiled,
     _unwrap_objects,
 )
@@ -213,6 +214,7 @@ def setup(
         module: nn.Module,
         *optimizers: Optimizer,
         move_to_device: bool = True,
+        _reapply_compile: Optional[bool] = None,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
         r"""Set up a model and its optimizers for accelerated training.

@@ -221,12 +223,17 @@ def setup(
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
+            _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
+                corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
+                same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
+                FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.

         Returns:
             The tuple containing wrapped module and the optimizers, in the same order they were passed in.

         """
         self._validate_setup(module, optimizers)
+        module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
         original_module = module

         module = self._precision.convert_module(module)
@@ -242,6 +249,8 @@ def setup(
         else:
             module = self._strategy.setup_module(module)

+        if compile_kwargs is not None:
+            module = _to_compiled(module, compile_kwargs)
         module = _FabricModule(module, self._precision, original_module=original_module)

         # Update the _DeviceDtypeModuleMixin's device parameter
@@ -258,8 +267,8 @@ def setup(
         self._models_setup += 1

         if hasattr(original_module, "_fabric"):  # this is probably a LightningModule
-            original_module._fabric = self  # type: ignore[assignment]
-            original_module._fabric_optimizers = optimizers  # type: ignore[assignment]
+            original_module._fabric = self
+            original_module._fabric_optimizers = optimizers
             if original_module not in self._callbacks:
                 self._callbacks.append(original_module)

@@ -270,7 +279,9 @@ def setup(
             return (module, *optimizers)
         return module

-    def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _FabricModule:
+    def setup_module(
+        self, module: nn.Module, move_to_device: bool = True, _reapply_compile: Optional[bool] = None
+    ) -> _FabricModule:
         r"""Set up a model for accelerated training or inference.

         This is the same as calling ``.setup(model)`` with no optimizers. It is useful for inference or for certain
@@ -281,12 +292,17 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
             module: A :class:`torch.nn.Module` to set up
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
+            _reapply_compile: (Experimental) If set to ``True``, and the model was ``torch.compile``d before, the
+                corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
+                same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
+                FSDP etc.). Only supported on PyTorch >= 2.1. Defaults to ``False``, but it may change in the future.

         Returns:
             The wrapped model.

         """
         self._validate_setup_module(module)
+        module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
         original_module = module

         module = self._precision.convert_module(module)
@@ -296,6 +312,9 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri

         # Let strategy wrap and connect the module alone
         module = self._strategy.setup_module(module)
+
+        if compile_kwargs is not None:
+            module = _to_compiled(module, compile_kwargs)
         module = _FabricModule(module, self._precision, original_module=original_module)

         # Update the _DeviceDtypeModuleMixin's device parameter
@@ -305,7 +324,7 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri
         )

         if hasattr(original_module, "_fabric"):  # this is probably a LightningModule
-            original_module._fabric = self  # type: ignore[assignment]
+            original_module._fabric = self
             if original_module not in self._callbacks:
                 self._callbacks.append(original_module)

@@ -410,6 +429,7 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] =

         """
         module = model._forward_module if model is not None else model
+        module, _ = _unwrap_compiled(module)
         if isinstance(self._strategy, DeepSpeedStrategy):
             if model is None:
                 if self._models_setup == 0:
@@ -641,7 +661,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
             skip.

         """
-        module = _unwrap_compiled(module)
+        module, _ = _unwrap_compiled(module)
         if not isinstance(module, _FabricModule):
             raise TypeError(
                 "You need to set up the model first before you can call `fabric.no_backward_sync()`:"
@@ -656,7 +676,9 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Conte
                 category=PossibleUserWarning,
             )
             return nullcontext()
-        return self._strategy._backward_sync_control.no_backward_sync(module._forward_module)
+
+        forward_module, _ = _unwrap_compiled(module._forward_module)
+        return self._strategy._backward_sync_control.no_backward_sync(forward_module)

     def sharded_model(self) -> ContextManager:
         r"""Instantiate a model under this context manager to prepare it for model-parallel sharding.
@@ -772,7 +794,7 @@ def load(
         # We need to unwrap objects (see above) but this creates a new dictionary. In-place updates
         # (for user metadata) wouldn't show up in the original dict, so we need to copy the data back.
         for k in list(unwrapped_state.keys()):
-            obj = _unwrap_compiled(state[k])
+            obj, _ = _unwrap_compiled(state[k])
             if isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader)):
                 continue
             state[k] = unwrapped_state[k]
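
For reference, a minimal sketch of how the two helpers used above compose. `_unwrap_compiled` and `_to_compiled` are private utilities from the wrappers.py diff below, and capturing the compile arguments relies on lightning being imported before `torch.compile` is called:

import torch
from lightning.fabric.wrappers import _to_compiled, _unwrap_compiled  # importing this module patches torch.compile

model = torch.nn.Linear(32, 2)  # illustrative toy model
compiled = torch.compile(model, mode="reduce-overhead")

# New tuple convention: (inner object, captured compile kwargs or None)
inner, compile_kwargs = _unwrap_compiled(compiled)  # -> (model, {"mode": "reduce-overhead"})
plain, no_kwargs = _unwrap_compiled(model)          # -> (model, None) for non-compiled objects

# In Fabric.setup(), the strategy wraps `inner` first (DDP/FSDP) and the wrapper is then
# re-compiled; here we re-compile `inner` directly for brevity.
recompiled = _to_compiled(inner, compile_kwargs)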

src/lightning/fabric/wrappers.py

+65 −8

@@ -12,8 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from copy import deepcopy
 from functools import wraps
-from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, TypeVar, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -32,6 +47,9 @@
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.types import Optimizable

+if TYPE_CHECKING:
+    from torch._dynamo import OptimizedModule
+
 T_destination = TypeVar("T_destination", bound=Dict[str, Any])
 _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")

@@ -285,8 +303,8 @@ def _unwrap_objects(collection: Any) -> Any:
     def _unwrap(
         obj: Union[_FabricModule, _FabricOptimizer, _FabricDataLoader]
     ) -> Union[nn.Module, Optimizer, DataLoader]:
-        if isinstance(unwrapped := _unwrap_compiled(obj), _FabricModule):
-            return unwrapped._forward_module
+        if isinstance(unwrapped := _unwrap_compiled(obj)[0], _FabricModule):
+            return _unwrap_compiled(unwrapped._forward_module)[0]
         if isinstance(obj, _FabricOptimizer):
             return obj.optimizer
         if isinstance(obj, _FabricDataLoader):
@@ -302,19 +320,33 @@ def _unwrap(
     return apply_to_collection(collection, dtype=tuple(types), function=_unwrap)


-def _unwrap_compiled(obj: Any) -> Any:
+def _unwrap_compiled(obj: Union[Any, "OptimizedModule"]) -> Tuple[Union[Any, nn.Module], Optional[Dict[str, Any]]]:
     """Removes the :class:`torch._dynamo.OptimizedModule` around the object if it is wrapped.

     Use this function before instance checks against e.g. :class:`_FabricModule`.

     """
     if not _TORCH_GREATER_EQUAL_2_0:
-        return obj
+        # obj can't be an `OptimizedModule` anyway
+        return obj, None
+
     from torch._dynamo import OptimizedModule

     if isinstance(obj, OptimizedModule):
-        return obj._orig_mod
-    return obj
+        if (compile_kwargs := getattr(obj, "_compile_kwargs", None)) is None:
+            raise RuntimeError(
+                "Failed to determine the arguments that were used to compile the module. Make sure to import"
+                " lightning before `torch.compile` is used."
+            )
+        return obj._orig_mod, compile_kwargs
+    return obj, None
+
+
+def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> "OptimizedModule":
+    if not _TORCH_GREATER_EQUAL_2_0:
+        raise RuntimeError("Converting to a compiled module is only supported in PyTorch >= 2.0.0")
+
+    return torch.compile(module, **compile_kwargs)  # type: ignore[return-value]


 def is_wrapped(obj: object) -> bool:
@@ -328,5 +360,30 @@ def is_wrapped(obj: object) -> bool:
         obj: The object to test.

     """
-    obj = _unwrap_compiled(obj)
+    obj, _ = _unwrap_compiled(obj)
     return isinstance(obj, (_FabricModule, _FabricOptimizer, _FabricDataLoader))
+
+
+def _capture_compile_kwargs(compile_fn: Callable) -> Callable:
+    """Wraps the ``torch.compile`` function and captures the compile arguments.
+
+    We extract the compile arguments so that we can reapply ``torch.compile`` in ``Fabric.setup()`` with the
+    same arguments as the user passed to the original call. The arguments get stored in a dictionary
+    ``_compile_kwargs`` on the returned compiled module.
+
+    """
+    # Limitation: Currently, the global compile config does not get captured on a per-model basis.
+    # PyTorch will resolve this in the future: https://github.com/pytorch/pytorch/issues/116575
+
+    @wraps(compile_fn)
+    def _capture(model: Any, **kwargs: Any) -> Any:
+        compiled_model = compile_fn(model, **kwargs)
+        if isinstance(model, nn.Module):
+            compiled_model._compile_kwargs = deepcopy(kwargs)
+        return compiled_model
+
+    return _capture
+
+
+if _TORCH_GREATER_EQUAL_2_0:
+    torch.compile = _capture_compile_kwargs(torch.compile)
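
A small sketch of the capture mechanism above in action; `_compile_kwargs` is a private attribute set by the patched `torch.compile`, and the import-order requirement mirrors the error message in `_unwrap_compiled` (assumes PyTorch >= 2.0):

import lightning.fabric  # import lightning first so torch.compile is patched to record its kwargs
import torch

model = torch.nn.Linear(32, 2)  # illustrative toy model
compiled = torch.compile(model, mode="reduce-overhead", dynamic=False)

# The patched torch.compile stored a copy of the keyword arguments on the compiled module,
# which is what lets Fabric.setup(..., _reapply_compile=True) re-compile over the strategy wrapper.
assert compiled._compile_kwargs == {"mode": "reduce-overhead", "dynamic": False}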

src/lightning/pytorch/CHANGELOG.md

+3
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added the option `ModelCheckpoint(save_last='link')` to create a symbolic link for the 'last.ckpt' file ([#19191](https://github.com/Lightning-AI/lightning/pull/19191))


+- Added a utility function and CLI to consolidate FSDP sharded checkpoints into a single file ([#19213](https://github.com/Lightning-AI/lightning/pull/19213))
+
+
 ### Changed

 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))

tests/tests_fabric/strategies/test_ddp_integration.py

+40
@@ -11,14 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from copy import deepcopy
+from unittest import mock
+from unittest.mock import Mock

 import pytest
 import torch
 from lightning.fabric import Fabric
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from torch.nn.parallel.distributed import DistributedDataParallel

 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
+from tests_fabric.test_fabric import BoringModel


 @pytest.mark.parametrize(
@@ -64,6 +70,40 @@ def assert_params_equal(params0, params1):
     assert_params_equal(params_before, wrapped_model.parameters())


+@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True)
+@mock.patch(
+    "lightning.fabric.wrappers.torch.compile",
+    Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
+)
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile():
+    """Test that Fabric can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
+    from torch._dynamo import OptimizedModule
+
+    fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")
+    fabric.launch()
+
+    model = BoringModel()
+    compile_kwargs = {"mode": "reduce-overhead"}
+    compiled_model = torch.compile(model, **compile_kwargs)
+    torch.compile.reset_mock()
+
+    fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
+
+    assert isinstance(fabric_model._forward_module, OptimizedModule)
+    assert isinstance(fabric_model._forward_module._orig_mod, DistributedDataParallel)
+    # Assert we called compile again with the same arguments, but on the DDP-wrapped module
+    torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)
+
+    assert fabric_model._original_module == model
+    assert fabric_model._forward_module._orig_mod.module == model
+    assert fabric_model.device == fabric.device
+
+    # Smoke-testing forward to ensure we don't get compilation errors
+    for _ in range(3):
+        fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()
+
+
 @pytest.mark.parametrize(
     ("clip_type", "accelerator", "precision"),
     [
