@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from contextlib import nullcontext
 from re import escape
 from unittest import mock
 from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call
@@ -22,7 +23,6 @@
 import torch.distributed
 import torch.nn.functional
 from lightning.fabric.fabric import Fabric
-from lightning.fabric.plugins import Precision
 from lightning.fabric.strategies import (
     DataParallelStrategy,
     DDPStrategy,
@@ -34,6 +34,7 @@
 )
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.exceptions import MisconfigurationException
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
@@ -611,12 +612,74 @@ def test_rank_properties():
 def test_backward():
     """Test that backward() calls into the precision plugin."""
     fabric = Fabric()
-    fabric._strategy = Mock(spec=Precision)
+    fabric._strategy = Mock(spec=Strategy)
     loss = Mock()
     fabric.backward(loss, "arg", keyword="kwarg")
     fabric._strategy.backward.assert_called_with(loss, None, "arg", keyword="kwarg")


+@pytest.mark.parametrize(("strategy", "precision", "error_expected"), [
+    ("auto", "32-true", False),
+    ("auto", "bf16-true", False),
+    ("auto", "bf16-mixed", True),
+    pytest.param("fsdp", "32-true", True, marks=RunIf(min_cuda_gpus=1, min_torch="2.0.0")),
+])
+@pytest.mark.parametrize("setup_method", ["setup", "setup_module"])
+@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
+def test_backward_required(_, strategy, precision, error_expected, setup_method):
+    """Test under which strategy and precision configurations the `fabric.backward()` call is required."""
+    fabric = Fabric(
+        accelerator=("cuda" if strategy == "fsdp" else "cpu"),
+        strategy=strategy,
+        precision=precision,
+        devices=1,
+    )
+    fabric._launched = True
+    fabric.strategy.setup_module = lambda module: module
+
+    error_context = (
+        pytest.raises(RuntimeError, match=escape("requires you to call `fabric.backward(loss)`")) if error_expected
+        else nullcontext()
+    )
+    batch = torch.rand(2, 2)
+
+    # One model
+    model1 = nn.Linear(2, 2)
+    assert not (model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks)
+    model1 = getattr(fabric, setup_method)(model1)
+    assert model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks
+    loss = model1(batch).sum()
+    with error_context:
+        loss.backward()
+    loss = model1(batch).sum()
+    fabric.backward(loss)  # no error
+    assert not fabric._backward_called
+
+    # Two models chained
+    model2 = torch.nn.Linear(2, 2)
+    model2 = getattr(fabric, setup_method)(model2)
+    loss = model2(model1(batch)).sum()
+    with error_context:
+        loss.backward()
+    loss = model2(model1(batch)).sum()
+    fabric.backward(loss)  # no error
+    assert not fabric._backward_called
+
+    # Two independent models
+    loss1 = model1(batch).sum()
+    loss2 = model2(batch).sum()
+    with error_context:
+        loss1.backward()
+    with error_context:
+        loss2.backward()
+    loss1 = model1(batch).sum()
+    loss2 = model2(batch).sum()
+    fabric.backward(loss1)  # no error
+    assert not fabric._backward_called
+    fabric.backward(loss2)  # no error
+    assert not fabric._backward_called
+
+
 @RunIf(deepspeed=True, mps=False)
 def test_backward_model_input_required():
     """Test that when using deepspeed and multiple models, backward() requires the model as input."""
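
For context, a minimal usage sketch of the behavior the new `test_backward_required` test guards (a hypothetical script, not part of this commit): with a Fabric-managed module under a configuration the test marks as error_expected, such as `precision="bf16-mixed"` on CPU, calling `loss.backward()` directly raises a RuntimeError, and gradients must instead flow through `fabric.backward(loss)`:

import torch
from lightning.fabric import Fabric

# Sketch only: "bf16-mixed" on the "auto" (CPU) strategy is one of the
# configurations parametrized with error_expected=True above.
fabric = Fabric(accelerator="cpu", strategy="auto", precision="bf16-mixed", devices=1)
fabric.launch()

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

loss = model(torch.rand(2, 2)).sum()
# loss.backward()  # would raise: RuntimeError("... requires you to call `fabric.backward(loss)`")
fabric.backward(loss)  # routes the backward pass through the strategy and its precision plugin
optimizer.step()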