Commit 9bf7c9e

update
1 parent 652b0bc commit 9bf7c9e

6 files changed, +32 −17 lines changed

docs/source/Instruction/命令行参数.md

Lines changed: 2 additions & 1 deletion
@@ -356,10 +356,11 @@ Vera uses the three parameters `target_modules`, `target_regex`, and `modules_to_save`.
 - Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh
 - packing_cache: Specifies the packing cache directory. Defaults to `None`, meaning the cache is stored under the path given by the environment variable `$MODELSCOPE_CACHE`. When using packing across nodes, make sure the packing cache path is shared and identical on all nodes, either by setting the `MODELSCOPE_CACHE` environment variable or by passing `--packing_cache <shared_path>` on the command line.
 - 🔥lazy_tokenize: Whether to use lazy tokenization. If set to False, all dataset samples are tokenized before training (for multimodal models, this includes reading images from disk). Defaults to False for LLM training and True for MLLM training, to save memory.
+- use_logits_to_keep: Pass `logits_to_keep` in `forward` based on the labels, so that unnecessary logits are neither computed nor stored, reducing memory usage and speeding up training. Defaults to None, which selects automatically.
 - acc_strategy: Strategy for computing accuracy during training and validation. Options are `seq`-level and `token`-level accuracy; defaults to `token`.
 - max_new_tokens: Generation parameter override. Maximum number of tokens to generate when predict_with_generate=True; defaults to 64.
 - temperature: Generation parameter override. Temperature when predict_with_generate=True; defaults to 0.
-- optimizer: Custom optimizer name for the plugin; defaults to None.
+- optimizer: Custom optimizer name for the plugin; defaults to None. Available optimizers are listed [here](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/optimizer.py).
 - metric: Custom metric name for the plugin. Defaults to None, i.e. 'acc' when predict_with_generate=False and 'nlg' when predict_with_generate=True.
 - eval_use_evalscope: Whether to use evalscope for evaluation during training; this parameter must be set to enable evaluation. See the [example](../Instruction/评测.md#训练中评测) for usage.
 - eval_datasets: Evaluation datasets; multiple datasets can be set, separated by spaces.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 2 additions & 1 deletion
@@ -365,10 +365,11 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
 - Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh
 - packing_cache: Specifies the directory for the packing cache. The default value is `None`, which means the cache will be stored in the path defined by the environment variable `$MODELSCOPE_CACHE`. When using the packing feature across multiple nodes, ensure that all nodes share the same packing cache directory. You can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding the `--packing_cache <shared_path>` argument on the command line.
 - 🔥lazy_tokenize: Whether to use lazy tokenization. If set to False, all dataset samples are tokenized before training (for multimodal models, this includes reading images from disk). This parameter defaults to False for LLM training and True for MLLM training, to save memory.
+- use_logits_to_keep: Pass `logits_to_keep` in the `forward` method based on labels to reduce the computation and storage of unnecessary logits, thereby reducing memory usage and accelerating training. The default is `None`, which enables automatic selection.
 - acc_strategy: Strategy for calculating accuracy during training and validation. Options are `seq`-level and `token`-level accuracy, with `token` as the default.
 - max_new_tokens: Generation parameter override. The maximum number of tokens to generate when `predict_with_generate=True`, defaulting to 64.
 - temperature: Generation parameter override. The temperature setting when `predict_with_generate=True`, defaulting to 0.
-- optimizer: Custom optimizer name for the plugin, defaults to None.
+- optimizer: Custom optimizer name for the plugin, defaults to None. Available optimizers are listed [here](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/optimizer.py).
 - metric: Custom metric name for the plugin. Defaults to None: 'acc' when `predict_with_generate=False` and 'nlg' when `predict_with_generate=True`.
 - eval_use_evalscope: Whether to use evalscope for evaluation during training; this parameter must be set to enable evaluation. Refer to the [example](../Instruction/Evaluation.md#evaluation-during-training). Default is False.
 - eval_datasets: Evaluation datasets; multiple datasets can be set, separated by spaces.
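The `use_logits_to_keep` parameter documented above skips logits for positions whose label is `-100`. Below is a minimal, self-contained PyTorch sketch of the idea (toy sizes, not ms-swift code): it compares the full `[seq, vocab]` logits tensor against logits computed only at the positions that actually feed the loss. The real implementation instead passes a `logits_to_keep` mask into the model's `forward` (for transformers models that accept it), rather than slicing hidden states by hand.

```python
import torch

torch.manual_seed(0)
vocab, hidden, seq = 32000, 64, 512
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
hidden_states = torch.randn(1, seq, hidden)

# Toy labels: only the last 32 positions are supervised; the rest are -100 (ignored by the loss).
labels = torch.full((1, seq), -100)
labels[:, -32:] = torch.randint(0, vocab, (1, 32))

# Dense path: logits for every position -> shape [1, 512, 32000].
dense_logits = lm_head(hidden_states)

# Sparse path: keep only the positions whose next-token label is supervised
# (shift the mask left by one, and always keep the final position).
loss_mask = (labels != -100)[0]
keep = torch.nn.functional.pad(loss_mask[1:], (0, 1), value=True)
sparse_logits = lm_head(hidden_states[:, keep])   # shape [1, 33, 32000]

print(dense_logits.shape, sparse_logits.shape)    # roughly 15x fewer logits to store in this toy case
```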

swift/llm/argument/train_args.py

Lines changed: 0 additions & 1 deletion
@@ -112,7 +112,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
 
     # plugin
     loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
-    optimizer: Optional[str] = None
     metric: Optional[str] = None
 
     # extra

swift/llm/train/sft.py

Lines changed: 1 addition & 8 deletions
@@ -75,13 +75,6 @@ def _get_dataset(self):
 
         return train_dataset, val_dataset
 
-    def _get_loss_func(self):
-        args = self.args
-        loss_type = args.loss_type
-        if loss_type is None and args.loss_scale != 'default':
-            loss_type = 'loss_scale'
-        return get_loss_func(loss_type)
-
     def _get_data_collator(self):
         args = self.args
         template = self.template
@@ -141,7 +134,7 @@ def _get_trainer_kwargs(self):
         return {
             'compute_metrics': compute_metrics,
             'preprocess_logits_for_metrics': preprocess_logits_for_metrics,
-            'compute_loss_func': self._get_loss_func()
+            'compute_loss_func': get_loss_func(args.loss_type)
         }
 
     def _save_trainer_state(self, trainer):
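For context on the removal above: `_get_loss_func` used to pick the `'loss_scale'` loss once, at trainer-construction time, whenever `args.loss_scale != 'default'`. After this commit, `get_loss_func(args.loss_type)` is wired in directly and the `'loss_scale'` fallback happens per batch inside `compute_loss`, only when the batch actually provides a `loss_scale` tensor (as reconstructed in the trainers.py hunk below). A plain-Python sketch of the two selection orders, using stub functions rather than the ms-swift APIs and assuming `get_loss_func(None)` returns `None`:

```python
def get_loss_func(loss_type):
    # Stub standing in for swift.plugin.loss.get_loss_func; assumed to return None for None.
    return None if loss_type is None else f'<loss func: {loss_type}>'

def old_selection(loss_type, loss_scale_arg):
    # Pre-commit: decided once from the command-line loss_scale argument.
    if loss_type is None and loss_scale_arg != 'default':
        loss_type = 'loss_scale'
    return get_loss_func(loss_type)

def new_selection(loss_type, batch_has_loss_scale):
    # Post-commit: decided per batch, only when the collator actually provides loss_scale.
    compute_loss_func = get_loss_func(loss_type)
    if batch_has_loss_scale and compute_loss_func is None:
        compute_loss_func = get_loss_func('loss_scale')
    return compute_loss_func

print(old_selection(None, '<non-default>'))              # <loss func: loss_scale>
print(new_selection(None, batch_has_loss_scale=True))    # <loss func: loss_scale>
print(new_selection(None, batch_has_loss_scale=False))   # None -> falls back to the default HF loss path
```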

swift/trainers/arguments.py

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ class TrainArgumentsMixin:
     aligner_lr: Optional[float] = None
     vit_lr: Optional[float] = None
     optimizer: Optional[str] = None
+    use_logits_to_keep: Optional[bool] = None
 
     # torchacc
     metric_warmup_step: Optional[float] = 0
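`use_logits_to_keep` is a tri-state flag: `True` or `False` force the behavior, and `None` (the default) lets the trainer decide. As the trainers.py diff below shows, the automatic choice requires that the batch carries labels, that the batch dimension is 1 (the padding_free/packing layouts), and that the base model's `forward` accepts a `logits_to_keep` argument. A small sketch of that rule with toy models (not ms-swift code):

```python
import inspect
import torch
from torch import nn

class ToyModelWithKeep(nn.Module):
    def forward(self, input_ids=None, labels=None, logits_to_keep=None):
        return input_ids

class ToyModelWithoutKeep(nn.Module):
    def forward(self, input_ids=None, labels=None):
        return input_ids

def resolve_use_logits_to_keep(model, inputs, use_logits_to_keep=None):
    if use_logits_to_keep is not None:      # an explicit True/False always wins
        return use_logits_to_keep
    return ('labels' in inputs and inputs['labels'].shape[0] == 1
            and 'logits_to_keep' in inspect.signature(model.forward).parameters)

batch1 = {'input_ids': torch.ones(1, 8, dtype=torch.long), 'labels': torch.ones(1, 8, dtype=torch.long)}
batch2 = {'input_ids': torch.ones(2, 8, dtype=torch.long), 'labels': torch.ones(2, 8, dtype=torch.long)}
print(resolve_use_logits_to_keep(ToyModelWithKeep(), batch1))      # True: batch size 1 and supported forward
print(resolve_use_logits_to_keep(ToyModelWithoutKeep(), batch1))   # False: forward has no logits_to_keep
print(resolve_use_logits_to_keep(ToyModelWithKeep(), batch2))      # False: batch size > 1
```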

swift/trainers/trainers.py

Lines changed: 26 additions & 6 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from huggingface/transformers.
+import inspect
 import os
 from contextlib import contextmanager, nullcontext
 from functools import wraps
@@ -15,10 +16,12 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 from transformers.utils import is_peft_available
 
-from swift.utils import JsonlWriter, Serializer, gc_collect
+from swift.utils import JsonlWriter, Serializer, gc_collect, get_logger, is_mp
 from .arguments import Seq2SeqTrainingArguments, TrainingArguments
 from .mixin import DataLoaderMixin, SwiftMixin
 
+logger = get_logger()
+
 
 class Trainer(SwiftMixin, HfTrainer):
     args: TrainingArguments
@@ -152,15 +155,32 @@ def prediction_step(
         return None, response_list, labels_list
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        from swift.plugin.loss import get_loss_func
         loss_kwargs = {}
         labels = None
-        if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs:
-            labels = inputs.pop('labels')
-
+        compute_loss_func = self.compute_loss_func
         loss_scale = inputs.pop('loss_scale', None)
         if loss_scale is not None:
             loss_kwargs['loss_scale'] = loss_scale
+            if compute_loss_func is None:
+                compute_loss_func = get_loss_func('loss_scale')
+        if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs:
+            labels = inputs.pop('labels')
 
+        base_model = self.template.get_base_model(self.model)
+        use_logits_to_keep = self.args.use_logits_to_keep
+        if use_logits_to_keep is None:
+            use_logits_to_keep = 'labels' in inputs and inputs['labels'].shape[
+                0] == 1 and 'logits_to_keep' in inspect.signature(base_model.forward).parameters
+        logger.info_once(f'use_logits_to_keep: {use_logits_to_keep}')
+        # padding_free or packing
+        if use_logits_to_keep:
+            loss_mask = (inputs['labels'] != -100)[0]
+            inputs['labels'] = inputs['labels'][:, loss_mask]
+            inputs['labels'] = nn.functional.pad(inputs['labels'], (1, 0), value=-100)
+            inputs['logits_to_keep'] = nn.functional.pad(loss_mask[1:], (0, 1), value=True)
+            if is_mp():
+                inputs['logits_to_keep'] = inputs['logits_to_keep'].cpu()
         with self.template.compute_loss_context(self.model, inputs):
             outputs = model(**inputs)
         # Save past state if it exists
@@ -188,8 +208,8 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             else:
                 model_name = unwrapped_model._get_name()
             # User-defined compute_loss function
-            if self.compute_loss_func is not None:
-                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch, **loss_kwargs)
+            if compute_loss_func is not None:
+                loss = compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch, **loss_kwargs)
             elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                 loss = self.label_smoother(outputs, labels, shift_labels=True)
             else:
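The label surgery in the hunk above is easy to misread, so here is a standalone PyTorch sketch of the same alignment (toy tensors, not ms-swift code): supervised labels are kept, padded with a leading `-100`, and the keep mask is shifted left by one with the last position forced to `True`, so that after the usual causal shift the reduced logits and the reduced labels line up one-to-one.

```python
import torch
import torch.nn as nn

# A packed batch; -100 marks positions that do not contribute to the loss.
labels = torch.tensor([[-100, -100, -100, 11, 12, 13, -100, 21, 22]])   # shape [1, 9]

loss_mask = (labels != -100)[0]                                   # which labels are supervised
kept_labels = labels[:, loss_mask]                                # [1, 5]: only supervised labels remain
kept_labels = nn.functional.pad(kept_labels, (1, 0), value=-100)  # leading -100: the first kept logit has no target
logits_to_keep = nn.functional.pad(loss_mask[1:], (0, 1), value=True)  # shift by one; always keep the last position

# The model returns logits only at the `logits_to_keep` positions, so after the causal shift
# (logits[:, :-1] vs labels[:, 1:]) the two reduced tensors match position for position.
print(kept_labels)                # tensor([[-100,  11,  12,  13,  21,  22]])
print(int(logits_to_keep.sum()))  # 6 -> exactly one kept logit per entry of kept_labels
```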

0 commit comments
