Commit 7880c11

Alternative mechanism to detect missing Fabric.backward() call (#19493)

1 parent ea89133

7 files changed: +92 -56 lines

src/lightning/fabric/CHANGELOG.md (+1 -1)

@@ -23,7 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `Fabric.rank_zero_first` context manager now uses a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
 
-- Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447))
+- Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447), [#19493](https://github.com/Lightning-AI/lightning/pull/19493))
 
 
 -
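To make the changelog entry concrete, here is a usage sketch of the behavior this commit enforces. The `"bf16-mixed"` precision choice is an assumption for illustration; the check only fires when the selected strategy or precision plugin customizes the backward call.

```python
# Hypothetical usage sketch: with a precision plugin that customizes backward
# (assumed here via precision="bf16-mixed"), a plain loss.backward() is rejected.
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", precision="bf16-mixed")
fabric.launch()

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

loss = model(torch.randn(4, 2)).sum()
# loss.backward()  # would raise: "... requires you to call `fabric.backward(loss)` instead of `loss.backward()`."
fabric.backward(loss)  # routes the backward pass through the strategy and precision plugins
optimizer.step()
```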

src/lightning/fabric/fabric.py (+9 -31)

@@ -40,6 +40,7 @@
 from torch.optim import Optimizer
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
 
+import lightning.fabric
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.connector import _PLUGIN_INPUT, _PRECISION_INPUT, _Connector, _is_using_cli
 from lightning.fabric.loggers import Logger
@@ -142,7 +143,6 @@ def __init__(
         self._loggers = loggers if isinstance(loggers, list) else [loggers]
         self._models_setup: int = 0
         self._launched: bool = False
-        self._backward_called: bool = False
 
         self._prepare_run_method()
         if _is_using_cli():
@@ -253,19 +253,15 @@ def setup(
 
         if compile_kwargs is not None:
             module = _to_compiled(module, compile_kwargs)
-        module = _FabricModule(module, self._precision, original_module=original_module)
-        self._require_fabric_backward(module)
+        module = _FabricModule(module, self._strategy, original_module=original_module)
 
         # Update the _DeviceDtypeModuleMixin's device parameter
         # NOTE: for sharded strategies or manual device placement, there's no single root device
         _update_properties(
             module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
         )
 
-        optimizers = [
-            _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
-            for optimizer in optimizers
-        ]
+        optimizers = [_FabricOptimizer(optimizer, self._strategy, self._callbacks) for optimizer in optimizers]
 
         self._models_setup += 1
 
@@ -318,8 +314,7 @@ def setup_module(
 
         if compile_kwargs is not None:
             module = _to_compiled(module, compile_kwargs)
-        module = _FabricModule(module, self._precision, original_module=original_module)
-        self._require_fabric_backward(module)
+        module = _FabricModule(module, self._strategy, original_module=original_module)
 
         # Update the _DeviceDtypeModuleMixin's device parameter
         # NOTE: for sharded strategies or manual device placement, there's no single root device
@@ -448,9 +443,11 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] =
             # requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call.
             self._strategy._deepspeed_engine = module
 
-        self._backward_called = True
-        self._strategy.backward(tensor, module, *args, **kwargs)
-        self._backward_called = False
+        lightning.fabric.wrappers._in_fabric_backward = True
+        try:
+            self._strategy.backward(tensor, module, *args, **kwargs)
+        finally:
+            lightning.fabric.wrappers._in_fabric_backward = False
 
     def clip_gradients(
         self,
@@ -1092,25 +1089,6 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None
         if any(not isinstance(dl, DataLoader) for dl in dataloaders):
             raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
 
-    def _require_fabric_backward(self, module: _FabricModule) -> None:
-        strategy_requires = is_overridden("backward", self._strategy, parent=Strategy)
-        precision_requires = any(
-            is_overridden(method, self._precision, parent=Precision)
-            for method in ("pre_backward", "backward", "post_backward")
-        )
-
-        def _backward_hook(*_: Any, **__: Any) -> None:
-            if (strategy_requires or precision_requires) and not self._backward_called:
-                raise RuntimeError(
-                    "The current strategy and precision selection requires you to call `fabric.backward(loss)`"
-                    " instead of `loss.backward()`."
-                )
-
-        if _TORCH_GREATER_EQUAL_2_0:
-            module.register_full_backward_pre_hook(_backward_hook, prepend=True)
-        else:
-            module.register_full_backward_hook(_backward_hook)
-
     @staticmethod
     def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
         callbacks = callbacks if callbacks is not None else []
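The heart of the change in `Fabric.backward` above is the flag toggle around the strategy call. A minimal sketch of that pattern; the `run_strategy_backward` helper is hypothetical, while `_in_fabric_backward` is the module-level flag introduced in `lightning.fabric.wrappers`:

```python
# Minimal sketch of the try/finally flag toggle; `run_strategy_backward` is a hypothetical helper.
import lightning.fabric.wrappers as wrappers


def run_strategy_backward(strategy, tensor, module, *args, **kwargs):
    wrappers._in_fabric_backward = True  # tensor hooks registered by _FabricModule check this flag
    try:
        strategy.backward(tensor, module, *args, **kwargs)
    finally:
        # reset even if backward raises, so a later plain loss.backward() is still caught
        wrappers._in_fabric_backward = False
```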

src/lightning/fabric/wrappers.py (+34 -7)

@@ -13,7 +13,7 @@
 # limitations under the License.
 import inspect
 from copy import deepcopy
-from functools import wraps
+from functools import partial, wraps
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -31,6 +31,7 @@
 )
 
 import torch
+from lightning_utilities import is_overridden
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch import nn as nn
@@ -53,6 +54,8 @@
 T_destination = TypeVar("T_destination", bound=Dict[str, Any])
 _LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")
 
+_in_fabric_backward: bool = False
+
 
 class _FabricOptimizer:
     def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional[List[Callable]] = None) -> None:
@@ -105,7 +108,7 @@ def __getattr__(self, item: Any) -> Any:
 
 class _FabricModule(_DeviceDtypeModuleMixin):
     def __init__(
-        self, forward_module: nn.Module, precision: Precision, original_module: Optional[nn.Module] = None
+        self, forward_module: nn.Module, strategy: Strategy, original_module: Optional[nn.Module] = None
     ) -> None:
         """The FabricModule is a thin wrapper around the :class:`torch.nn.Module` and handles precision / autocast
         automatically for the forward pass.
@@ -114,7 +117,7 @@ def __init__(
 
         Args:
             forward_module: The module to wrap the ``forward`` method on.
-            precision: Reference to the precision plugin for handling precision context
+            strategy: Reference to the strategy for handling precision etc.
            original_module: The original, unmodified module as passed into the
                 :meth:`lightning.fabric.fabric.Fabric.setup` method. This is needed when attribute lookup
                 on this wrapper should pass through to the original module.
@@ -123,7 +126,7 @@ def __init__(
         super().__init__()
         self._forward_module = forward_module
         self._original_module = original_module or forward_module
-        self._precision = precision
+        self._strategy = strategy
         self._fabric_module_initialized = True
 
     @property
@@ -133,12 +136,15 @@ def module(self) -> nn.Module:
     @override
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
-        args, kwargs = self._precision.convert_input((args, kwargs))
+        precision = self._strategy.precision
+        args, kwargs = precision.convert_input((args, kwargs))
 
-        with self._precision.forward_context():
+        with precision.forward_context():
             output = self._forward_module(*args, **kwargs)
 
-        output = self._precision.convert_output(output)
+        output = precision.convert_output(output)
+
+        apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook)
         return output
 
     @overload
@@ -214,6 +220,19 @@ def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
 
         return _wrapped_method
 
+    def _register_backward_hook(self, tensor: Tensor) -> Tensor:
+        if not tensor.requires_grad:
+            return tensor
+
+        strategy_requires = is_overridden("backward", self._strategy, parent=Strategy)
+        precision_requires = any(
+            is_overridden(method, self._strategy.precision, parent=Precision)
+            for method in ("pre_backward", "backward", "post_backward")
+        )
+        hook = partial(_backward_hook, (strategy_requires or precision_requires))
+        tensor.register_hook(hook)
+        return tensor
+
     @override
     def __getattr__(self, item: Any) -> Any:
         if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
@@ -347,6 +366,14 @@ def _to_compiled(module: nn.Module, compile_kwargs: Dict[str, Any]) -> "Optimize
     return torch.compile(module, **compile_kwargs)  # type: ignore[return-value]
 
 
+def _backward_hook(requires_backward: bool, *_: Any) -> None:
+    if requires_backward and not _in_fabric_backward:
+        raise RuntimeError(
+            "The current strategy and precision selection requires you to call `fabric.backward(loss)`"
+            " instead of `loss.backward()`."
+        )
+
+
 def is_wrapped(obj: object) -> bool:
     """Checks if an object was set up by Fabric.
 
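Compared to the previous module-level backward hook, the new mechanism attaches a per-tensor autograd hook to every grad-requiring tensor in the forward output, so it also works for outputs that are dicts or tuples of tensors. A self-contained sketch of the idea in plain PyTorch (an illustration, not the library code):

```python
# Self-contained sketch of the per-tensor hook mechanism; names mirror the diff,
# but this is an illustration independent of Lightning.
from functools import partial

import torch

_in_fabric_backward = False  # the real flag is flipped by fabric.backward()


def _backward_hook(requires_backward: bool, *_):
    if requires_backward and not _in_fabric_backward:
        raise RuntimeError("call `fabric.backward(loss)` instead of `loss.backward()`")


def _register_backward_hook(tensor: torch.Tensor) -> torch.Tensor:
    if tensor.requires_grad:
        # fires when autograd computes the gradient for this tensor
        tensor.register_hook(partial(_backward_hook, True))
    return tensor


loss = _register_backward_hook(torch.nn.Linear(2, 2)(torch.randn(1, 2)).sum())
try:
    loss.backward()  # the hook raises because the flag was never set
except RuntimeError as err:
    print(err)
```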

tests/tests_fabric/conftest.py (+10)

@@ -114,6 +114,16 @@ def thread_police_duuu_daaa_duuu_daaa():
             raise AssertionError(f"Test left zombie thread: {thread}")
 
 
+@pytest.fixture(autouse=True)
+def reset_in_fabric_backward():
+    """Ensures that the wrappers.in_fabric_backward global variable gets reset after each test."""
+    import lightning.fabric.wrappers as wrappers
+
+    assert hasattr(wrappers, "_in_fabric_backward")
+    yield
+    wrappers._in_fabric_backward = False
+
+
 @pytest.fixture()
 def reset_deterministic_algorithm():
     """Ensures that torch determinism settings are reset before the next test runs."""

tests/tests_fabric/loggers/test_tensorboard.py (+1 -1)

@@ -164,7 +164,7 @@ def test_tensorboard_log_graph(tmp_path, example_input_array):
         logger._experiment.reset_mock()
 
     # model wrapped in `FabricModule`
-    wrapped = _FabricModule(model, precision=Mock())
+    wrapped = _FabricModule(model, strategy=Mock())
     logger.log_graph(wrapped, example_input_array)
     if example_input_array is not None:
         logger.experiment.add_graph.assert_called_with(model, example_input_array)

tests/tests_fabric/test_fabric.py (+29 -9)

@@ -34,7 +34,6 @@
 )
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.exceptions import MisconfigurationException
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
@@ -646,25 +645,27 @@ def test_backward_required(_, strategy, precision, error_expected, setup_method)
 
     # One model
     model1 = nn.Linear(2, 2)
-    assert not (model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks)
     model1 = getattr(fabric, setup_method)(model1)
-    assert model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks
-    loss = model1(batch).sum()
+    output = model1(batch)
+    assert output._backward_hooks is not None
+    loss = output.sum()
     with error_context:
         loss.backward()
     loss = model1(batch).sum()
+    assert not lightning.fabric.wrappers._in_fabric_backward
     fabric.backward(loss)  # no error
-    assert not fabric._backward_called
+    assert not lightning.fabric.wrappers._in_fabric_backward
 
     # Two models chained
     model2 = torch.nn.Linear(2, 2)
     model2 = getattr(fabric, setup_method)(model2)
-    loss = model2(model1(batch)).sum()
+    output = model2(model1(batch))
+    assert output._backward_hooks is not None
+    loss = output.sum()
     with error_context:
         loss.backward()
     loss = model2(model1(batch)).sum()
     fabric.backward(loss)  # no error
-    assert not fabric._backward_called
 
     # Two independent models
     loss1 = model1(batch).sum()
@@ -676,9 +677,28 @@ def test_backward_required(_, strategy, precision, error_expected, setup_method)
     loss1 = model1(batch).sum()
     loss2 = model2(batch).sum()
     fabric.backward(loss1)  # no error
-    assert not fabric._backward_called
     fabric.backward(loss2)  # no error
-    assert not fabric._backward_called
+
+    # Model that returns a datastructure of tensors
+    class DictReturnModel(nn.Linear):
+        def forward(self, x):
+            return {
+                "loss": super().forward(x).sum(),
+                "other": torch.rand(2, 2),  # does not require grad
+            }
+
+    model3 = DictReturnModel(2, 2)
+    model3 = getattr(fabric, setup_method)(model3)
+    output = model3(batch)
+    loss = output["loss"]
+    other = output["other"]
+    assert loss._backward_hooks is not None
+    assert other._backward_hooks is None
+
+    with error_context:
+        (loss * 2).backward()
+    loss = model3(batch)["loss"]
+    fabric.backward(loss * 2)  # no error
 
 
 @RunIf(deepspeed=True, mps=False)
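The `DictReturnModel` case above is covered because `_FabricModule.forward` applies the hook registration to every tensor in the output collection via `apply_to_collection`. A small sketch of that traversal; `tag` is a hypothetical stand-in for `_FabricModule._register_backward_hook`:

```python
# Sketch of the collection traversal behind the DictReturnModel test; `tag` is a
# hypothetical stand-in for the real hook-registration function.
import torch
from lightning_utilities.core.apply_func import apply_to_collection


def tag(t: torch.Tensor) -> torch.Tensor:
    # only tensors with requires_grad=True would actually get a hook
    print(f"visited tensor, requires_grad={t.requires_grad}")
    return t


output = {"loss": torch.randn((), requires_grad=True), "other": torch.rand(2, 2)}
apply_to_collection(output, dtype=torch.Tensor, function=tag)
```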

tests/tests_fabric/test_wrappers.py (+8 -7)

@@ -103,13 +103,13 @@ def __init__(self, module):
 
     # Regular case: forward_module == original_module -> no warnings
     original_module = OriginalModule()
-    fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
+    fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
     assert fabric_module.method_without_module_invocation() == 100
 
     # Special case: original module wrapped by forward module: -> warn if method accepts args
     original_module = OriginalModule()
     wrapped_module = ModuleWrapper(original_module)
-    fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
+    fabric_module = _FabricModule(forward_module=wrapped_module, strategy=Mock(), original_module=original_module)
     assert fabric_module.method_without_module_invocation() == 100
     with pytest.raises(
         RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
@@ -254,7 +254,7 @@ def check_autocast(forward_input):
         return forward_input
 
     module = Mock(wraps=torch.nn.Identity(), side_effect=check_autocast)
-    fabric_module = _FabricModule(module, fabric._precision).to(device)
+    fabric_module = _FabricModule(module, fabric._strategy).to(device)
     out = fabric_module(torch.tensor([1, 2, 3], dtype=input_type, device=device))
     assert module.call_args[0][0].dtype == expected_type
     assert out.dtype == input_type or out.dtype == torch.get_default_dtype()
@@ -560,10 +560,11 @@ def validation_step(self, arg, kwarg=None):
     def normal_method(self):
         pass
 
-    precision = Mock(wraps=Precision())
+    strategy = Mock()
+    strategy.precision = Mock(wraps=Precision())
     original_module = LightningModule()
     forward_module = DDP(original_module)
-    fabric_module = _FabricModule(forward_module=forward_module, precision=precision, original_module=original_module)
+    fabric_module = _FabricModule(forward_module=forward_module, strategy=strategy, original_module=original_module)
 
     # Regular methods on the original_module are visible and identical on the fabric_module ...
     assert fabric_module.normal_method.__wrapped__ == original_module.normal_method
@@ -585,13 +586,13 @@ def normal_method(self):
     assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"
     assert fabric_module.training_step("train_arg", kwarg="train_kwarg") == "training_step_return"  # call 2nd time
     assert fabric_module.validation_step("val_arg", kwarg="val_kwarg") == "validation_step_return"
-    precision.forward_context.assert_called()
+    strategy.precision.forward_context.assert_called()
 
     # The forward method remains untouched/unpatched after the special methods have been called
     assert original_module.forward.__name__ == "forward"
 
     # Special case: forward_module == original_module -> no special treatment applied
-    fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
+    fabric_module = _FabricModule(forward_module=original_module, strategy=Mock(), original_module=original_module)
     assert fabric_module.training_step == original_module.training_step
     assert fabric_module.validation_step == original_module.validation_step
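A note on the mocking pattern used in these tests: a bare `Mock()` auto-creates a `strategy.precision` attribute, but wrapping a real `Precision` keeps `convert_input`, `forward_context`, and `convert_output` functional so `_FabricModule.forward` still behaves normally. A sketch, assuming the real `Precision` plugin import:

```python
# Sketch of the test mocking pattern; assumes the real Precision plugin from lightning.fabric.
from unittest.mock import Mock

from lightning.fabric.plugins import Precision

strategy = Mock()
strategy.precision = Mock(wraps=Precision())  # delegates calls to a real Precision instance
```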
