Skip to content

Commit 4de45f0

Browse files
committed
handle pytorch versions
1 parent 0d4444e commit 4de45f0

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/tests_fabric/test_fabric.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from lightning.fabric.strategies.strategy import _Sharded
3636
from lightning.fabric.utilities.exceptions import MisconfigurationException
37+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
3738
from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
3839
from lightning.fabric.utilities.warnings import PossibleUserWarning
3940
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
@@ -644,9 +645,9 @@ def test_backward_required(_, strategy, precision, error_expected, setup_method)
644645

645646
# One model
646647
model1 = nn.Linear(2, 2)
647-
assert not model1._backward_pre_hooks
648+
assert not (model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks)
648649
model1 = getattr(fabric, setup_method)(model1)
649-
assert model1._backward_pre_hooks
650+
assert model1._backward_pre_hooks if _TORCH_GREATER_EQUAL_2_0 else model1._backward_hooks
650651
loss = model1(batch).sum()
651652
with error_context:
652653
loss.backward()

0 commit comments

Comments
 (0)