
Commit ab41c74

fix transformers 4.52 device_map ddp (#4424)
1 parent 0dc2045 commit ab41c74

2 files changed: +7 additions, -6 deletions

swift/llm/model/patcher.py

Lines changed: 5 additions & 3 deletions

@@ -298,8 +298,10 @@ def patch_mp_ddp():
     This should be called before any training starts.
     """
     global _mp_ddp_patched
-    if is_mp_ddp() and not _mp_ddp_patched:
-        _mp_ddp_patched = True
+    if _mp_ddp_patched:
+        return
+    _mp_ddp_patched = True
+    if is_mp_ddp():
         from accelerate.utils.modeling import get_balanced_memory, infer_auto_device_map

         @wraps(infer_auto_device_map)
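The reshuffled guard makes patch_mp_ddp() idempotent: the flag is now checked and set before the is_mp_ddp() condition, so a second call returns immediately and the flag is recorded even when MP+DDP is inactive. A minimal, self-contained sketch of that pattern, using hypothetical stand-in names (should_patch, apply_patches) rather than swift's own helpers:

# Sketch of the early-return idempotency guard used in patch_mp_ddp() above.
# should_patch() and apply_patches() are hypothetical stand-ins, not swift APIs.
_patched = False

def should_patch() -> bool:
    return True  # stand-in for is_mp_ddp()

def apply_patches() -> None:
    print('patching get_balanced_memory / infer_auto_device_map')

def patch_once() -> None:
    global _patched
    if _patched:
        return           # already handled: never re-apply the monkey patches
    _patched = True      # set the flag even if the condition below is false
    if should_patch():
        apply_patches()

patch_once()  # applies the patches once
patch_once()  # no-op on every later call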
@@ -321,7 +323,7 @@ def _infer_auto_device_map_patch(model: nn.Module,
         _old_ddp_init = DDP.__init__
         accelerate.accelerator.torch.nn.parallel.DistributedDataParallel.__init__ = (
             lambda self, model, device_ids, output_device, *args, **kwargs: _old_ddp_init(self, model, *args, **kwargs))
-        transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: None
+        transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: {}
         transformers.modeling_utils.infer_auto_device_map = _infer_auto_device_map_patch

     if is_mp_ddp() or use_torchacc():
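The second hunk changes the stubbed-out transformers.modeling_utils.get_balanced_memory to return an empty dict instead of None. A hedged reading of the transformers 4.52 issue named in the commit title: the return value is consumed downstream as a max_memory-style mapping, and mapping operations fail on None but are harmless on {}. A small illustration of that difference (downstream_consumer is a made-up stand-in, not the actual transformers call site):

def downstream_consumer(max_memory):
    # mimics code that copies or iterates the mapping before device-map inference
    return {device: limit for device, limit in max_memory.items()}

stub_old = lambda *args, **kwargs: None   # pre-fix stub
stub_new = lambda *args, **kwargs: {}     # post-fix stub

try:
    downstream_consumer(stub_old())
except AttributeError as err:
    print('None breaks mapping-style access:', err)

print('empty dict passes through:', downstream_consumer(stub_new()))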

swift/trainers/trainers.py

Lines changed: 2 additions & 3 deletions

@@ -176,13 +176,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         logger.info_once(f'use_logits_to_keep: {use_logits_to_keep}')

         if use_logits_to_keep:
-            if inputs['labels'].shape[0] == 1:
+            if inputs['labels'].shape[0] == 1 and not is_mp():
+                # device_map may encounter device mismatch issues.
                 loss_mask = (inputs['labels'] != -100)[0]
                 inputs['labels'] = inputs['labels'][:, loss_mask]
                 inputs['labels'] = nn.functional.pad(inputs['labels'], (1, 0), value=-100)
                 inputs['logits_to_keep'] = nn.functional.pad(loss_mask[1:], (0, 1), value=True)
-                if is_mp():
-                    inputs['logits_to_keep'] = inputs['logits_to_keep'].cpu()
             else:
                 inputs['logits_to_keep'] = (inputs['labels'].shape[-1] -
                                             (torch.ne(inputs['labels'], -100).int().argmax(-1))).max().item() + 1
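For context, the compressed branch above only runs for a single-sequence batch: labels equal to -100 contribute no loss, so they are sliced away and a boolean logits_to_keep mask, shifted by one because the logit at position i predicts the token at position i + 1, marks which logits still need to be computed. Under model parallelism (is_mp()) this boolean indexing previously required moving the mask to CPU; the commit instead skips the branch there entirely to avoid device_map device-mismatch issues. A self-contained toy example of the tensor bookkeeping (values are made up, not from swift):

import torch
import torch.nn as nn

labels = torch.tensor([[-100, -100, 11, 12, -100, 13]])

loss_mask = (labels != -100)[0]                                   # [F, F, T, T, F, T]
packed_labels = labels[:, loss_mask]                              # tensor([[11, 12, 13]])
packed_labels = nn.functional.pad(packed_labels, (1, 0), value=-100)
logits_to_keep = nn.functional.pad(loss_mask[1:], (0, 1), value=True)

print(packed_labels)    # tensor([[-100,   11,   12,   13]])
print(logits_to_keep)   # tensor([False,  True,  True, False,  True,  True])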
