
Commit a5dfdc2

[megatron] support DPO (#4193)
1 parent 19b34bc commit a5dfdc2

22 files changed: +899 −125 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -74,6 +74,7 @@ You can contact us and communicate with us by adding our group:

 ## 🎉 News
+- 🎁 2025.06.11: Support for using Megatron parallelism techniques for RLHF training. The training script can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/rlhf).
 - 🎁 2025.05.29: Support sequence parallel in pt, sft, dpo and grpo, check script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text).
 - 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models).
 - 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).

README_CN.md

Lines changed: 2 additions & 1 deletion
@@ -70,7 +70,8 @@
 - **Model quantization**: Supports quantized export with AWQ, GPTQ, and BNB; the exported models support accelerated inference with vLLM/LmDeploy and can be trained further.

 ## 🎉 News
-- 🎁 2025.05.29: Support sequence parallelism for pt, sft, dpo, and grpo; see the [script](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text).
+- 🎁 2025.06.11: Support RLHF training with Megatron parallelism techniques; the training scripts are [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/rlhf)
+- 🎁 2025.05.29: Support sequence parallelism for pt, sft, dpo, and grpo; see the [script](https://github.com/modelscope/ms-swift/tree/main/examples/train/long_text)
 - 🎁 2025.05.11: The reward model in GRPO supports custom processing logic; for a GenRM example see [here](./docs/source/Instruction/GRPO.md#自定义奖励模型)
 - 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025; the paper is available [here](https://ojs.aaai.org/index.php/AAAI/article/view/35383)
 - 🎁 2025.03.23: Support multi-turn GRPO for building training in multi-turn dialogue scenarios (e.g., agent tool calling); see the [training script](examples/train/grpo/internal/vllm_multi_round.sh)

docs/source/Instruction/Megatron-SWIFT训练.md

Lines changed: 18 additions & 2 deletions
@@ -110,7 +110,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 ```

 - For pretraining, you can use `megatron pt` instead of `megatron sft`, which will train with a generative template.
-- **More examples**: Including packing, multi-node training, 32K context, MoE models, and pre-training, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron)
+- **More examples**: Including packing, multi-node training, 32K context, DPO, MoE models, and pre-training, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron)

 ## Benchmark
@@ -290,8 +290,17 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 - moe_expert_capacity_factor: Capacity factor for each expert; None means no tokens are dropped. Defaults to None.
 - moe_shared_expert_overlap: Enables overlap between shared-expert computation and scheduler communication. If this option is not enabled, shared experts execute after the routed experts. Only effective when `moe_shared_expert_intermediate_size` is set. Defaults to False.

+**DPO parameters**:
+- ref_load: Load path for the reference model. Defaults to None, i.e. it is set to `load`.
+- beta: Same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig): it controls the degree of deviation from the reference model; a higher beta means less deviation. For the IPO loss (loss_type="ipo"), beta is the regularization parameter referred to in the [paper](https://huggingface.co/papers/2310.12036). Defaults to 0.1.
+- rpo_alpha: Parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e. the SFT loss) in the loss: `loss = dpo_loss + rpo_alpha * nll_loss`. Defaults to 1.
+- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Defaults to False.
+- label_smoothing: Defaults to 0.
+- f_divergence_type: Defaults to `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values.
+- loss_type: Defaults to 'sigmoid'. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values.

-### Megatron Training Parameters
+### Training Parameters

 Megatron training parameters inherit from the Megatron parameters and the basic parameters. For the basic parameters, see [here](./命令行参数.md#基本参数). The following parameters are also included:

@@ -302,3 +311,10 @@ Megatron training parameters inherit from the Megatron parameters and the basic parameters…
 - lazy_tokenize: Defaults to False. If set to False, all dataset samples are tokenized before training (this avoids errors surfacing during training); if set to True, tokenization happens during training (this saves memory).
 - max_epochs: Forces training to exit once `max_epochs` is reached, then validates and saves the weights. This parameter is particularly useful with streaming datasets. Defaults to None.
 - Note: if you use a non-streaming dataset, this parameter computes train_iters automatically, so there is no need to pass `train_iters` manually.
+
+### RLHF Parameters
+In addition to inheriting the training parameters, the following parameters are supported:
+- rlhf_type: Defaults to 'dpo'. Currently only 'dpo' is available.
+- loss_scale: Overrides the loss_scale in the [basic parameters](./命令行参数.md). Defaults to 'last_round'.
+- calculate_per_token_loss: Overrides the Megatron parameter; defaults to False.
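For orientation, the beta, label_smoothing, rpo_alpha, and reference_free parameters documented above correspond to the standard TRL-style sigmoid DPO objective with an added NLL term; the formula below is shown for reference only, and the exact Megatron-SWIFT implementation may differ in detail:

$$
\begin{aligned}
h_\theta &= \beta\left[\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} - \log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right],\\
\mathcal{L}_{\mathrm{DPO}} &= -(1-\varepsilon)\log\sigma(h_\theta) - \varepsilon\log\sigma(-h_\theta),\\
\mathcal{L} &= \mathcal{L}_{\mathrm{DPO}} + \alpha_{\mathrm{rpo}}\,\mathcal{L}_{\mathrm{NLL}},
\end{aligned}
$$

where $(y_w, y_l)$ are the chosen/rejected responses, $\varepsilon$ is label_smoothing, $\alpha_{\mathrm{rpo}}$ is rpo_alpha, and $\mathcal{L}_{\mathrm{NLL}}$ is the SFT loss on the chosen response. With reference_free the $\pi_{\mathrm{ref}}$ terms are dropped, which is equivalent to a uniform reference model.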

docs/source_en/Instruction/Megatron-SWIFT-Training.md

Lines changed: 21 additions & 2 deletions
@@ -114,7 +114,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 ```

 - For pretraining, you can use `megatron pt` instead of `megatron sft`, which will use a generative template for training.
-- **More examples**: Including packing, multi-node training, 32K context, MoE models, and pre-training, can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron).
+- **More examples**: Including packing, multi-node training, 32K context, DPO, MoE models, and pre-training, can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron).

 ## Benchmark
 The speed comparison of full-parameter training for Dense/MoE models using `megatron sft` and `swift sft` on a single machine with eight A800 GPUs is shown below. The corresponding scripts can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark).
@@ -302,7 +302,17 @@ seq_length: Defaults to None, meaning it is set to `max_length`.
 - moe_expert_capacity_factor: Capacity factor for each expert; None means no tokens will be dropped. Default is None.
 - moe_shared_expert_overlap: Enable overlapping of shared-expert computation with scheduler communication. If this option is not enabled, shared experts execute after the routed experts. Only effective when `moe_shared_expert_intermediate_size` is set. Default is False.

-### Megatron Training Parameters
+**DPO Parameters**
+- ref_load: The path from which to load the reference model. Defaults to `None`, which means it is set to `load`.
+- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model; a higher beta value indicates less deviation. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter referred to in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1.
+- rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function: `loss = dpo_loss + rpo_alpha * nll_loss`. Default is 1.
+- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`.
+- label_smoothing: Default is 0.
+- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values.
+- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values.
+
+### Training Parameters

 Megatron training parameters inherit from the Megatron parameters and the basic parameters. For information on the basic parameters, see [here](./Command-line-parameters.md#base-arguments). The following parameters are also included:

@@ -313,3 +323,12 @@ Megatron training parameters inherit from Megatron parameters and basic paramete
 - lazy_tokenize: Default is False. If set to False, all dataset samples are tokenized before training (this avoids errors surfacing during training); if set to True, tokenization happens during training (this saves memory).
 - max_epochs: Forces training to exit after reaching `max_epochs`, then validates and saves the model weights. This parameter is especially useful when using a streaming dataset. Default is None.
 - Note: If you use a non-streaming dataset, this parameter automatically computes train_iters for you, so there is no need to pass `train_iters` manually.
+
+### RLHF Parameters
+
+In addition to inheriting the training parameters, the following parameters are also supported:
+
+- rlhf_type: Default is 'dpo'. Currently, only 'dpo' is available.
+- loss_scale: Overrides the `loss_scale` in the [basic parameters](./Command-line-parameters.md). Default is 'last_round'.
+- calculate_per_token_loss: Overrides the Megatron parameter. Default is False.
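To make the relationship between beta, label_smoothing, rpo_alpha, reference_free, and loss_type='sigmoid' above concrete, here is a minimal PyTorch sketch of the pairwise loss they describe. It follows the TRL-style sigmoid formulation and the `loss = dpo_loss + rpo_alpha * nll_loss` relation quoted in the parameter list; the function and argument names are invented for the example, and this is not the code added by this commit.

```python
import torch
import torch.nn.functional as F


def dpo_rpo_loss(policy_chosen_logps, policy_rejected_logps,
                 ref_chosen_logps, ref_rejected_logps,
                 chosen_nll_loss,
                 beta=0.1, label_smoothing=0.0, rpo_alpha=1.0,
                 reference_free=False):
    """Sigmoid-style DPO loss plus the RPO NLL term (illustrative sketch).

    The *_logps arguments are summed log-probabilities of the chosen/rejected
    responses, shape (batch,); chosen_nll_loss is the SFT loss on the chosen responses.
    """
    if reference_free:
        # Implicit reference that assigns equal probability to all responses.
        ref_chosen_logps = torch.zeros_like(policy_chosen_logps)
        ref_rejected_logps = torch.zeros_like(policy_rejected_logps)

    # Log-ratio margin between chosen and rejected responses.
    logits = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)

    # TRL-style sigmoid loss with optional label smoothing.
    dpo_loss = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                - F.logsigmoid(-beta * logits) * label_smoothing).mean()

    # RPO: add the NLL (SFT) term on the chosen responses.
    return dpo_loss + rpo_alpha * chosen_nll_loss
```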
File renamed without changes.
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
# 4 * 60GiB
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NPROC_PER_NODE=4 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
megatron rlhf \
    --rlhf_type dpo \
    --load Qwen3-8B-Base-mcore \
    --dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \
    --tensor_model_parallel_size 4 \
    --micro_batch_size 8 \
    --global_batch_size 16 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --max_epochs 1 \
    --finetune true \
    --cross_entropy_loss_fusion true \
    --lr 1e-5 \
    --lr_warmup_iters 50 \
    --min_lr 1e-6 \
    --save megatron_output/Qwen3-8B-Base \
    --eval_interval 200 \
    --save_interval 200 \
    --max_length 8192 \
    --num_workers 8 \
    --dataset_num_proc 8 \
    --no_save_optim true \
    --no_save_rng true \
    --sequence_parallel true \
    --attention_backend flash \
    --beta 0.1 \
    --rpo_alpha 1 \
    --loss_type sigmoid
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# 8 * 64GiB
NPROC_PER_NODE=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
megatron rlhf \
    --rlhf_type dpo \
    --load Qwen1.5-MoE-A2.7B-mcore \
    --dataset 'hjh0119/shareAI-Llama3-DPO-zh-en-emoji#20000' \
    --tensor_model_parallel_size 2 \
    --expert_model_parallel_size 4 \
    --moe_grouped_gemm true \
    --moe_shared_expert_overlap true \
    --moe_aux_loss_coeff 0.01 \
    --micro_batch_size 4 \
    --global_batch_size 16 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --max_epochs 1 \
    --finetune true \
    --cross_entropy_loss_fusion true \
    --lr 1e-5 \
    --lr_warmup_iters 100 \
    --min_lr 1e-6 \
    --save megatron_output/Qwen1.5-MoE-A2.7B \
    --eval_interval 200 \
    --save_interval 200 \
    --max_length 8192 \
    --num_workers 8 \
    --dataset_num_proc 8 \
    --no_save_optim true \
    --no_save_rng true \
    --sequence_parallel true \
    --attention_backend flash \
    --beta 0.1 \
    --rpo_alpha 1 \
    --loss_type sigmoid

swift/cli/_megatron/main.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 ROUTE_MAPPING: Dict[str, str] = {
     'pt': 'swift.cli._megatron.pt',
     'sft': 'swift.cli._megatron.sft',
+    'rlhf': 'swift.cli._megatron.rlhf',
 }

swift/cli/_megatron/rlhf.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.megatron import megatron_rlhf_main

if __name__ == '__main__':
    megatron_rlhf_main()
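The 'rlhf' entry added to ROUTE_MAPPING above points at this module, so `megatron rlhf ...` ultimately calls `megatron_rlhf_main()`. As a rough sketch of how such a route table is typically consumed — the actual dispatcher lives in swift/cli/_megatron/main.py and may use a different mechanism, for example launching the routed module as a subprocess — the lookup amounts to:

```python
# Illustrative only: generic dispatch over a subcommand -> module table.
# This is not the code in swift/cli/_megatron/main.py.
import runpy
import sys

ROUTE_MAPPING = {  # mirror of the table shown in the diff above
    'pt': 'swift.cli._megatron.pt',
    'sft': 'swift.cli._megatron.sft',
    'rlhf': 'swift.cli._megatron.rlhf',
}


def dispatch(argv):
    subcommand, rest = argv[0], argv[1:]
    module = ROUTE_MAPPING[subcommand]  # e.g. 'rlhf' -> 'swift.cli._megatron.rlhf'
    sys.argv = [module] + rest          # remaining flags pass through to the module
    # run_name='__main__' makes the module's `if __name__ == '__main__':` block execute
    runpy.run_module(module, run_name='__main__')
```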

swift/megatron/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -12,15 +12,15 @@
 from swift.utils.import_utils import _LazyModule

 if TYPE_CHECKING:
-    from .train import megatron_sft_main, megatron_pt_main
+    from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main
     from .utils import convert_hf2mcore, convert_mcore2hf
-    from .argument import MegatronTrainArguments
+    from .argument import MegatronTrainArguments, MegatronRLHFArguments
     from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model
 else:
     _import_structure = {
-        'train': ['megatron_sft_main', 'megatron_pt_main'],
+        'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'],
         'utils': ['convert_hf2mcore', 'convert_mcore2hf'],
-        'argument': ['MegatronTrainArguments'],
+        'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'],
         'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model']
     }

swift/megatron/argument/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .megatron_args import MegatronArguments
+from .rlhf_args import MegatronRLHFArguments
 from .train_args import MegatronTrainArguments

swift/megatron/argument/megatron_args.py

Lines changed: 13 additions & 1 deletion
@@ -14,7 +14,19 @@

 @dataclass
-class ExtraMegatronArguments:
+class RLHFMegatronArgumentsMixin:
+    ref_load: Optional[str] = None
+
+    beta: float = 0.1
+    rpo_alpha: float = 1.
+    reference_free: bool = False
+    label_smoothing: float = 0.
+    f_divergence_type: str = 'reverse_kl'
+    loss_type: str = 'sigmoid'
+
+
+@dataclass
+class ExtraMegatronArguments(RLHFMegatronArgumentsMixin):
     padded_vocab_size: Optional[int] = None
     rope_scaling: Optional[Union[dict, str]] = None
     torch_dtype: Optional[torch.dtype] = None
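One thing worth noting about this layout (an observation, not text from the commit): every field on RLHFMegatronArgumentsMixin has a default, so it can be slotted in as a base class of ExtraMegatronArguments without breaking dataclass field ordering, and the DPO knobs simply become extra defaulted fields on the existing argument object. A reduced, self-contained sketch of the pattern, with shortened, hypothetical names:

```python
from dataclasses import dataclass, fields
from typing import Optional


# Reduced mirror of the pattern used above; illustrative only.
@dataclass
class RLHFMixin:
    ref_load: Optional[str] = None
    beta: float = 0.1
    rpo_alpha: float = 1.0


@dataclass
class ExtraArgs(RLHFMixin):
    padded_vocab_size: Optional[int] = None


# Because every mixin field has a default, subclasses can keep adding defaulted
# fields without hitting "non-default argument follows default argument" errors,
# and existing call sites remain valid.
args = ExtraArgs(padded_vocab_size=151936)
print([f.name for f in fields(args)])  # ['ref_load', 'beta', 'rpo_alpha', 'padded_vocab_size']
print(args.beta)                       # 0.1
```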

swift/megatron/argument/rlhf_args.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass
from typing import Literal

from .train_args import MegatronTrainArguments


@dataclass
class MegatronRLHFArguments(MegatronTrainArguments):
    rlhf_type: Literal['dpo'] = 'dpo'
    loss_scale: str = 'last_round'

    calculate_per_token_loss: bool = False
