Skip to content

Commit 40ccffc

Browse files
authored
[pt/sft] Feature channel loss (#4405)
1 parent 691c3d4 commit 40ccffc

File tree

8 files changed

+112
-2
lines changed

8 files changed

+112
-2
lines changed

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数.
351351
- check_model: 检查本地模型文件有损坏或修改并给出提示,默认为True。如果是断网环境,请设置为False。
352352
- 🔥create_checkpoint_symlink: 额外创建checkpoint软链接,方便书写自动化训练脚本。best_model和last_model的软链接路径分别为f'{output_dir}/best'和f'{output_dir}/last'。
353353
- loss_type: loss类型。默认为None,使用模型自带损失函数。
354+
- channels: 数据集包含的channel集合。默认为None。结合`--loss_type channel_loss`使用,可参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh)。
354355
- 🔥packing: 是否使用序列packing提升计算效率,默认为False。当前支持`swift pt/sft`
355356
- 注意:使用packing请结合`--attn_impl flash_attn`使用且"transformers>=4.44",具体查看[该PR](https://github.com/huggingface/transformers/pull/31629)
356357
- 支持的多模态模型参考:https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
360360
- check_model: Check local model files for corruption or modification and give a prompt, default is True. If in an offline environment, please set to False.
361361
- 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively.
362362
- loss_type: Type of loss. Defaults to None, which uses the model's built-in loss function.
363+
- channels: Set of channels included in the dataset. Defaults to None. Used in conjunction with `--loss_type channel_loss`. Refer to [this example](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh) for more details.
363364
- 🔥packing: Whether to use sequence packing to improve computational efficiency. The default value is False. Currently supports `swift pt/sft`.
364365
- Note: When using packing, please combine it with `--attn_impl flash_attn` and ensure "transformers>=4.44". For details, see [this PR](https://github.com/huggingface/transformers/pull/31629).
365366
- Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# use loss_type channel_loss
2+
# channels specifies the channels included in the dataset
3+
# each data sample must have a 'channel' field
4+
# e.g.
5+
# {"channel": "chat",
6+
# "messages": [
7+
# {"role": "system", "content": "You are a helpful assistant"},
8+
# {"role": "user", "content": "What color do you like?"},
9+
# {"role": "assistant", "content": "I like blue."}
10+
# ]}
11+
CUDA_VISIBLE_DEVICES=0 \
12+
swift sft \
13+
--model Qwen/Qwen2.5-0.5B-Instruct \
14+
--dataset '/path/to/channel_dataset' \
15+
--train_type full \
16+
--torch_dtype bfloat16 \
17+
--num_train_epochs 1 \
18+
--per_device_train_batch_size 1 \
19+
--per_device_eval_batch_size 1 \
20+
--learning_rate 1e-5 \
21+
--gradient_accumulation_steps 8 \
22+
--eval_steps 100 \
23+
--save_steps 100 \
24+
--save_total_limit 2 \
25+
--logging_steps 1 \
26+
--max_length 512 \
27+
--output_dir output \
28+
--system 'You are a helpful assistant.' \
29+
--warmup_ratio 0.05 \
30+
--dataloader_num_workers 4 \
31+
--loss_type channel_loss \
32+
--channels 'chat' 'math' 'code'

swift/llm/dataset/preprocessor/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020

2121

2222
class RowPreprocessor:
23-
standard_keys = ['messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects']
23+
standard_keys = [
24+
'messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects', 'channel'
25+
]
2426

2527
def __init__(self,
2628
*,
@@ -308,6 +310,7 @@ def __call__(
308310
dataset = self._cast_pil_image(dataset)
309311

310312
ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
313+
keep_columns = ['channel']
311314
with self._patch_arrow_writer(), safe_ddp_context(None, True):
312315
try:
313316
dataset_mapped = dataset.map(
@@ -316,7 +319,7 @@ def __call__(
316319
'strict': strict,
317320
'ignore_max_length_error': ignore_max_length_error
318321
},
319-
remove_columns=list(dataset.features.keys()),
322+
remove_columns=[col for col in dataset.features.keys() if col not in keep_columns],
320323
**map_kwargs)
321324
except NotImplementedError:
322325
pass

swift/llm/template/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,10 +1374,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
13741374
else:
13751375
inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
13761376
input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
1377+
channel = [b['channel'] for b in batch if b.get('channel') is not None]
13771378
if inputs_embeds:
13781379
res['inputs_embeds'] = inputs_embeds
13791380
if input_ids:
13801381
res['input_ids'] = input_ids
1382+
if channel:
1383+
res['channel'] = channel
13811384
for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
13821385
val = [b[key] for b in batch if b.get(key) is not None]
13831386
if val:

swift/plugin/loss.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@
55

66
import numpy as np
77
import torch
8+
import torch.distributed as dist
89
import torch.nn.functional as F
910
from accelerate.utils import gather_object
1011
from torch import nn
1112
from torch.nn import CrossEntropyLoss, MSELoss
1213
from transformers.utils import strtobool
1314

15+
from swift.plugin import MeanMetric
16+
1417

1518
class LossType:
1619
loss_scale = 'loss_scale'
1720
cosine_similarity = 'cosine_similarity'
1821
contrastive = 'contrastive'
1922
online_contrastive = 'online_contrastive'
2023
infonce = 'infonce'
24+
channel_loss = 'channel_loss'
2125

2226

2327
LOSS_MAPPING = {}
@@ -382,6 +386,61 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
382386
return loss
383387

384388

389+
@register_loss_func(LossType.channel_loss)
def channel_loss_func(outputs, labels, num_items_in_batch=None, sample_channels=None, trainer=None) -> torch.Tensor:
    """Next-token cross-entropy loss that additionally logs a mean loss per data channel.

    The returned training loss is the ordinary shifted cross-entropy over all
    tokens whose label is not -100. As a side effect, the per-token losses of
    every sample are bucketed by that sample's channel and accumulated on
    ``trainer.state``; once per ``gradient_accumulation_steps`` calls the
    buckets are reduced (across ranks when torch.distributed is initialized)
    and the per-channel mean is pushed into ``trainer._custom_metrics`` under
    the key ``loss_<channel>``.

    Args:
        outputs: Model output exposing ``.logits``.
        labels: Label tensor aligned with the logits; positions equal to -100
            are ignored.
        num_items_in_batch: If given, the summed token loss is divided by this
            value; otherwise by the number of non-ignored tokens in ``labels``.
        sample_channels: One channel name per sample (row) of the batch.
        trainer: The trainer; ``trainer.args.channels``,
            ``trainer.args.gradient_accumulation_steps``, ``trainer.state`` and
            ``trainer._custom_metrics`` are used.

    Returns:
        Scalar loss tensor (still attached to the autograd graph).

    Raises:
        AssertionError: If ``sample_channels``, ``trainer`` or
            ``trainer.args.channels`` is missing.
    """
    assert sample_channels is not None, 'Data does not have channel field.'
    # `trainer` defaults to None in the signature, so fail with a clear message
    # instead of an AttributeError on `trainer.args`.
    assert trainer is not None, 'channel_loss requires the trainer to be passed to the loss function.'
    channels = trainer.args.channels
    assert channels is not None, 'Please pass --channels as a hyperparameter.'
    logits = outputs.logits

    # compute token loss
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_labels = shift_labels.view(-1)
    token_loss = loss_fct(flat_logits, flat_labels).view_as(shift_labels)  # [batch_size, seq_len-1]

    mask = shift_labels != -100

    state = trainer.state
    # Tolerate a state that has not been primed with the bookkeeping attributes yet.
    # NOTE(review): this counter advances on every call; if the loss function is also
    # invoked during evaluation the accumulation boundary below may drift - confirm.
    state.local_step = getattr(state, 'local_step', 0) + 1
    ch_loss_steps = getattr(state, 'ch_loss_steps', None)
    if ch_loss_steps is None:
        ch_loss_steps = {}
        state.ch_loss_steps = ch_loss_steps

    # Metrics only: detach so the stored slices do not keep this micro-batch's autograd
    # graph alive across accumulation steps, and use float32 so that summation is not
    # done in a reduced-precision dtype.
    metric_token_loss = token_loss.detach().float()
    # Only buckets for configured channels are ever read and reset below; storing other
    # channels would make the lists grow without bound.
    tracked_channels = set(channels)
    for ch in set(sample_channels) & tracked_channels:
        idx = torch.tensor([s == ch for s in sample_channels], device=logits.device)
        ch_loss_step = metric_token_loss[idx].masked_select(mask[idx])
        ch_loss_steps.setdefault(ch, []).append(ch_loss_step)

    # At the end of a global step, compute the mean loss for each channel
    if state.local_step % trainer.args.gradient_accumulation_steps == 0:
        use_dist = dist.is_available() and dist.is_initialized()
        # Iterate the configured channels (not the locally seen ones) so that every rank
        # issues the same sequence of collective calls.
        for ch in channels:
            loss_sum_tensor = torch.zeros(1, dtype=torch.float32, device=logits.device)
            num_items_tensor = torch.zeros(1, dtype=torch.float32, device=logits.device)
            for step_loss in ch_loss_steps.get(ch, []):
                loss_sum_tensor += step_loss.sum()
                num_items_tensor += step_loss.numel()
            if use_dist:
                dist.all_reduce(loss_sum_tensor, op=dist.ReduceOp.SUM)
                dist.all_reduce(num_items_tensor, op=dist.ReduceOp.SUM)
            num_items = num_items_tensor.item()

            # Log only when the channel actually contributed tokens during this global step.
            if num_items > 0:
                ch_loss = loss_sum_tensor.item() / num_items
                metric_key = f'loss_{ch}'
                trainer._custom_metrics.setdefault(metric_key, MeanMetric(nan_value=None)).update(ch_loss)
            # Reset
            ch_loss_steps[ch] = []

    # return loss
    total_loss = token_loss.masked_select(mask).sum()
    if num_items_in_batch is not None:
        return total_loss / num_items_in_batch
    total_tokens = mask.sum()
    return total_loss / (total_tokens.float() + 1e-12)
442+
443+
385444
def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
386445
if loss_type is None:
387446
return None

swift/trainers/arguments.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class TrainArgumentsMixin:
5151
vit_lr: Optional[float] = None
5252
optimizer: Optional[str] = None
5353
use_logits_to_keep: Optional[bool] = None
54+
channels: List[str] = None
5455

5556
# torchacc
5657
metric_warmup_step: Optional[float] = 0

swift/trainers/trainers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
164164
loss_kwargs['loss_scale'] = loss_scale
165165
if compute_loss_func is None:
166166
compute_loss_func = get_loss_func('loss_scale')
167+
168+
sample_channels = inputs.pop('channel', None)
169+
if sample_channels is not None:
170+
state = self.state
171+
setattr(state, 'local_step', getattr(state, 'local_step', 0))
172+
setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {}))
173+
174+
loss_kwargs['sample_channels'] = sample_channels
175+
loss_kwargs['trainer'] = self
176+
167177
if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs:
168178
labels = inputs.pop('labels')
169179

0 commit comments

Comments
 (0)