Skip to content

Commit 17e2dff

Browse files
committed
support XCCL on deepspeed side
Signed-off-by: yisheng <yi.sheng@intel.com>
1 parent c2c8199 commit 17e2dff

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed

accelerator/real_accelerator.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,14 @@ def get_accelerator():
131131
if accelerator_name is None:
132132
try:
133133
import intel_extension_for_pytorch as ipex
134-
135134
if ipex._C._has_xpu():
136135
accelerator_name = "xpu"
137136
except ImportError as e:
138-
pass
137+
import torch
138+
if torch.xpu.is_available():
139+
accelerator_name = "xpu"
140+
else:
141+
pass
139142
if accelerator_name is None:
140143
try:
141144
import torch_npu # noqa: F401,F811 # type: ignore

accelerator/xpu_accelerator.py

+30-11
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,32 @@
55

66
import torch
77
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
8-
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
9-
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
108
import functools
11-
129
import importlib
1310
import inspect
1411

12+
try:
13+
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
14+
oneccl_imported_p = True
15+
except ImportError as e:
16+
oneccl_imported_p = False
17+
18+
try:
19+
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
20+
ipex_imported_p = True
21+
except ImportError as e:
22+
ipex_imported_p = False
1523

1624
class XPU_Accelerator(DeepSpeedAccelerator):
1725

1826
def __init__(self):
1927
self._name = 'xpu'
2028
self._communication_backend_name = 'ccl'
29+
if oneccl_imported_p:
30+
self._communication_backend_name = 'ccl'
31+
else:
32+
# changed to xccl if not using torch-CCL on XPU device
33+
self._communication_backend_name = 'xccl'
2134
self._compile_backend = "inductor"
2235
self.aligned_tensors = []
2336
self.class_dict = None
@@ -26,11 +39,14 @@ def is_synchronized_device(self):
2639
return False
2740

2841
def use_host_timers(self):
29-
# WA XPU event will be consolidated in 2.6
30-
if ipex.__version__ < '2.6':
31-
return True
32-
else:
42+
if not ipex_imported_p:
3343
return self.is_synchronized_device()
44+
else:
45+
# WA XPU event will be consolidated in 2.6
46+
if ipex.__version__ < '2.6':
47+
return True
48+
else:
49+
return self.is_synchronized_device()
3450

3551
def resolves_data_dependency(self):
3652
return self.is_synchronized_device()
@@ -290,10 +306,13 @@ def get_op_builder(self, class_name):
290306
return self.class_dict['NotImplementedBuilder']
291307

292308
def build_extension(self):
293-
try:
294-
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
295-
except ImportError:
296-
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
309+
if ipex_imported_p:
310+
try:
311+
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
312+
except ImportError:
313+
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
314+
else:
315+
from torch.utils.cpp_extension import DpcppBuildExtension
297316
return DpcppBuildExtension
298317

299318
def export_envs(self):

0 commit comments

Comments
 (0)