
fix transformers 4.52 device_map ddp #4424

Merged

swift/llm/model/patcher.py (8 changes: 5 additions & 3 deletions)
```diff
@@ -298,8 +298,10 @@ def patch_mp_ddp():
     This should be called before any training starts.
     """
     global _mp_ddp_patched
-    if is_mp_ddp() and not _mp_ddp_patched:
-        _mp_ddp_patched = True
+    if _mp_ddp_patched:
+        return
+    _mp_ddp_patched = True
+    if is_mp_ddp():
         from accelerate.utils.modeling import get_balanced_memory, infer_auto_device_map
 
         @wraps(infer_auto_device_map)
@@ -321,7 +323,7 @@ def _infer_auto_device_map_patch(model: nn.Module,
         _old_ddp_init = DDP.__init__
         accelerate.accelerator.torch.nn.parallel.DistributedDataParallel.__init__ = (
             lambda self, model, device_ids, output_device, *args, **kwargs: _old_ddp_init(self, model, *args, **kwargs))
-        transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: None
+        transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: {}
         transformers.modeling_utils.infer_auto_device_map = _infer_auto_device_map_patch
 
     if is_mp_ddp() or use_torchacc():
```
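
For context, the first hunk makes `patch_mp_ddp()` idempotent (an early `return` once `_mp_ddp_patched` is set) and keeps the device_map-aware overrides behind `is_mp_ddp()`; the second hunk changes the stubbed `transformers.modeling_utils.get_balanced_memory` to return `{}` instead of `None`, presumably so transformers 4.52 can keep treating the result as a `max_memory`-style mapping. The sketch below is a minimal, self-contained approximation of that patching pattern in plain PyTorch: `is_mp_ddp` is stubbed out as a stand-in for ms-swift's real detection helper, and only the `DDP.__init__` override is reproduced.

```python
# Minimal sketch of the idempotent monkey-patch pattern from patch_mp_ddp().
# `is_mp_ddp` below is a placeholder for ms-swift's real helper that detects
# the "device_map model parallel + DDP" setup; only DDP.__init__ is patched.
from functools import wraps

from torch.nn.parallel import DistributedDataParallel as DDP

_mp_ddp_patched = False


def is_mp_ddp() -> bool:
    # Placeholder: the real check inspects the distributed/device_map config.
    return True


def patch_mp_ddp() -> None:
    """Apply the patch at most once, before any training starts."""
    global _mp_ddp_patched
    if _mp_ddp_patched:  # early return keeps repeated calls harmless
        return
    _mp_ddp_patched = True
    if not is_mp_ddp():
        return

    _old_ddp_init = DDP.__init__

    @wraps(_old_ddp_init)
    def _ddp_init_without_device_ids(self, module, device_ids=None, output_device=None, *args, **kwargs):
        # A device_map-sharded model already spans several GPUs, so DDP must
        # not be asked to move it onto a single device.
        return _old_ddp_init(self, module, *args, **kwargs)

    DDP.__init__ = _ddp_init_without_device_ids
```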

swift/trainers/trainers.py (5 changes: 2 additions & 3 deletions)
```diff
@@ -176,13 +176,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
             logger.info_once(f'use_logits_to_keep: {use_logits_to_keep}')
 
         if use_logits_to_keep:
-            if inputs['labels'].shape[0] == 1:
+            if inputs['labels'].shape[0] == 1 and not is_mp():
+                # device_map may encounter device mismatch issues.
                 loss_mask = (inputs['labels'] != -100)[0]
                 inputs['labels'] = inputs['labels'][:, loss_mask]
                 inputs['labels'] = nn.functional.pad(inputs['labels'], (1, 0), value=-100)
                 inputs['logits_to_keep'] = nn.functional.pad(loss_mask[1:], (0, 1), value=True)
-                if is_mp():
-                    inputs['logits_to_keep'] = inputs['logits_to_keep'].cpu()
             else:
                 inputs['logits_to_keep'] = (inputs['labels'].shape[-1] -
                                             (torch.ne(inputs['labels'], -100).int().argmax(-1))).max().item() + 1
```
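
To make the `logits_to_keep` shaping in the single-sample branch concrete, here is a toy, self-contained illustration in plain PyTorch. The tensor values are invented for the example; `is_mp` and the rest of the trainer are not reproduced.

```python
# Toy walk-through of the logits_to_keep shaping above (values are made up;
# -100 marks label positions the loss ignores).
import torch
import torch.nn as nn

labels = torch.tensor([[-100, -100, -100, 5, 7, -100, 9, 2]])  # shape (1, 8)

# Keep only supervised label positions, then re-pad one -100 on the left so
# the usual "predict token i from position i-1" shift still lines up.
loss_mask = (labels != -100)[0]        # (8,) bool mask of supervised positions
trimmed_labels = labels[:, loss_mask]  # tensor([[5, 7, 9, 2]])
trimmed_labels = nn.functional.pad(trimmed_labels, (1, 0), value=-100)

# The logits mask is the loss mask shifted left by one and padded with True,
# so the model returns exactly the logits those labels need.
logits_to_keep = nn.functional.pad(loss_mask[1:], (0, 1), value=True)

print(trimmed_labels)             # tensor([[-100,    5,    7,    9,    2]])
print(int(logits_to_keep.sum()))  # 5, same as trimmed_labels.shape[-1]

# The `else` branch (larger batches) keeps a single integer instead: the last
# N positions, where N reaches back to the earliest supervised label plus the
# one-position shift.
scalar_keep = (labels.shape[-1] - torch.ne(labels, -100).int().argmax(-1)).max().item() + 1
print(scalar_keep)                # 6
```

Adding `not is_mp()` to the single-sample branch sends model-parallel runs to the scalar form instead of moving the boolean mask to CPU, which matches the in-code comment about device_map device-mismatch issues.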