 from accelerate.hooks import remove_hook_from_submodules
 from accelerate.test_utils import (
     require_bnb,
-    require_cuda,
     require_cuda_or_xpu,
     require_multi_device,
     require_multi_gpu_or_xpu,
     torch_device,
 )
 from accelerate.utils import is_hpu_available, offload_state_dict
+from accelerate.utils.memory import clear_device_cache
 from accelerate.utils.versions import is_torch_version
 
 
@@ -379,7 +379,7 @@ def test_dispatch_model_tied_weights_memory(self):
 
         torch_accelerator_module = getattr(torch, torch_device_type)
 
-        torch_accelerator_module.empty_cache()  # Needed in case we run several tests in a row.
+        clear_device_cache()  # Needed in case we run several tests in a row.
 
         model = nn.Sequential(
             OrderedDict(
@@ -443,7 +443,7 @@ def test_dispatch_model_tied_weights_memory_with_nested_offload_cpu(self):
         # Test that we do not duplicate tied weights at any point during dispatch_model call.
 
         torch_accelerator_module = getattr(torch, torch_device_type)
-        torch_accelerator_module.empty_cache()  # Needed in case we run several tests in a row.
+        clear_device_cache()  # Needed in case we run several tests in a row.
 
         class SubModule(torch.nn.Module):
             def __init__(self, ref_to_parameter):
@@ -521,7 +521,7 @@ def forward(self, x):
 
         torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)
 
-        torch_accelerator_module.empty_cache()
+        clear_device_cache()
 
         free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(torch_device)[0]
 
@@ -536,14 +536,16 @@ def forward(self, x):
 
     # This test fails because sometimes data_ptr() of compute2.weight is the same as compute1.weight.
     # I checked that the values are not the same but it gives the same address. This does not happen on my local machine.
-    @require_cuda
+    @require_cuda_or_xpu
     @unittest.skip(
         "Flaky test, we should have enough coverage with test_dispatch_model_tied_weights_memory_with_nested_offload_cpu test"
     )
     def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self):
         # Test that we do not duplicate tied weights at any point during dispatch_model call.
 
-        torch.cuda.empty_cache()  # Needed in case we run several tests in a row.
+        torch_accelerator_module = getattr(torch, torch_device_type)
+
+        clear_device_cache()  # Needed in case we run several tests in a row.
 
         class SubModule(torch.nn.Module):
             def __init__(self, ref_to_parameter):
@@ -589,37 +591,43 @@ def forward(self, x):
         expected = model(x)
 
         # Just to initialize CUDA context.
-        a = torch.rand(5).to("cuda:0")  # noqa: F841
+        device_0 = f"{torch_device_type}:0"
+        a = torch.rand(5).to(device_0)  # noqa: F841
 
-        free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0]
+        free_memory_bytes = torch_accelerator_module.mem_get_info(device_0)[0]
         required_memory_bytes = 2 * 5000 * 5000 * (32 // 8)  # 200 MB
 
         # Leaving 150 MB of free memory for possible buffers, etc.
         n_vals = (free_memory_bytes - required_memory_bytes - int(200e6)) // (32 // 8)
-        foo = torch.rand(n_vals, device="cuda:0")  # noqa: F841
+        foo = torch.rand(n_vals, device=device_0)  # noqa: F841
 
-        free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
+        free_memory_bytes_before_dispatch = torch_accelerator_module.mem_get_info(device_0)[0]
         with TemporaryDirectory() as tmp_dir:
             dispatch_model(model, device_map, offload_dir=tmp_dir)
-            free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
+            free_memory_bytes_after_dispatch = torch_accelerator_module.mem_get_info(device_0)[0]
 
             assert (free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130
 
+            oom_error = (
+                torch.OutOfMemoryError
+                if hasattr(torch, "OutOfMemoryError")
+                else torch_accelerator_module.OutOfMemoryError
+            )
             with torch.no_grad():
                 try:
                     output = model(x)
-                except torch.cuda.OutOfMemoryError as e:
-                    raise torch.cuda.OutOfMemoryError(
+                except oom_error as e:
+                    raise oom_error(
                         f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_disk. {e}"
                     )
                 except Exception as e:
                     raise e
 
             torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)
 
-            torch.cuda.empty_cache()
+            clear_device_cache()
 
-            free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0]
+            free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(device_0)[0]
 
             # Check that we have no more references on GPU for the offloaded tied weight.
             n_non_empty = 0
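The change set above converges on one device-agnostic pattern: resolve the backend module from the active device type and route cache clearing through accelerate's helper. Below is a minimal sketch of that pattern outside the test file (assuming a CUDA or XPU accelerator is available; clear_device_cache and torch_device are the accelerate helpers imported in the diff, while the torch_device_type derivation is an illustrative assumption, not code from the commit):

import torch

from accelerate.test_utils import torch_device
from accelerate.utils.memory import clear_device_cache

# torch_device is a string such as "cuda", "xpu" or "cuda:0"; strip any index
# to recover the backend module name (illustrative, not taken from the diff).
torch_device_type = torch_device.split(":")[0]
torch_accelerator_module = getattr(torch, torch_device_type)

clear_device_cache()  # backend-agnostic replacement for torch.cuda.empty_cache()

# CUDA and XPU both expose mem_get_info; a CPU-only run would not reach here.
free_bytes, total_bytes = torch_accelerator_module.mem_get_info(f"{torch_device_type}:0")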