
Commit 265025b

Inform the user about a missing fabric.backward() call (#19447)
1 parent 6745994 commit 265025b

File tree: 6 files changed (+100 -5 lines)

src/lightning/fabric/CHANGELOG.md (+4)

@@ -22,6 +22,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `Fabric.rank_zero_first` context manager now uses a barrier without timeout to avoid long-running tasks to be interrupted ([#19448](https://github.com/Lightning-AI/lightning/pull/19448))
 
+
+- Fabric now raises an error if you forget to call `fabric.backward()` when it is needed by the strategy or precision selection ([#19447](https://github.com/Lightning-AI/lightning/pull/19447))
+
+
 -
 
 ### Deprecated
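
For orientation, the new behavior from the user's side looks roughly like the sketch below. This is a minimal example written for this summary, not code from the commit; per the new `test_backward_required` parametrization, a `bf16-mixed` precision selection (here on CPU) is one configuration where the check fires.

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", precision="bf16-mixed", devices=1)
fabric.launch()

model = torch.nn.Linear(2, 2)
model, optimizer = fabric.setup(model, torch.optim.SGD(model.parameters(), lr=0.1))

loss = model(torch.rand(2, 2)).sum()
# loss.backward()      # after this change: RuntimeError asking for `fabric.backward(loss)`
fabric.backward(loss)  # supported path: lets the precision plugin handle the backward call
optimizer.step()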

src/lightning/fabric/cli.py (+4 -2)

@@ -53,8 +53,10 @@ def _legacy_main() -> None:
     Raises deprecation warning and runs through fabric cli if necessary, else runs the entrypoint directly
 
     """
-    print("`lightning run model` is deprecated and will be removed in future versions."
-          " Please call `fabric run model` instead.")
+    print(
+        "`lightning run model` is deprecated and will be removed in future versions."
+        " Please call `fabric run model` instead."
+    )
     args = sys.argv[1:]
     if args and args[0] == "run" and args[1] == "model":
         _main()

src/lightning/fabric/fabric.py (+24)

@@ -142,6 +142,7 @@ def __init__(
         self._loggers = loggers if isinstance(loggers, list) else [loggers]
         self._models_setup: int = 0
         self._launched: bool = False
+        self._backward_called: bool = False
 
         self._prepare_run_method()
         if _is_using_cli():

@@ -253,6 +254,7 @@ def setup(
         if compile_kwargs is not None:
             module = _to_compiled(module, compile_kwargs)
         module = _FabricModule(module, self._precision, original_module=original_module)
+        self._require_fabric_backward(module)
 
         # Update the _DeviceDtypeModuleMixin's device parameter
         # NOTE: for sharded strategies or manual device placement, there's no single root device

@@ -317,6 +319,7 @@ def setup_module(
         if compile_kwargs is not None:
             module = _to_compiled(module, compile_kwargs)
         module = _FabricModule(module, self._precision, original_module=original_module)
+        self._require_fabric_backward(module)
 
         # Update the _DeviceDtypeModuleMixin's device parameter
         # NOTE: for sharded strategies or manual device placement, there's no single root device

@@ -445,7 +448,9 @@ def backward(self, tensor: Tensor, *args: Any, model: Optional[_FabricModule] =
             # requires to attach the current `DeepSpeedEngine` for the `_FabricOptimizer.step` call.
             self._strategy._deepspeed_engine = module
 
+        self._backward_called = True
         self._strategy.backward(tensor, module, *args, **kwargs)
+        self._backward_called = False
 
     def clip_gradients(
         self,

@@ -1090,6 +1095,25 @@ def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None
         if any(not isinstance(dl, DataLoader) for dl in dataloaders):
             raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
 
+    def _require_fabric_backward(self, module: _FabricModule) -> None:
+        strategy_requires = is_overridden("backward", self._strategy, parent=Strategy)
+        precision_requires = any(
+            is_overridden(method, self._precision, parent=Precision)
+            for method in ("pre_backward", "backward", "post_backward")
+        )
+
+        def _backward_hook(*_: Any, **__: Any) -> None:
+            if (strategy_requires or precision_requires) and not self._backward_called:
+                raise RuntimeError(
+                    "The current strategy and precision selection requires you to call `fabric.backward(loss)`"
+                    " instead of `loss.backward()`."
+                )
+
+        if _TORCH_GREATER_EQUAL_2_0:
+            module.register_full_backward_pre_hook(_backward_hook, prepend=True)
+        else:
+            module.register_full_backward_hook(_backward_hook)
+
     @staticmethod
     def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
         callbacks = callbacks if callbacks is not None else []
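
The enforcement above is plain PyTorch: `setup()`/`setup_module()` attach a full backward (pre-)hook to the wrapped module, and `Fabric.backward()` toggles a flag around the strategy's backward call so the hook can tell the two entry points apart. Below is a minimal, self-contained sketch of that pattern with illustrative names (`_Guard`, `guarded_backward`); it is not Fabric's code and assumes PyTorch >= 2.0 for `register_full_backward_pre_hook`.

import torch
from torch import nn


class _Guard:
    # Mirrors the idea of Fabric._backward_called: only the sanctioned entry point sets it.
    backward_called = False


def _require_guarded_backward(module: nn.Module) -> None:
    def _hook(*_, **__):
        if not _Guard.backward_called:
            raise RuntimeError("call guarded_backward(loss) instead of loss.backward()")

    # The pre-hook runs when autograd reaches this module during any backward pass.
    module.register_full_backward_pre_hook(_hook, prepend=True)


def guarded_backward(loss: torch.Tensor) -> None:
    _Guard.backward_called = True
    try:
        loss.backward()
    finally:
        _Guard.backward_called = False


model = nn.Linear(2, 2)
_require_guarded_backward(model)
loss = model(torch.rand(2, 2)).sum()
guarded_backward(loss)  # passes; a bare loss.backward() would raise inside the hook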

tests/tests_fabric/strategies/test_fsdp_integration.py (+2 -1)

@@ -441,7 +441,8 @@ def test_reapply_compile():
 
     # Smoke-testing forward to ensure we don't get compilation errors
     for _ in range(3):
-        fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()
+        loss = fabric_model(torch.randn(2, 32, device=fabric.device)).sum()
+        fabric.backward(loss)
 
 
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)

tests/tests_fabric/test_cli.py (+1)

@@ -181,6 +181,7 @@ def test_cli_through_fabric_entry_point():
     message = "Usage: fabric run model [OPTIONS] SCRIPT [SCRIPT_ARGS]"
     assert message in result.stdout or message in result.stderr
 
+
 @pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package")
 def test_cli_through_lightning_entry_point():
     result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True)

tests/tests_fabric/test_fabric.py (+65 -2)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from contextlib import nullcontext
 from re import escape
 from unittest import mock
 from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call

@@ -22,7 +23,6 @@
 import torch.distributed
 import torch.nn.functional
 from lightning.fabric.fabric import Fabric
-from lightning.fabric.plugins import Precision
 from lightning.fabric.strategies import (
     DataParallelStrategy,
     DDPStrategy,

@@ -34,6 +34,7 @@
 )
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.exceptions import MisconfigurationException
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer

@@ -611,12 +612,74 @@ def test_rank_properties():
 def test_backward():
     """Test that backward() calls into the precision plugin."""
     fabric = Fabric()
-    fabric._strategy = Mock(spec=Precision)
+    fabric._strategy = Mock(spec=Strategy)
     loss = Mock()
     fabric.backward(loss, "arg", keyword="kwarg")
     fabric._strategy.backward.assert_called_with(loss, None, "arg", keyword="kwarg")
 
 
+@pytest.mark.parametrize(("strategy", "precision", "error_expected"), [
+    ("auto", "32-true", False),
+    ("auto", "bf16-true", False),
+    ("auto", "bf16-mixed", True),
+    pytest.param("fsdp", "32-true", True, marks=RunIf(min_cuda_gpus=1, min_torch="2.0.0")),
+])
+@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
+def test_backward_required(_, strategy, precision, error_expected, setup_method):
+    """Test under which strategy and precision configurations the `fabric.backward()` call is required."""
+    fabric = Fabric(
+        accelerator=("cuda" if strategy == "fsdp" else "cpu"),
+        strategy=strategy,
+        precision=precision,
+        devices=1
+    )
+    fabric._launched = True
+    fabric.strategy.setup_module = lambda module: module
+
+    error_context = (
+        pytest.raises(RuntimeError, match=escape("requires you to call `fabric.backward(loss)`")) if error_expected
+        else nullcontext()
+    )
+    batch = torch.rand(2, 2)
+
+    # One model
+    model1 = nn.Linear(2, 2)
+    assert not (model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks)
+    model1 = getattr(fabric, setup_method)(model1)
+    assert model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks
+    loss = model1(batch).sum()
+    with error_context:
+        loss.backward()
+    loss = model1(batch).sum()
+    fabric.backward(loss)  # no error
+    assert not fabric._backward_called
+
+    # Two models chained
+    model2 = torch.nn.Linear(2, 2)
+    model2 = getattr(fabric, setup_method)(model2)
+    loss = model2(model1(batch)).sum()
+    with error_context:
+        loss.backward()
+    loss = model2(model1(batch)).sum()
+    fabric.backward(loss)  # no error
+    assert not fabric._backward_called
+
+    # Two independent models
+    loss1 = model1(batch).sum()
+    loss2 = model2(batch).sum()
+    with error_context:
+        loss1.backward()
+    with error_context:
+        loss2.backward()
+    loss1 = model1(batch).sum()
+    loss2 = model2(batch).sum()
+    fabric.backward(loss1)  # no error
+    assert not fabric._backward_called
+    fabric.backward(loss2)  # no error
+    assert not fabric._backward_called
+
+
 @RunIf(deepspeed=True, mps=False)
 def test_backward_model_input_required():
     """Test that when using deepspeed and multiple models, backward() requires the model as input."""

0 commit comments
