29
29
from lightning .fabric import Fabric
30
30
from lightning .fabric .plugins import FSDPPrecision
31
31
from lightning .fabric .strategies import FSDPStrategy
32
+ from lightning .fabric .utilities .imports import _TORCH_LESS_EQUAL_2_6
32
33
from lightning .fabric .utilities .load import _load_distributed_checkpoint
33
34
from lightning .fabric .wrappers import _FabricOptimizer
34
35
from tests_fabric .helpers .datasets import RandomDataset
@@ -411,8 +412,10 @@ def test_reapply_compile():
411
412
fabric .launch ()
412
413
413
414
model = BoringModel ()
414
- # compile_kwargs = {"mode": "reduce-overhead"}
415
- compiled_model = torch .compile (model ) # , **compile_kwargs
415
+ # currently (PyTorch 2.6) using ruduce-overhead here casues a RuntimeError:
416
+ # Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.
417
+ compile_kwargs = {"mode" : "reduce-overhead" } if _TORCH_LESS_EQUAL_2_6 else {}
418
+ compiled_model = torch .compile (model , ** compile_kwargs )
416
419
torch .compile .reset_mock ()
417
420
418
421
fabric_model = fabric .setup (compiled_model , _reapply_compile = True )
@@ -421,7 +424,7 @@ def test_reapply_compile():
421
424
assert isinstance (fabric_model ._forward_module ._orig_mod , FullyShardedDataParallel )
422
425
423
426
# Assert we called compile again with the same arguments, but on the FSDP-wrapped module
424
- torch .compile .assert_called_with (fabric_model ._forward_module ._orig_mod ) # , **compile_kwargs
427
+ torch .compile .assert_called_with (fabric_model ._forward_module ._orig_mod , ** compile_kwargs )
425
428
426
429
assert fabric_model ._original_module == model
427
430
assert fabric_model ._forward_module ._orig_mod .module == model
0 commit comments