Skip to content

[pt/sft] Feature channel loss #4405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数.
- check_model: 检查本地模型文件有损坏或修改并给出提示,默认为True。如果是断网环境,请设置为False。
- 🔥create_checkpoint_symlink: 额外创建checkpoint软链接,方便书写自动化训练脚本。best_model和last_model的软链接路径分别为f'{output_dir}/best'和f'{output_dir}/last'。
- loss_type: loss类型。默认为None,使用模型自带损失函数。
- channels: 数据集包含的channel集合。默认为None。结合`--loss_type channel_loss`使用,可参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh)。
- 🔥packing: 是否使用序列packing提升计算效率,默认为False。当前支持`swift pt/sft`。
- 注意:使用packing请结合`--attn_impl flash_attn`使用且"transformers>=4.44",具体查看[该PR](https://github.com/huggingface/transformers/pull/31629)。
- 支持的多模态模型参考:https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh
Expand Down
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
- check_model: Check local model files for corruption or modification and give a prompt, default is True. If in an offline environment, please set to False.
- 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively.
- loss_type: Type of loss. Defaults to None, which uses the model's built-in loss function.
- channels: Set of channels included in the dataset. Defaults to None. Used in conjunction with `--loss_type channel_loss`. Refer to [this example](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh) for more details.
- 🔥packing: Whether to use sequence packing to improve computational efficiency. The default value is False. Currently supports `swift pt/sft`.
- Note: When using packing, please combine it with `--attn_impl flash_attn` and ensure "transformers>=4.44". For details, see [this PR](https://github.com/huggingface/transformers/pull/31629).
- Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh
Expand Down
32 changes: 32 additions & 0 deletions examples/train/plugins/channel_loss.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Example: full-parameter SFT using the `channel_loss` loss type.
# `--loss_type channel_loss` selects the channel loss function.
# `--channels` lists the channels that occur in the dataset.
# Every dataset row must carry a 'channel' field naming its channel.
# Example row:
# {"channel": "chat",
# "messages": [
# {"role": "system", "content": "You are a helpful assistant"},
# {"role": "user", "content": "What color do you like?"},
# {"role": "assistant", "content": "I like blue."}
# ]}
CUDA_VISIBLE_DEVICES=0 \
swift sft \
--model Qwen/Qwen2.5-0.5B-Instruct \
--dataset '/path/to/channel_dataset' \
--train_type full \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 8 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 1 \
--max_length 512 \
--output_dir output \
--system 'You are a helpful assistant.' \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--loss_type channel_loss \
--channels 'chat' 'math' 'code'
7 changes: 5 additions & 2 deletions swift/llm/dataset/preprocessor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@


class RowPreprocessor:
standard_keys = ['messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects']
standard_keys = [
'messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects', 'channel'
]

def __init__(self,
*,
Expand Down Expand Up @@ -308,6 +310,7 @@ def __call__(
dataset = self._cast_pil_image(dataset)

ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
keep_columns = ['channel']
with self._patch_arrow_writer(), safe_ddp_context(None, True):
try:
dataset_mapped = dataset.map(
Expand All @@ -316,7 +319,7 @@ def __call__(
'strict': strict,
'ignore_max_length_error': ignore_max_length_error
},
remove_columns=list(dataset.features.keys()),
remove_columns=[col for col in dataset.features.keys() if col not in keep_columns],
**map_kwargs)
except NotImplementedError:
pass
Expand Down
3 changes: 3 additions & 0 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,10 +1372,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
else:
inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
channel = [b['channel'] for b in batch if b.get('channel') is not None]
if inputs_embeds:
res['inputs_embeds'] = inputs_embeds
if input_ids:
res['input_ids'] = input_ids
if channel:
res['channel'] = channel
for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
val = [b[key] for b in batch if b.get(key) is not None]
if val:
Expand Down
59 changes: 59 additions & 0 deletions swift/plugin/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from accelerate.utils import gather_object
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.utils import strtobool

from swift.plugin import MeanMetric


# Namespace of string constants naming the available loss types.
# Each value is the key passed to `register_loss_func` when a loss function
# is registered (e.g. `channel_loss` for `channel_loss_func`).
class LossType:
loss_scale = 'loss_scale'
cosine_similarity = 'cosine_similarity'
contrastive = 'contrastive'
online_contrastive = 'online_contrastive'
infonce = 'infonce'
channel_loss = 'channel_loss'


LOSS_MAPPING = {}
Expand Down Expand Up @@ -382,6 +386,61 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
return loss


@register_loss_func(LossType.channel_loss)
def channel_loss_func(outputs, labels, num_items_in_batch=None, sample_channels=None, trainer=None) -> torch.Tensor:
    """Next-token cross-entropy loss that also records a mean loss per channel.

    The returned training loss is the summed token loss over all positions
    whose shifted label is not -100, divided by ``num_items_in_batch`` when it
    is given and by the number of such positions otherwise.

    As a side effect, per-token losses are grouped by the channel of their
    sample and cached on ``trainer.state.ch_loss_steps``. Every
    ``trainer.args.gradient_accumulation_steps`` calls, the cached losses of
    each configured channel are summed (across ranks when
    ``torch.distributed`` is initialised), averaged per token and pushed into
    ``trainer._custom_metrics`` under the key ``loss_{channel}``.

    Args:
        outputs: Model output exposing ``logits``.
        labels: Label ids aligned with ``logits``; -100 marks ignored tokens.
        num_items_in_batch: Optional normaliser for the returned loss.
        sample_channels: Channel name of each sample in the batch, in batch
            order.
        trainer: Trainer providing ``args.channels``,
            ``args.gradient_accumulation_steps``, ``state`` and
            ``_custom_metrics``. ``state.local_step`` and
            ``state.ch_loss_steps`` must already exist on ``trainer.state``.

    Returns:
        Scalar loss tensor used for back-propagation.

    Raises:
        AssertionError: If ``sample_channels`` is missing or
            ``trainer.args.channels`` is not set.
    """
    assert sample_channels is not None, 'Data does not have channel field.'
    channels = trainer.args.channels
    assert channels is not None, 'Please pass --channels as a hyperparameter.'
    logits = outputs.logits

    # Per-token loss: position t predicts token t + 1, hence the shift.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    flat_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_labels = shift_labels.view(-1)
    token_loss = loss_fct(flat_logits, flat_labels).view_as(shift_labels)  # [batch_size, seq_len-1]

    mask = shift_labels != -100

    state = trainer.state
    state.local_step += 1

    # The cached values are only used for logging. Detach them so the cache
    # does not hold a reference to this micro-batch's autograd graph until
    # the end of the global step.
    stat_loss = token_loss.detach()
    for ch in set(sample_channels):
        idx = torch.tensor([s == ch for s in sample_channels], device=logits.device)
        ch_loss_step = stat_loss[idx].masked_select(mask[idx])
        state.ch_loss_steps.setdefault(ch, []).append(ch_loss_step)

    # At the end of a global step, compute the mean loss for each channel.
    if state.local_step % trainer.args.gradient_accumulation_steps == 0:
        # Iterate the configured channels (not the locally seen ones) so that
        # every rank issues the same sequence of all_reduce calls.
        for ch in channels:
            ch_loss_steps = state.ch_loss_steps.get(ch, [])
            loss_sum_tensor = torch.tensor([sum(torch.sum(x) for x in ch_loss_steps)],
                                           dtype=torch.float32,
                                           device=logits.device)
            num_items_tensor = torch.tensor([sum(x.numel() for x in ch_loss_steps)],
                                            dtype=torch.float32,
                                            device=logits.device)
            if dist.is_initialized():
                dist.all_reduce(loss_sum_tensor, op=dist.ReduceOp.SUM)
                dist.all_reduce(num_items_tensor, op=dist.ReduceOp.SUM)
            loss_sum = loss_sum_tensor.item()
            num_items = num_items_tensor.item()
            ch_loss = loss_sum / (num_items + 1e-12)

            # A channel with no tokens in this global step yields 0 and is
            # not reported.
            if ch_loss > 0.0:
                metric_key = f'loss_{ch}'
                trainer._custom_metrics.setdefault(metric_key, MeanMetric(nan_value=None)).update(ch_loss)
        # Reset the whole cache, including channels that appear in the data
        # but are not listed in `channels`; otherwise their lists would grow
        # without bound over the course of training.
        state.ch_loss_steps.clear()

    # Training loss over all non-ignored tokens of the batch.
    total_loss = token_loss.masked_select(mask).sum()
    total_tokens = mask.sum()
    return total_loss / num_items_in_batch if num_items_in_batch is not None \
        else total_loss / (total_tokens.float() + 1e-12)


def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
if loss_type is None:
return None
Expand Down
1 change: 1 addition & 0 deletions swift/trainers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TrainArgumentsMixin:
vit_lr: Optional[float] = None
optimizer: Optional[str] = None
use_logits_to_keep: Optional[bool] = None
channels: List[str] = None

# torchacc
metric_warmup_step: Optional[float] = 0
Expand Down
10 changes: 10 additions & 0 deletions swift/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
loss_kwargs['loss_scale'] = loss_scale
if compute_loss_func is None:
compute_loss_func = get_loss_func('loss_scale')

sample_channels = inputs.pop('channel', None)
if sample_channels is not None:
state = self.state
setattr(state, 'local_step', getattr(state, 'local_step', 0))
setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {}))

loss_kwargs['sample_channels'] = sample_channels
loss_kwargs['trainer'] = self

if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs:
labels = inputs.pop('labels')

Expand Down