diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py
index 9c1713ca4c6c..9b64805db5c8 100644
--- a/accelerator/real_accelerator.py
+++ b/accelerator/real_accelerator.py
@@ -131,11 +131,14 @@ def get_accelerator():
     if accelerator_name is None:
         try:
             import intel_extension_for_pytorch as ipex
-
             if ipex._C._has_xpu():
                 accelerator_name = "xpu"
         except ImportError as e:
-            pass
+            import torch
+            if torch.xpu.is_available():
+                accelerator_name = "xpu"
+            else:
+                pass
     if accelerator_name is None:
         try:
             import torch_npu  # noqa: F401,F811 # type: ignore
diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py
index ad8a10710bf2..fcce2502c428 100644
--- a/accelerator/xpu_accelerator.py
+++ b/accelerator/xpu_accelerator.py
@@ -5,19 +5,33 @@
 
 import torch
 from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
-import intel_extension_for_pytorch as ipex  # noqa: F401 # type: ignore
-import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
 import functools
-
 import importlib
 import inspect
 
+try:
+    import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
+    oneccl_imported_p = True
+except ImportError as e:
+    oneccl_imported_p = False
+
+try:
+    import intel_extension_for_pytorch as ipex  # noqa: F401 # type: ignore
+    ipex_imported_p = True
+except ImportError as e:
+    ipex_imported_p = False
+
 
 class XPU_Accelerator(DeepSpeedAccelerator):
 
     def __init__(self):
         self._name = 'xpu'
         self._communication_backend_name = 'ccl'
+        if oneccl_imported_p:
+            self._communication_backend_name = 'ccl'
+        else:
+            # changed to xccl if not using torch-CCL on XPU device
+            self._communication_backend_name = 'xccl'
         self._compile_backend = "inductor"
         self.aligned_tensors = []
         self.class_dict = None
@@ -26,11 +40,14 @@ def is_synchronized_device(self):
         return False
 
     def use_host_timers(self):
-        # WA XPU event will be consolidated in 2.6
-        if ipex.__version__ < '2.6':
-            return True
-        else:
+        if not ipex_imported_p:
             return self.is_synchronized_device()
+        else:
+            # WA XPU event will be consolidated in 2.6
+            if ipex.__version__ < '2.6':
+                return True
+            else:
+                return self.is_synchronized_device()
 
     def resolves_data_dependency(self):
         return self.is_synchronized_device()
@@ -290,10 +307,13 @@ def get_op_builder(self, class_name):
         return self.class_dict['NotImplementedBuilder']
 
     def build_extension(self):
-        try:
-            from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
-        except ImportError:
-            from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
+        if ipex_imported_p:
+            try:
+                from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
+            except ImportError:
+                from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
+        else:
+            from torch.utils.cpp_extension import DpcppBuildExtension
         return DpcppBuildExtension
 
     def export_envs(self):