Skip to content

[pt/sft] Feature channel loss #4405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions examples/train/plugins/channel_loss.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# use loss_type channel_loss
# data should have 'channel' field
# eg.
# {"channel": "chat",
# "messages": [
# {"role": "system", "content": "You are a helpful assistant"},
# {"role": "user", "content": "What color do you like?"},
# {"role": "assistant", "content": "I like blue."}
# ]}
CUDA_VISIBLE_DEVICES=0 \
swift sft \
--model Qwen/Qwen2.5-0.5B-Instruct \
--dataset '/path/to/channel_dataset' \
--train_type full \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 8 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 1 \
--max_length 512 \
--output_dir output \
--system 'You are a helpful assistant.' \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--loss_type channel_loss
12 changes: 11 additions & 1 deletion swift/llm/dataset/preprocessor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@


class RowPreprocessor:
standard_keys = ['messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects']
standard_keys = [
'messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects', 'channel'
]

def __init__(self,
*,
Expand Down Expand Up @@ -303,6 +305,9 @@ def __call__(
if 'solution' in dataset.features:
with safe_ddp_context(None, True):
dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
channel = None
if 'channel' in dataset.features:
channel = dataset['channel']
dataset = self._rename_columns(dataset)
dataset = self.prepare_dataset(dataset)
dataset = self._cast_pil_image(dataset)
Expand All @@ -323,6 +328,11 @@ def __call__(
if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped):
logger.info(
f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}')
if channel:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是为什么呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是为了读取数据的‘channel’字段,把它放到inputs中。
数据示例:{
"channel": "chat",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "What color do you like?"},
{"role": "assistant", "content": "I like blue."}]
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的map应该是不需要的。dataset_mapped中应该已经包含channel了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
这一步会把channel删掉,是否改成把channel从remove_columns中移除?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已处理,见c12f566

dataset_mapped = dataset_mapped.map(
lambda example, idx: {
**example, 'channel': channel[idx]
}, with_indices=True)

return dataset_mapped

Expand Down
3 changes: 3 additions & 0 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,10 +1372,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
else:
inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
channel = [b['channel'] for b in batch if b.get('channel') is not None]
if inputs_embeds:
res['inputs_embeds'] = inputs_embeds
if input_ids:
res['input_ids'] = input_ids
if channel:
res['channel'] = channel
for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
val = [b[key] for b in batch if b.get(key) is not None]
if val:
Expand Down
60 changes: 60 additions & 0 deletions swift/plugin/loss.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.utils import strtobool
import torch.distributed as dist
from swift.plugin import MeanMetric


class LossType:
Expand All @@ -18,6 +20,7 @@ class LossType:
contrastive = 'contrastive'
online_contrastive = 'online_contrastive'
infonce = 'infonce'
channel_loss = 'channel_loss'


LOSS_MAPPING = {}
Expand Down Expand Up @@ -382,6 +385,63 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
return loss


@register_loss_func(LossType.channel_loss)
def channel_loss_func(outputs,
labels,
loss_scale=None,
num_items_in_batch=None,
channels=None,
trainer=None) -> torch.Tensor:
assert channels is not None, "channels should not be None"
logits = outputs.logits
channel_cid = trainer.channel_cid
cid_channel = trainer.cid_channel
channels_tensor = torch.tensor([channel_cid[channel] for channel in channels], device=logits.device)

# compute token loss
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss(reduction='none')
flat_logits = shift_logits.view(-1, shift_logits.size(-1))
flat_labels = shift_labels.view(-1)
token_loss = loss_fct(flat_logits, flat_labels).view_as(shift_labels) # [batch_size, seq_len-1]

mask = shift_labels != -100

state = trainer.state
state.local_step += 1

for cid in torch.unique(channels_tensor):
idx = (channels_tensor == cid)
ch_loss_step = token_loss[idx].masked_select(mask[idx])
state.ch_loss_steps.setdefault(cid.item(), []).append(ch_loss_step)

# At the end of a global step, compute the mean loss for each channel
if state.local_step % trainer.args.gradient_accumulation_steps == 0:
for cid, ch_name in cid_channel.items():
ch_loss_steps = state.ch_loss_steps.get(cid, [])
if ch_loss_steps:
loss_sum_tensor = torch.tensor([sum(torch.sum(x) for x in ch_loss_steps)], device=logits.device)
num_items_tensor = torch.tensor([sum(x.numel() for x in ch_loss_steps)], device=logits.device)
ch_loss = (loss_sum_tensor / num_items_tensor)
if dist.is_initialized():
gather_loss = [torch.zeros_like(ch_loss) for _ in range(dist.get_world_size())]
dist.all_gather(gather_loss, ch_loss)
ch_loss = torch.cat(gather_loss, dim=0)
ch_loss = ch_loss.mean().item()

metric_key = f'loss_{ch_name}'
trainer._custom_metrics.setdefault(metric_key, MeanMetric(nan_value=None)).update(ch_loss)
# Reset
state.ch_loss_steps[cid] = []

# return loss
total_loss = token_loss.masked_select(mask).sum()
total_tokens = mask.sum()
return total_loss / num_items_in_batch if num_items_in_batch is not None \
else total_loss / (total_tokens.float() + 1e-12)


def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]:
if loss_type is None:
return None
Expand Down
20 changes: 20 additions & 0 deletions swift/trainers/trainers.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def __init__(self, *args, **kwargs):
self.infer_engine = PtEngine.from_model_template(
self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size)
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl'))
self.channel_cid = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个写个注释吧

channel_cid和cid_channel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改为通过参数方式传入,不需要channel_cid和cid_channel了

self.cid_channel = None

@staticmethod
def _predict_data_collator(batch):
Expand Down Expand Up @@ -157,6 +159,24 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs:
labels = inputs.pop('labels')

channels = inputs.pop('channel', None)
if channels is not None:
self.channel_cid = self.channel_cid or {}
self.cid_channel = self.cid_channel or {}

for ch in set(channels):
if ch not in self.channel_cid:
cid = len(self.channel_cid)
self.channel_cid[ch] = cid
self.cid_channel[cid] = ch

state = self.state
setattr(state, 'local_step', getattr(state, 'local_step', 0))
setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {}))

loss_kwargs['channels'] = channels
loss_kwargs['trainer'] = self

loss_scale = inputs.pop('loss_scale', None)
if loss_scale is not None:
loss_kwargs['loss_scale'] = loss_scale
Expand Down
Loading