diff --git a/swift/llm/model/patcher.py b/swift/llm/model/patcher.py
index 2302c44cc..226894ff7 100644
--- a/swift/llm/model/patcher.py
+++ b/swift/llm/model/patcher.py
@@ -298,8 +298,10 @@ def patch_mp_ddp():
     This should be called before any training starts.
     """
     global _mp_ddp_patched
-    if is_mp_ddp() and not _mp_ddp_patched:
-        _mp_ddp_patched = True
+    if _mp_ddp_patched:
+        return
+    _mp_ddp_patched = True
+    if is_mp_ddp():
         from accelerate.utils.modeling import get_balanced_memory, infer_auto_device_map
 
         @wraps(infer_auto_device_map)
@@ -321,7 +323,7 @@ def _infer_auto_device_map_patch(model: nn.Module,
         _old_ddp_init = DDP.__init__
         accelerate.accelerator.torch.nn.parallel.DistributedDataParallel.__init__ = (
             lambda self, model, device_ids, output_device, *args, **kwargs: _old_ddp_init(self, model, *args, **kwargs))
-        transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: None
+        transformers.modeling_utils.get_balanced_memory = lambda *args, **kwargs: {}
         transformers.modeling_utils.infer_auto_device_map = _infer_auto_device_map_patch
 
     if is_mp_ddp() or use_torchacc():
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index 1a91f5b7e..5689d57bb 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -176,13 +176,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             logger.info_once(f'use_logits_to_keep: {use_logits_to_keep}')
 
         if use_logits_to_keep:
-            if inputs['labels'].shape[0] == 1:
+            if inputs['labels'].shape[0] == 1 and not is_mp():
+                # device_map may encounter device mismatch issues.
                 loss_mask = (inputs['labels'] != -100)[0]
                 inputs['labels'] = inputs['labels'][:, loss_mask]
                 inputs['labels'] = nn.functional.pad(inputs['labels'], (1, 0), value=-100)
                 inputs['logits_to_keep'] = nn.functional.pad(loss_mask[1:], (0, 1), value=True)
-                if is_mp():
-                    inputs['logits_to_keep'] = inputs['logits_to_keep'].cpu()
             else:
                 inputs['logits_to_keep'] = (inputs['labels'].shape[-1] -
                                             (torch.ne(inputs['labels'], -100).int().argmax(-1))).max().item() + 1
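
Note on the second hunk: the `logits_to_keep` path compresses the loss computation to just the supervised positions, and the change now skips it entirely under model parallel instead of moving the mask to CPU. Below is a minimal, self-contained sketch of why the labels are left-padded by one and the mask shifted by one (toy tensors only; this is not the ms-swift implementation): the logit at position i predicts the token at position i + 1.

```python
import torch
import torch.nn.functional as F

# Toy single-sequence batch; -100 is the ignore index for unsupervised tokens.
labels = torch.tensor([[-100, -100, 5, 7, -100, 9]])

# Boolean mask of supervised label positions (batch size must be 1, since
# trimming columns with a per-sequence mask would misalign a larger batch).
loss_mask = (labels != -100)[0]

# Keep only the supervised label columns, then left-pad one -100 so the labels
# realign with the kept logits after the causal shift inside the model.
labels = F.pad(labels[:, loss_mask], (1, 0), value=-100)

# The logit needed for a supervised token sits one position earlier, so the
# keep-mask is loss_mask shifted left by one, with the last position kept.
logits_to_keep = F.pad(loss_mask[1:], (0, 1), value=True)

print(labels)          # tensor([[-100,    5,    7,    9]])
print(logits_to_keep)  # tensor([False,  True,  True, False,  True,  True])
```

After the standard causal shift (logits[:-1] vs. labels[1:]), the three kept non-final logits (original positions 1, 2, 4) line up exactly with the supervised tokens 5, 7, 9, so the model only needs to materialize 4 of the 6 logit rows.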