
Commit 19b34bc

[megatron/dpo] fix megatron packing_cache & update DPOTrainer (#4556)
1 parent cb4ac9b commit 19b34bc

10 files changed: +56 -56 lines changed

docs/source/Instruction/Megatron-SWIFT训练.md

Lines changed: 5 additions & 4 deletions
@@ -172,7 +172,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 - seq_length: 默认为None,即设置为`max_length`。对数据集长度进行限制请使用基本参数中的`--max_length`控制,无需设置此参数。
 - use_cpu_initialization: 在cpu上初始化权重,默认为False。在进行HF和MCore权重转换时会被使用。
 - no_create_attention_mask_in_dataloader: 在dataloader中不创建attention mask,默认为True。
-- extra_megatron_kwargs: Additional parameters passed to Megatron, provided as a JSON object. Defaults to None.
+- extra_megatron_kwargs: 传入megatron的其他参数,使用json传递。默认为None。
 
 **学习率参数**:
 - 🔥lr: 初始学习率,最终会根据学习率预热策略和衰减策略决定每个迭代的学习率,默认为1e-5。
@@ -221,7 +221,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 
 **日志参数**:
 - log_params_norm: 记录参数的norm。默认为False。
-- log_throughput: 记录每个GPU的吞吐量。默认为True
+- log_throughput: 记录每个GPU的吞吐量。默认为False
 - 注意:在非packing情况下,log_throughput并不准确,因为`seq_length`并不等于真实序列长度。
 - tensorboard_log_interval: 记录到tensorboard的间隔(steps),默认为1。
 - tensorboard_queue_size: 队列长度(与磁盘IO相关),类似于写入的间隔。默认为50。
@@ -235,7 +235,8 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 - wandb_save_dir: 本地保存 wandb 结果的路径。默认为''。
 
 **评估参数**:
-- 🔥eval_iters: 评估的迭代次数,默认为100。
+- 🔥eval_iters: 评估的迭代次数,默认为-1,根据验证数据集的数量设置合适的值。
+- 注意:若使用流式数据集,该值需要手动设置。
 - 🔥eval_interval: 评估的间隔(steps),默认为None,即设置为save_interval。
 
 **混合精度参数**:
@@ -295,7 +296,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 Megatron训练参数继承自Megatron参数和基本参数。基本参数的内容可以参考[这里](./命令行参数.md#基本参数)。此外还包括以下参数:
 
 - add_version: 在`save`上额外增加目录`'<版本号>-<时间戳>'`防止权重覆盖,默认为True。
-- 🔥packing: 是否使用序列packing,默认为False。
+- 🔥packing: 是否使用序列packing,默认为False。当前支持`megatron pt/sft`
 - 🔥packing_cache: 指定 packing 缓存目录。默认值为`None`,表示缓存将存储在环境变量 `$MODELSCOPE_CACHE`所指定的路径下。在跨节点使用 packing 功能时,需确保所有节点的 packing 缓存路径共享且一致。你可以通过设置`MODELSCOPE_CACHE`环境变量,或在命令行中添加 `--packing_cache <shared_path>`参数来实现这一要求。
 - 🔥streaming: 流式读取并处理数据集,默认False。通常在处理大型数据集时,设置为True。更多流式的参数查看命令行参数文档。
 - lazy_tokenize: 默认为False。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(这可以避免在训练中出现报错);设置为True,则在训练中对数据集进行tokenize(这可以节约内存)。

docs/source_en/Instruction/Megatron-SWIFT-Training.md

Lines changed: 5 additions & 4 deletions
@@ -175,7 +175,7 @@ The speed comparison of full-parameter training for Dense/MoE models using `mega
 seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the dataset length, please use the `--max_length` parameter in the basic arguments; there is no need to set this parameter.
 - use_cpu_initialization: Initializes weights on the CPU, default is False. Used during HF and MCore weight conversion.
 - no_create_attention_mask_in_dataloader: Does not create an attention mask in the dataloader, default is True.
-- extra_megatron_kwargs: 传入megatron的其他参数,使用json传递。默认为None。
+- extra_megatron_kwargs: Additional parameters passed to Megatron, provided as a JSON object. Defaults to None.
 
 **Learning Rate Parameters**:
 
@@ -229,7 +229,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the
 **Logging Parameters**:
 
 - log_params_norm: Logs the norm of parameters. Default is False.
-- log_throughput: Logs throughput per GPU. Default is True.
+- log_throughput: Logs throughput per GPU. Default is False.
 - Note: In non-packing scenarios, log_throughput is not accurate because `seq_length` does not equal the actual sequence length.
 - tensorboard_log_interval: Interval (steps) for logging to TensorBoard, default is 1.
 - tensorboard_queue_size: Queue length (related to disk I/O), similar to write intervals. Default is 50.
@@ -244,7 +244,8 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the
 
 **Evaluation Parameters**:
 
-- 🔥eval_iters: Number of evaluation iterations, default is 100.
+- 🔥eval_iters: The number of iterations for evaluation. Defaults to -1, and a suitable value will be set based on the size of the validation dataset.
+- Note: If using a streaming dataset, this value needs to be set manually.
 - 🔥eval_interval: Evaluation interval (steps), default is None, meaning it will be set to save_interval.
 
 **Mixed Precision Parameters**:
@@ -306,7 +307,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the
 Megatron training parameters inherit from Megatron parameters and basic parameters. For information on basic parameters, see [here](./Command-line-parameters.md#base-arguments). Additionally, the following parameters are included:
 
 - add_version: Adds a directory `<version>-<timestamp>` to `save` to prevent overwriting weights, default is True.
-- 🔥packing: Whether to use sequence packing, defaults to False.
+- 🔥packing: Whether to use sequence packing, defaults to False. Currently supports `megatron pt/sft`.
 - 🔥packing_cache: Specifies the directory for packing cache. The default value is `None`, which means the cache will be stored in the path defined by the environment variable `$MODELSCOPE_CACHE`. When using the packing feature across multiple nodes, ensure that all nodes share the same packing cache directory. You can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding the `--packing_cache <shared_path>` argument in the command line.
 - 🔥streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. For more information on streaming parameters, refer to the command-line parameters documentation.
 - lazy_tokenize: Default is False. If this parameter is set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization occurs during training (this saves memory).
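
As a concrete illustration of the shared `packing_cache` requirement documented above, the sketch below shows one way a node could resolve the directory: an explicit `--packing_cache` value wins, otherwise `$MODELSCOPE_CACHE` is used. The helper name `resolve_packing_cache` and the fallback path are assumptions for this sketch, not ms-swift's actual implementation.

```python
import os
from typing import Optional


def resolve_packing_cache(packing_cache: Optional[str] = None) -> str:
    """Illustrative sketch: return the directory used for the packing cache."""
    if packing_cache is not None:
        cache_dir = packing_cache
    else:
        # Fallback path is an assumption made for this sketch.
        cache_dir = os.environ.get('MODELSCOPE_CACHE', os.path.expanduser('~/.cache/modelscope'))
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


# For multi-node packing, every node must resolve to the same shared path,
# e.g. by exporting MODELSCOPE_CACHE=/mnt/shared/cache on all nodes.
print(resolve_packing_cache())
```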

swift/llm/model/model/qwen.py

Lines changed: 1 addition & 0 deletions
@@ -936,4 +936,5 @@ def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx
         ],
         TemplateType.qwen3_emb,
         get_model_tokenizer_with_flash_attn,
+        additional_saved_files=['config_sentence_transformers.json', '1_Pooling', 'modules.json'],
         architectures=['Qwen3ForCausalLM']))
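
This registration change appears to pair with the removal of the hard-coded `copy_files_by_pattern` calls from `swift/trainers/mixin.py` later in this commit: the extra SentenceTransformers artifacts are now declared per model via `additional_saved_files`. How the trainer consumes that list is not shown in this diff; the helper below is a hypothetical sketch of the copy step, with made-up paths.

```python
import os
import shutil
from typing import List


def copy_additional_saved_files(model_dir: str, output_dir: str, additional_saved_files: List[str]) -> None:
    """Hypothetical sketch: copy the declared extra files/directories next to the saved weights."""
    for name in additional_saved_files:
        src = os.path.join(model_dir, name)
        dst = os.path.join(output_dir, name)
        if os.path.isdir(src):
            shutil.copytree(src, dst, dirs_exist_ok=True)  # e.g. the '1_Pooling' directory
        elif os.path.isfile(src):
            shutil.copy2(src, dst)  # e.g. 'modules.json'


copy_additional_saved_files(
    './Qwen3-Embedding-0.6B', './output/checkpoint-100',
    ['config_sentence_transformers.json', '1_Pooling', 'modules.json'])
```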

swift/megatron/argument/megatron_args.py

Lines changed: 2 additions & 2 deletions
@@ -150,7 +150,7 @@ class MegatronArguments(ExtraMegatronArguments):
 
     # logging
     log_params_norm: bool = False
-    log_throughput: bool = True
+    log_throughput: bool = False
     tensorboard_log_interval: int = 1
     tensorboard_queue_size: int = 50
     log_timers_to_tensorboard: bool = True
@@ -163,7 +163,7 @@ class MegatronArguments(ExtraMegatronArguments):
     wandb_save_dir: Optional[str] = None
 
     # evaluate
-    eval_iters: int = 100
+    eval_iters: int = -1
     eval_interval: Optional[int] = None
 
     # other

swift/megatron/train/sft.py

Lines changed: 18 additions & 9 deletions
@@ -37,19 +37,28 @@ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None)
         args.save_args(args.save)
 
     @contextmanager
-    def _get_train_iters(self, train_dataset):
-        from megatron.training import training
+    def _get_iters(self, train_dataset, val_dataset):
         origin_initialize_megatron = training.initialize_megatron
 
         def initialize_megatron(*_args, **kwargs):
             res = origin_initialize_megatron(*_args, **kwargs)
             args = get_args()
-            if args.train_iters is None and hasattr(train_dataset, '__len__'):
-                data_parallel_size = mpu.get_data_parallel_world_size()
-                step_batch_size = \
-                    args.micro_batch_size * data_parallel_size
-                dataset_sample = len(train_dataset) // step_batch_size * step_batch_size
-                args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1
+            data_parallel_size = mpu.get_data_parallel_world_size()
+            step_batch_size = args.micro_batch_size * data_parallel_size
+            if args.train_iters is None:
+                if hasattr(train_dataset, '__len__'):
+                    dataset_sample = len(train_dataset) // step_batch_size * step_batch_size
+                    args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1
+                else:
+                    raise ValueError(
+                        'You are using a streaming training dataset. Please explicitly specify `--train_iters`.')
+            if val_dataset is not None and args.eval_iters < 0:
+                if hasattr(val_dataset, '__len__'):
+                    dataset_sample = len(val_dataset) // step_batch_size * step_batch_size
+                    args.eval_iters = max(dataset_sample // args.global_batch_size, 1)
+                else:
+                    raise ValueError(
+                        'You are using a streaming validation dataset. Please explicitly specify `--eval_iters`.')
             return res
 
         training.initialize_megatron = initialize_megatron
@@ -136,7 +145,7 @@ def run(self):
         logging_path = os.path.join(args.save, 'logging.jsonl')
         logger.info(f'The logging file will be saved in: {logging_path}')
         try:
-            with patch_megatron_data_collator(data_collator), self._get_train_iters(train_dataset):
+            with patch_megatron_data_collator(data_collator), self._get_iters(train_dataset, val_dataset):
                 extra_args_provider = args.megatron_model_meta.extra_args_provider
                 pretrain(
                     datasets_provider,
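
To make the `train_iters` / `eval_iters` arithmetic in `_get_iters` above concrete, here is a small standalone sketch with made-up numbers (the dataset sizes, batch sizes and `max_epochs` below are illustrative assumptions, not values from this commit):

```python
# Standalone illustration of the iteration arithmetic in _get_iters;
# in real training these values come from Megatron's parsed args.
micro_batch_size = 1
data_parallel_size = 8
global_batch_size = 16
max_epochs = 1

train_samples = 10_000
val_samples = 200

step_batch_size = micro_batch_size * data_parallel_size             # 8

# Drop the tail that does not fill a whole step batch, then convert to iterations.
train_sample = train_samples // step_batch_size * step_batch_size   # 10000
train_iters = (train_sample * max_epochs // global_batch_size) + 1  # 626

# eval_iters < 0 means "derive it from the validation set"; at least 1 iteration.
val_sample = val_samples // step_batch_size * step_batch_size       # 200
eval_iters = max(val_sample // global_batch_size, 1)                # 12

print(train_iters, eval_iters)  # 626 12
```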

swift/megatron/train/utils.py

Lines changed: 2 additions & 0 deletions
@@ -83,11 +83,13 @@ def _broadcast(item):
             _broadcast(batch['position_ids'])
 
         elif mpu.is_pipeline_first_stage():
+            batch['labels'] = None
             _broadcast(batch['input_ids'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])
 
         elif mpu.is_pipeline_last_stage():
+            batch['input_ids'] = None
             _broadcast(batch['labels'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])
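
A toy illustration of why the hunk above clears the unused field on each pipeline stage: the first stage only needs `input_ids` to compute embeddings, while the last stage only needs `labels` to compute the loss, so the other tensor can be dropped before broadcasting. The helper below is illustrative only and simulates the stage checks with booleans.

```python
def prune_batch_for_stage(batch: dict, is_first_stage: bool, is_last_stage: bool) -> dict:
    """Illustrative only: drop the tensor a pipeline stage never reads."""
    batch = dict(batch)
    if is_first_stage and not is_last_stage:
        batch['labels'] = None      # the loss is computed on the last stage
    if is_last_stage and not is_first_stage:
        batch['input_ids'] = None   # embeddings were computed on the first stage
    return batch


example = {'input_ids': [1, 2, 3], 'labels': [2, 3, 4], 'attention_mask': None, 'position_ids': [0, 1, 2]}
print(prune_batch_for_stage(example, is_first_stage=True, is_last_stage=False))
```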

swift/trainers/mixin.py

Lines changed: 0 additions & 5 deletions
@@ -223,11 +223,6 @@ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
         else:
             if self.model.__class__.__name__ != 'SentenceTransformer':
                 self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
-                # For embedding models, they should copy extra sentence_transformers files
-                from swift.utils import copy_files_by_pattern
-                copy_files_by_pattern(self.model.model_dir, output_dir, 'config_sentence_transformers.json')
-                copy_files_by_pattern(self.model.model_dir, output_dir, '1_Pooling/config.json')
-                copy_files_by_pattern(self.model.model_dir, output_dir, 'modules.json')
             else:
 
     @contextmanager

swift/trainers/rlhf_trainer/dpo_trainer.py

Lines changed: 14 additions & 15 deletions
@@ -39,17 +39,11 @@ def __init__(self,
 
         super().__init__(model, ref_model, *_args, **kwargs)
 
-    def get_nll_loss(self, logits, labels):
-        # Flatten the tokens
-        loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_token_id)
-        logits = logits.view(-1, logits.shape[-1])
-        labels = labels.view(-1)
-        # Enable model parallelism
-        labels = labels.to(logits.device)
-        return loss_fct(logits, labels)
-
     def concatenated_forward(
-        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]], **kwargs
+            self,
+            model: nn.Module,
+            batch: Dict[str, Union[List, torch.LongTensor]],
+            is_ref_model: bool = False
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         batch = batch.copy()
         labels = batch.pop('labels', None)
@@ -76,7 +70,9 @@ def concatenated_forward(
         if not self.is_encoder_decoder and self.template.sequence_parallel_size == 1:
             # Shift so that tokens < n predict n
             labels = torch.roll(labels, shifts=-1, dims=1)
-        per_token_logps, mean_all_logits, loss_mask = self.get_per_token_logps(all_logits, labels)
+        per_token_logps, mean_all_logits, loss_mask = self.get_per_token_logps(
+            all_logits, labels, label_pad_token_id=self.label_pad_token_id)
+        origin_per_token_logps = per_token_logps
         if self.loss_type == 'ipo':
             size_completion = loss_mask.sum(dim=-1)
             per_token_logps = per_token_logps / size_completion
@@ -90,15 +86,17 @@ def concatenated_forward(
                 all_logps[i] = per_token_logps[:, start:end].sum()
             num_examples = all_logps.shape[0] // 2
             num_tokens = cu_seqlens[num_examples]
-            output['nll_loss'] = self.get_nll_loss(all_logits[:, :num_tokens], labels[:, :num_tokens])
+            if not is_ref_model:
+                output['nll_loss'] = -origin_per_token_logps[:, :num_tokens][loss_mask[:, :num_tokens]].mean()
             output['chosen_logps'] = all_logps[:num_examples]
             output['rejected_logps'] = all_logps[num_examples:]
             output['mean_chosen_logits'] = mean_all_logits[:, :num_tokens][loss_mask[:, :num_tokens]].mean()
             output['mean_rejected_logits'] = mean_all_logits[:, num_tokens:][loss_mask[:, num_tokens:]].mean()
         else:
             all_logps = per_token_logps.sum(-1)
             num_examples = labels.shape[0] // 2
-            output['nll_loss'] = self.get_nll_loss(all_logits[:num_examples], labels[:num_examples])
+            if not is_ref_model:
+                output['nll_loss'] = -origin_per_token_logps[:num_examples][loss_mask[:num_examples]].mean()
             output['chosen_logps'] = all_logps[:num_examples]
             output['rejected_logps'] = all_logps[num_examples:]
             output['mean_chosen_logits'] = mean_all_logits[:num_examples][loss_mask[:num_examples]].mean()
@@ -107,15 +105,16 @@ def concatenated_forward(
             output['aux_loss'] = outputs.aux_loss
         return output
 
+    @staticmethod
     def get_per_token_logps(
-        self,
             logits: torch.FloatTensor,
             labels: torch.LongTensor,
+            label_pad_token_id=-100,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if logits.shape[:-1] != labels.shape:
             raise ValueError(f'Logits (batch and sequence length dim) {logits.shape[:-1]}'
                              'and labels must have the same shape {labels.shape}')
-        loss_mask = labels != self.label_pad_token_id
+        loss_mask = labels != label_pad_token_id
         labels = labels.clone()
         labels[~loss_mask] = 0
         # https://github.com/huggingface/trl/pull/2799
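
The removal of `get_nll_loss` above relies on the negative mean of the masked per-token log-probabilities being equal to the token-averaged cross-entropy over non-padded labels. A small self-contained check with random tensors (illustrative only; the trainer's real inputs are already shifted and may be packed):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
label_pad_token_id = -100
logits = torch.randn(2, 6, 11)        # (batch, seq, vocab)
labels = torch.randint(0, 11, (2, 6))
labels[:, -2:] = label_pad_token_id   # pad a few positions

# Old formulation: flatten and let CrossEntropyLoss ignore padded labels.
ce = nn.CrossEntropyLoss(ignore_index=label_pad_token_id)
old_nll = ce(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))

# New formulation: gather per-token log-probs, mask, negate the mean.
loss_mask = labels != label_pad_token_id
gather_labels = labels.clone()
gather_labels[~loss_mask] = 0
per_token_logps = torch.log_softmax(logits, dim=-1).gather(-1, gather_labels.unsqueeze(-1)).squeeze(-1)
new_nll = -per_token_logps[loss_mask].mean()

assert torch.allclose(old_nll, new_nll, atol=1e-6)
print(old_nll.item(), new_nll.item())
```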

swift/trainers/sequence_parallel/ulysses.py

Lines changed: 4 additions & 9 deletions
@@ -176,13 +176,13 @@ def old_policy(self):
 
 
 # For DPO
-def get_per_token_logps(self,
-                        logits: torch.FloatTensor,
+def get_per_token_logps(logits: torch.FloatTensor,
                         labels: torch.LongTensor,
+                        label_pad_token_id=-100,
                         ulysses=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     if labels.shape[1] > logits.shape[1]:
         _, _, labels, _, _, _ = ulysses.pad_and_split_inputs(None, None, labels, None, None, None)
-    loss_mask = labels != self.label_pad_token_id
+    loss_mask = labels != label_pad_token_id
     labels = labels.clone() # No need to shift, pad and split has shifted the inputs.
     labels[~loss_mask] = 0
     labels = labels.to(logits.device)
@@ -840,12 +840,7 @@ def prepare_trainer(self, trainer):
         elif trainer.__class__.__name__ == 'DPOTrainer':
             trainer._origin_prepare_inputs = trainer._prepare_inputs
             trainer._prepare_inputs = MethodType(partial(_prepare_inputs, ulysses=self), trainer)
-            trainer.get_per_token_logps = MethodType(partial(get_per_token_logps, ulysses=self), trainer)
-
-            def rlhf_loss_scale_sp_func(_, *args, **kwargs):
-                return loss_scale_sp_func(*args, ulysses=self, **kwargs)
-
-            trainer.get_nll_loss = MethodType(rlhf_loss_scale_sp_func, trainer)
+            trainer.get_per_token_logps = partial(get_per_token_logps, ulysses=self)
 
         elif trainer.__class__.__name__ == 'GRPOTrainer':
             assert version.parse(trl.__version__) >= version.parse('0.18.0')
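
A toy illustration of the patching change above: the old code attached `get_per_token_logps` as a bound method via `MethodType` (so it received the trainer as `self`), while the new code attaches a plain `functools.partial`, so the function no longer needs a `self` parameter. The classes and functions below are stand-ins, not ms-swift code.

```python
from functools import partial
from types import MethodType


def scaled(value, factor=1):
    return value * factor


class Trainer:
    pass


trainer = Trainer()

# Old style: a bound method; the wrapped callable needs an explicit `self` slot.
trainer.scaled_method = MethodType(lambda self, value: scaled(value, factor=2), trainer)

# New style: a plain partial stored as an attribute; no implicit first argument.
trainer.scaled_fn = partial(scaled, factor=2)

print(trainer.scaled_method(3), trainer.scaled_fn(3))  # 6 6
```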

swift/utils/torch_utils.py

Lines changed: 5 additions & 8 deletions
@@ -413,21 +413,18 @@ def check_shared_disk(error, cache_dir: Optional[str] = None):
     os.makedirs(cache_dir, exist_ok=True)
     tmp_path = os.path.join(cache_dir, 'check_shared_disk.tmp')
     is_shared_disk = True
-    with safe_ddp_context(None, True):
-        if os.path.exists(tmp_path):
-            os.remove(tmp_path)
+
     try:
         with safe_ddp_context(None, True):
             if is_master():
                 with open(tmp_path, 'w'):
                     pass
-            else:
-                if not os.path.exists(tmp_path):
-                    is_shared_disk = False
+            if not os.path.exists(tmp_path):
+                is_shared_disk = False
+            shared_state = [None] * dist.get_world_size()
+            dist.all_gather_object(shared_state, is_shared_disk)
     finally:
         if is_master() and os.path.exists(tmp_path):
             os.remove(tmp_path)
-    shared_state = [None] * dist.get_world_size()
-    dist.all_gather_object(shared_state, is_shared_disk)
     if not all(shared_state):
         raise error
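
The same idea as `check_shared_disk` above in a standalone form: rank 0 writes a marker file into the candidate cache directory, every rank checks whether it can see the file, and the results are gathered so all ranks agree. This sketch uses a plain `dist.barrier()` instead of swift's `safe_ddp_context`, is meant to be launched with `torchrun`, and is illustrative rather than the library's code.

```python
import os
import torch.distributed as dist


def check_shared_dir(cache_dir: str) -> None:
    rank = dist.get_rank()
    os.makedirs(cache_dir, exist_ok=True)
    tmp_path = os.path.join(cache_dir, 'check_shared_disk.tmp')
    try:
        if rank == 0:
            with open(tmp_path, 'w'):
                pass
        dist.barrier()                      # ensure rank 0 wrote the marker first
        visible = os.path.exists(tmp_path)  # other ranks only see it on a shared disk
        states = [None] * dist.get_world_size()
        dist.all_gather_object(states, visible)
    finally:
        dist.barrier()
        if rank == 0 and os.path.exists(tmp_path):
            os.remove(tmp_path)
    if not all(states):
        raise OSError(f'{cache_dir} is not shared across all ranks.')


if __name__ == '__main__':
    dist.init_process_group('gloo')
    check_shared_dir(os.environ.get('MODELSCOPE_CACHE', '/tmp/packing_cache'))
    dist.destroy_process_group()
```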
