Skip to content

Commit 84cae47

Browse files
committed
_TORCH_LESS_EQUAL_2_6
1 parent 3a2499e commit 84cae47

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

src/lightning/fabric/utilities/imports.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
3535
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
3636
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
37+
_TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
3738

3839
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
3940

tests/tests_fabric/strategies/test_ddp_integration.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from torch.nn.parallel.distributed import DistributedDataParallel
2424

2525
from lightning.fabric import Fabric
26+
from lightning.fabric.utilities.imports import _TORCH_LESS_EQUAL_2_6
2627
from tests_fabric.helpers.runif import RunIf
2728
from tests_fabric.strategies.test_single_device import _run_test_clip_gradients
2829
from tests_fabric.test_fabric import BoringModel
@@ -84,16 +85,18 @@ def test_reapply_compile():
8485
fabric.launch()
8586

8687
model = BoringModel()
87-
# compile_kwargs = {"mode": "reduce-overhead"}
88-
compiled_model = torch.compile(model) # , **compile_kwargs
88+
# currently (PyTorch 2.6) using ruduce-overhead here casues a RuntimeError:
89+
# Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.
90+
compile_kwargs = {"mode": "reduce-overhead"} if _TORCH_LESS_EQUAL_2_6 else {}
91+
compiled_model = torch.compile(model, **compile_kwargs)
8992
torch.compile.reset_mock()
9093

9194
fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
9295

9396
assert isinstance(fabric_model._forward_module, OptimizedModule)
9497
assert isinstance(fabric_model._forward_module._orig_mod, DistributedDataParallel)
9598
# Assert we called compile again with the same arguments, but on the DDP-wrapped module
96-
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod) # , **compile_kwargs
99+
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)
97100

98101
assert fabric_model._original_module == model
99102
assert fabric_model._forward_module._orig_mod.module == model

tests/tests_fabric/strategies/test_fsdp_integration.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from lightning.fabric import Fabric
3030
from lightning.fabric.plugins import FSDPPrecision
3131
from lightning.fabric.strategies import FSDPStrategy
32+
from lightning.fabric.utilities.imports import _TORCH_LESS_EQUAL_2_6
3233
from lightning.fabric.utilities.load import _load_distributed_checkpoint
3334
from lightning.fabric.wrappers import _FabricOptimizer
3435
from tests_fabric.helpers.datasets import RandomDataset
@@ -411,8 +412,10 @@ def test_reapply_compile():
411412
fabric.launch()
412413

413414
model = BoringModel()
414-
# compile_kwargs = {"mode": "reduce-overhead"}
415-
compiled_model = torch.compile(model) # , **compile_kwargs
415+
# currently (PyTorch 2.6) using ruduce-overhead here casues a RuntimeError:
416+
# Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.
417+
compile_kwargs = {"mode": "reduce-overhead"} if _TORCH_LESS_EQUAL_2_6 else {}
418+
compiled_model = torch.compile(model, **compile_kwargs)
416419
torch.compile.reset_mock()
417420

418421
fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
@@ -421,7 +424,7 @@ def test_reapply_compile():
421424
assert isinstance(fabric_model._forward_module._orig_mod, FullyShardedDataParallel)
422425

423426
# Assert we called compile again with the same arguments, but on the FSDP-wrapped module
424-
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod) # , **compile_kwargs
427+
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)
425428

426429
assert fabric_model._original_module == model
427430
assert fabric_model._forward_module._orig_mod.module == model

0 commit comments

Comments
 (0)