5
5
6
6
import torch
7
7
from deepspeed .accelerator .abstract_accelerator import DeepSpeedAccelerator
8
- import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
9
- import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
10
8
import functools
11
-
12
9
import importlib
13
10
import inspect
14
11
12
+ try :
13
+ import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
14
+ oneccl_imported_p = True
15
+ except ImportError as e :
16
+ oneccl_imported_p = False
17
+
18
+ try :
19
+ import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
20
+ ipex_imported_p = True
21
+ except ImportError as e :
22
+ ipex_imported_p = False
15
23
16
24
class XPU_Accelerator (DeepSpeedAccelerator ):
17
25
18
26
def __init__ (self ):
19
27
self ._name = 'xpu'
20
28
self ._communication_backend_name = 'ccl'
29
+ if oneccl_imported_p :
30
+ self ._communication_backend_name = 'ccl'
31
+ else :
32
+ # changed to xccl if not using torch-CCL on XPU device
33
+ self ._communication_backend_name = 'xccl'
21
34
self ._compile_backend = "inductor"
22
35
self .aligned_tensors = []
23
36
self .class_dict = None
@@ -26,11 +39,14 @@ def is_synchronized_device(self):
26
39
return False
27
40
28
41
def use_host_timers (self ):
29
- # WA XPU event will be consolidated in 2.6
30
- if ipex .__version__ < '2.6' :
31
- return True
32
- else :
42
+ if not ipex_imported_p :
33
43
return self .is_synchronized_device ()
44
+ else :
45
+ # WA XPU event will be consolidated in 2.6
46
+ if ipex .__version__ < '2.6' :
47
+ return True
48
+ else :
49
+ return self .is_synchronized_device ()
34
50
35
51
def resolves_data_dependency (self ):
36
52
return self .is_synchronized_device ()
@@ -290,10 +306,13 @@ def get_op_builder(self, class_name):
290
306
return self .class_dict ['NotImplementedBuilder' ]
291
307
292
308
def build_extension (self ):
293
- try :
294
- from intel_extension_for_pytorch .xpu .cpp_extension import DpcppBuildExtension
295
- except ImportError :
296
- from intel_extension_for_pytorch .xpu .utils import DpcppBuildExtension
309
+ if ipex_imported_p :
310
+ try :
311
+ from intel_extension_for_pytorch .xpu .cpp_extension import DpcppBuildExtension
312
+ except ImportError :
313
+ from intel_extension_for_pytorch .xpu .utils import DpcppBuildExtension
314
+ else :
315
+ from torch .utils .cpp_extension import DpcppBuildExtension
297
316
return DpcppBuildExtension
298
317
299
318
def export_envs (self ):
0 commit comments