From 2200de1cbd2f67d9c0aa901c678a126e1b497d86 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 13 May 2025 16:21:31 +0800 Subject: [PATCH 01/38] support megatron dpo --- swift/cli/_megatron/main.py | 3 ++- swift/cli/_megatron/rlhf.py | 5 +++++ swift/megatron/__init__.py | 2 +- swift/megatron/argument/__init__.py | 1 + swift/megatron/argument/megatron_args.py | 9 ++++++++- swift/megatron/argument/rlhf_args.py | 9 +++++++++ swift/megatron/argument/train_args.py | 2 -- swift/megatron/train/__init__.py | 2 ++ swift/megatron/train/rlhf.py | 24 ++++++++++++++++++++++++ swift/utils/tb_utils.py | 2 ++ 10 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 swift/cli/_megatron/rlhf.py create mode 100644 swift/megatron/argument/rlhf_args.py create mode 100644 swift/megatron/train/rlhf.py diff --git a/swift/cli/_megatron/main.py b/swift/cli/_megatron/main.py index 3a53c375a0..a83dcac8b1 100644 --- a/swift/cli/_megatron/main.py +++ b/swift/cli/_megatron/main.py @@ -7,8 +7,9 @@ logger = get_logger() ROUTE_MAPPING: Dict[str, str] = { - 'sft': 'swift.cli._megatron.sft', 'pt': 'swift.cli._megatron.pt', + 'sft': 'swift.cli._megatron.sft', + 'rlhf': 'swift.cli._megatron.rlhf', } diff --git a/swift/cli/_megatron/rlhf.py b/swift/cli/_megatron/rlhf.py new file mode 100644 index 0000000000..096e1e3808 --- /dev/null +++ b/swift/cli/_megatron/rlhf.py @@ -0,0 +1,5 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from swift.megatron import megatron_rlhf_main + +if __name__ == '__main__': + megatron_rlhf_main() diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index 9e7c6b4060..c3263ed8ec 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -12,7 +12,7 @@ from swift.utils.import_utils import _LazyModule if TYPE_CHECKING: - from .train import megatron_sft_main, megatron_pt_main + from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main from .utils import convert_hf2mcore, convert_mcore2hf from .argument import MegatronTrainArguments from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/argument/__init__.py b/swift/megatron/argument/__init__.py index 032d3c471b..a2ad08daa3 100644 --- a/swift/megatron/argument/__init__.py +++ b/swift/megatron/argument/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .megatron_args import MegatronArguments +from .rlhf_args import MegatronRLHFArguments from .train_args import MegatronTrainArguments diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 90309ff114..dc9546a85d 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -11,7 +11,14 @@ @dataclass -class ExtraMegatronArguments: +class RLHFMegatronArgumentsMixin: + beta: float = 0.1 + rpo_alpha: float = 1. + ref_load: Optional[str] = None + + +@dataclass +class ExtraMegatronArguments(RLHFMegatronArgumentsMixin): padded_vocab_size: Optional[int] = None rope_scaling: Optional[Union[dict, str]] = None torch_dtype: Optional[torch.dtype] = None diff --git a/swift/megatron/argument/rlhf_args.py b/swift/megatron/argument/rlhf_args.py new file mode 100644 index 0000000000..a6104cacdf --- /dev/null +++ b/swift/megatron/argument/rlhf_args.py @@ -0,0 +1,9 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import Literal + +from .train_args import MegatronTrainArguments + + +class MegatronRLHFArguments(MegatronTrainArguments): + rlhf_type: Literal['dpo'] = 'dpo' + loss_scale: str = 'last_round' diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index c43b5e8f76..8be90eb90f 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -2,8 +2,6 @@ import os from dataclasses import dataclass -import torch - from swift.llm import BaseArguments from swift.llm.argument.base_args import to_abspath from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master diff --git a/swift/megatron/train/__init__.py b/swift/megatron/train/__init__.py index 8f6a98be92..1b091bd4a3 100644 --- a/swift/megatron/train/__init__.py +++ b/swift/megatron/train/__init__.py @@ -1,2 +1,4 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. from .pt import megatron_pt_main +from .rlhf import megatron_rlhf_main from .sft import megatron_sft_main diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py new file mode 100644 index 0000000000..ab0a31fd01 --- /dev/null +++ b/swift/megatron/train/rlhf.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from typing import List, Union + +from megatron.core.enums import ModelType +from megatron.training import pretrain + +from swift.utils import get_logger, is_master, plot_images +from ..argument import MegatronRLHFArguments +from ..utils import patch_megatron_tokenizer +from .patcher import patch_megatron_data_collator, patch_training_log +from .sft import MegatronSft +from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider + +logger = get_logger() + + +class MegatronRLHF(MegatronSft): + args_class = MegatronRLHFArguments + args: args_class + + +def megatron_rlhf_main(args: Union[List[str], MegatronRLHFArguments, None] = None): + return MegatronRLHF(args).main() diff --git a/swift/utils/tb_utils.py b/swift/utils/tb_utils.py index 050e84d1ae..387275da09 100644 --- a/swift/utils/tb_utils.py +++ b/swift/utils/tb_utils.py @@ -45,6 +45,8 @@ def plot_images(images_dir: str, figsize: Tuple[int, int] = (8, 5), dpi: int = 100) -> None: """Using tensorboard's data content to plot images""" + if not os.path.exists(tb_dir): + return smooth_key = smooth_key or [] os.makedirs(images_dir, exist_ok=True) fname = [fname for fname in os.listdir(tb_dir) if os.path.isfile(os.path.join(tb_dir, fname))][0] From 3a21dcaac8909573c397673381021c9dd066d984 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 13 May 2025 16:23:33 +0800 Subject: [PATCH 02/38] update --- swift/megatron/train/rlhf.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index ab0a31fd01..b8860cd1c2 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,16 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-import os from typing import List, Union - -from megatron.core.enums import ModelType -from megatron.training import pretrain - -from swift.utils import get_logger, is_master, plot_images +from swift.utils import get_logger from ..argument import MegatronRLHFArguments -from ..utils import patch_megatron_tokenizer -from .patcher import patch_megatron_data_collator, patch_training_log from .sft import MegatronSft -from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider logger = get_logger() From 21d044a8b9cc09267105a8ca1b66be752139d923 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 13 May 2025 21:26:18 +0800 Subject: [PATCH 03/38] update --- swift/megatron/init.py | 46 ------------------------------------ swift/megatron/train/rlhf.py | 22 +++++++++++++++++ swift/megatron/train/sft.py | 45 ++++++++++++++++++++++++++++++++++- tests/megatron/test_rlhf.py | 19 +++++++++++++++ 4 files changed, 85 insertions(+), 47 deletions(-) create mode 100644 tests/megatron/test_rlhf.py diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 72380c414a..7319b5970f 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -1,7 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os import sys -from contextlib import contextmanager from swift.llm import git_clone_github from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run @@ -21,53 +20,8 @@ def _patch_transformer_engine(): pass -def new_cyclic_iter(iter): - from megatron.training import get_args - args = get_args() - max_epochs = args.max_epochs - i = 0 - while True: - if getattr(args, 'is_training', False): - if max_epochs and i >= max_epochs: - logger.info(f'Training of {i} epochs has been completed, the training has finished.') - break - logger.info(f'The training of Epoch {i} starts...') - for x in iter: - yield x - i += 1 - - -@contextmanager -def _training_context(): - from megatron.training import get_args - args = get_args() - args.is_training = True - try: - yield - finally: - args.is_training = False - - -def _patch_max_epochs(): - # support max_epochs - from megatron.training import training - train_step_origin = training.train_step - - def train_step(*args, **kwargs): - with _training_context(): - try: - return train_step_origin(*args, **kwargs) - except StopIteration: - return {}, True, True, True, 0, None, None - - training.train_step = train_step - - training.cyclic_iter = new_cyclic_iter - - def _patch_megatron(): _patch_transformer_engine() - _patch_max_epochs() def init_megatron_env() -> None: diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index b8860cd1c2..a5cd71195e 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,5 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import List, Union + +from megatron.training import get_args, get_model, training +from megatron.training.checkpointing import load_checkpoint + from swift.utils import get_logger from ..argument import MegatronRLHFArguments from .sft import MegatronSft @@ -11,6 +15,24 @@ class MegatronRLHF(MegatronSft): args_class = MegatronRLHFArguments args: args_class + def _patch_setup_model_and_optimizer(self): + origin_setup_model_and_optimizer = training.setup_model_and_optimizer + + def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs): + args = get_args() + ref_model = get_model(model_provider_func, model_type) + if args.ref_load is not None: + args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( + ref_model, None, None, load_arg='ref_load') + args.ref_model = ref_model + return origin_setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) + + training.setup_model_and_optimizer = setup_model_and_optimizer + + def run(self): + self._patch_setup_model_and_optimizer() + super().run() + def megatron_rlhf_main(args: Union[List[str], MegatronRLHFArguments, None] = None): return MegatronRLHF(args).main() diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 4fa3e24f18..081db24a9f 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -1,9 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +from contextlib import contextmanager from typing import List, Union from megatron.core.enums import ModelType -from megatron.training import pretrain +from megatron.training import get_args, pretrain, training from swift.llm.train import SwiftSft from swift.utils import get_logger, is_master, plot_images @@ -30,8 +31,50 @@ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) self.template.use_megatron = True args.save_args(args.save) + @staticmethod + def new_cyclic_iter(iter): + args = get_args() + max_epochs = args.max_epochs + i = 0 + while True: + if getattr(args, 'is_training', False): + if max_epochs and i >= max_epochs: + logger.info(f'Training of {i} epochs has been completed, the training has finished.') + break + logger.info(f'The training of Epoch {i} starts...') + for x in iter: + yield x + i += 1 + + @staticmethod + @contextmanager + def _training_context(): + args = get_args() + args.is_training = True + try: + yield + finally: + args.is_training = False + + def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): + return self._train_step_origin(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config) + + def _patch_train_step(self): + # support max_epochs + def train_step(*args, **kwargs): + with self._training_context(): + try: + return self.train_step(*args, **kwargs) + except StopIteration: + return {}, True, True, True, 0, None, None + + self._train_step_origin = training.train_step + training.train_step = train_step + training.cyclic_iter = MegatronSft.new_cyclic_iter + def run(self): args = self.args + self._patch_train_step() train_dataset, val_dataset = self._get_dataset() train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) diff --git a/tests/megatron/test_rlhf.py b/tests/megatron/test_rlhf.py new file mode 100644 index 0000000000..3bdc857d19 --- /dev/null +++ b/tests/megatron/test_rlhf.py @@ -0,0 +1,19 @@ +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' + +def test_dpo(): + from swift.megatron import 
megatron_rlhf_main, MegatronRLHFArguments + megatron_rlhf_main( + MegatronRLHFArguments( + load='Qwen2.5-3B-Instruct-mcore', + dataset=[ + 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#1000' + ], + tensor_model_parallel_size=2, + train_iters=100, + eval_iters=5, + finetune=True)) + +if __name__ == '__main__': + test_dpo() From 29714f10570e46bab70edccd561557f435ddfd7e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 13 May 2025 21:33:30 +0800 Subject: [PATCH 04/38] update --- swift/megatron/__init__.py | 6 +++--- swift/megatron/train/rlhf.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/swift/megatron/__init__.py b/swift/megatron/__init__.py index c3263ed8ec..b3ba174b26 100644 --- a/swift/megatron/__init__.py +++ b/swift/megatron/__init__.py @@ -14,13 +14,13 @@ if TYPE_CHECKING: from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main from .utils import convert_hf2mcore, convert_mcore2hf - from .argument import MegatronTrainArguments + from .argument import MegatronTrainArguments, MegatronRLHFArguments from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model else: _import_structure = { - 'train': ['megatron_sft_main', 'megatron_pt_main'], + 'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'], 'utils': ['convert_hf2mcore', 'convert_mcore2hf'], - 'argument': ['MegatronTrainArguments'], + 'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'], 'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'] } diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index a5cd71195e..1c4ae4142d 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -29,6 +29,9 @@ def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) training.setup_model_and_optimizer = setup_model_and_optimizer + def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): + super().train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config) + def run(self): self._patch_setup_model_and_optimizer() super().run() From 48bdef97814ad9fb81fcb3abc3b603654c046cbe Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 13 May 2025 21:35:55 +0800 Subject: [PATCH 05/38] update --- swift/megatron/train/sft.py | 28 +++++++++++++++++++++++++--- swift/megatron/train/utils.py | 22 ---------------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 081db24a9f..b310e45907 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -1,17 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os from contextlib import contextmanager +from functools import partial from typing import List, Union from megatron.core.enums import ModelType -from megatron.training import get_args, pretrain, training +from megatron.training import get_args, get_timers, pretrain, training from swift.llm.train import SwiftSft from swift.utils import get_logger, is_master, plot_images from ..argument import MegatronTrainArguments from ..utils import patch_megatron_tokenizer from .patcher import patch_megatron_data_collator, patch_training_log -from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider +from .utils import build_streaming_dataloader, get_batch, get_swift_datasets_provider logger = get_logger() @@ -72,6 +73,27 @@ def train_step(*args, **kwargs): training.train_step = train_step training.cyclic_iter = MegatronSft.new_cyclic_iter + def forward_step(self, data_iterator, model): + from pretrain_gpt import loss_func + + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + global stimer + with stimer(bdata=True): + data = get_batch(data_iterator) + if not data: + raise StopIteration + tokens, labels, attention_mask, position_ids, packed_seq_params = data + timers('batch-generator').stop() + + with stimer: + output_tensor = model( + tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params) + loss_mask = None if labels is None else (labels != -100).float() + return output_tensor, partial(loss_func, loss_mask) + def run(self): args = self.args self._patch_train_step() @@ -94,7 +116,7 @@ def run(self): datasets_provider, args.megatron_model_meta.model_provider, ModelType.encoder_or_decoder, - forward_step, + self.forward_step, args_defaults=args.extra_args) finally: # Visualization diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py index 69caa161d1..73e45c4b87 100644 --- a/swift/megatron/train/utils.py +++ b/swift/megatron/train/utils.py @@ -1,5 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from functools import partial from typing import Any, Dict, Optional import torch @@ -206,24 +205,3 @@ def get_batch(data_iterator): # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) return batch.values() - - -def forward_step(data_iterator, model): - from pretrain_gpt import loss_func - - timers = get_timers() - - # Get the batch. 
- timers('batch-generator', log_level=2).start() - global stimer - with stimer(bdata=True): - data = get_batch(data_iterator) - if not data: - raise StopIteration - tokens, labels, attention_mask, position_ids, packed_seq_params = data - timers('batch-generator').stop() - - with stimer: - output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params) - loss_mask = None if labels is None else (labels != -100).float() - return output_tensor, partial(loss_func, loss_mask) From 6806521531df3fc974aa06b8eab2f2d3752a38c9 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 13 May 2025 21:36:45 +0800 Subject: [PATCH 06/38] update --- tests/megatron/test_rlhf.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/megatron/test_rlhf.py b/tests/megatron/test_rlhf.py index 3bdc857d19..5329934672 100644 --- a/tests/megatron/test_rlhf.py +++ b/tests/megatron/test_rlhf.py @@ -2,14 +2,13 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' + def test_dpo(): from swift.megatron import megatron_rlhf_main, MegatronRLHFArguments megatron_rlhf_main( MegatronRLHFArguments( load='Qwen2.5-3B-Instruct-mcore', - dataset=[ - 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#1000' - ], + dataset=['hjh0119/shareAI-Llama3-DPO-zh-en-emoji#1000'], tensor_model_parallel_size=2, train_iters=100, eval_iters=5, From ebfe9d1509e4f4195e870da01dc8aa73c457223e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 14 May 2025 00:19:01 +0800 Subject: [PATCH 07/38] update --- swift/llm/data_loader.py | 15 +++++++++++++-- swift/trainers/mixin.py | 4 ++-- swift/trainers/sequence_parallel/ulysses.py | 7 +++++-- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/swift/llm/data_loader.py b/swift/llm/data_loader.py index fe20854ca6..2415e754cc 100644 --- a/swift/llm/data_loader.py +++ b/swift/llm/data_loader.py @@ -4,6 +4,8 @@ import torch.distributed as dist from torch.utils.data import DataLoader +from swift.llm import to_device + class BatchSamplerShard: @@ -56,18 +58,25 @@ def __len__(self) -> int: class DataLoaderShard(DataLoader): - def __init__(self, dataset, batch_sampler: BatchSamplerShard, **dataloader_params): + def __init__(self, dataset, batch_sampler: BatchSamplerShard, device=None, **dataloader_params): self.batch_sampler = batch_sampler + self.device = device super().__init__(dataset, batch_sampler=self.batch_sampler, **dataloader_params) def set_epoch(self, epoch: int): self.batch_sampler.set_epoch(epoch) + def __iter__(self): + for item in super().__iter__(): + if self.device: + item = to_device(item, self.device) + yield item class DataLoaderDispatcher: - def __init__(self, base_dataloader): + def __init__(self, base_dataloader, device=None): self.base_dataloader = base_dataloader + self.device = device @property def rank(self): @@ -102,4 +111,6 @@ def __iter__(self): data = self._scatter_object_list(None) if data is None: break + if self.device: + data = to_device(data, self.device) yield data diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index fbd382d99f..6a402510f6 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -493,13 +493,13 @@ def get_train_dataloader(self): if hasattr(train_dataset, '__len__'): batch_sampler = BatchSamplerShard( len(train_dataset), batch_size=self._train_batch_size, **batch_sampler_params) - dataloader = DataLoaderShard(train_dataset, batch_sampler, **dataloader_params) + dataloader = DataLoaderShard(train_dataset, batch_sampler, self.accelerator.device, **dataloader_params) else: # 
IterableDataset if dist.is_initialized() and dataloader_params['prefetch_factor']: dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size() dataloader = DataLoader(train_dataset, batch_size=self._train_batch_size, **dataloader_params) - dataloader = DataLoaderDispatcher(dataloader) + dataloader = DataLoaderDispatcher(dataloader, self.accelerator.device) return dataloader diff --git a/swift/trainers/sequence_parallel/ulysses.py b/swift/trainers/sequence_parallel/ulysses.py index d9c415c15e..f317e156e6 100644 --- a/swift/trainers/sequence_parallel/ulysses.py +++ b/swift/trainers/sequence_parallel/ulysses.py @@ -144,9 +144,10 @@ def set_epoch(self, epoch: int) -> None: class UlyssesDispatcher(DataLoaderDispatcher): - def __init__(self, base_dataloader, ulysses): + def __init__(self, base_dataloader, ulysses, device=None): super().__init__(base_dataloader) self.ulysses = ulysses + self.device = device def __iter__(self): base_iter = iter(self.base_dataloader) @@ -161,6 +162,8 @@ def __iter__(self): pass if data is None: break + if self.device: + data = to_device(data, self.device) yield data @@ -546,7 +549,7 @@ def get_dataloader(self, trainer, dataset, batch_size): if dist.is_initialized() and dataloader_params['prefetch_factor']: dataloader_params['prefetch_factor'] = dataloader_params['prefetch_factor'] * dist.get_world_size() dataloader = DataLoader(dataset, batch_size=batch_size, **dataloader_params) - dataloader = UlyssesDispatcher(dataloader, self) + dataloader = UlyssesDispatcher(dataloader, self, trainer.accelerator.device) return dataloader def prepare_trainer(self, trainer): From 5802d8e1f98de9a9ecddd92c294686c92affc0a1 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 14 May 2025 00:19:34 +0800 Subject: [PATCH 08/38] update --- swift/llm/data_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/llm/data_loader.py b/swift/llm/data_loader.py index 2415e754cc..a928e33d66 100644 --- a/swift/llm/data_loader.py +++ b/swift/llm/data_loader.py @@ -72,6 +72,7 @@ def __iter__(self): item = to_device(item, self.device) yield item + class DataLoaderDispatcher: def __init__(self, base_dataloader, device=None): From fef4b9b2c85505d2ddb6cac655d6de6822a3cc97 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 14 May 2025 00:35:53 +0800 Subject: [PATCH 09/38] update --- tests/megatron/test_rlhf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/megatron/test_rlhf.py b/tests/megatron/test_rlhf.py index 5329934672..54b2ed9ad7 100644 --- a/tests/megatron/test_rlhf.py +++ b/tests/megatron/test_rlhf.py @@ -14,5 +14,6 @@ def test_dpo(): eval_iters=5, finetune=True)) + if __name__ == '__main__': test_dpo() From 7a595d309ecc839ecf3d65616d14f694103a87fc Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 14 May 2025 22:58:53 +0800 Subject: [PATCH 10/38] update --- swift/megatron/train/sft.py | 2 ++ swift/megatron/train/utils.py | 4 ---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index b310e45907..5d2bcc7916 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -6,6 +6,7 @@ from megatron.core.enums import ModelType from megatron.training import get_args, get_timers, pretrain, training +from megatron.core.utils import StragglerDetector from swift.llm.train import SwiftSft from swift.utils import get_logger, is_master, plot_images @@ -16,6 +17,7 @@ logger = get_logger() +stimer = StragglerDetector() class MegatronSft(SwiftSft): 
args_class = MegatronTrainArguments diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py index 73e45c4b87..b91278f90c 100644 --- a/swift/megatron/train/utils.py +++ b/swift/megatron/train/utils.py @@ -4,15 +4,11 @@ import torch from megatron.core import mpu from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.utils import StragglerDetector from megatron.training import get_args, get_timers from megatron.training.training import cyclic_iter from swift.llm import DataLoaderDispatcher -stimer = StragglerDetector() - - def get_swift_datasets_provider(train_dataset, val_dataset): def swift_datasets_provider(train_val_test_num_samples): From 9dde75af3edb4f777c5765b6f0cfda7c5d44ed24 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 15 May 2025 13:58:41 +0800 Subject: [PATCH 11/38] update --- swift/megatron/train/rlhf.py | 33 +++++++++++++++++++++++++++++++-- swift/megatron/train/sft.py | 8 ++++---- swift/megatron/train/utils.py | 23 ++++++++++++----------- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 1c4ae4142d..495f38230f 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,12 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import List, Union -from megatron.training import get_args, get_model, training +import torch +from megatron.core import mpu +from megatron.training import get_args, get_model, get_timers, training from megatron.training.checkpointing import load_checkpoint +from megatron.training.utils import unwrap_model from swift.utils import get_logger from ..argument import MegatronRLHFArguments from .sft import MegatronSft +from .utils import get_batch logger = get_logger() @@ -29,8 +33,33 @@ def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) training.setup_model_and_optimizer = setup_model_and_optimizer + def ref_forward(self, data_iterator): + args = get_args() + ref_model = unwrap_model(args.ref_model[0]) + timers = get_timers() + timers('batch-ref-generator', log_level=2).start() + global stimer + with stimer(bdata=True): + data = get_batch(data_iterator) + timers('batch-ref-generator').stop() + + ref_forward_step = InferenceForwardStep(ref_model) + ref_logits = ref_forward_step(**data) + + if mpu.is_pipeline_last_stage(): + data['ref_logits'] = ref_logits.detach() + else: + data['ref_logits'] = None + return data + def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - super().train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config) + args = get_args() + num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) + res = [] + for i in range(num_iters_per_step): + with torch.no_grad(): + res.append(self.ref_forward(data_iterator)) + super().train_step(self, forward_step_func, iter(res), model, optimizer, opt_param_scheduler, config) def run(self): self._patch_setup_model_and_optimizer() diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 5d2bcc7916..0facfc47c7 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -5,8 +5,8 @@ from typing import List, Union from megatron.core.enums import ModelType -from megatron.training import get_args, get_timers, pretrain, training from megatron.core.utils import StragglerDetector +from megatron.training import get_args, get_timers, pretrain, 
training from swift.llm.train import SwiftSft from swift.utils import get_logger, is_master, plot_images @@ -19,6 +19,7 @@ stimer = StragglerDetector() + class MegatronSft(SwiftSft): args_class = MegatronTrainArguments args: args_class @@ -87,12 +88,11 @@ def forward_step(self, data_iterator, model): data = get_batch(data_iterator) if not data: raise StopIteration - tokens, labels, attention_mask, position_ids, packed_seq_params = data timers('batch-generator').stop() with stimer: - output_tensor = model( - tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params) + output_tensor = model(**data) + labels = data.get('labels') loss_mask = None if labels is None else (labels != -100).float() return output_tensor, partial(loss_func, loss_mask) diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py index b91278f90c..d88cf83189 100644 --- a/swift/megatron/train/utils.py +++ b/swift/megatron/train/utils.py @@ -9,6 +9,7 @@ from swift.llm import DataLoaderDispatcher + def get_swift_datasets_provider(train_dataset, val_dataset): def swift_datasets_provider(train_val_test_num_samples): @@ -54,10 +55,10 @@ def _broadcast(item): except StopIteration: seq_length = -1 else: - tokens = data['input_ids'] - seq_length = tokens.shape[1] + input_ids = data['input_ids'] + seq_length = input_ids.shape[1] batch = { - 'tokens': tokens.cuda(non_blocking=True), + 'input_ids': input_ids.cuda(non_blocking=True), 'labels': data['labels'].cuda(non_blocking=True), 'attention_mask': None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True), @@ -68,13 +69,13 @@ def _broadcast(item): if seq_length.item() == -1: return {} if args.pipeline_model_parallel_size == 1: - _broadcast(batch['tokens']) + _broadcast(batch['input_ids']) _broadcast(batch['labels']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) elif mpu.is_pipeline_first_stage(): - _broadcast(batch['tokens']) + _broadcast(batch['input_ids']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) @@ -89,7 +90,7 @@ def _broadcast(item): if seq_length.item() == -1: return {} micro_batch_size = 1 # use qkv_format 'thd' - tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device()) + input_ids = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device()) labels = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device()) if args.create_attention_mask_in_dataloader: attention_mask = torch.empty((micro_batch_size, 1, seq_length, seq_length), @@ -102,7 +103,7 @@ def _broadcast(item): device=torch.cuda.current_device()) if args.pipeline_model_parallel_size == 1: - _broadcast(tokens) + _broadcast(input_ids) _broadcast(labels) _broadcast(attention_mask) _broadcast(position_ids) @@ -110,18 +111,18 @@ def _broadcast(item): elif mpu.is_pipeline_first_stage(): labels = None - _broadcast(tokens) + _broadcast(input_ids) _broadcast(attention_mask) _broadcast(position_ids) elif mpu.is_pipeline_last_stage(): - tokens = None + input_ids = None _broadcast(labels) _broadcast(attention_mask) _broadcast(position_ids) # compat packing & cp - batch = {'tokens': tokens, 'labels': labels, 'attention_mask': attention_mask, 'position_ids': position_ids} + batch = {'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask, 'position_ids': position_ids} return batch @@ -200,4 +201,4 @@ def get_batch(data_iterator): batch['packed_seq_params'] 
= get_packed_seq_params(batch['position_ids']) # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank(batch) - return batch.values() + return batch From 007c9eddaadc48a6c1cd405fa0cba99a34e19ad8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 28 May 2025 16:29:27 +0800 Subject: [PATCH 12/38] update --- swift/megatron/init.py | 46 +---------------------------------- swift/megatron/train/utils.py | 7 +++++- 2 files changed, 7 insertions(+), 46 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 1c016d5271..0375fc1140 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os import sys +from contextlib import contextmanager from swift.llm import git_clone_github from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run @@ -29,50 +30,6 @@ def _patch_transformer_engine(): pass -def new_cyclic_iter(iter): - from megatron.training import get_args - args = get_args() - max_epochs = args.max_epochs - i = 0 - while True: - if getattr(args, 'is_training', False): - if max_epochs and i >= max_epochs: - logger.info(f'Training of {i} epochs has been completed, the training has finished.') - break - logger.info(f'The training of Epoch {i} starts...') - for x in iter: - yield x - i += 1 - - -@contextmanager -def _training_context(): - from megatron.training import get_args - args = get_args() - args.is_training = True - try: - yield - finally: - args.is_training = False - - -def _patch_max_epochs(): - # support max_epochs - from megatron.training import training - train_step_origin = training.train_step - - def train_step(*args, **kwargs): - with _training_context(): - try: - return train_step_origin(*args, **kwargs) - except StopIteration: - return {}, True, True, True, 0, None, None - - training.train_step = train_step - - training.cyclic_iter = new_cyclic_iter - - def _patch__batched_p2p_ops(): from megatron.core.pipeline_parallel import p2p_communication @@ -87,7 +44,6 @@ def _batched_p2p_ops(**kwargs): def _patch_megatron(): _patch_transformer_engine() - _patch_max_epochs() _patch__batched_p2p_ops() diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py index 33f70e86dd..b96e2db907 100644 --- a/swift/megatron/train/utils.py +++ b/swift/megatron/train/utils.py @@ -130,7 +130,12 @@ def _broadcast(item): _broadcast(attention_mask) _broadcast(position_ids) # compat packing & cp - batch = {'input_ids': input_ids, 'labels': labels, 'attention_mask': attention_mask, 'position_ids': position_ids} + batch = { + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': attention_mask, + 'position_ids': position_ids + } return batch From 1f0f411c48521ed9754c76423a20958f554cd376 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Jun 2025 14:55:46 +0800 Subject: [PATCH 13/38] update --- swift/megatron/train/rlhf.py | 3 +-- swift/megatron/train/sft.py | 8 +++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 495f38230f..c0f880116a 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -38,8 +38,7 @@ def ref_forward(self, data_iterator): ref_model = unwrap_model(args.ref_model[0]) timers = get_timers() timers('batch-ref-generator', log_level=2).start() - global stimer - with stimer(bdata=True): + with self.stimer(bdata=True): data = get_batch(data_iterator) 
timers('batch-ref-generator').stop() diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 5698a617dc..5661e7d617 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -18,8 +18,6 @@ logger = get_logger() -stimer = StragglerDetector() - class MegatronSft(SwiftSft): args_class = MegatronTrainArguments @@ -35,6 +33,7 @@ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) self._prepare_template() self.template.use_megatron = True args.save_args(args.save) + self.stimer = StragglerDetector() @contextmanager def _get_train_iters(self, train_dataset): @@ -106,14 +105,13 @@ def forward_step(self, data_iterator, model): # Get the batch. timers('batch-generator', log_level=2).start() - global stimer - with stimer(bdata=True): + with self.stimer(bdata=True): data = get_batch(data_iterator) if not data: raise StopIteration timers('batch-generator').stop() - with stimer: + with self.stimer: output_tensor = model(**data) labels = data.get('labels') loss_mask = None if labels is None else (labels != -100).float() From 615befc853c0b18af6dddf8c47d53549fcc288b8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Jun 2025 20:08:01 +0800 Subject: [PATCH 14/38] update --- swift/megatron/train/rlhf.py | 77 +++++++++++++++++---- swift/trainers/rlhf_trainer/dpo_trainer.py | 20 ++++-- swift/trainers/sequence_parallel/ulysses.py | 8 +-- 3 files changed, 80 insertions(+), 25 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index c0f880116a..d7abe72352 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import List, Union +from functools import partial +from typing import List, Tuple, Union import torch from megatron.core import mpu @@ -19,36 +20,51 @@ class MegatronRLHF(MegatronSft): args_class = MegatronRLHFArguments args: args_class + def _prepare_template(self) -> None: + super()._prepare_template() + self.template.set_mode('rlhf') + def _patch_setup_model_and_optimizer(self): origin_setup_model_and_optimizer = training.setup_model_and_optimizer def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs): args = get_args() ref_model = get_model(model_provider_func, model_type) - if args.ref_load is not None: - args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( - ref_model, None, None, load_arg='ref_load') - args.ref_model = ref_model + if args.ref_load is None: + args.ref_load = args.load + args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( + ref_model, None, None, load_arg='ref_load') + self.ref_model = ref_model[0] + self.ref_model.eval() return origin_setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) training.setup_model_and_optimizer = setup_model_and_optimizer def ref_forward(self, data_iterator): + from swift.trainers import DPOTrainer args = get_args() - ref_model = unwrap_model(args.ref_model[0]) + ref_model = unwrap_model(self.ref_model) timers = get_timers() timers('batch-ref-generator', log_level=2).start() with self.stimer(bdata=True): data = get_batch(data_iterator) + if not data: + raise StopIteration timers('batch-ref-generator').stop() - - ref_forward_step = InferenceForwardStep(ref_model) - ref_logits = ref_forward_step(**data) - - if mpu.is_pipeline_last_stage(): - data['ref_logits'] = ref_logits.detach() - else: - data['ref_logits'] = None + labels = data.pop('labels', None) + 
with torch.no_grad(): + ref_logits = ref_model(**data) + ref_logits = ref_logits.to(torch.float32) + per_token_logps, _, loss_mask = DPOTrainer.get_per_token_logps(ref_logits, labels) + cu_seqlens = data['packed_seq_params'].cu_seqlens_q[:args.micro_batch_size * 2 + 1] + all_logps = per_token_logps.new_zeros((args.micro_batch_size, )) + for i in range(args.micro_batch_size): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + all_logps[i] = per_token_logps[:, start:end].sum() + num_tokens = cu_seqlens[args.micro_batch_size] + # output['nll_loss'] = self.get_nll_loss(all_logits[:, :num_tokens], labels[:, :num_tokens]) + data['labels'] = labels + data['logps'] = all_logps return data def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): @@ -58,12 +74,43 @@ def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_par for i in range(num_iters_per_step): with torch.no_grad(): res.append(self.ref_forward(data_iterator)) - super().train_step(self, forward_step_func, iter(res), model, optimizer, opt_param_scheduler, config) + return super().train_step(forward_step_func, iter(res), model, optimizer, opt_param_scheduler, config) def run(self): self._patch_setup_model_and_optimizer() super().run() + def loss_func(self, loss_mask: torch.Tensor, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor): + from swift.trainers import DPOTrainer + nll_loss = torch.sum(output_tensor * loss_mask) / loss_mask.sum() + losses, chosen_rewards, rejected_rewards = DPOTrainer.dpo_loss(model_output['chosen_logps'], + model_output['rejected_logps'], ref_chosen_logps, + ref_rejected_logps) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + def forward_step(self, data_iterator, model): + timers = get_timers() + args = get_args() + + # Get the batch. 
+ timers('batch-generator', log_level=2).start() + with self.stimer(bdata=True): + data = next(data_iterator) + timers('batch-generator').stop() + ref_logps = data.pop('logps') + labels = data.pop('labels') + with self.stimer: + logits = model(**data) + if labels is None: + loss_mask = None + else: + loss = model.module.module.compute_language_model_loss(labels, logits) + cu_seqlens = data['packed_seq_params'].cu_seqlens_q[:args.micro_batch_size * 2 + 1] + num_tokens = cu_seqlens[args.micro_batch_size] + loss_mask = (labels != -100).float() + loss_mask[:, num_tokens:] = 0 + return logits, partial(self.loss_func, loss_mask, ref_logps=ref_logps) + def megatron_rlhf_main(args: Union[List[str], MegatronRLHFArguments, None] = None): return MegatronRLHF(args).main() diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index 95c2e27305..8385ab3128 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -49,7 +49,10 @@ def get_nll_loss(self, logits, labels): return loss_fct(logits, labels) def concatenated_forward( - self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]], **kwargs + self, + model: nn.Module, + batch: Dict[str, Union[List, torch.LongTensor]], + is_ref_model=False ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: batch = batch.copy() labels = batch.pop('labels', None) @@ -81,7 +84,9 @@ def concatenated_forward( if not self.is_encoder_decoder and self.template.sequence_parallel_size == 1: # Shift so that tokens < n predict n labels = torch.roll(labels, shifts=-1, dims=1) - per_token_logps, mean_all_logits, loss_mask = self.get_per_token_logps(all_logits, labels) + per_token_logps, mean_all_logits, loss_mask = self.get_per_token_logps( + all_logits, labels, label_pad_token_id=self.label_pad_token_id) + origin_per_token_logps = per_token_logps if self.loss_type == 'ipo': size_completion = loss_mask.sum(dim=-1) per_token_logps = per_token_logps / size_completion @@ -95,7 +100,8 @@ def concatenated_forward( all_logps[i] = per_token_logps[:, start:end].sum() num_examples = all_logps.shape[0] // 2 num_tokens = cu_seqlens[num_examples] - output['nll_loss'] = self.get_nll_loss(all_logits[:, :num_tokens], labels[:, :num_tokens]) + if not is_ref_model: + output['nll_loss'] = -origin_per_token_logps[:, :num_tokens][loss_mask[:, :num_tokens]].mean() output['chosen_logps'] = all_logps[:num_examples] output['rejected_logps'] = all_logps[num_examples:] output['mean_chosen_logits'] = mean_all_logits[:, :num_tokens][loss_mask[:, :num_tokens]].mean() @@ -103,7 +109,8 @@ def concatenated_forward( else: all_logps = per_token_logps.sum(-1) num_examples = labels.shape[0] // 2 - output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples]) + if not is_ref_model: + output['nll_loss'] = -origin_per_token_logps[:num_examples][loss_mask[:num_examples]].mean() output['chosen_logps'] = all_logps[:num_examples] output['rejected_logps'] = all_logps[num_examples:] output['mean_chosen_logits'] = mean_all_logits[:num_examples][loss_mask[:num_examples]].mean() @@ -112,15 +119,16 @@ def concatenated_forward( output['aux_loss'] = outputs.aux_loss return output + @staticmethod def get_per_token_logps( - self, logits: torch.FloatTensor, labels: torch.LongTensor, + label_pad_token_id=-100, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if logits.shape[:-1] != labels.shape: raise ValueError(f'Logits (batch and sequence 
length dim) {logits.shape[:-1]}' 'and labels must have the same shape {labels.shape}') - loss_mask = labels != self.label_pad_token_id + loss_mask = labels != label_pad_token_id labels = labels.clone() labels[~loss_mask] = 0 # https://github.com/huggingface/trl/pull/2799 diff --git a/swift/trainers/sequence_parallel/ulysses.py b/swift/trainers/sequence_parallel/ulysses.py index da836bbeb7..2304c8dc4a 100644 --- a/swift/trainers/sequence_parallel/ulysses.py +++ b/swift/trainers/sequence_parallel/ulysses.py @@ -167,12 +167,12 @@ def old_policy(self): # For DPO -def get_per_token_logps(self, - logits: torch.FloatTensor, +def get_per_token_logps(logits: torch.FloatTensor, labels: torch.LongTensor, + label_pad_token_id=-100, ulysses=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: _, _, labels, _, _, _ = ulysses.pad_and_split_inputs(None, None, labels, None, None, None) - loss_mask = labels != self.label_pad_token_id + loss_mask = labels != label_pad_token_id labels = labels.clone() # No need to shift, pad and split has shifted the inputs. labels[~loss_mask] = 0 labels = labels.to(logits.device) @@ -827,7 +827,7 @@ def prepare_trainer(self, trainer): if trainer.__class__.__name__ in ('Seq2SeqTrainer', 'DPOTrainer'): trainer.compute_loss_func = partial(loss_scale_sp_func, ulysses=self) if trainer.__class__.__name__ == 'DPOTrainer': - trainer.get_per_token_logps = MethodType(partial(get_per_token_logps, ulysses=self), trainer) + trainer.get_per_token_logps = partial(get_per_token_logps, ulysses=self) def rlhf_loss_scale_sp_func(_, *args, **kwargs): return loss_scale_sp_func(*args, ulysses=self, **kwargs) From ab5bdfacb713b5d8d6ab55badb8122f078057ec8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Jun 2025 20:08:49 +0800 Subject: [PATCH 15/38] update --- tests/megatron/test_rlhf.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/tests/megatron/test_rlhf.py b/tests/megatron/test_rlhf.py index 54b2ed9ad7..b04597f650 100644 --- a/tests/megatron/test_rlhf.py +++ b/tests/megatron/test_rlhf.py @@ -1,6 +1,6 @@ import os -os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' +os.environ['CUDA_VISIBLE_DEVICES'] = '0' def test_dpo(): @@ -9,11 +9,32 @@ def test_dpo(): MegatronRLHFArguments( load='Qwen2.5-3B-Instruct-mcore', dataset=['hjh0119/shareAI-Llama3-DPO-zh-en-emoji#1000'], - tensor_model_parallel_size=2, + tensor_model_parallel_size=1, train_iters=100, eval_iters=5, finetune=True)) +# {'loss': 2.58519292, 'grad_norm': 17.12728882, 'learning_rate': 9.998e-05, 'memory(GiB)': 11.29, 'train_speed(iter/s)': 0.028085, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/chosen': -372.38751221, 'logps/rejected': -485.3223877, 'logits/chosen': -1.5612483, 'logits/rejected': -1.18372846, 'nll_loss': 1.89204574, 'epoch': 0.02, 'global_step/max_steps': '1/100', 'percentage': '1.00%', 'elapsed_time': '34s', 'remaining_time': '56m 45s'} +# {'loss': 2.03561759, 'grad_norm': 1.61355615, 'learning_rate': 9.938e-05, 'memory(GiB)': 11.29, 'train_speed(iter/s)': 0.04982, 'rewards/chosen': 3.43259692, 'rewards/rejected': 0.04113517, 'rewards/accuracies': 1.0, 'rewards/margins': 3.39146185, 'logps/chosen': -420.24295044, 'logps/rejected': -530.99975586, 'logits/chosen': -1.50232077, 'logits/rejected': -0.97954285, 'nll_loss': 1.89556026, 'epoch': 0.08, 'global_step/max_steps': '5/100', 'percentage': '5.00%', 'elapsed_time': '1m 39s', 'remaining_time': '31m 23s'} +# {'loss': 2.58330202, 'grad_norm': 17.195858, 
'learning_rate': 9.998e-05, 'memory(GiB)': 10.31, 'train_speed(iter/s)': 0.046415, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/chosen': -372.38418579, 'logps/rejected': -485.07315063, 'logits/chosen': -1.56147575, 'logits/rejected': -1.18395948, 'nll_loss': 1.89015484, 'epoch': 0.02, 'global_step/max_steps': '1/100', 'percentage': '1.00%', 'elapsed_time': '20s', 'remaining_time': '33m 29s'} +# {'loss': 2.03590226, 'grad_norm': 1.62211466, 'learning_rate': 9.938e-05, 'memory(GiB)': 10.32, 'train_speed(iter/s)': 0.058378, 'rewards/chosen': 3.42427087, 'rewards/rejected': 0.01991111, 'rewards/accuracies': 1.0, 'rewards/margins': 3.40435982, 'logps/chosen': -420.37756348, 'logps/rejected': -531.13378906, 'logits/chosen': -1.50198746, 'logits/rejected': -0.97799051, 'nll_loss': 1.89604306, 'epoch': 0.08, 'global_step/max_steps': '5/100', 'percentage': '5.00%', 'elapsed_time': '1m 24s', 'remaining_time': '26m 43s'} + + +def test_hf(): + from swift.llm import rlhf_main, RLHFArguments + rlhf_main( + RLHFArguments( + model='Qwen/Qwen2.5-3B-Instruct', + dataset=['hjh0119/shareAI-Llama3-DPO-zh-en-emoji#1000'], + max_steps=100, + padding_free=True, + attn_impl='flash_attn', + train_dataloader_shuffle=False, + use_logits_to_keep=False, + )) + + if __name__ == '__main__': - test_dpo() + # test_dpo() + test_hf() From 515d476ba09021bc6941a0916fb70e566d29b3a5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 5 Jun 2025 23:31:32 +0800 Subject: [PATCH 16/38] update --- swift/megatron/argument/megatron_args.py | 7 +- swift/megatron/argument/rlhf_args.py | 4 + swift/megatron/train/rlhf.py | 99 ++++++++++++++++-------- 3 files changed, 78 insertions(+), 32 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 2cf80b72ac..16e25bfcca 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -15,9 +15,14 @@ @dataclass class RLHFMegatronArgumentsMixin: + ref_load: Optional[str] = None + beta: float = 0.1 rpo_alpha: float = 1. - ref_load: Optional[str] = None + reference_free: bool = False + label_smoothing: float = 0. + f_divergence_type: str = 'reverse_kl' + loss_type: str = 'sigmoid' @dataclass diff --git a/swift/megatron/argument/rlhf_args.py b/swift/megatron/argument/rlhf_args.py index a6104cacdf..2119c54dcd 100644 --- a/swift/megatron/argument/rlhf_args.py +++ b/swift/megatron/argument/rlhf_args.py @@ -1,9 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from dataclasses import dataclass from typing import Literal from .train_args import MegatronTrainArguments +@dataclass class MegatronRLHFArguments(MegatronTrainArguments): rlhf_type: Literal['dpo'] = 'dpo' loss_scale: str = 'last_round' + + calculate_per_token_loss: bool = False diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index d7abe72352..1fbd4390dc 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+from collections import namedtuple from functools import partial from typing import List, Tuple, Union @@ -8,7 +9,8 @@ from megatron.training.checkpointing import load_checkpoint from megatron.training.utils import unwrap_model -from swift.utils import get_logger +from swift.trainers import DPOTrainer +from swift.utils import get_current_device, get_logger from ..argument import MegatronRLHFArguments from .sft import MegatronSft from .utils import get_batch @@ -16,10 +18,28 @@ logger = get_logger() +class DummyDPOTrainer(DPOTrainer): + # For reusing the dpo_loss function in TRL. + def __init__(self, args): + from trl.trainer import FDivergenceConstants + self.accelerator = namedtuple('Accelerator', ['device'])(device=get_current_device()) + self.f_alpha_divergence_coef = 1. + self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: self.f_alpha_divergence_coef} + self.reference_free = args.reference_free + self.label_smoothing = args.label_smoothing + self.f_divergence_type = args.f_divergence_type + self.loss_type = args.loss_type + self.beta = args.beta + + class MegatronRLHF(MegatronSft): args_class = MegatronRLHFArguments args: args_class + def __init__(self, args: Union[List[str], MegatronRLHFArguments, None] = None) -> None: + super().__init__(args) + self.dummy_dpo_trainer = DummyDPOTrainer(self.args) + def _prepare_template(self) -> None: super()._prepare_template() self.template.set_mode('rlhf') @@ -41,8 +61,6 @@ def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) training.setup_model_and_optimizer = setup_model_and_optimizer def ref_forward(self, data_iterator): - from swift.trainers import DPOTrainer - args = get_args() ref_model = unwrap_model(self.ref_model) timers = get_timers() timers('batch-ref-generator', log_level=2).start() @@ -51,21 +69,24 @@ def ref_forward(self, data_iterator): if not data: raise StopIteration timers('batch-ref-generator').stop() - labels = data.pop('labels', None) + labels = data['labels'] with torch.no_grad(): - ref_logits = ref_model(**data) - ref_logits = ref_logits.to(torch.float32) - per_token_logps, _, loss_mask = DPOTrainer.get_per_token_logps(ref_logits, labels) - cu_seqlens = data['packed_seq_params'].cu_seqlens_q[:args.micro_batch_size * 2 + 1] - all_logps = per_token_logps.new_zeros((args.micro_batch_size, )) - for i in range(args.micro_batch_size): + output_tensor = ref_model(**data) + data['logps'] = self.get_logps(output_tensor, labels, data['packed_seq_params']) + return data + + @staticmethod + def get_logps(output_tensor, labels, packed_seq_params): + args = get_args() + per_token_logps = -output_tensor + loss_mask = labels != -100 + per_token_logps = per_token_logps * loss_mask + cu_seqlens = packed_seq_params.cu_seqlens_q[:args.micro_batch_size * 2 + 1] + all_logps = per_token_logps.new_zeros((args.micro_batch_size * 2, )) + for i in range(args.micro_batch_size * 2): start, end = cu_seqlens[i], cu_seqlens[i + 1] all_logps[i] = per_token_logps[:, start:end].sum() - num_tokens = cu_seqlens[args.micro_batch_size] - # output['nll_loss'] = self.get_nll_loss(all_logits[:, :num_tokens], labels[:, :num_tokens]) - data['labels'] = labels - data['logps'] = all_logps - return data + return all_logps def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): args = get_args() @@ -80,17 +101,41 @@ def run(self): self._patch_setup_model_and_optimizer() super().run() - def loss_func(self, loss_mask: torch.Tensor, output_tensor: torch.Tensor, *, ref_logps: 
torch.Tensor): + def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, + packed_seq_params): from swift.trainers import DPOTrainer + args = get_args() + loss_mask = labels != -100 + num_tokens = packed_seq_params.cu_seqlens_q[args.micro_batch_size] + loss_mask[:, num_tokens:] = 0 nll_loss = torch.sum(output_tensor * loss_mask) / loss_mask.sum() - losses, chosen_rewards, rejected_rewards = DPOTrainer.dpo_loss(model_output['chosen_logps'], - model_output['rejected_logps'], ref_chosen_logps, - ref_rejected_logps) + logps = self.get_logps(output_tensor, labels, packed_seq_params) + loss, chosen_rewards, rejected_rewards = self.dummy_dpo_trainer.dpo_loss( + logps[:args.micro_batch_size], + logps[args.micro_batch_size:], + ref_logps[:args.micro_batch_size], + ref_logps[args.micro_batch_size:], + ) reward_accuracies = (chosen_rewards > rejected_rewards).float() + if args.rpo_alpha > 0: + loss = loss + args.rpo_alpha * nll_loss + loss = loss.mean() + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + return loss, { + 'loss': reporting_loss, + 'rewards/chosen': chosen_rewards, + 'rewards/rejected': rejected_rewards, + 'rewards/accuracies': reward_accuracies, + 'rewards/margins': chosen_rewards - rejected_rewards, + 'logps/chosen': logps[:args.micro_batch_size], + 'logps/rejected': logps[args.micro_batch_size:], + 'nll_loss': nll_loss + } def forward_step(self, data_iterator, model): timers = get_timers() - args = get_args() # Get the batch. timers('batch-generator', log_level=2).start() @@ -98,18 +143,10 @@ def forward_step(self, data_iterator, model): data = next(data_iterator) timers('batch-generator').stop() ref_logps = data.pop('logps') - labels = data.pop('labels') with self.stimer: - logits = model(**data) - if labels is None: - loss_mask = None - else: - loss = model.module.module.compute_language_model_loss(labels, logits) - cu_seqlens = data['packed_seq_params'].cu_seqlens_q[:args.micro_batch_size * 2 + 1] - num_tokens = cu_seqlens[args.micro_batch_size] - loss_mask = (labels != -100).float() - loss_mask[:, num_tokens:] = 0 - return logits, partial(self.loss_func, loss_mask, ref_logps=ref_logps) + output_tensor = model(**data) + return output_tensor, partial( + self.loss_func, ref_logps=ref_logps, labels=data['labels'], packed_seq_params=data['packed_seq_params']) def megatron_rlhf_main(args: Union[List[str], MegatronRLHFArguments, None] = None): From 15db9af2f1c9e24a3c380dc5235eaea4f3eb5df8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Jun 2025 10:09:47 +0800 Subject: [PATCH 17/38] update --- ...Megatron-SWIFT\350\256\255\347\273\203.md" | 2 +- .../Instruction/Megatron-SWIFT-Training.md | 2 +- examples/train/megatron/{ => moe}/moe.sh | 0 .../train/megatron/{ => moe}/qwen3_moe.sh | 0 examples/train/megatron/rlhf/dpo.sh | 34 +++++++++++++++++++ swift/megatron/argument/megatron_args.py | 2 +- swift/megatron/train/rlhf.py | 12 +++---- 7 files changed, 43 insertions(+), 9 deletions(-) rename examples/train/megatron/{ => moe}/moe.sh (100%) rename examples/train/megatron/{ => moe}/qwen3_moe.sh (100%) create mode 100644 examples/train/megatron/rlhf/dpo.sh diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index cf375aae4d..4bbf8eb3cb 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ 
"b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -221,7 +221,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I **日志参数**: - log_params_norm: 记录参数的norm。默认为False。 -- log_throughput: 记录每个GPU的吞吐量。默认为True。 +- log_throughput: 记录每个GPU的吞吐量。默认为False。 - 注意:在非packing情况下,log_throughput并不准确,因为`seq_length`并不等于真实序列长度。 - tensorboard_log_interval: 记录到tensorboard的间隔(steps),默认为1。 - tensorboard_queue_size: 队列长度(与磁盘IO相关),类似于写入的间隔。默认为50。 diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index 50998fcf3d..7bafeec671 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -229,7 +229,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the **Logging Parameters**: - log_params_norm: Logs the norm of parameters. Default is False. -- log_throughput: Logs throughput per GPU. Default is True. +- log_throughput: Logs throughput per GPU. Default is False. - Note: In non-packing scenarios, log_throughput is not accurate because `seq_length` does not equal the actual sequence length. - tensorboard_log_interval: Interval (steps) for logging to TensorBoard, default is 1. - tensorboard_queue_size: Queue length (related to disk I/O), similar to write intervals. Default is 50. diff --git a/examples/train/megatron/moe.sh b/examples/train/megatron/moe/moe.sh similarity index 100% rename from examples/train/megatron/moe.sh rename to examples/train/megatron/moe/moe.sh diff --git a/examples/train/megatron/qwen3_moe.sh b/examples/train/megatron/moe/qwen3_moe.sh similarity index 100% rename from examples/train/megatron/qwen3_moe.sh rename to examples/train/megatron/moe/qwen3_moe.sh diff --git a/examples/train/megatron/rlhf/dpo.sh b/examples/train/megatron/rlhf/dpo.sh new file mode 100644 index 0000000000..2e06c17214 --- /dev/null +++ b/examples/train/megatron/rlhf/dpo.sh @@ -0,0 +1,34 @@ +# 4 * 78GiB +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=4 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +megatron rlhf \ + --rlhf_type dpo \ + --load Qwen3-8B-Base-mcore \ + --dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \ + --tensor_model_parallel_size 4 \ + --micro_batch_size 16 \ + --global_batch_size 16 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 1 \ + --eval_iters 20 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_iters 50 \ + --min_lr 1e-6 \ + --save megatron_output/Qwen3-32B \ + --eval_interval 200 \ + --save_interval 200 \ + --max_length 8192 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash \ + --beta 0.1 \ + --rpo_alpha 1 \ + --loss_type sigmoid diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 16e25bfcca..3511d0c2ff 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -162,7 +162,7 @@ class MegatronArguments(ExtraMegatronArguments): # logging log_params_norm: bool = False - log_throughput: bool = True + log_throughput: bool = False tensorboard_log_interval: int = 1 tensorboard_queue_size: int = 50 log_timers_to_tensorboard: bool = True diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 1fbd4390dc..be34543d30 100644 --- a/swift/megatron/train/rlhf.py 
+++ b/swift/megatron/train/rlhf.py @@ -125,12 +125,12 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab return loss, { 'loss': reporting_loss, - 'rewards/chosen': chosen_rewards, - 'rewards/rejected': rejected_rewards, - 'rewards/accuracies': reward_accuracies, - 'rewards/margins': chosen_rewards - rejected_rewards, - 'logps/chosen': logps[:args.micro_batch_size], - 'logps/rejected': logps[args.micro_batch_size:], + 'rewards/chosen': chosen_rewards.mean(), + 'rewards/rejected': rejected_rewards.mean(), + 'rewards/accuracies': reward_accuracies.mean(), + 'rewards/margins': (chosen_rewards - rejected_rewards).mean(), + 'logps/chosen': logps[:args.micro_batch_size].mean(), + 'logps/rejected': logps[args.micro_batch_size:].mean(), 'nll_loss': nll_loss } From f8d29a75defd6bc36e8a4d1fa946c242685761c7 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Jun 2025 10:12:11 +0800 Subject: [PATCH 18/38] update shell --- examples/train/megatron/rlhf/dpo.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/train/megatron/rlhf/dpo.sh b/examples/train/megatron/rlhf/dpo.sh index 2e06c17214..a8ea3dd2f3 100644 --- a/examples/train/megatron/rlhf/dpo.sh +++ b/examples/train/megatron/rlhf/dpo.sh @@ -1,4 +1,4 @@ -# 4 * 78GiB +# 4 * 60GiB PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=4 \ CUDA_VISIBLE_DEVICES=0,1,2,3 \ From c96cef778e7ca35810e39c05f52f7fbf87d0f4db Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Jun 2025 10:19:59 +0800 Subject: [PATCH 19/38] update --- tests/megatron/test_rlhf.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/megatron/test_rlhf.py b/tests/megatron/test_rlhf.py index b04597f650..2cd8432c86 100644 --- a/tests/megatron/test_rlhf.py +++ b/tests/megatron/test_rlhf.py @@ -1,6 +1,6 @@ import os -os.environ['CUDA_VISIBLE_DEVICES'] = '0' +os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' def test_dpo(): @@ -9,18 +9,14 @@ def test_dpo(): MegatronRLHFArguments( load='Qwen2.5-3B-Instruct-mcore', dataset=['hjh0119/shareAI-Llama3-DPO-zh-en-emoji#1000'], - tensor_model_parallel_size=1, + micro_batch_size=16, + tensor_model_parallel_size=2, train_iters=100, eval_iters=5, + log_interval=1, finetune=True)) -# {'loss': 2.58519292, 'grad_norm': 17.12728882, 'learning_rate': 9.998e-05, 'memory(GiB)': 11.29, 'train_speed(iter/s)': 0.028085, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/chosen': -372.38751221, 'logps/rejected': -485.3223877, 'logits/chosen': -1.5612483, 'logits/rejected': -1.18372846, 'nll_loss': 1.89204574, 'epoch': 0.02, 'global_step/max_steps': '1/100', 'percentage': '1.00%', 'elapsed_time': '34s', 'remaining_time': '56m 45s'} -# {'loss': 2.03561759, 'grad_norm': 1.61355615, 'learning_rate': 9.938e-05, 'memory(GiB)': 11.29, 'train_speed(iter/s)': 0.04982, 'rewards/chosen': 3.43259692, 'rewards/rejected': 0.04113517, 'rewards/accuracies': 1.0, 'rewards/margins': 3.39146185, 'logps/chosen': -420.24295044, 'logps/rejected': -530.99975586, 'logits/chosen': -1.50232077, 'logits/rejected': -0.97954285, 'nll_loss': 1.89556026, 'epoch': 0.08, 'global_step/max_steps': '5/100', 'percentage': '5.00%', 'elapsed_time': '1m 39s', 'remaining_time': '31m 23s'} -# {'loss': 2.58330202, 'grad_norm': 17.195858, 'learning_rate': 9.998e-05, 'memory(GiB)': 10.31, 'train_speed(iter/s)': 0.046415, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/chosen': 
-372.38418579, 'logps/rejected': -485.07315063, 'logits/chosen': -1.56147575, 'logits/rejected': -1.18395948, 'nll_loss': 1.89015484, 'epoch': 0.02, 'global_step/max_steps': '1/100', 'percentage': '1.00%', 'elapsed_time': '20s', 'remaining_time': '33m 29s'} -# {'loss': 2.03590226, 'grad_norm': 1.62211466, 'learning_rate': 9.938e-05, 'memory(GiB)': 10.32, 'train_speed(iter/s)': 0.058378, 'rewards/chosen': 3.42427087, 'rewards/rejected': 0.01991111, 'rewards/accuracies': 1.0, 'rewards/margins': 3.40435982, 'logps/chosen': -420.37756348, 'logps/rejected': -531.13378906, 'logits/chosen': -1.50198746, 'logits/rejected': -0.97799051, 'nll_loss': 1.89604306, 'epoch': 0.08, 'global_step/max_steps': '5/100', 'percentage': '5.00%', 'elapsed_time': '1m 24s', 'remaining_time': '26m 43s'} - - def test_hf(): from swift.llm import rlhf_main, RLHFArguments rlhf_main( @@ -36,5 +32,5 @@ def test_hf(): if __name__ == '__main__': - # test_dpo() - test_hf() + test_dpo() + # test_hf() From b6fc6d92a81d060958b4097df57a8536eca7a90c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Jun 2025 13:29:46 +0800 Subject: [PATCH 20/38] update --- swift/megatron/init.py | 235 ++++++++++++++++++++++++++++++++++++ swift/megatron/train/sft.py | 26 ++-- 2 files changed, 248 insertions(+), 13 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index db7cb988db..958e94a36c 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -1,7 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import datetime import os import sys +import torch + from swift.llm import git_clone_github from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run @@ -41,9 +44,241 @@ def _batched_p2p_ops(**kwargs): p2p_communication._batched_p2p_ops = _batched_p2p_ops +def _patch_training_log(): + from megatron.core import mpu + from megatron.core.transformer.moe.moe_utils import track_moe_metrics + from megatron.training.theoretical_memory_usage import report_theoretical_memory + from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper + from megatron.training import (training, get_args, get_timers, get_tensorboard_writer, get_wandb_writer, + get_one_logger, one_logger_utils, is_last_rank, print_rank_last) + from megatron.training.training import num_floating_point_operations + from megatron.core.num_microbatches_calculator import get_num_microbatches + from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory + + def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, + report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad): + """Log training information such as losses, timing, ....""" + args = get_args() + timers = get_timers() + writer = get_tensorboard_writer() + wandb_writer = get_wandb_writer() + + # Advanced, skipped, and Nan iterations. + advanced_iters_key = 'advanced iterations' + skipped_iters_key = 'skipped iterations' + nan_iters_key = 'nan iterations' + # Advanced iterations. + if not skipped_iter: + total_loss_dict[advanced_iters_key] = total_loss_dict.get(advanced_iters_key, 0) + 1 + else: + if advanced_iters_key not in total_loss_dict: + total_loss_dict[advanced_iters_key] = 0 + # Skipped iterations. 
+ total_loss_dict[skipped_iters_key] = total_loss_dict.get(skipped_iters_key, 0) + skipped_iter + # Update losses and set nan iterations + got_nan = False + for key in loss_dict: + if not skipped_iter: + total_loss_dict[key] = total_loss_dict.get(key, torch.tensor([0.0], dtype=torch.float, + device='cuda')) + loss_dict[key] + else: + value = loss_dict[key].float().sum().item() + is_nan = value == float('inf') or value == -float('inf') or value != value + got_nan = got_nan or is_nan + total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int(got_nan) + + # Logging. + timers_to_log = [ + 'forward-backward', 'forward-compute', 'backward-compute', 'batch-generator', 'forward-recv', + 'forward-send', 'backward-recv', 'backward-send', 'forward-send-forward-recv', 'forward-send-backward-recv', + 'backward-send-forward-recv', 'backward-send-backward-recv', 'forward-backward-send-forward-backward-recv', + 'layernorm-grads-all-reduce', 'embedding-grads-all-reduce', 'all-grads-sync', 'params-all-gather', + 'optimizer-copy-to-main-grad', 'optimizer-unscale-and-check-inf', 'optimizer-clip-main-grad', + 'optimizer-count-zeros', 'optimizer-inner-step', 'optimizer-copy-main-to-model-params', 'optimizer' + ] + + # Calculate batch size. + batch_size = args.micro_batch_size * args.data_parallel_size * get_num_microbatches() + + # Track app tag & app tag ID + one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length) + + total_iterations = total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key] + + # learning rate will be None on ranks without trainable params, so we must gather across mp ranks + learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate) + # Tensorboard values. + # Timer requires all the ranks to call. 
+ if args.log_timers_to_tensorboard and (iteration % args.tensorboard_log_interval == 0): + timers.write(timers_to_log, writer, iteration, normalizer=total_iterations) + if writer and (iteration % args.tensorboard_log_interval == 0): + if wandb_writer: + wandb_writer.log({'samples vs steps': args.consumed_train_samples}, iteration) + writer.add_scalar('learning-rate', learning_rate, iteration) + writer.add_scalar('learning-rate vs samples', learning_rate, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'learning-rate': learning_rate}, iteration) + if args.decoupled_lr is not None: + writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration) + if args.skipped_train_samples > 0: + writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration) + if wandb_writer: + wandb_writer.log({'skipped-train-samples': args.skipped_train_samples}, iteration) + writer.add_scalar('batch-size', batch_size, iteration) + writer.add_scalar('batch-size vs samples', batch_size, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'batch-size': batch_size}, iteration) + for key in loss_dict: + writer.add_scalar(key, loss_dict[key], iteration) + writer.add_scalar(key + ' vs samples', loss_dict[key], args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({key: loss_dict[key]}, iteration) + if args.log_loss_scale_to_tensorboard: + writer.add_scalar('loss-scale', loss_scale, iteration) + writer.add_scalar('loss-scale vs samples', loss_scale, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'loss-scale': loss_scale}, iteration) + if args.log_world_size_to_tensorboard: + writer.add_scalar('world-size', args.world_size, iteration) + writer.add_scalar('world-size vs samples', args.world_size, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'world-size': args.world_size}, iteration) + if grad_norm is not None: + writer.add_scalar('grad-norm', grad_norm, iteration) + writer.add_scalar('grad-norm vs samples', grad_norm, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'grad-norm': grad_norm}, iteration) + if num_zeros_in_grad is not None: + writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) + writer.add_scalar('num-zeros vs samples', num_zeros_in_grad, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration) + if params_norm is not None: + writer.add_scalar('params-norm', params_norm, iteration) + writer.add_scalar('params-norm vs samples', params_norm, args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'params-norm': params_norm}, iteration) + if args.log_memory_to_tensorboard: + mem_stats = torch.cuda.memory_stats() + writer.add_scalar( + 'mem-reserved-bytes', + mem_stats['reserved_bytes.all.current'], + iteration, + ) + writer.add_scalar( + 'mem-allocated-bytes', + mem_stats['allocated_bytes.all.current'], + iteration, + ) + writer.add_scalar( + 'mem-max-allocated-bytes', + mem_stats['allocated_bytes.all.peak'], + iteration, + ) + writer.add_scalar( + 'mem-allocated-count', + mem_stats['allocation.all.current'], + iteration, + ) + if args.num_experts is not None: + moe_loss_scale = 1 / get_num_microbatches() + track_names = [] + if args.moe_router_load_balancing_type in ['aux_loss', 'seq_aux_loss']: + track_names.append('load_balancing_loss') + if args.moe_z_loss_coeff is not None: + track_names.append('z_loss') + track_moe_metrics( + loss_scale=moe_loss_scale, + iteration=iteration, + 
writer=writer, + wandb_writer=wandb_writer, + total_loss_dict=total_loss_dict, + per_layer_logging=args.moe_per_layer_logging, + force_initialize=True, + track_names=track_names, + num_layers=args.num_layers, + moe_layer_freq=args.moe_layer_freq) + if args.mtp_num_layers is not None: + mtp_loss_scale = 1 / get_num_microbatches() + MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict) + if iteration % args.log_interval == 0: + if args.record_memory_history and is_last_rank(): + snapshot = torch.cuda.memory._snapshot() + from pickle import dump + with open(args.memory_snapshot_path, 'wb') as f: + dump(snapshot, f) + + elapsed_time = timers('interval-time').elapsed(barrier=True) + elapsed_time_per_iteration = elapsed_time / total_iterations + + throughput = num_floating_point_operations(args, batch_size) / ( + elapsed_time_per_iteration * 10**12 * args.world_size) + + one_logger_utils.track_e2e_metrics(args.log_throughput, throughput) + + if args.log_timers_to_tensorboard: + if writer: + writer.add_scalar('iteration-time', elapsed_time_per_iteration, iteration) + if wandb_writer: + wandb_writer.log({'iteration-time': elapsed_time_per_iteration}, iteration) + log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]" + log_string += ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters) + log_string += ' consumed samples: {:12d} |'.format(args.consumed_train_samples) + if args.skipped_train_samples > 0: + log_string += ' skipped samples: {:12d} |'.format(args.skipped_train_samples) + log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(elapsed_time_per_iteration * 1000.0) + if args.log_throughput: + log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |' + if args.log_timers_to_tensorboard: + if writer: + writer.add_scalar('throughput', throughput, iteration) + if wandb_writer: + wandb_writer.log({'throughput': throughput}, iteration) + # Decoupled_learning_rate should be not None only on first and last pipeline stage. + log_string += f' learning rate: {learning_rate:.6E} |' + if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) + or mpu.is_pipeline_last_stage(ignore_virtual=True)): + assert decoupled_learning_rate is not None + log_string += f' decoupled learning rate: {decoupled_learning_rate:.6E} |' + else: + assert decoupled_learning_rate is None + log_string += f' global batch size: {batch_size:5d} |' + for key in total_loss_dict: + if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]: + avg = total_loss_dict[key].item() / float(max(1, total_loss_dict[advanced_iters_key])) + log_string += ' {}: {:.6E} |'.format(key, avg) + total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda') + log_string += f' loss scale: {loss_scale:.1f} |' + if grad_norm is not None: + log_string += f' grad norm: {grad_norm:.3f} |' + if num_zeros_in_grad is not None: + log_string += f' num zeros: {num_zeros_in_grad} |' + if params_norm is not None: + log_string += f' params norm: {params_norm:.3f} |' + log_string += ' number of skipped iterations: {:3d} |'.format(total_loss_dict[skipped_iters_key]) + log_string += ' number of nan iterations: {:3d} |'.format(total_loss_dict[nan_iters_key]) + total_loss_dict[advanced_iters_key] = 0 + total_loss_dict[skipped_iters_key] = 0 + total_loss_dict[nan_iters_key] = 0 + print_rank_last(log_string) + if report_memory_flag: + # Report memory after optimizer state has been initialized. 
+ if torch.distributed.get_rank() == 0: + num_microbatches = get_num_microbatches() + report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True) + report_memory(f'(after {iteration} iterations)') + report_memory_flag = False + timers.log(timers_to_log, normalizer=args.log_interval) + + return report_memory_flag + + training.training_log = training_log + + def _patch_megatron(): _patch_transformer_engine() _patch__batched_p2p_ops() + _patch_training_log() def init_megatron_env() -> None: diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 5661e7d617..cded8ad1a9 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -83,20 +83,20 @@ def _training_context(): args.is_training = False def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - return self._train_step_origin(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config) - - def _patch_train_step(self): + with self._training_context(): + try: + return self._origin_train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, + config) + except StopIteration: + return {}, True, True, True, 0, None, None + + def _patch_megatron(self): # support max_epochs - def train_step(*args, **kwargs): - with self._training_context(): - try: - return self.train_step(*args, **kwargs) - except StopIteration: - return {}, True, True, True, 0, None, None - - self._train_step_origin = training.train_step - training.train_step = train_step + self._origin_train_step = training.train_step + training.train_step = self.train_step training.cyclic_iter = MegatronSft.new_cyclic_iter + # patch training_log + self._origin_training_log = training.training_log def forward_step(self, data_iterator, model): from pretrain_gpt import loss_func @@ -119,7 +119,7 @@ def forward_step(self, data_iterator, model): def run(self): args = self.args - self._patch_train_step() + self._patch_megatron() train_dataset, val_dataset = self._get_dataset() train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) From 95f74fae5872bb71b45b01053e3355718ceb3708 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Jun 2025 13:30:14 +0800 Subject: [PATCH 21/38] update --- swift/megatron/init.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 958e94a36c..8b951c7529 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -45,6 +45,7 @@ def _batched_p2p_ops(**kwargs): def _patch_training_log(): + # TODO: support swanlab from megatron.core import mpu from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.training.theoretical_memory_usage import report_theoretical_memory From ac1c33bb8ee7257bb149dae670655659378888f6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Jun 2025 14:12:28 +0800 Subject: [PATCH 22/38] update --- examples/train/megatron/rlhf/dpo.sh | 2 +- swift/megatron/init.py | 8 ++++++-- swift/megatron/train/rlhf.py | 27 ++++----------------------- 3 files changed, 11 insertions(+), 26 deletions(-) diff --git a/examples/train/megatron/rlhf/dpo.sh b/examples/train/megatron/rlhf/dpo.sh index a8ea3dd2f3..a1d1ae499b 100644 --- a/examples/train/megatron/rlhf/dpo.sh +++ b/examples/train/megatron/rlhf/dpo.sh @@ -19,7 +19,7 @@ megatron rlhf \ --lr 1e-5 \ --lr_warmup_iters 50 \ --min_lr 1e-6 \ - --save megatron_output/Qwen3-32B \ + --save megatron_output/Qwen3-8B-Base \ --eval_interval 200 \ --save_interval 200 
\ --max_length 8192 \ diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 8b951c7529..0f1a2ae4f7 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import datetime import os import sys +from datetime import datetime import torch @@ -279,7 +279,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r def _patch_megatron(): _patch_transformer_engine() _patch__batched_p2p_ops() - _patch_training_log() + try: + _patch_training_log() + logger.info('Patch training_log successfully applied.') + except Exception: + pass def init_megatron_env() -> None: diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index be34543d30..fef2e12b09 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -5,7 +5,7 @@ import torch from megatron.core import mpu -from megatron.training import get_args, get_model, get_timers, training +from megatron.training import get_args, get_model, training from megatron.training.checkpointing import load_checkpoint from megatron.training.utils import unwrap_model @@ -38,6 +38,7 @@ class MegatronRLHF(MegatronSft): def __init__(self, args: Union[List[str], MegatronRLHFArguments, None] = None) -> None: super().__init__(args) + self._patch_setup_model_and_optimizer() self.dummy_dpo_trainer = DummyDPOTrainer(self.args) def _prepare_template(self) -> None: @@ -62,13 +63,10 @@ def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) def ref_forward(self, data_iterator): ref_model = unwrap_model(self.ref_model) - timers = get_timers() - timers('batch-ref-generator', log_level=2).start() with self.stimer(bdata=True): data = get_batch(data_iterator) if not data: raise StopIteration - timers('batch-ref-generator').stop() labels = data['labels'] with torch.no_grad(): output_tensor = ref_model(**data) @@ -88,19 +86,6 @@ def get_logps(output_tensor, labels, packed_seq_params): all_logps[i] = per_token_logps[:, start:end].sum() return all_logps - def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - args = get_args() - num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) - res = [] - for i in range(num_iters_per_step): - with torch.no_grad(): - res.append(self.ref_forward(data_iterator)) - return super().train_step(forward_step_func, iter(res), model, optimizer, opt_param_scheduler, config) - - def run(self): - self._patch_setup_model_and_optimizer() - super().run() - def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, packed_seq_params): from swift.trainers import DPOTrainer @@ -135,13 +120,9 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab } def forward_step(self, data_iterator, model): - timers = get_timers() + with torch.no_grad(): + data = self.ref_forward(data_iterator) - # Get the batch. 
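For reference, the sigmoid preference loss that `DummyDPOTrainer` borrows from TRL can be sketched standalone. This is a minimal illustration assuming the packing layout used above (the first `micro_batch_size` sequences of each micro-batch are the chosen responses, the remainder the rejected ones); the tensor values, `beta`, and `rpo_alpha` below are illustrative, not taken from a real run:

```python
import torch
import torch.nn.functional as F


def sigmoid_dpo_loss(policy_logps, ref_logps, nll_loss, beta=0.1, rpo_alpha=1.0, micro_batch_size=2):
    # Split the packed batch: first half chosen, second half rejected
    # (the same ordering that get_logps relies on via cu_seqlens).
    policy_chosen, policy_rejected = policy_logps[:micro_batch_size], policy_logps[micro_batch_size:]
    ref_chosen, ref_rejected = ref_logps[:micro_batch_size], ref_logps[micro_batch_size:]

    # loss_type='sigmoid' with label_smoothing=0: -log sigmoid(beta * difference of log-ratios).
    logratios = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    losses = -F.logsigmoid(beta * logratios)

    # Implicit rewards, reported as the rewards/* metrics.
    chosen_rewards = beta * (policy_chosen - ref_chosen).detach()
    rejected_rewards = beta * (policy_rejected - ref_rejected).detach()

    # RPO term: add the NLL (SFT) loss on the chosen tokens, weighted by rpo_alpha.
    loss = losses.mean() + rpo_alpha * nll_loss
    return loss, chosen_rewards, rejected_rewards


# Toy sequence-level log-probabilities for 2 chosen + 2 rejected sequences.
policy_logps = torch.tensor([-350.0, -360.0, -480.0, -490.0])
ref_logps = torch.tensor([-352.0, -361.0, -478.0, -488.0])
loss, chosen_rewards, rejected_rewards = sigmoid_dpo_loss(policy_logps, ref_logps, nll_loss=torch.tensor(1.9))
print(loss, chosen_rewards, rejected_rewards)
```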
- timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = next(data_iterator) - timers('batch-generator').stop() ref_logps = data.pop('logps') with self.stimer: output_tensor = model(**data) From 7a75d938bf4723d018557ad4cf5a00e839e630ee Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 6 Jun 2025 15:34:23 +0800 Subject: [PATCH 23/38] update --- swift/megatron/argument/megatron_args.py | 2 +- swift/megatron/init.py | 1 + swift/megatron/train/sft.py | 12 +++++++----- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 3511d0c2ff..ce91a3efb6 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -175,7 +175,7 @@ class MegatronArguments(ExtraMegatronArguments): wandb_save_dir: Optional[str] = None # evaluate - eval_iters: int = 100 + eval_iters: Optional[int] = None eval_interval: Optional[int] = None # other diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 0f1a2ae4f7..0d729c9d87 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -56,6 +56,7 @@ def _patch_training_log(): from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory + # Code borrowed from megatron-lm def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad): """Log training information such as losses, timing, ....""" diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index cded8ad1a9..71bbf75d53 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -36,19 +36,21 @@ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) self.stimer = StragglerDetector() @contextmanager - def _get_train_iters(self, train_dataset): + def _get_iters(self, train_dataset, val_dataset): from megatron.training import training origin_initialize_megatron = training.initialize_megatron def initialize_megatron(*_args, **kwargs): res = origin_initialize_megatron(*_args, **kwargs) args = get_args() + data_parallel_size = mpu.get_data_parallel_world_size() + step_batch_size = args.micro_batch_size * data_parallel_size if args.train_iters is None and hasattr(train_dataset, '__len__'): - data_parallel_size = mpu.get_data_parallel_world_size() - step_batch_size = \ - args.micro_batch_size * data_parallel_size dataset_sample = len(train_dataset) // step_batch_size * step_batch_size args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1 + if val_dataset is not None and args.eval_iters is None and hasattr(val_dataset, '__len__'): + dataset_sample = len(val_dataset) // step_batch_size * step_batch_size + args.eval_iters = max(dataset_sample // args.global_batch_size, 1) return res training.initialize_megatron = initialize_megatron @@ -134,7 +136,7 @@ def run(self): logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'The logging file will be saved in: {logging_path}') try: - with patch_megatron_data_collator(data_collator), self._get_train_iters(train_dataset): + with patch_megatron_data_collator(data_collator), self._get_iters(train_dataset, val_dataset): extra_args_provider = args.megatron_model_meta.extra_args_provider pretrain( datasets_provider, From d981bc5a093c678423c41c3d6372e8a929319b89 Mon Sep 17 
00:00:00 2001 From: Jintao Huang Date: Fri, 6 Jun 2025 15:34:55 +0800 Subject: [PATCH 24/38] update --- tests/megatron/test_rlhf.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/megatron/test_rlhf.py b/tests/megatron/test_rlhf.py index 2cd8432c86..7b8a8013cd 100644 --- a/tests/megatron/test_rlhf.py +++ b/tests/megatron/test_rlhf.py @@ -8,13 +8,14 @@ def test_dpo(): megatron_rlhf_main( MegatronRLHFArguments( load='Qwen2.5-3B-Instruct-mcore', - dataset=['hjh0119/shareAI-Llama3-DPO-zh-en-emoji#1000'], + dataset=['hjh0119/shareAI-Llama3-DPO-zh-en-emoji#10000'], micro_batch_size=16, tensor_model_parallel_size=2, - train_iters=100, - eval_iters=5, + eval_interval=5, log_interval=1, - finetune=True)) + finetune=True, + max_epochs=1, + )) def test_hf(): From 9b089b595e7b4b3d88a8537687ed2cb0198120ad Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 7 Jun 2025 17:54:35 +0800 Subject: [PATCH 25/38] fix dpo emoji dataset --- swift/llm/dataset/dataset/llm.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 96e646f39d..8b9f70714d 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -684,11 +684,22 @@ def repair_conversations(s: Union[str, Any]) -> Any: preprocess_func=MessagesPreprocessor(repair_messages=repair_conversations), tags=['chat', 'em'])) + +class EmojiPreprocessr(ResponsePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: + # Remove dirty characters + row['query'] = row['query'].replace('️', '') + row['response'] = row['response'].replace('️', '') + row['rejected_response'] = row['rejected_response'].replace('️', '') + return super().preprocess(row) + + register_dataset( DatasetMeta( ms_dataset_id='hjh0119/shareAI-Llama3-DPO-zh-en-emoji', hf_dataset_id='shareAI/DPO-zh-en-emoji', - preprocess_func=ResponsePreprocessor(columns={ + preprocess_func=EmojiPreprocessr(columns={ 'answer_zh': 'response', 'answer_en': 'rejected_response' }), From a6067bd990f47813be9c4a23c4f261e5c3cde581 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 9 Jun 2025 22:37:53 +0800 Subject: [PATCH 26/38] update --- .../Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" | 2 +- docs/source_en/Instruction/Megatron-SWIFT-Training.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index 4bbf8eb3cb..64d778c7d9 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -172,7 +172,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I - seq_length: 默认为None,即设置为`max_length`。对数据集长度进行限制请使用基本参数中的`--max_length`控制,无需设置此参数。 - use_cpu_initialization: 在cpu上初始化权重,默认为False。在进行HF和MCore权重转换时会被使用。 - no_create_attention_mask_in_dataloader: 在dataloader中不创建attention mask,默认为True。 -- extra_megatron_kwargs: Additional parameters passed to Megatron, provided as a JSON object. Defaults to None. 
+- extra_megatron_kwargs: 传入megatron的其他参数,使用json传递。默认为None。 **学习率参数**: - 🔥lr: 初始学习率,最终会根据学习率预热策略和衰减策略决定每个迭代的学习率,默认为1e-5。 diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index 7bafeec671..de1a6adc99 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -175,7 +175,7 @@ The speed comparison of full-parameter training for Dense/MoE models using `mega seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the dataset length, please use the `--max_length` parameter in the basic arguments; there is no need to set this parameter. - use_cpu_initialization: Initializes weights on the CPU, default is False. Used during HF and MCore weight conversion. - no_create_attention_mask_in_dataloader: Does not create an attention mask in the dataloader, default is True. -- extra_megatron_kwargs: 传入megatron的其他参数,使用json传递。默认为None。 +- extra_megatron_kwargs: Additional parameters passed to Megatron, provided as a JSON object. Defaults to None. **Learning Rate Parameters**: From bd46a59cd1be3fb4107926963ed7be8768a11e65 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 9 Jun 2025 23:28:51 +0800 Subject: [PATCH 27/38] update --- .../Megatron-SWIFT\350\256\255\347\273\203.md" | 12 +++++++++++- .../Instruction/Megatron-SWIFT-Training.md | 13 ++++++++++++- swift/megatron/argument/megatron_args.py | 2 +- swift/megatron/train/sft.py | 18 ++++++++++++------ 4 files changed, 36 insertions(+), 9 deletions(-) diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index 64d778c7d9..5eafb6674a 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -235,7 +235,8 @@ I am a language model developed by swift, you can call me swift-robot. How can I - wandb_save_dir: 本地保存 wandb 结果的路径。默认为''。 **评估参数**: -- 🔥eval_iters: 评估的迭代次数,默认为100。 +- 🔥eval_iters: 评估的迭代次数,默认为-1,根据验证数据集的数量设置合适的值。 + - 注意:若使用流式数据集,该值需要手动设置。 - 🔥eval_interval: 评估的间隔(steps),默认为None,即设置为save_interval。 **混合精度参数**: @@ -289,6 +290,15 @@ I am a language model developed by swift, you can call me swift-robot. 
How can I - moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。 - moe_shared_expert_overlap: 启用共享专家计算与调度器通信之间的重叠。如果不启用此选项,共享专家将在路由专家之后执行。仅在设置了`moe_shared_expert_intermediate_size`时有效。默认为False。 +**DPO参数**: +- ref_load: ref_model的加载路径。默认为None,即设置为`load`。 +- beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。 +- rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失)。`loss = dpo_loss + rpo_alpha * nll_loss`。默认为1。 +- reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。 +- label_smoothing: 默认为0.。 +- f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 +- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 + ### Megatron训练参数 diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index de1a6adc99..34dcfe58d8 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -244,7 +244,8 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the **Evaluation Parameters**: -- 🔥eval_iters: Number of evaluation iterations, default is 100. +- 🔥eval_iters: The number of iterations for evaluation. Defaults to -1, and a suitable value will be set based on the size of the validation dataset. + - Note: If using a streaming dataset, this value needs to be set manually. - 🔥eval_interval: Evaluation interval (steps), default is None, meaning it will be set to save_interval. **Mixed Precision Parameters**: @@ -301,6 +302,16 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the - moe_expert_capacity_factor: Capacity factor for each expert, None means no tokens will be dropped. Default is None. - moe_shared_expert_overlap: Enable overlapping of shared expert computation with scheduler communication. If this option is not enabled, shared experts will execute after the routing experts. Only effective when `moe_shared_expert_intermediate_size` is set. Default is False. +**DPO Parameters** +- ref_load: The path to load the reference model. Defaults to `None`, which means it will be set to `load`. +- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1. +- rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) used to control the weight of the NLL term (i.e., SFT loss) in the loss function. The total loss is calculated as `loss = dpo_loss + rpo_alpha * nll_loss`. Default is 1. +- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`. +- label_smoothing: Default is 0. +- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. +- loss_type: Default is `'sigmoid'`. 
See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. + + ### Megatron Training Parameters Megatron training parameters inherit from Megatron parameters and basic parameters. For information on basic parameters, see [here](./Command-line-parameters.md#base-arguments). Additionally, the following parameters are included: diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index ce91a3efb6..a00401a9f9 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -175,7 +175,7 @@ class MegatronArguments(ExtraMegatronArguments): wandb_save_dir: Optional[str] = None # evaluate - eval_iters: Optional[int] = None + eval_iters: int = -1 eval_interval: Optional[int] = None # other diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 71bbf75d53..7958641618 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -45,12 +45,18 @@ def initialize_megatron(*_args, **kwargs): args = get_args() data_parallel_size = mpu.get_data_parallel_world_size() step_batch_size = args.micro_batch_size * data_parallel_size - if args.train_iters is None and hasattr(train_dataset, '__len__'): - dataset_sample = len(train_dataset) // step_batch_size * step_batch_size - args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1 - if val_dataset is not None and args.eval_iters is None and hasattr(val_dataset, '__len__'): - dataset_sample = len(val_dataset) // step_batch_size * step_batch_size - args.eval_iters = max(dataset_sample // args.global_batch_size, 1) + if args.train_iters is None: + if hasattr(train_dataset, '__len__'): + dataset_sample = len(train_dataset) // step_batch_size * step_batch_size + args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1 + else: + raise ValueError('You are using a streaming training dataset. Please explicitly specify `--train_iters`.') + if val_dataset is not None and args.eval_iters < 0: + if hasattr(val_dataset, '__len__'): + dataset_sample = len(val_dataset) // step_batch_size * step_batch_size + args.eval_iters = max(dataset_sample // args.global_batch_size, 1) + else: + raise ValueError('You are using a streaming validation dataset. 
Please explicitly specify `--eval_iters`.') return res training.initialize_megatron = initialize_megatron From 568b3aa5633a56e70b5874b63d77f61d2ff3d459 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 10 Jun 2025 01:24:39 +0800 Subject: [PATCH 28/38] update --- examples/train/megatron/rlhf/moe.sh | 36 +++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 examples/train/megatron/rlhf/moe.sh diff --git a/examples/train/megatron/rlhf/moe.sh b/examples/train/megatron/rlhf/moe.sh new file mode 100644 index 0000000000..6026fc3843 --- /dev/null +++ b/examples/train/megatron/rlhf/moe.sh @@ -0,0 +1,36 @@ +# 8 * 64GiB +NPROC_PER_NODE=8 \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +megatron rlhf \ + --rlhf_type dpo \ + --load Qwen1.5-MoE-A2.7B-mcore \ + --dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \ + --tensor_model_parallel_size 2 \ + --expert_model_parallel_size 4 \ + --moe_grouped_gemm true \ + --moe_shared_expert_overlap true \ + --moe_aux_loss_coeff 0.01 \ + --micro_batch_size 4 \ + --global_batch_size 16 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --max_epochs 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_iters 100 \ + --min_lr 1e-6 \ + --save megatron_output/Qwen1.5-MoE-A2.7B \ + --eval_interval 200 \ + --save_interval 200 \ + --max_length 8192 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --attention_backend flash \ + --beta 0.1 \ + --rpo_alpha 1 \ + --loss_type sigmoid From 0dff938ed90252dc6c1c9f9438021b4431ce0e82 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 10 Jun 2025 02:00:37 +0800 Subject: [PATCH 29/38] update --- .../Megatron-SWIFT\350\256\255\347\273\203.md" | 11 +++++++++-- .../Instruction/Megatron-SWIFT-Training.md | 13 +++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index 5eafb6674a..59048f925e 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -300,14 +300,21 @@ I am a language model developed by swift, you can call me swift-robot. 
How can I - loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 -### Megatron训练参数 +### 训练参数 Megatron训练参数继承自Megatron参数和基本参数。基本参数的内容可以参考[这里](./命令行参数.md#基本参数)。此外还包括以下参数: - add_version: 在`save`上额外增加目录`'<版本号>-<时间戳>'`防止权重覆盖,默认为True。 -- 🔥packing: 是否使用序列packing,默认为False。 +- 🔥packing: 是否使用序列packing,默认为False。当前支持`megatron pt/sft`。 - 🔥packing_cache: 指定 packing 缓存目录。默认值为`None`,表示缓存将存储在环境变量 `$MODELSCOPE_CACHE`所指定的路径下。在跨节点使用 packing 功能时,需确保所有节点的 packing 缓存路径共享且一致。你可以通过设置`MODELSCOPE_CACHE`环境变量,或在命令行中添加 `--packing_cache `参数来实现这一要求。 - 🔥streaming: 流式读取并处理数据集,默认False。通常在处理大型数据集时,设置为True。更多流式的参数查看命令行参数文档。 - lazy_tokenize: 默认为False。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(这可以避免在训练中出现报错);设置为True,则在训练中对数据集进行tokenize(这可以节约内存)。 - max_epochs: 训练到`max_epochs`时强制退出训练,并对权重进行验证和保存。该参数在使用流式数据集时很有用。默认为None。 - 注意:如果你使用非流式数据集,该参数会为你自动计算train_iters,你不需要手动传入`train_iters`。 + + +### RLHF参数 +除了继承训练参数外,还支持以下参数: +- rlhf_type: 默认为'dpo'。目前可选择为'dpo'。 +- loss_scale: 覆盖[基本参数](./命令行参数.md)中的loss_scale。默认为'last_round'。 +- calculate_per_token_loss: 覆盖Megatron参数,默认为False。 diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index 34dcfe58d8..2c3a0382ee 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -312,14 +312,23 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the - loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. -### Megatron Training Parameters +### Training Parameters Megatron training parameters inherit from Megatron parameters and basic parameters. For information on basic parameters, see [here](./Command-line-parameters.md#base-arguments). Additionally, the following parameters are included: - add_version: Adds a directory `-` to `save` to prevent overwriting weights, default is True. -- 🔥packing: Whether to use sequence packing, defaults to False. +- 🔥packing: Whether to use sequence packing, defaults to False. Currently supports `megatron pt/sft`. - 🔥packing_cache: Specifies the directory for packing cache. The default value is `None`, which means the cache will be stored in the path defined by the environment variable `$MODELSCOPE_CACHE`. When using the packing feature across multiple nodes, ensure that all nodes share the same packing cache directory. You can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding the `--packing_cache ` argument in the command line. - 🔥streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. For more information on streaming parameters, refer to the command-line parameters documentation. - lazy_tokenize: Default is False. If this parameter is set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization occurs during training (this saves memory). - max_epochs: Forces the training to exit after reaching `max_epochs`, and performs validation and saving of the model weights. This parameter is especially useful when using a streaming dataset. Default is None. - Note: If you use a non-streaming dataset, this parameter will automatically calculate train_iters for you, so there is no need to pass `train_iters` manually. 
+ + +### RLHF Parameters + +In addition to inheriting the training parameters, the following parameters are also supported: + +- rlhf_type: Default is 'dpo'. Currently, only 'dpo' is available. +- loss_scale: Overrides the `loss_scale` in [basic parameters](https://idealab.alibaba-inc.com/command_line_arguments.md). Default is 'last_round'. +- calculate_per_token_loss: Overrides the Megatron parameter. Default is False. From c57a29e17f49531f609f8d92fd31aceb02924c5c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 01:48:48 +0800 Subject: [PATCH 30/38] update --- examples/train/megatron/rlhf/dpo.sh | 1 - swift/megatron/train/rlhf.py | 55 ++++++++++++++++++++++++----- swift/megatron/train/sft.py | 6 ++-- swift/megatron/train/utils.py | 2 ++ 4 files changed, 53 insertions(+), 11 deletions(-) diff --git a/examples/train/megatron/rlhf/dpo.sh b/examples/train/megatron/rlhf/dpo.sh index a1d1ae499b..6f6eabded0 100644 --- a/examples/train/megatron/rlhf/dpo.sh +++ b/examples/train/megatron/rlhf/dpo.sh @@ -13,7 +13,6 @@ megatron rlhf \ --recompute_method uniform \ --recompute_num_layers 1 \ --max_epochs 1 \ - --eval_iters 20 \ --finetune true \ --cross_entropy_loss_fusion true \ --lr 1e-5 \ diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index fef2e12b09..5243c7c0cd 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,10 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from collections import namedtuple from functools import partial -from typing import List, Tuple, Union +from typing import List, Union import torch from megatron.core import mpu +from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank from megatron.training import get_args, get_model, training from megatron.training.checkpointing import load_checkpoint from megatron.training.utils import unwrap_model @@ -61,16 +62,45 @@ def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) training.setup_model_and_optimizer = setup_model_and_optimizer + @staticmethod + def _forward_step_helper(model, inputs): + args = get_args() + if mpu.is_pipeline_first_stage(): + micro_batch_size = 1 # use qkv_format 'thd' + seq_length = inputs['input_ids'].shape[1] + if args.sequence_parallel: + seq_length //= mpu.get_tensor_model_parallel_world_size() + recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], + device=torch.cuda.current_device(), + dtype=torch.int64) + else: + recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) + recv_from_prev_pipeline_rank_(recv_shape_buffer) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(recv_shape_buffer) + shape = recv_shape_buffer.tolist() + + if not mpu.is_pipeline_first_stage(): + recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) + recv_from_prev_pipeline_rank_(recv_buffer) + model.set_input_tensor(recv_buffer) + output_tensor = model(**inputs) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + output_tensor = None + + return output_tensor + def ref_forward(self, data_iterator): ref_model = unwrap_model(self.ref_model) with self.stimer(bdata=True): data = get_batch(data_iterator) if not data: raise StopIteration - labels = data['labels'] + labels = data.get('labels') with torch.no_grad(): - output_tensor = ref_model(**data) - data['logps'] = self.get_logps(output_tensor, labels, 
data['packed_seq_params']) + output_tensor = self._forward_step_helper(ref_model, data) + data['logps'] = None if labels is None else self.get_logps(output_tensor, labels, data['packed_seq_params']) return data @staticmethod @@ -79,11 +109,13 @@ def get_logps(output_tensor, labels, packed_seq_params): per_token_logps = -output_tensor loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask - cu_seqlens = packed_seq_params.cu_seqlens_q[:args.micro_batch_size * 2 + 1] + cu_seqlens = packed_seq_params.cu_seqlens_q[:args.micro_batch_size * 2 + 1] // args.context_parallel_size all_logps = per_token_logps.new_zeros((args.micro_batch_size * 2, )) for i in range(args.micro_batch_size * 2): start, end = cu_seqlens[i], cu_seqlens[i + 1] all_logps[i] = per_token_logps[:, start:end].sum() + if args.context_parallel_size > 1: + torch.distributed.all_reduce(all_logps, group=mpu.get_context_parallel_group()) return all_logps def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, @@ -91,9 +123,13 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab from swift.trainers import DPOTrainer args = get_args() loss_mask = labels != -100 - num_tokens = packed_seq_params.cu_seqlens_q[args.micro_batch_size] + num_tokens = packed_seq_params.cu_seqlens_q[args.micro_batch_size] // args.context_parallel_size loss_mask[:, num_tokens:] = 0 - nll_loss = torch.sum(output_tensor * loss_mask) / loss_mask.sum() + nll_loss = torch.concat([torch.sum(output_tensor * loss_mask)[None], loss_mask.sum()[None]]) + if args.context_parallel_size > 1: + torch.distributed.all_reduce(nll_loss, group=mpu.get_context_parallel_group()) + nll_loss = nll_loss[0] / nll_loss[1] + logps = self.get_logps(output_tensor, labels, packed_seq_params) loss, chosen_rewards, rejected_rewards = self.dummy_dpo_trainer.dpo_loss( logps[:args.micro_batch_size], @@ -127,7 +163,10 @@ def forward_step(self, data_iterator, model): with self.stimer: output_tensor = model(**data) return output_tensor, partial( - self.loss_func, ref_logps=ref_logps, labels=data['labels'], packed_seq_params=data['packed_seq_params']) + self.loss_func, + ref_logps=ref_logps, + labels=data.get('labels'), + packed_seq_params=data.get('packed_seq_params')) def megatron_rlhf_main(args: Union[List[str], MegatronRLHFArguments, None] = None): diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 7958641618..60df64aa73 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -50,13 +50,15 @@ def initialize_megatron(*_args, **kwargs): dataset_sample = len(train_dataset) // step_batch_size * step_batch_size args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1 else: - raise ValueError('You are using a streaming training dataset. Please explicitly specify `--train_iters`.') + raise ValueError( + 'You are using a streaming training dataset. Please explicitly specify `--train_iters`.') if val_dataset is not None and args.eval_iters < 0: if hasattr(val_dataset, '__len__'): dataset_sample = len(val_dataset) // step_batch_size * step_batch_size args.eval_iters = max(dataset_sample // args.global_batch_size, 1) else: - raise ValueError('You are using a streaming validation dataset. Please explicitly specify `--eval_iters`.') + raise ValueError( + 'You are using a streaming validation dataset. 
Please explicitly specify `--eval_iters`.') return res training.initialize_megatron = initialize_megatron diff --git a/swift/megatron/train/utils.py b/swift/megatron/train/utils.py index 991c8a671c..92530dffa4 100644 --- a/swift/megatron/train/utils.py +++ b/swift/megatron/train/utils.py @@ -83,11 +83,13 @@ def _broadcast(item): _broadcast(batch['position_ids']) elif mpu.is_pipeline_first_stage(): + batch['labels'] = None _broadcast(batch['input_ids']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) elif mpu.is_pipeline_last_stage(): + batch['input_ids'] = None _broadcast(batch['labels']) _broadcast(batch['attention_mask']) _broadcast(batch['position_ids']) From 79eca7358f04293d658925956389451f8922650d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 02:34:45 +0800 Subject: [PATCH 31/38] update --- swift/megatron/train/rlhf.py | 11 ++- swift/megatron/train/sft.py | 136 ++++++++++++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 2 deletions(-) diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 5243c7c0cd..f63a9f351d 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -155,9 +155,18 @@ def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, lab 'nll_loss': nll_loss } + def _replace_data_iterator(self, data_iterator): + args = get_args() + num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) + res = [] + for i in range(num_iters_per_step): + with torch.no_grad(): + res.append(self.ref_forward(data_iterator)) + return iter(res) + def forward_step(self, data_iterator, model): with torch.no_grad(): - data = self.ref_forward(data_iterator) + data = next(data_iterator) ref_logps = data.pop('logps') with self.stimer: diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 60df64aa73..3a822f098f 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -1,13 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os +import time from contextlib import contextmanager from functools import partial from typing import List, Union +import torch from megatron.core import mpu from megatron.core.enums import ModelType +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine from megatron.core.utils import StragglerDetector -from megatron.training import get_args, get_timers, pretrain, training +from megatron.training import ft_integration, get_args, get_timers, is_last_rank, pretrain, print_rank_0, training from swift.llm.train import SwiftSft from swift.utils import get_logger, is_master, plot_images @@ -92,14 +97,140 @@ def _training_context(): finally: args.is_training = False + def _replace_data_iterator(self, data_iterator): + return data_iterator + def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): with self._training_context(): try: + data_iterator = self._replace_data_iterator(data_iterator) return self._origin_train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config) except StopIteration: return {}, True, True, True, 0, None, None + def evaluate(self, + forward_step_func, + data_iterator, + model, + process_non_loss_data_func, + config, + verbose=False, + non_loss_data_func=None): + """Evaluation.""" + args = get_args() + timers = get_timers() + + timers('evaluate', log_level=0).start(barrier=True) + + if args.vision_pretraining and args.vision_pretraining_type == 'dino': + from megatron.legacy.model.vision.knn_monitor import compute_feature_bank + compute_feature_bank(model) + + # Turn on evaluation mode which disables dropout. + for model_module in model: + model_module.eval() + + # Disable result validation during evaluation + rerun_state_machine = get_rerun_state_machine() + rerun_mode = rerun_state_machine.get_mode() + rerun_state_machine.set_mode(RerunMode.DISABLED) + + total_loss_dict = {} + + # make validation batch size independent from training batch size + eval_batch_size = args.global_batch_size + eval_num_microbatches = eval_batch_size // \ + (args.micro_batch_size * args.data_parallel_size) + + with torch.no_grad(): + iteration = 0 + if verbose: + print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples') + while iteration < args.eval_iters: + iteration += 1 + if verbose: + print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}') + + forward_backward_func = get_forward_backward_func() + # Don't care about timing during evaluation + config.timers = None + ft_integration.on_eval_step_start() + data_iterator = self._replace_data_iterator(data_iterator) + loss_dicts = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=eval_num_microbatches, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True) + ft_integration.on_eval_step_end() + config.timers = get_timers() + + # Empty unused memory + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Reduce across processes. 
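The evaluation override below accumulates each metric as a running (numerator, denominator) pair, so values reported either as plain tensors or as Megatron-style (sum, count) tuples average correctly across micro-batches. A toy illustration with made-up metric dicts; the keys and numbers are hypothetical:

```python
import torch

# Hypothetical per-micro-batch metric dicts; the tuple form mirrors Megatron's
# (sum, count) convention, which the accumulation loop below supports.
loss_dicts = [
    {'loss': torch.tensor(2.0), 'nll_loss': (torch.tensor(5.0), torch.tensor(2.0))},
    {'loss': torch.tensor(3.0), 'nll_loss': (torch.tensor(3.0), torch.tensor(2.0))},
]

total_loss_dict = {}
for loss_dict in loss_dicts:
    for key, val in loss_dict.items():
        acc = total_loss_dict.setdefault(key, torch.zeros(2))
        if isinstance(val, (tuple, list)):
            acc[0] += val[0]   # numerator (e.g. summed loss)
            acc[1] += val[1]   # denominator (e.g. token count)
        else:
            acc[0] += val      # plain scalar: count one micro-batch
            acc[1] += 1

averaged = {key: (acc[0] / acc[1]).item() for key, acc in total_loss_dict.items()}
print(averaged)  # {'loss': 2.5, 'nll_loss': 2.0}
```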
+ for loss_dict in loss_dicts: + for key in loss_dict: + if key not in total_loss_dict: + total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda() + val = loss_dict[key] + if isinstance(val, tuple) or isinstance(val, list): + total_loss_dict[key][0] += val[0] + total_loss_dict[key][1] += val[1] + else: + total_loss_dict[key][0] += val + total_loss_dict[key][1] += 1 + + args.consumed_valid_samples += eval_batch_size + + if args.exit_duration_in_mins: + train_time = (time.time() - training._TRAIN_START_TIME) / 60.0 + done_cuda = torch.tensor([train_time > args.exit_duration_in_mins], dtype=torch.int, device='cuda') + torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + if done: + rerun_state_machine.set_mode(rerun_mode) + print_rank_0('Exiting during evaluation, timelimit reached') + return None, None, True + + collected_non_loss_data = None + if non_loss_data_func is not None: + collected_non_loss_data = non_loss_data_func(model) + elif process_non_loss_data_func is not None and is_last_rank(): + collected_non_loss_data = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True, + collect_non_loss_data=True) + + # Move model back to the train mode. + for model_module in model: + model_module.train() + + for key in total_loss_dict: + numerator, denominator = total_loss_dict[key] + total_loss_dict[key] = numerator / denominator + + timers('evaluate').stop() + timers.log(['evaluate']) + + rerun_state_machine.set_mode(rerun_mode) + + rerun_state_machine.set_mode(rerun_mode) + + return total_loss_dict, collected_non_loss_data, False + def _patch_megatron(self): # support max_epochs self._origin_train_step = training.train_step @@ -107,6 +238,9 @@ def _patch_megatron(self): training.cyclic_iter = MegatronSft.new_cyclic_iter # patch training_log self._origin_training_log = training.training_log + # patch evaluate + self._origin_evaluate = training.evaluate + training.evaluate = self.evaluate def forward_step(self, data_iterator, model): from pretrain_gpt import loss_func From 2ad6fd2d9e4aa5e448da5f1e1e3df5f0910a082f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 10:41:02 +0800 Subject: [PATCH 32/38] update --- swift/megatron/train/rlhf.py | 171 +----------- swift/megatron/train/sft.py | 252 +----------------- swift/megatron/train/trainers/__init__.py | 3 + swift/megatron/train/trainers/dpo_trainer.py | 171 ++++++++++++ swift/megatron/train/trainers/trainer.py | 260 +++++++++++++++++++ 5 files changed, 451 insertions(+), 406 deletions(-) create mode 100644 swift/megatron/train/trainers/__init__.py create mode 100644 swift/megatron/train/trainers/dpo_trainer.py create mode 100644 swift/megatron/train/trainers/trainer.py diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index f63a9f351d..e78588dc64 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -1,182 +1,31 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
-from collections import namedtuple -from functools import partial from typing import List, Union -import torch -from megatron.core import mpu -from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank -from megatron.training import get_args, get_model, training -from megatron.training.checkpointing import load_checkpoint -from megatron.training.utils import unwrap_model - -from swift.trainers import DPOTrainer -from swift.utils import get_current_device, get_logger +from swift.utils import get_logger from ..argument import MegatronRLHFArguments from .sft import MegatronSft -from .utils import get_batch +from .trainers import MegatronDPOTrainer logger = get_logger() -class DummyDPOTrainer(DPOTrainer): - # For reusing the dpo_loss function in TRL. - def __init__(self, args): - from trl.trainer import FDivergenceConstants - self.accelerator = namedtuple('Accelerator', ['device'])(device=get_current_device()) - self.f_alpha_divergence_coef = 1. - self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: self.f_alpha_divergence_coef} - self.reference_free = args.reference_free - self.label_smoothing = args.label_smoothing - self.f_divergence_type = args.f_divergence_type - self.loss_type = args.loss_type - self.beta = args.beta - - class MegatronRLHF(MegatronSft): args_class = MegatronRLHFArguments args: args_class + trainer_cls = MegatronDPOTrainer - def __init__(self, args: Union[List[str], MegatronRLHFArguments, None] = None) -> None: - super().__init__(args) - self._patch_setup_model_and_optimizer() - self.dummy_dpo_trainer = DummyDPOTrainer(self.args) + def prepare_trainer(self): + args = self.args + if args.rlhf_type == 'dpo': + trainer_cls = MegatronDPOTrainer + else: + raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.') + return trainer_cls() def _prepare_template(self) -> None: super()._prepare_template() self.template.set_mode('rlhf') - def _patch_setup_model_and_optimizer(self): - origin_setup_model_and_optimizer = training.setup_model_and_optimizer - - def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs): - args = get_args() - ref_model = get_model(model_provider_func, model_type) - if args.ref_load is None: - args.ref_load = args.load - args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( - ref_model, None, None, load_arg='ref_load') - self.ref_model = ref_model[0] - self.ref_model.eval() - return origin_setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) - - training.setup_model_and_optimizer = setup_model_and_optimizer - - @staticmethod - def _forward_step_helper(model, inputs): - args = get_args() - if mpu.is_pipeline_first_stage(): - micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] - if args.sequence_parallel: - seq_length //= mpu.get_tensor_model_parallel_world_size() - recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], - device=torch.cuda.current_device(), - dtype=torch.int64) - else: - recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) - recv_from_prev_pipeline_rank_(recv_shape_buffer) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(recv_shape_buffer) - shape = recv_shape_buffer.tolist() - - if not mpu.is_pipeline_first_stage(): - recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) - 
recv_from_prev_pipeline_rank_(recv_buffer) - model.set_input_tensor(recv_buffer) - output_tensor = model(**inputs) - if not mpu.is_pipeline_last_stage(): - send_to_next_pipeline_rank(output_tensor) - output_tensor = None - - return output_tensor - - def ref_forward(self, data_iterator): - ref_model = unwrap_model(self.ref_model) - with self.stimer(bdata=True): - data = get_batch(data_iterator) - if not data: - raise StopIteration - labels = data.get('labels') - with torch.no_grad(): - output_tensor = self._forward_step_helper(ref_model, data) - data['logps'] = None if labels is None else self.get_logps(output_tensor, labels, data['packed_seq_params']) - return data - - @staticmethod - def get_logps(output_tensor, labels, packed_seq_params): - args = get_args() - per_token_logps = -output_tensor - loss_mask = labels != -100 - per_token_logps = per_token_logps * loss_mask - cu_seqlens = packed_seq_params.cu_seqlens_q[:args.micro_batch_size * 2 + 1] // args.context_parallel_size - all_logps = per_token_logps.new_zeros((args.micro_batch_size * 2, )) - for i in range(args.micro_batch_size * 2): - start, end = cu_seqlens[i], cu_seqlens[i + 1] - all_logps[i] = per_token_logps[:, start:end].sum() - if args.context_parallel_size > 1: - torch.distributed.all_reduce(all_logps, group=mpu.get_context_parallel_group()) - return all_logps - - def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, - packed_seq_params): - from swift.trainers import DPOTrainer - args = get_args() - loss_mask = labels != -100 - num_tokens = packed_seq_params.cu_seqlens_q[args.micro_batch_size] // args.context_parallel_size - loss_mask[:, num_tokens:] = 0 - nll_loss = torch.concat([torch.sum(output_tensor * loss_mask)[None], loss_mask.sum()[None]]) - if args.context_parallel_size > 1: - torch.distributed.all_reduce(nll_loss, group=mpu.get_context_parallel_group()) - nll_loss = nll_loss[0] / nll_loss[1] - - logps = self.get_logps(output_tensor, labels, packed_seq_params) - loss, chosen_rewards, rejected_rewards = self.dummy_dpo_trainer.dpo_loss( - logps[:args.micro_batch_size], - logps[args.micro_batch_size:], - ref_logps[:args.micro_batch_size], - ref_logps[args.micro_batch_size:], - ) - reward_accuracies = (chosen_rewards > rejected_rewards).float() - if args.rpo_alpha > 0: - loss = loss + args.rpo_alpha * nll_loss - loss = loss.mean() - reporting_loss = loss.clone().detach() - torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) - - return loss, { - 'loss': reporting_loss, - 'rewards/chosen': chosen_rewards.mean(), - 'rewards/rejected': rejected_rewards.mean(), - 'rewards/accuracies': reward_accuracies.mean(), - 'rewards/margins': (chosen_rewards - rejected_rewards).mean(), - 'logps/chosen': logps[:args.micro_batch_size].mean(), - 'logps/rejected': logps[args.micro_batch_size:].mean(), - 'nll_loss': nll_loss - } - - def _replace_data_iterator(self, data_iterator): - args = get_args() - num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) - res = [] - for i in range(num_iters_per_step): - with torch.no_grad(): - res.append(self.ref_forward(data_iterator)) - return iter(res) - - def forward_step(self, data_iterator, model): - with torch.no_grad(): - data = next(data_iterator) - - ref_logps = data.pop('logps') - with self.stimer: - output_tensor = model(**data) - return output_tensor, partial( - self.loss_func, - ref_logps=ref_logps, - labels=data.get('labels'), - 
packed_seq_params=data.get('packed_seq_params')) - def megatron_rlhf_main(args: Union[List[str], MegatronRLHFArguments, None] = None): return MegatronRLHF(args).main() diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 3a822f098f..100d09c93b 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -1,25 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os -import time -from contextlib import contextmanager from functools import partial from typing import List, Union -import torch -from megatron.core import mpu -from megatron.core.enums import ModelType -from megatron.core.num_microbatches_calculator import get_num_microbatches -from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine -from megatron.core.utils import StragglerDetector from megatron.training import ft_integration, get_args, get_timers, is_last_rank, pretrain, print_rank_0, training from swift.llm.train import SwiftSft from swift.utils import get_logger, is_master, plot_images from ..argument import MegatronTrainArguments from ..utils import patch_megatron_tokenizer -from .patcher import patch_megatron_data_collator -from .utils import build_streaming_dataloader, get_batch, get_swift_datasets_provider +from .trainers import MegatronTrainer +from .utils import build_streaming_dataloader, get_batch logger = get_logger() @@ -28,6 +19,9 @@ class MegatronSft(SwiftSft): args_class = MegatronTrainArguments args: args_class + def prepare_trainer(self): + return MegatronTrainer() + def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) -> None: self.train_msg = {} super(SwiftSft, self).__init__(args) @@ -38,232 +32,10 @@ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) self._prepare_template() self.template.use_megatron = True args.save_args(args.save) - self.stimer = StragglerDetector() - - @contextmanager - def _get_iters(self, train_dataset, val_dataset): - from megatron.training import training - origin_initialize_megatron = training.initialize_megatron - - def initialize_megatron(*_args, **kwargs): - res = origin_initialize_megatron(*_args, **kwargs) - args = get_args() - data_parallel_size = mpu.get_data_parallel_world_size() - step_batch_size = args.micro_batch_size * data_parallel_size - if args.train_iters is None: - if hasattr(train_dataset, '__len__'): - dataset_sample = len(train_dataset) // step_batch_size * step_batch_size - args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1 - else: - raise ValueError( - 'You are using a streaming training dataset. Please explicitly specify `--train_iters`.') - if val_dataset is not None and args.eval_iters < 0: - if hasattr(val_dataset, '__len__'): - dataset_sample = len(val_dataset) // step_batch_size * step_batch_size - args.eval_iters = max(dataset_sample // args.global_batch_size, 1) - else: - raise ValueError( - 'You are using a streaming validation dataset. 
Please explicitly specify `--eval_iters`.') - return res - - training.initialize_megatron = initialize_megatron - try: - yield - finally: - training.initialize_megatron = origin_initialize_megatron - - @staticmethod - def new_cyclic_iter(iter): - args = get_args() - max_epochs = args.max_epochs - i = 0 - while True: - if getattr(args, 'is_training', False): - if max_epochs and i >= max_epochs: - logger.info(f'Training of {i} epochs has been completed, the training has finished.') - break - logger.info(f'The training of Epoch {i} starts...') - for x in iter: - yield x - i += 1 - - @staticmethod - @contextmanager - def _training_context(): - args = get_args() - args.is_training = True - try: - yield - finally: - args.is_training = False - - def _replace_data_iterator(self, data_iterator): - return data_iterator - - def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): - with self._training_context(): - try: - data_iterator = self._replace_data_iterator(data_iterator) - return self._origin_train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, - config) - except StopIteration: - return {}, True, True, True, 0, None, None - - def evaluate(self, - forward_step_func, - data_iterator, - model, - process_non_loss_data_func, - config, - verbose=False, - non_loss_data_func=None): - """Evaluation.""" - args = get_args() - timers = get_timers() - - timers('evaluate', log_level=0).start(barrier=True) - - if args.vision_pretraining and args.vision_pretraining_type == 'dino': - from megatron.legacy.model.vision.knn_monitor import compute_feature_bank - compute_feature_bank(model) - - # Turn on evaluation mode which disables dropout. - for model_module in model: - model_module.eval() - - # Disable result validation during evaluation - rerun_state_machine = get_rerun_state_machine() - rerun_mode = rerun_state_machine.get_mode() - rerun_state_machine.set_mode(RerunMode.DISABLED) - - total_loss_dict = {} - - # make validation batch size independent from training batch size - eval_batch_size = args.global_batch_size - eval_num_microbatches = eval_batch_size // \ - (args.micro_batch_size * args.data_parallel_size) - - with torch.no_grad(): - iteration = 0 - if verbose: - print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples') - while iteration < args.eval_iters: - iteration += 1 - if verbose: - print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}') - - forward_backward_func = get_forward_backward_func() - # Don't care about timing during evaluation - config.timers = None - ft_integration.on_eval_step_start() - data_iterator = self._replace_data_iterator(data_iterator) - loss_dicts = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iterator, - model=model, - num_microbatches=eval_num_microbatches, - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - forward_only=True) - ft_integration.on_eval_step_end() - config.timers = get_timers() - - # Empty unused memory - if args.empty_unused_memory_level >= 1: - torch.cuda.empty_cache() - - if mpu.is_pipeline_last_stage(ignore_virtual=True): - # Reduce across processes. 
- for loss_dict in loss_dicts: - for key in loss_dict: - if key not in total_loss_dict: - total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda() - val = loss_dict[key] - if isinstance(val, tuple) or isinstance(val, list): - total_loss_dict[key][0] += val[0] - total_loss_dict[key][1] += val[1] - else: - total_loss_dict[key][0] += val - total_loss_dict[key][1] += 1 - - args.consumed_valid_samples += eval_batch_size - - if args.exit_duration_in_mins: - train_time = (time.time() - training._TRAIN_START_TIME) / 60.0 - done_cuda = torch.tensor([train_time > args.exit_duration_in_mins], dtype=torch.int, device='cuda') - torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX) - done = done_cuda.item() - if done: - rerun_state_machine.set_mode(rerun_mode) - print_rank_0('Exiting during evaluation, timelimit reached') - return None, None, True - - collected_non_loss_data = None - if non_loss_data_func is not None: - collected_non_loss_data = non_loss_data_func(model) - elif process_non_loss_data_func is not None and is_last_rank(): - collected_non_loss_data = forward_backward_func( - forward_step_func=forward_step_func, - data_iterator=data_iterator, - model=model, - num_microbatches=get_num_microbatches(), - seq_length=args.seq_length, - micro_batch_size=args.micro_batch_size, - decoder_seq_length=args.decoder_seq_length, - forward_only=True, - collect_non_loss_data=True) - - # Move model back to the train mode. - for model_module in model: - model_module.train() - - for key in total_loss_dict: - numerator, denominator = total_loss_dict[key] - total_loss_dict[key] = numerator / denominator - - timers('evaluate').stop() - timers.log(['evaluate']) - - rerun_state_machine.set_mode(rerun_mode) - - rerun_state_machine.set_mode(rerun_mode) - - return total_loss_dict, collected_non_loss_data, False - - def _patch_megatron(self): - # support max_epochs - self._origin_train_step = training.train_step - training.train_step = self.train_step - training.cyclic_iter = MegatronSft.new_cyclic_iter - # patch training_log - self._origin_training_log = training.training_log - # patch evaluate - self._origin_evaluate = training.evaluate - training.evaluate = self.evaluate - - def forward_step(self, data_iterator, model): - from pretrain_gpt import loss_func - - timers = get_timers() - - # Get the batch. 
- timers('batch-generator', log_level=2).start() - with self.stimer(bdata=True): - data = get_batch(data_iterator) - if not data: - raise StopIteration - timers('batch-generator').stop() - - with self.stimer: - output_tensor = model(**data) - labels = data.get('labels') - loss_mask = None if labels is None else (labels != -100).float() - return output_tensor, partial(loss_func, loss_mask) + self.trainer = self.prepare_trainer() def run(self): args = self.args - self._patch_megatron() train_dataset, val_dataset = self._get_dataset() train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) @@ -272,21 +44,11 @@ def run(self): train_dataset = build_streaming_dataloader(args, train_dataset, data_collator) if val_dataset is not None: val_dataset = build_streaming_dataloader(args, val_dataset, data_collator) - datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) - datasets_provider.is_distributed = True logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'The logging file will be saved in: {logging_path}') try: - with patch_megatron_data_collator(data_collator), self._get_iters(train_dataset, val_dataset): - extra_args_provider = args.megatron_model_meta.extra_args_provider - pretrain( - datasets_provider, - args.megatron_model_meta.model_provider, - ModelType.encoder_or_decoder, - self.forward_step, - extra_args_provider=extra_args_provider, - args_defaults=args.extra_args) + self.trainer.train(train_dataset, val_dataset, data_collator) finally: # Visualization if is_master(): diff --git a/swift/megatron/train/trainers/__init__.py b/swift/megatron/train/trainers/__init__.py new file mode 100644 index 0000000000..c891081541 --- /dev/null +++ b/swift/megatron/train/trainers/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .dpo_trainer import MegatronDPOTrainer +from .trainer import MegatronTrainer diff --git a/swift/megatron/train/trainers/dpo_trainer.py b/swift/megatron/train/trainers/dpo_trainer.py new file mode 100644 index 0000000000..7c2208e213 --- /dev/null +++ b/swift/megatron/train/trainers/dpo_trainer.py @@ -0,0 +1,171 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from collections import namedtuple +from functools import partial + +import torch +from megatron.core import mpu +from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank +from megatron.training import get_args, get_model, training +from megatron.training.checkpointing import load_checkpoint +from megatron.training.utils import unwrap_model + +from swift.trainers import DPOTrainer +from swift.utils import get_current_device, get_logger +from ..utils import get_batch +from .trainer import MegatronTrainer + +logger = get_logger() + + +class DummyDPOTrainer(DPOTrainer): + # For reusing the dpo_loss function in TRL. + def __init__(self): + args = get_args() + from trl.trainer import FDivergenceConstants + self.accelerator = namedtuple('Accelerator', ['device'])(device=get_current_device()) + self.f_alpha_divergence_coef = 1. 
+ self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: self.f_alpha_divergence_coef} + self.reference_free = args.reference_free + self.label_smoothing = args.label_smoothing + self.f_divergence_type = args.f_divergence_type + self.loss_type = args.loss_type + self.beta = args.beta + + +class MegatronDPOTrainer(MegatronTrainer): + + def __init__(self): + super().__init__() + self._patch_setup_model_and_optimizer() + self.dummy_dpo_trainer = DummyDPOTrainer() + + def _patch_setup_model_and_optimizer(self): + origin_setup_model_and_optimizer = training.setup_model_and_optimizer + + def setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs): + args = get_args() + ref_model = get_model(model_provider_func, model_type) + if args.ref_load is None: + args.ref_load = args.load + args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( + ref_model, None, None, load_arg='ref_load') + self.ref_model = ref_model[0] + self.ref_model.eval() + return origin_setup_model_and_optimizer(model_provider_func, model_type, *_args, **kwargs) + + training.setup_model_and_optimizer = setup_model_and_optimizer + + @staticmethod + def _forward_step_helper(model, inputs): + args = get_args() + if mpu.is_pipeline_first_stage(): + micro_batch_size = 1 # use qkv_format 'thd' + seq_length = inputs['input_ids'].shape[1] + if args.sequence_parallel: + seq_length //= mpu.get_tensor_model_parallel_world_size() + recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], + device=torch.cuda.current_device(), + dtype=torch.int64) + else: + recv_shape_buffer = torch.empty((3, ), device=torch.cuda.current_device(), dtype=torch.int64) + recv_from_prev_pipeline_rank_(recv_shape_buffer) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(recv_shape_buffer) + shape = recv_shape_buffer.tolist() + + if not mpu.is_pipeline_first_stage(): + recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=args.params_dtype) + recv_from_prev_pipeline_rank_(recv_buffer) + model.set_input_tensor(recv_buffer) + output_tensor = model(**inputs) + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + output_tensor = None + + return output_tensor + + def ref_forward(self, data_iterator): + ref_model = unwrap_model(self.ref_model) + with self.stimer(bdata=True): + data = get_batch(data_iterator) + if not data: + raise StopIteration + labels = data.get('labels') + with torch.no_grad(): + output_tensor = self._forward_step_helper(ref_model, data) + data['logps'] = None if labels is None else self.get_logps(output_tensor, labels, data['packed_seq_params']) + return data + + @staticmethod + def get_logps(output_tensor, labels, packed_seq_params): + args = get_args() + per_token_logps = -output_tensor + loss_mask = labels != -100 + per_token_logps = per_token_logps * loss_mask + cu_seqlens = packed_seq_params.cu_seqlens_q[:args.micro_batch_size * 2 + 1] // args.context_parallel_size + all_logps = per_token_logps.new_zeros((args.micro_batch_size * 2, )) + for i in range(args.micro_batch_size * 2): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + all_logps[i] = per_token_logps[:, start:end].sum() + if args.context_parallel_size > 1: + torch.distributed.all_reduce(all_logps, group=mpu.get_context_parallel_group()) + return all_logps + + def loss_func(self, output_tensor: torch.Tensor, *, ref_logps: torch.Tensor, labels: torch.Tensor, + packed_seq_params): + from swift.trainers import DPOTrainer + args = 
get_args() + loss_mask = labels != -100 + num_tokens = packed_seq_params.cu_seqlens_q[args.micro_batch_size] // args.context_parallel_size + loss_mask[:, num_tokens:] = 0 + nll_loss = torch.concat([torch.sum(output_tensor * loss_mask)[None], loss_mask.sum()[None]]) + if args.context_parallel_size > 1: + torch.distributed.all_reduce(nll_loss, group=mpu.get_context_parallel_group()) + nll_loss = nll_loss[0] / nll_loss[1] + + logps = self.get_logps(output_tensor, labels, packed_seq_params) + loss, chosen_rewards, rejected_rewards = self.dummy_dpo_trainer.dpo_loss( + logps[:args.micro_batch_size], + logps[args.micro_batch_size:], + ref_logps[:args.micro_batch_size], + ref_logps[args.micro_batch_size:], + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + if args.rpo_alpha > 0: + loss = loss + args.rpo_alpha * nll_loss + loss = loss.mean() + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + return loss, { + 'loss': reporting_loss, + 'rewards/chosen': chosen_rewards.mean(), + 'rewards/rejected': rejected_rewards.mean(), + 'rewards/accuracies': reward_accuracies.mean(), + 'rewards/margins': (chosen_rewards - rejected_rewards).mean(), + 'logps/chosen': logps[:args.micro_batch_size].mean(), + 'logps/rejected': logps[args.micro_batch_size:].mean(), + 'nll_loss': nll_loss + } + + def _replace_data_iterator(self, data_iterator): + args = get_args() + num_iters_per_step = args.global_batch_size // (args.micro_batch_size * mpu.get_data_parallel_world_size()) + res = [] + for i in range(num_iters_per_step): + with torch.no_grad(): + res.append(self.ref_forward(data_iterator)) + return iter(res) + + def forward_step(self, data_iterator, model): + with torch.no_grad(): + data = next(data_iterator) + + ref_logps = data.pop('logps') + with self.stimer: + output_tensor = model(**data) + return output_tensor, partial( + self.loss_func, + ref_logps=ref_logps, + labels=data.get('labels'), + packed_seq_params=data.get('packed_seq_params')) diff --git a/swift/megatron/train/trainers/trainer.py b/swift/megatron/train/trainers/trainer.py new file mode 100644 index 0000000000..00027f1356 --- /dev/null +++ b/swift/megatron/train/trainers/trainer.py @@ -0,0 +1,260 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import time +from contextlib import contextmanager +from functools import partial + +import torch +from megatron.core import mpu +from megatron.core.enums import ModelType +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine +from megatron.core.utils import StragglerDetector +from megatron.training import ft_integration, get_args, get_timers, is_last_rank, pretrain, print_rank_0, training + +from swift.utils import get_logger +from ..patcher import patch_megatron_data_collator +from ..utils import get_batch, get_swift_datasets_provider + +logger = get_logger() + + +class MegatronTrainer: + + def __init__(self): + self.stimer = StragglerDetector() + self._patch_megatron() + + @contextmanager + def _get_iters(self, train_dataset, val_dataset): + origin_initialize_megatron = training.initialize_megatron + + def initialize_megatron(*_args, **kwargs): + res = origin_initialize_megatron(*_args, **kwargs) + args = get_args() + data_parallel_size = mpu.get_data_parallel_world_size() + step_batch_size = args.micro_batch_size * data_parallel_size + if args.train_iters is None: + if hasattr(train_dataset, '__len__'): + dataset_sample = len(train_dataset) // step_batch_size * step_batch_size + args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1 + else: + raise ValueError( + 'You are using a streaming training dataset. Please explicitly specify `--train_iters`.') + if val_dataset is not None and args.eval_iters < 0: + if hasattr(val_dataset, '__len__'): + dataset_sample = len(val_dataset) // step_batch_size * step_batch_size + args.eval_iters = max(dataset_sample // args.global_batch_size, 1) + else: + raise ValueError( + 'You are using a streaming validation dataset. 
Please explicitly specify `--eval_iters`.') + return res + + training.initialize_megatron = initialize_megatron + try: + yield + finally: + training.initialize_megatron = origin_initialize_megatron + + @staticmethod + def new_cyclic_iter(iter): + args = get_args() + max_epochs = args.max_epochs + i = 0 + while True: + if getattr(args, 'is_training', False): + if max_epochs and i >= max_epochs: + logger.info(f'Training of {i} epochs has been completed, the training has finished.') + break + logger.info(f'The training of Epoch {i} starts...') + for x in iter: + yield x + i += 1 + + @staticmethod + @contextmanager + def _training_context(): + args = get_args() + args.is_training = True + try: + yield + finally: + args.is_training = False + + def _replace_data_iterator(self, data_iterator): + return data_iterator + + def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config): + with self._training_context(): + try: + data_iterator = self._replace_data_iterator(data_iterator) + return self._origin_train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, + config) + except StopIteration: + return {}, True, True, True, 0, None, None + + def evaluate(self, + forward_step_func, + data_iterator, + model, + process_non_loss_data_func, + config, + verbose=False, + non_loss_data_func=None): + """Evaluation.""" + args = get_args() + timers = get_timers() + + timers('evaluate', log_level=0).start(barrier=True) + + if args.vision_pretraining and args.vision_pretraining_type == 'dino': + from megatron.legacy.model.vision.knn_monitor import compute_feature_bank + compute_feature_bank(model) + + # Turn on evaluation mode which disables dropout. + for model_module in model: + model_module.eval() + + # Disable result validation during evaluation + rerun_state_machine = get_rerun_state_machine() + rerun_mode = rerun_state_machine.get_mode() + rerun_state_machine.set_mode(RerunMode.DISABLED) + + total_loss_dict = {} + + # make validation batch size independent from training batch size + eval_batch_size = args.global_batch_size + eval_num_microbatches = eval_batch_size // (args.micro_batch_size * args.data_parallel_size) + + with torch.no_grad(): + iteration = 0 + if verbose: + print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples') + while iteration < args.eval_iters: + iteration += 1 + if verbose: + print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}') + + forward_backward_func = get_forward_backward_func() + # Don't care about timing during evaluation + config.timers = None + ft_integration.on_eval_step_start() + data_iterator = self._replace_data_iterator(data_iterator) + loss_dicts = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=eval_num_microbatches, + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True) + ft_integration.on_eval_step_end() + config.timers = get_timers() + + # Empty unused memory + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Reduce across processes. 
+ for loss_dict in loss_dicts: + for key in loss_dict: + if key not in total_loss_dict: + total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda() + val = loss_dict[key] + if isinstance(val, tuple) or isinstance(val, list): + total_loss_dict[key][0] += val[0] + total_loss_dict[key][1] += val[1] + else: + total_loss_dict[key][0] += val + total_loss_dict[key][1] += 1 + + args.consumed_valid_samples += eval_batch_size + + if args.exit_duration_in_mins: + train_time = (time.time() - training._TRAIN_START_TIME) / 60.0 + done_cuda = torch.tensor([train_time > args.exit_duration_in_mins], dtype=torch.int, device='cuda') + torch.distributed.all_reduce(done_cuda, op=torch.distributed.ReduceOp.MAX) + done = done_cuda.item() + if done: + rerun_state_machine.set_mode(rerun_mode) + print_rank_0('Exiting during evaluation, timelimit reached') + return None, None, True + + collected_non_loss_data = None + if non_loss_data_func is not None: + collected_non_loss_data = non_loss_data_func(model) + elif process_non_loss_data_func is not None and is_last_rank(): + collected_non_loss_data = forward_backward_func( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=model, + num_microbatches=get_num_microbatches(), + seq_length=args.seq_length, + micro_batch_size=args.micro_batch_size, + decoder_seq_length=args.decoder_seq_length, + forward_only=True, + collect_non_loss_data=True) + + # Move model back to the train mode. + for model_module in model: + model_module.train() + + for key in total_loss_dict: + numerator, denominator = total_loss_dict[key] + total_loss_dict[key] = numerator / denominator + + timers('evaluate').stop() + timers.log(['evaluate']) + + rerun_state_machine.set_mode(rerun_mode) + + rerun_state_machine.set_mode(rerun_mode) + + return total_loss_dict, collected_non_loss_data, False + + def _patch_megatron(self): + # support max_epochs + self._origin_train_step = training.train_step + training.train_step = self.train_step + training.cyclic_iter = self.new_cyclic_iter + # patch training_log + self._origin_training_log = training.training_log + # patch evaluate + self._origin_evaluate = training.evaluate + training.evaluate = self.evaluate + + def forward_step(self, data_iterator, model): + from pretrain_gpt import loss_func + + timers = get_timers() + + # Get the batch. 
+ timers('batch-generator', log_level=2).start() + with self.stimer(bdata=True): + data = get_batch(data_iterator) + if not data: + raise StopIteration + timers('batch-generator').stop() + + with self.stimer: + output_tensor = model(**data) + labels = data.get('labels') + loss_mask = None if labels is None else (labels != -100).float() + return output_tensor, partial(loss_func, loss_mask) + + def train(self, train_dataset, val_dataset, data_collator): + args = get_args() + datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) + datasets_provider.is_distributed = True + with patch_megatron_data_collator(data_collator), self._get_iters(train_dataset, val_dataset): + extra_args_provider = args.megatron_model_meta.extra_args_provider + pretrain( + datasets_provider, + args.megatron_model_meta.model_provider, + ModelType.encoder_or_decoder, + self.forward_step, + extra_args_provider=extra_args_provider, + args_defaults=args.extra_args) From c3308518a0550b811b875dfc3cf4ea224e71f9ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 10:57:21 +0800 Subject: [PATCH 33/38] update --- examples/train/megatron/rlhf/dpo.sh | 2 +- swift/megatron/train/rlhf.py | 2 +- swift/megatron/train/trainers/dpo_trainer.py | 9 ++++----- swift/megatron/train/trainers/trainer.py | 5 +++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/train/megatron/rlhf/dpo.sh b/examples/train/megatron/rlhf/dpo.sh index 6f6eabded0..c2b9a546aa 100644 --- a/examples/train/megatron/rlhf/dpo.sh +++ b/examples/train/megatron/rlhf/dpo.sh @@ -7,7 +7,7 @@ megatron rlhf \ --load Qwen3-8B-Base-mcore \ --dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \ --tensor_model_parallel_size 4 \ - --micro_batch_size 16 \ + --micro_batch_size 8 \ --global_batch_size 16 \ --recompute_granularity full \ --recompute_method uniform \ diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index e78588dc64..9b57d59655 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -20,7 +20,7 @@ def prepare_trainer(self): trainer_cls = MegatronDPOTrainer else: raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.') - return trainer_cls() + return trainer_cls(args) def _prepare_template(self) -> None: super()._prepare_template() diff --git a/swift/megatron/train/trainers/dpo_trainer.py b/swift/megatron/train/trainers/dpo_trainer.py index 7c2208e213..7c57b7bb62 100644 --- a/swift/megatron/train/trainers/dpo_trainer.py +++ b/swift/megatron/train/trainers/dpo_trainer.py @@ -19,8 +19,7 @@ class DummyDPOTrainer(DPOTrainer): # For reusing the dpo_loss function in TRL. - def __init__(self): - args = get_args() + def __init__(self, args): from trl.trainer import FDivergenceConstants self.accelerator = namedtuple('Accelerator', ['device'])(device=get_current_device()) self.f_alpha_divergence_coef = 1. 
@@ -34,10 +33,10 @@ def __init__(self): class MegatronDPOTrainer(MegatronTrainer): - def __init__(self): - super().__init__() + def __init__(self, args): + super().__init__(args) self._patch_setup_model_and_optimizer() - self.dummy_dpo_trainer = DummyDPOTrainer() + self.dummy_dpo_trainer = DummyDPOTrainer(args) def _patch_setup_model_and_optimizer(self): origin_setup_model_and_optimizer = training.setup_model_and_optimizer diff --git a/swift/megatron/train/trainers/trainer.py b/swift/megatron/train/trainers/trainer.py index 00027f1356..6bd239ae43 100644 --- a/swift/megatron/train/trainers/trainer.py +++ b/swift/megatron/train/trainers/trainer.py @@ -22,7 +22,8 @@ class MegatronTrainer: - def __init__(self): + def __init__(self, args): + self.args = args self.stimer = StragglerDetector() self._patch_megatron() @@ -246,7 +247,7 @@ def forward_step(self, data_iterator, model): return output_tensor, partial(loss_func, loss_mask) def train(self, train_dataset, val_dataset, data_collator): - args = get_args() + args = self.args datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) datasets_provider.is_distributed = True with patch_megatron_data_collator(data_collator), self._get_iters(train_dataset, val_dataset): From f3b5003ae3fc13ca2668f330a4b431a9ff40608e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 11:34:27 +0800 Subject: [PATCH 34/38] update --- swift/megatron/train/sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 100d09c93b..11d07b66b0 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -20,7 +20,7 @@ class MegatronSft(SwiftSft): args: args_class def prepare_trainer(self): - return MegatronTrainer() + return MegatronTrainer(self.args) def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) -> None: self.train_msg = {} From d110ed3ee4fbc6acbbed106eef0850af2ce6e395 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 14:12:51 +0800 Subject: [PATCH 35/38] fix --- swift/megatron/train/sft.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 11d07b66b0..f2a95c6d2a 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -1,16 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
import os -from functools import partial from typing import List, Union -from megatron.training import ft_integration, get_args, get_timers, is_last_rank, pretrain, print_rank_0, training - from swift.llm.train import SwiftSft from swift.utils import get_logger, is_master, plot_images from ..argument import MegatronTrainArguments from ..utils import patch_megatron_tokenizer from .trainers import MegatronTrainer -from .utils import build_streaming_dataloader, get_batch +from .utils import build_streaming_dataloader logger = get_logger() From e92f79cd1ca1967bc076226dd79196d232df6ea6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 14:16:10 +0800 Subject: [PATCH 36/38] update --- examples/train/megatron/{rlhf/dpo.sh => dpo/dense.sh} | 0 examples/train/megatron/{rlhf => dpo}/moe.sh | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename examples/train/megatron/{rlhf/dpo.sh => dpo/dense.sh} (100%) rename examples/train/megatron/{rlhf => dpo}/moe.sh (100%) diff --git a/examples/train/megatron/rlhf/dpo.sh b/examples/train/megatron/dpo/dense.sh similarity index 100% rename from examples/train/megatron/rlhf/dpo.sh rename to examples/train/megatron/dpo/dense.sh diff --git a/examples/train/megatron/rlhf/moe.sh b/examples/train/megatron/dpo/moe.sh similarity index 100% rename from examples/train/megatron/rlhf/moe.sh rename to examples/train/megatron/dpo/moe.sh From 4fe4e4ef93869c9566d5480b858e75a9346f63cb Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 14:21:59 +0800 Subject: [PATCH 37/38] update --- .../Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" | 2 +- docs/source_en/Instruction/Megatron-SWIFT-Training.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index 59048f925e..303cbe231c 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -110,7 +110,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I ``` - 若要进行预训练,你可以使用`megatron pt`替代`megatron sft`,这将会使用生成式的template进行训练。 -- **更多案例**:包括packing、多机、32K上下文、MoE模型、预训练,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron)。 +- **更多案例**:包括packing、多机、32K上下文、DPO、MoE模型、预训练,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron)。 ## Benchmark diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index 2c3a0382ee..11fef42846 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -114,7 +114,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I ``` - For pretraining, you can use `megatron pt` instead of `megatron sft`, which will use a generative template for training. -- **More examples**: Including packing, multi-node training, 32K context, MoE models, and pre-training, can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron). +- **More examples**: Including packing, multi-node training, 32K context, DPO, MoE models, and pre-training, can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron). 
## Benchmark The speed comparison of full-parameter training for Dense/MoE models using `megatron sft` and `swift sft` on a single machine with eight A800 GPUs is shown below. The corresponding scripts can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark). From 7e8df5039e09f57975a388443eb29f29ec2fccbc Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 11 Jun 2025 14:26:05 +0800 Subject: [PATCH 38/38] update --- README.md | 1 + README_CN.md | 3 ++- examples/train/megatron/{ => rlhf}/dpo/dense.sh | 0 examples/train/megatron/{ => rlhf}/dpo/moe.sh | 0 4 files changed, 3 insertions(+), 1 deletion(-) rename examples/train/megatron/{ => rlhf}/dpo/dense.sh (100%) rename examples/train/megatron/{ => rlhf}/dpo/moe.sh (100%) diff --git a/README.md b/README.md index d32c297651..52d018058d 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ You can contact us and communicate with us by adding our group: ## 🎉 News +- 🎁 2025.06.11: Support for using Megatron parallelism techniques for RLHF training. The training script can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/rlhf). - 🎁 2025.05.29: Support sequence parallel in pt, sft, dpo and grpo, check script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text). - 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models). - 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383). diff --git a/README_CN.md b/README_CN.md index 9db95346e0..3502cd0fe9 100644 --- a/README_CN.md +++ b/README_CN.md @@ -70,7 +70,8 @@ - **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。 ## 🎉 新闻 -- 🎁 2025.05.29: 支持pt、sft、dpo、grpo的序列并行,具体请查看[脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text). +- 🎁 2025.06.11: 支持使用Megatron并行技术进行RLHF训练,训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/rlhf)。 +- 🎁 2025.05.29: 支持pt、sft、dpo、grpo的序列并行,具体请查看[脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text)。 - 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)。 - 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。 - 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/vllm_multi_round.sh)。 diff --git a/examples/train/megatron/dpo/dense.sh b/examples/train/megatron/rlhf/dpo/dense.sh similarity index 100% rename from examples/train/megatron/dpo/dense.sh rename to examples/train/megatron/rlhf/dpo/dense.sh diff --git a/examples/train/megatron/dpo/moe.sh b/examples/train/megatron/rlhf/dpo/moe.sh similarity index 100% rename from examples/train/megatron/dpo/moe.sh rename to examples/train/megatron/rlhf/dpo/moe.sh
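For quick reference, the series above wires a new `megatron rlhf` subcommand into the CLI and backs it with `MegatronDPOTrainer`. A minimal single-node launch sketch is given below. The flags down to `--recompute_method` reuse those visible in `examples/train/megatron/rlhf/dpo/dense.sh` (patch 33); `--rlhf_type`, `--beta`, `--rpo_alpha`, `--ref_load`, and the checkpoint/dataset values are illustrative assumptions inferred from the argument fields the trainer reads (`args.rlhf_type`, `args.beta`, `args.rpo_alpha`, `args.ref_load`), not a verified command line.

```shell
# Illustrative sketch of launching Megatron DPO training with the new subcommand.
# Flags after --recompute_method (and all paths) are assumptions; see the scripts
# under examples/train/megatron/rlhf/dpo/ for a verified command.
megatron rlhf \
    --rlhf_type dpo \
    --load Qwen3-8B-Base-mcore \
    --dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \
    --tensor_model_parallel_size 4 \
    --micro_batch_size 8 \
    --global_batch_size 16 \
    --recompute_granularity full \
    --recompute_method uniform \
    --beta 0.1 \
    --rpo_alpha 1.0 \
    --ref_load Qwen3-8B-Base-mcore
```

When `--ref_load` is not given, the patched `setup_model_and_optimizer` falls back to `--load` for the reference model, so that flag can usually be omitted; `--rlhf_type dpo` is likewise the default and is shown only for clarity.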