Commit 97c93c4

enable test_dispatch_model_tied_weights_memory_with_nested_offload_cpu on xpu (#3569)
* enable test_dispatch_model_tied_weights_memory_with_nested_offload_cpu case on XPU
* replace hard-coded torch.cuda calls with device-dependent calls
* fix style
* use device-agnostic clear_device_cache
* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
1 parent cd37bbb commit 97c93c4
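
For context, here is a minimal sketch of the device-agnostic pattern this commit applies. clear_device_cache, its garbage_collection flag, and torch_device are real accelerate helpers that appear in the diff below; the torch_device_type derivation and the memory query at the end are illustrative assumptions, not code from this PR.

import torch

from accelerate.test_utils import torch_device
from accelerate.utils.memory import clear_device_cache

# torch_device is a string such as "cuda", "xpu", or "cpu"; strip a possible ":0" suffix.
torch_device_type = torch_device.split(":")[0]
torch_accelerator_module = getattr(torch, torch_device_type)

# Before: hard-coded to CUDA, so the test could only run on NVIDIA GPUs.
#     torch.cuda.empty_cache()
# After: backend-agnostic cache clearing; gc.collect() can be folded in as well.
clear_device_cache(garbage_collection=True)

# Backend-specific queries go through the resolved module instead of torch.cuda.
if hasattr(torch_accelerator_module, "mem_get_info"):
    free_bytes, total_bytes = torch_accelerator_module.mem_get_info(torch_device)
    print(f"{free_bytes / 1e9:.2f} GB free of {total_bytes / 1e9:.2f} GB on {torch_device}")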

2 files changed: +25 -20 lines

tests/test_big_modeling.py

Lines changed: 23 additions & 15 deletions

@@ -36,7 +36,6 @@
 from accelerate.hooks import remove_hook_from_submodules
 from accelerate.test_utils import (
     require_bnb,
-    require_cuda,
     require_cuda_or_xpu,
     require_multi_device,
     require_multi_gpu_or_xpu,
@@ -47,6 +46,7 @@
     torch_device,
 )
 from accelerate.utils import is_hpu_available, offload_state_dict
+from accelerate.utils.memory import clear_device_cache
 from accelerate.utils.versions import is_torch_version


@@ -379,7 +379,7 @@ def test_dispatch_model_tied_weights_memory(self):

         torch_accelerator_module = getattr(torch, torch_device_type)

-        torch_accelerator_module.empty_cache()  # Needed in case we run several tests in a row.
+        clear_device_cache()  # Needed in case we run several tests in a row.

         model = nn.Sequential(
             OrderedDict(
@@ -443,7 +443,7 @@ def test_dispatch_model_tied_weights_memory_with_nested_offload_cpu(self):
         # Test that we do not duplicate tied weights at any point during dispatch_model call.

         torch_accelerator_module = getattr(torch, torch_device_type)
-        torch_accelerator_module.empty_cache()  # Needed in case we run several tests in a row.
+        clear_device_cache()  # Needed in case we run several tests in a row.

         class SubModule(torch.nn.Module):
             def __init__(self, ref_to_parameter):
@@ -521,7 +521,7 @@ def forward(self, x):

         torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)

-        torch_accelerator_module.empty_cache()
+        clear_device_cache()

         free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(torch_device)[0]

@@ -536,14 +536,16 @@ def forward(self, x):

     # This test fails because sometimes data_ptr() of compute2.weight is the same as compute1.weight.
     # I checked that the values are not the same but it gives the same address. This does not happen on my local machine.
-    @require_cuda
+    @require_cuda_or_xpu
     @unittest.skip(
         "Flaky test, we should have enough coverage with test_dispatch_model_tied_weights_memory_with_nested_offload_cpu test"
     )
     def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self):
         # Test that we do not duplicate tied weights at any point during dispatch_model call.

-        torch.cuda.empty_cache()  # Needed in case we run several tests in a row.
+        torch_accelerator_module = getattr(torch, torch_device_type)
+
+        clear_device_cache()  # Needed in case we run several tests in a row.

         class SubModule(torch.nn.Module):
             def __init__(self, ref_to_parameter):
@@ -589,37 +591,43 @@ def forward(self, x):
         expected = model(x)

         # Just to initialize CUDA context.
-        a = torch.rand(5).to("cuda:0")  # noqa: F841
+        device_0 = f"{torch_device_type}:0"
+        a = torch.rand(5).to(device_0)  # noqa: F841

-        free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0]
+        free_memory_bytes = torch_accelerator_module.mem_get_info(device_0)[0]
         required_memory_bytes = 2 * 5000 * 5000 * (32 // 8)  # 200 MB

         # Leaving 150 MB of free memory for possible buffers, etc.
         n_vals = (free_memory_bytes - required_memory_bytes - int(200e6)) // (32 // 8)
-        foo = torch.rand(n_vals, device="cuda:0")  # noqa: F841
+        foo = torch.rand(n_vals, device=device_0)  # noqa: F841

-        free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
+        free_memory_bytes_before_dispatch = torch_accelerator_module.mem_get_info(device_0)[0]
         with TemporaryDirectory() as tmp_dir:
             dispatch_model(model, device_map, offload_dir=tmp_dir)
-            free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
+            free_memory_bytes_after_dispatch = torch_accelerator_module.mem_get_info(device_0)[0]

             assert (free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130

+            oom_error = (
+                torch.OutOfMemoryError
+                if hasattr(torch, "OutOfMemoryError")
+                else torch_accelerator_module.OutOfMemoryError
+            )
             with torch.no_grad():
                 try:
                     output = model(x)
-                except torch.cuda.OutOfMemoryError as e:
-                    raise torch.cuda.OutOfMemoryError(
+                except oom_error as e:
+                    raise oom_error(
                         f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_disk. {e}"
                     )
                 except Exception as e:
                     raise e

             torch.testing.assert_close(expected, output.cpu(), atol=ATOL, rtol=RTOL)

-            torch.cuda.empty_cache()
+            clear_device_cache()

-            free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0]
+            free_memory_bytes_after_infer = torch_accelerator_module.mem_get_info(device_0)[0]

             # Check that we have no more references on GPU for the offloaded tied weight.
             n_non_empty = 0
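
The oom_error indirection in the hunk above is worth calling out: recent PyTorch releases expose a device-agnostic torch.OutOfMemoryError, while older ones only define the exception on the backend module (for example torch.cuda.OutOfMemoryError). A minimal sketch of the same selection follows; the device resolution and the RuntimeError fallback are illustrative assumptions, not part of this PR.

import torch

# Illustrative device resolution; the test suite derives torch_device_type from accelerate.
torch_device_type = "cuda" if torch.cuda.is_available() else "cpu"
torch_accelerator_module = getattr(torch, torch_device_type)

# Prefer the device-agnostic alias when it exists, otherwise fall back to the backend
# module; RuntimeError is an assumed last resort for backends that define neither.
oom_error = (
    torch.OutOfMemoryError
    if hasattr(torch, "OutOfMemoryError")
    else getattr(torch_accelerator_module, "OutOfMemoryError", RuntimeError)
)
print(oom_error)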

tests/test_quantization.py

Lines changed: 2 additions & 5 deletions

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import gc
 import tempfile
 import unittest

@@ -543,8 +542,7 @@ def tearDown(self):
         del self.model_fp16
         del self.model_8bit

-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

     def test_memory_footprint(self):
         r"""
@@ -663,8 +661,7 @@ def tearDown(self):
         del self.model_fp16
         del self.model_4bit

-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

     def test_memory_footprint(self):
         r"""
