
[megatron] support DPO #4193


Merged
merged 49 commits into main from support_megatron_dpo on Jun 11, 2025

Commits (49)
2200de1
support megatron dpo
Jintao-Huang May 13, 2025
3a21dca
update
Jintao-Huang May 13, 2025
21d044a
update
Jintao-Huang May 13, 2025
29714f1
update
Jintao-Huang May 13, 2025
48bdef9
update
Jintao-Huang May 13, 2025
6806521
update
Jintao-Huang May 13, 2025
cc6b0f6
Merge branch 'main' into support_megatron_dpo
Jintao-Huang May 13, 2025
ebfe9d1
update
Jintao-Huang May 13, 2025
5802d8e
update
Jintao-Huang May 13, 2025
fef4b9b
update
Jintao-Huang May 13, 2025
0a69513
Merge branch 'main' into support_megatron_dpo
Jintao-Huang May 14, 2025
7a595d3
update
Jintao-Huang May 14, 2025
9dde75a
update
Jintao-Huang May 15, 2025
2552bda
Merge branch 'main' into support_megatron_dpo
Jintao-Huang May 28, 2025
007c9ed
update
Jintao-Huang May 28, 2025
b56b601
Merge branch 'main' into support_megatron_dpo
Jintao-Huang Jun 1, 2025
96e5f3a
Merge branch 'main' into support_megatron_dpo
Jintao-Huang Jun 1, 2025
54244e5
Merge branch 'main' into support_megatron_dpo
Jintao-Huang Jun 3, 2025
3382b0c
Merge branch 'main' into support_megatron_dpo
Jintao-Huang Jun 5, 2025
1f0f411
update
Jintao-Huang Jun 5, 2025
615befc
update
Jintao-Huang Jun 5, 2025
ab5bdfa
update
Jintao-Huang Jun 5, 2025
515d476
update
Jintao-Huang Jun 5, 2025
15db9af
update
Jintao-Huang Jun 6, 2025
f8d29a7
update shell
Jintao-Huang Jun 6, 2025
c96cef7
update
Jintao-Huang Jun 6, 2025
b6fc6d9
update
Jintao-Huang Jun 6, 2025
95f74fa
update
Jintao-Huang Jun 6, 2025
ac1c33b
update
Jintao-Huang Jun 6, 2025
7a75d93
update
Jintao-Huang Jun 6, 2025
d981bc5
update
Jintao-Huang Jun 6, 2025
9b089b5
fix dpo emoji dataset
Jintao-Huang Jun 7, 2025
93a7486
Merge branch 'fix_emoji_dpo_dataset' into support_megatron_dpo
Jintao-Huang Jun 7, 2025
db6fc3a
Merge branch 'main' into support_megatron_dpo
Jintao-Huang Jun 9, 2025
a6067bd
update
Jintao-Huang Jun 9, 2025
bd46a59
update
Jintao-Huang Jun 9, 2025
f8c8dcf
Merge branch 'main' into support_megatron_dpo
Jintao-Huang Jun 9, 2025
568b3aa
update
Jintao-Huang Jun 9, 2025
0dff938
update
Jintao-Huang Jun 9, 2025
c57a29e
update
Jintao-Huang Jun 10, 2025
79eca73
update
Jintao-Huang Jun 10, 2025
2ad6fd2
update
Jintao-Huang Jun 11, 2025
c330851
update
Jintao-Huang Jun 11, 2025
f3b5003
update
Jintao-Huang Jun 11, 2025
13c8696
Merge branch 'main' into support_megatron_dpo
Jintao-Huang Jun 11, 2025
d110ed3
fix
Jintao-Huang Jun 11, 2025
e92f79c
update
Jintao-Huang Jun 11, 2025
4fe4e4e
update
Jintao-Huang Jun 11, 2025
7e8df50
update
Jintao-Huang Jun 11, 2025
1 change: 1 addition & 0 deletions README.md
@@ -74,6 +74,7 @@ You can contact us and communicate with us by adding our group:


## 🎉 News
- 🎁 2025.06.11: Support for using Megatron parallelism techniques for RLHF training. The training script can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/rlhf).
- 🎁 2025.05.29: Support sequence parallel in pt, sft, dpo and grpo, check script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text).
- 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models).
- 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).
3 changes: 2 additions & 1 deletion README_CN.md
@@ -70,7 +70,8 @@
- **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。

## 🎉 新闻
- 🎁 2025.05.29: 支持pt、sft、dpo、grpo的序列并行,具体请查看[脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text).
- 🎁 2025.06.11: 支持使用Megatron并行技术进行RLHF训练,训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/rlhf)。
- 🎁 2025.05.29: 支持pt、sft、dpo、grpo的序列并行,具体请查看[脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text)。
- 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)。
- 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。
- 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/vllm_multi_round.sh)。
20 changes: 18 additions & 2 deletions docs/source/Instruction/Megatron-SWIFT训练.md
@@ -110,7 +110,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
```

- 若要进行预训练,你可以使用`megatron pt`替代`megatron sft`,这将会使用生成式的template进行训练。
- **更多案例**:包括packing、多机、32K上下文、MoE模型、预训练,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron)。
- **更多案例**:包括packing、多机、32K上下文、DPO、MoE模型、预训练,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron)。

## Benchmark

@@ -290,8 +290,17 @@ I am a language model developed by swift, you can call me swift-robot. How can I
- moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。
- moe_shared_expert_overlap: 启用共享专家计算与调度器通信之间的重叠。如果不启用此选项,共享专家将在路由专家之后执行。仅在设置了`moe_shared_expert_intermediate_size`时有效。默认为False。

**DPO参数**:
- ref_load: ref_model的加载路径。默认为None,即设置为`load`。
- beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。
- rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失)。`loss = dpo_loss + rpo_alpha * nll_loss`。默认为1。
- reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。
- label_smoothing: 默认为0.。
- f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。
- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。

### Megatron训练参数

### 训练参数

Megatron训练参数继承自Megatron参数和基本参数。基本参数的内容可以参考[这里](./命令行参数.md#基本参数)。此外还包括以下参数:

@@ -302,3 +311,10 @@
- lazy_tokenize: 默认为False。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(这可以避免在训练中出现报错);设置为True,则在训练中对数据集进行tokenize(这可以节约内存)。
- max_epochs: 训练到`max_epochs`时强制退出训练,并对权重进行验证和保存。该参数在使用流式数据集时很有用。默认为None。
- 注意:如果你使用非流式数据集,该参数会为你自动计算train_iters,你不需要手动传入`train_iters`。


### RLHF参数
除了继承训练参数外,还支持以下参数:
- rlhf_type: 默认为'dpo'。目前可选择为'dpo'。
- loss_scale: 覆盖[基本参数](./命令行参数.md)中的loss_scale。默认为'last_round'。
- calculate_per_token_loss: 覆盖Megatron参数,默认为False。
23 changes: 21 additions & 2 deletions docs/source_en/Instruction/Megatron-SWIFT-Training.md
@@ -114,7 +114,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
```

- For pretraining, you can use `megatron pt` instead of `megatron sft`, which will use a generative template for training.
- **More examples**: Including packing, multi-node training, 32K context, MoE models, and pre-training, can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron).
- **More examples**: Including packing, multi-node training, 32K context, DPO, MoE models, and pre-training, can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron).

## Benchmark
The speed comparison of full-parameter training for Dense/MoE models using `megatron sft` and `swift sft` on a single machine with eight A800 GPUs is shown below. The corresponding scripts can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark).
@@ -302,7 +302,17 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the
- moe_expert_capacity_factor: Capacity factor for each expert, None means no tokens will be dropped. Default is None.
- moe_shared_expert_overlap: Enable overlapping of shared expert computation with scheduler communication. If this option is not enabled, shared experts will execute after the routing experts. Only effective when `moe_shared_expert_intermediate_size` is set. Default is False.

### Megatron Training Parameters
**DPO Parameters**
- ref_load: The path to load the reference model. Defaults to `None`, which means it will be set to `load`.
- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1.
- rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) used to control the weight of the NLL term (i.e., SFT loss) in the loss function. The total loss is calculated as `loss = dpo_loss + rpo_alpha * nll_loss`. Default is 1.
- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`.
- label_smoothing: Default is 0.
- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values.
- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values.
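
Taken together, these options specify the standard TRL-style DPO objective. The snippet below is a minimal sketch of the default `loss_type='sigmoid'` case, reconstructed from the TRL definitions referenced above; it is illustrative only and not the Megatron implementation added in this PR.

```python
# Illustrative sketch of the sigmoid DPO loss with the RPO (NLL/SFT) term.
# Not the ms-swift/Megatron code; written against the TRL definitions.
import torch
import torch.nn.functional as F


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, nll_loss,
             beta=0.1, label_smoothing=0.0, rpo_alpha=1.0, reference_free=False):
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    if reference_free:
        # Implicit reference model that assigns equal probability to all responses.
        ref_logratios = torch.zeros_like(pi_logratios)
    else:
        ref_logratios = ref_chosen_logps - ref_rejected_logps
    logits = pi_logratios - ref_logratios
    # loss_type='sigmoid': -log(sigmoid(beta * logits)), optionally label-smoothed.
    losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
              - F.logsigmoid(-beta * logits) * label_smoothing)
    # rpo_alpha weights the NLL (SFT) term: loss = dpo_loss + rpo_alpha * nll_loss.
    return (losses + rpo_alpha * nll_loss).mean()


# Dummy per-sample sequence log-probabilities, just to show the call shape.
chosen = torch.tensor([-12.3, -8.7])
rejected = torch.tensor([-14.1, -9.9])
ref_chosen = torch.tensor([-12.5, -8.9])
ref_rejected = torch.tensor([-13.8, -9.7])
print(dpo_loss(chosen, rejected, ref_chosen, ref_rejected,
               nll_loss=torch.tensor([1.2, 0.9])))
```

With `reference_free=True` the reference log-ratios drop out of `logits`, and `rpo_alpha` controls how strongly the NLL (SFT) term contributes.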


### Training Parameters

Megatron training parameters inherit from Megatron parameters and basic parameters. For information on basic parameters, see [here](./Command-line-parameters.md#base-arguments). Additionally, the following parameters are included:

@@ -313,3 +323,12 @@
- lazy_tokenize: Default is False. If this parameter is set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization occurs during training (this saves memory).
- max_epochs: Forces the training to exit after reaching `max_epochs`, and performs validation and saving of the model weights. This parameter is especially useful when using a streaming dataset. Default is None.
- Note: If you use a non-streaming dataset, this parameter will automatically calculate train_iters for you, so there is no need to pass `train_iters` manually.


### RLHF Parameters

In addition to inheriting the training parameters, the following parameters are also supported:

- rlhf_type: Default is 'dpo'. Currently, only 'dpo' is available.
- loss_scale: Overrides the `loss_scale` in [basic parameters](./Command-line-parameters.md). Default is 'last_round'.
- calculate_per_token_loss: Overrides the Megatron parameter. Default is False.
File renamed without changes.
33 changes: 33 additions & 0 deletions examples/train/megatron/rlhf/dpo/dense.sh
@@ -0,0 +1,33 @@
# 4 * 60GiB
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
megatron rlhf \
--rlhf_type dpo \
--load Qwen3-8B-Base-mcore \
--dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \
--tensor_model_parallel_size 4 \
--micro_batch_size 8 \
--global_batch_size 16 \
--recompute_granularity full \
--recompute_method uniform \
--recompute_num_layers 1 \
--max_epochs 1 \
--finetune true \
--cross_entropy_loss_fusion true \
--lr 1e-5 \
--lr_warmup_iters 50 \
--min_lr 1e-6 \
--save megatron_output/Qwen3-8B-Base \
--eval_interval 200 \
--save_interval 200 \
--max_length 8192 \
--num_workers 8 \
--dataset_num_proc 8 \
--no_save_optim true \
--no_save_rng true \
--sequence_parallel true \
--attention_backend flash \
--beta 0.1 \
--rpo_alpha 1 \
--loss_type sigmoid
36 changes: 36 additions & 0 deletions examples/train/megatron/rlhf/dpo/moe.sh
@@ -0,0 +1,36 @@
# 8 * 64GiB
NPROC_PER_NODE=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
megatron rlhf \
--rlhf_type dpo \
--load Qwen1.5-MoE-A2.7B-mcore \
--dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \
--tensor_model_parallel_size 2 \
--expert_model_parallel_size 4 \
--moe_grouped_gemm true \
--moe_shared_expert_overlap true \
--moe_aux_loss_coeff 0.01 \
--micro_batch_size 4 \
--global_batch_size 16 \
--recompute_granularity full \
--recompute_method uniform \
--recompute_num_layers 1 \
--max_epochs 1 \
--finetune true \
--cross_entropy_loss_fusion true \
--lr 1e-5 \
--lr_warmup_iters 100 \
--min_lr 1e-6 \
--save megatron_output/Qwen1.5-MoE-A2.7B \
--eval_interval 200 \
--save_interval 200 \
--max_length 8192 \
--num_workers 8 \
--dataset_num_proc 8 \
--no_save_optim true \
--no_save_rng true \
--sequence_parallel true \
--attention_backend flash \
--beta 0.1 \
--rpo_alpha 1 \
--loss_type sigmoid
1 change: 1 addition & 0 deletions swift/cli/_megatron/main.py
@@ -9,6 +9,7 @@
ROUTE_MAPPING: Dict[str, str] = {
'pt': 'swift.cli._megatron.pt',
'sft': 'swift.cli._megatron.sft',
'rlhf': 'swift.cli._megatron.rlhf',
}


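For context on the one-line route added above: a mapping like this is typically consumed by a small dispatcher that runs the selected submodule. The sketch below shows that common pattern under the assumption that `megatron <command>` forwards the remaining flags to the routed module; it is not the actual body of `swift/cli/_megatron/main.py`.

```python
# Hedged sketch of how such a ROUTE_MAPPING is usually dispatched; the real
# main.py in ms-swift may differ in details.
import runpy
import sys
from typing import Dict

ROUTE_MAPPING: Dict[str, str] = {
    'pt': 'swift.cli._megatron.pt',
    'sft': 'swift.cli._megatron.sft',
    'rlhf': 'swift.cli._megatron.rlhf',  # route added by this PR
}


def cli_main():
    command, rest = sys.argv[1], sys.argv[2:]
    module = ROUTE_MAPPING[command]                 # e.g. 'rlhf'
    sys.argv = [module] + rest                      # forward remaining CLI flags
    runpy.run_module(module, run_name='__main__')   # triggers the module's __main__ guard


if __name__ == '__main__':
    cli_main()
```

With the new entry, `megatron rlhf --rlhf_type dpo ...` resolves to `swift.cli._megatron.rlhf`, whose entry point is the next file in this diff.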
5 changes: 5 additions & 0 deletions swift/cli/_megatron/rlhf.py
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.megatron import megatron_rlhf_main

if __name__ == '__main__':
megatron_rlhf_main()
8 changes: 4 additions & 4 deletions swift/megatron/__init__.py
@@ -12,15 +12,15 @@
from swift.utils.import_utils import _LazyModule

if TYPE_CHECKING:
from .train import megatron_sft_main, megatron_pt_main
from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main
from .utils import convert_hf2mcore, convert_mcore2hf
from .argument import MegatronTrainArguments
from .argument import MegatronTrainArguments, MegatronRLHFArguments
from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model
else:
_import_structure = {
'train': ['megatron_sft_main', 'megatron_pt_main'],
'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'],
'utils': ['convert_hf2mcore', 'convert_mcore2hf'],
'argument': ['MegatronTrainArguments'],
'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'],
'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model']
}

1 change: 1 addition & 0 deletions swift/megatron/argument/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .megatron_args import MegatronArguments
from .rlhf_args import MegatronRLHFArguments
from .train_args import MegatronTrainArguments
14 changes: 13 additions & 1 deletion swift/megatron/argument/megatron_args.py
@@ -14,7 +14,19 @@


@dataclass
class ExtraMegatronArguments:
class RLHFMegatronArgumentsMixin:
ref_load: Optional[str] = None

beta: float = 0.1
rpo_alpha: float = 1.
reference_free: bool = False
label_smoothing: float = 0.
f_divergence_type: str = 'reverse_kl'
loss_type: str = 'sigmoid'


@dataclass
class ExtraMegatronArguments(RLHFMegatronArgumentsMixin):
padded_vocab_size: Optional[int] = None
rope_scaling: Optional[Union[dict, str]] = None
torch_dtype: Optional[torch.dtype] = None
13 changes: 13 additions & 0 deletions swift/megatron/argument/rlhf_args.py
@@ -0,0 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass
from typing import Literal

from .train_args import MegatronTrainArguments


@dataclass
class MegatronRLHFArguments(MegatronTrainArguments):
rlhf_type: Literal['dpo'] = 'dpo'
loss_scale: str = 'last_round'

calculate_per_token_loss: bool = False
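
To make the layering concrete, here is a self-contained toy sketch (stand-in classes, not the ms-swift definitions): the DPO defaults come from a mixin, and the RLHF argument class overrides a couple of training defaults such as `loss_scale` and `calculate_per_token_loss`.

```python
# Toy stand-ins for illustration only; the real classes live in
# swift/megatron/argument and inherit the full Megatron argument stack.
from dataclasses import asdict, dataclass
from typing import Literal, Optional


@dataclass
class RLHFMixinSketch:                      # mirrors RLHFMegatronArgumentsMixin
    ref_load: Optional[str] = None
    beta: float = 0.1
    rpo_alpha: float = 1.0
    reference_free: bool = False
    label_smoothing: float = 0.0
    f_divergence_type: str = 'reverse_kl'
    loss_type: str = 'sigmoid'


@dataclass
class TrainArgsSketch(RLHFMixinSketch):     # placeholder training-argument layer
    loss_scale: str = 'default'
    calculate_per_token_loss: bool = True


@dataclass
class RLHFArgsSketch(TrainArgsSketch):      # analogous to MegatronRLHFArguments
    rlhf_type: Literal['dpo'] = 'dpo'
    loss_scale: str = 'last_round'          # override of the training default
    calculate_per_token_loss: bool = False  # override of the Megatron default


print(asdict(RLHFArgsSketch(beta=0.05, loss_type='ipo')))  # all fields on one object
```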