From b94950ee21089128046d03e11e47d8a935a38b1f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Sat, 24 May 2025 00:36:55 +0800 Subject: [PATCH 01/13] add channel into inputs --- swift/llm/dataset/preprocessor/core.py | 9 ++++++++- swift/llm/template/base.py | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py index 7c5760d951..e5b79f51ed 100644 --- a/swift/llm/dataset/preprocessor/core.py +++ b/swift/llm/dataset/preprocessor/core.py @@ -20,7 +20,8 @@ class RowPreprocessor: - standard_keys = ['messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects'] + standard_keys = ['messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects', + 'channel'] def __init__(self, *, @@ -303,6 +304,9 @@ def __call__( if 'solution' in dataset.features: with safe_ddp_context(None, True): dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs) + channel = None + if 'channel' in dataset.features: + channel = dataset['channel'] dataset = self._rename_columns(dataset) dataset = self.prepare_dataset(dataset) dataset = self._cast_pil_image(dataset) @@ -323,6 +327,9 @@ def __call__( if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped): logger.info( f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}') + if channel: + dataset_mapped = dataset_mapped.map(lambda example, idx: {**example, 'channel': channel[idx]}, + with_indices=True) return dataset_mapped diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index de78bf1ada..e1346a9d47 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1372,10 +1372,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in else: inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None] input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None] + channel = [b['channel'] for b in batch if b.get('channel') is not None] if inputs_embeds: res['inputs_embeds'] = inputs_embeds if input_ids: res['input_ids'] = input_ids + if channel: + res['channel'] = channel for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']: val = [b[key] for b in batch if b.get(key) is not None] if val: From 5744d21556e6209c55aff4136cddb0ee00c6124f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 29 May 2025 10:16:11 +0800 Subject: [PATCH 02/13] compute channel loss --- swift/plugin/loss.py | 55 ++++++++++++++++++++++++++++++++++++++ swift/trainers/trainers.py | 20 ++++++++++++++ 2 files changed, 75 insertions(+) mode change 100644 => 100755 swift/plugin/loss.py mode change 100644 => 100755 swift/trainers/trainers.py diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py old mode 100644 new mode 100755 index 6ad82a5dee..3f1e9a21c3 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -10,6 +10,8 @@ from torch import nn from torch.nn import CrossEntropyLoss, MSELoss from transformers.utils import strtobool +import torch.distributed as dist +from swift.plugin import MeanMetric class LossType: @@ -18,6 +20,7 @@ class LossType: contrastive = 'contrastive' online_contrastive = 'online_contrastive' infonce = 'infonce' + channel_loss = 'channel_loss' LOSS_MAPPING = {} @@ -382,6 +385,58 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch return loss +@register_loss_func(LossType.channel_loss) +def channel_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None, channels=None, + trainer=None) -> torch.Tensor: + logits = outputs.logits + channel_cid = trainer.channel_cid + cid_channel = trainer.cid_channel + channels_tensor = torch.tensor([channel_cid[channel] for channel in channels], device=logits.device) + + # compute token loss + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss(reduction='none') + flat_logits = shift_logits.view(-1, shift_logits.size(-1)) + flat_labels = shift_labels.view(-1) + token_loss = loss_fct(flat_logits, flat_labels).view_as(shift_labels) # [batch_size, seq_len-1] + + mask = shift_labels != -100 + + state = trainer.state + state.local_step += 1 + + for cid in torch.unique(channels_tensor): + idx = (channels_tensor == cid) + ch_loss_step = token_loss[idx].masked_select(mask[idx]) + state.ch_loss_steps.setdefault(cid.item(), []).append(ch_loss_step) + + # At the end of a global step, compute the mean loss for each channel + if state.local_step % trainer.args.gradient_accumulation_steps == 0: + for cid, ch_name in cid_channel.items(): + ch_loss_steps = state.ch_loss_steps.get(cid, []) + if ch_loss_steps: + loss_sum_tensor = torch.tensor([sum(torch.sum(x) for x in ch_loss_steps)], device=logits.device) + num_items_tensor = torch.tensor([sum(x.numel() for x in ch_loss_steps)], device=logits.device) + ch_loss = (loss_sum_tensor / num_items_tensor) + if dist.is_initialized(): + gather_loss = [torch.zeros_like(ch_loss) for _ in range(dist.get_world_size())] + dist.all_gather(gather_loss, ch_loss) + ch_loss = torch.cat(gather_loss, dim=0) + ch_loss = ch_loss.mean().item() + + metric_key = f'loss_{ch_name}' + trainer._custom_metrics.setdefault(metric_key, MeanMetric(nan_value=None)).update(ch_loss) + # Reset + state.ch_loss_steps[cid] = [] + + # return loss + total_loss = token_loss.masked_select(mask).sum() + total_tokens = mask.sum() + return total_loss / num_items_in_batch if num_items_in_batch is not None \ + else total_loss / (total_tokens.float() + 1e-12) + + def get_loss_func(loss_type: Optional[str]) -> Optional[Callable]: if loss_type is None: return None diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py old mode 100644 new mode 100755 index 24bd3e4282..732b1aa988 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -88,6 +88,8 @@ def __init__(self, *args, **kwargs): self.infer_engine = PtEngine.from_model_template( self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size) self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl')) + self.channel_cid=None + self.cid_channel=None @staticmethod def _predict_data_collator(batch): @@ -157,6 +159,24 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs: labels = inputs.pop('labels') + channels = inputs.pop('channel', None) + if channels is not None: + self.channel_cid = self.channel_cid or {} + self.cid_channel = self.cid_channel or {} + + for ch in set(channels): + if ch not in self.channel_cid: + cid = len(self.channel_cid) + self.channel_cid[ch] = cid + self.cid_channel[cid] = ch + + state = self.state + setattr(state, 'local_step', getattr(state, 'local_step', 0)) + setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {})) + + loss_kwargs['channels'] = channels + loss_kwargs['trainer'] = self + loss_scale = inputs.pop('loss_scale', None) if loss_scale is not None: loss_kwargs['loss_scale'] = loss_scale From da264eec57a0a437833a22b01eab5f68e14e4231 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 29 May 2025 10:41:15 +0800 Subject: [PATCH 03/13] channel loss example --- examples/train/plugins/channel_loss.sh | 31 ++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 examples/train/plugins/channel_loss.sh diff --git a/examples/train/plugins/channel_loss.sh b/examples/train/plugins/channel_loss.sh new file mode 100644 index 0000000000..2fa0dd9996 --- /dev/null +++ b/examples/train/plugins/channel_loss.sh @@ -0,0 +1,31 @@ +# use loss_type channel_loss +# data should have 'channel' field +# eg. +# {"channel": "chat", +# "messages": [ +# {"role": "system", "content": "You are a helpful assistant"}, +# {"role": "user", "content": "What color do you like?"}, +# {"role": "assistant", "content": "I like blue."} +# ]} +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model Qwen/Qwen2.5-7B-Instruct \ + --dataset '/path/to/your_channel_dataset' \ + --train_type full \ + --dataset '/path/to/dataset' \ + --torch_dtype bfloat16 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-5 \ + --gradient_accumulation_steps 8 \ + --eval_steps 100 \ + --save_steps 100 \ + --save_total_limit 2 \ + --logging_steps 1 \ + --max_length 512 \ + --output_dir output \ + --system 'You are a helpful assistant.' \ + --warmup_ratio 0.05 \ + --dataloader_num_workers 4 \ + --loss_type channel_loss From bf8cf778b83ff5b691169416bf252f66d2425062 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 29 May 2025 11:25:03 +0800 Subject: [PATCH 04/13] code lint --- swift/llm/dataset/preprocessor/core.py | 11 +++++++---- swift/plugin/loss.py | 6 +++++- swift/trainers/trainers.py | 4 ++-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py index e5b79f51ed..c4a06d225c 100644 --- a/swift/llm/dataset/preprocessor/core.py +++ b/swift/llm/dataset/preprocessor/core.py @@ -20,8 +20,9 @@ class RowPreprocessor: - standard_keys = ['messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects', - 'channel'] + standard_keys = [ + 'messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects', 'channel' + ] def __init__(self, *, @@ -328,8 +329,10 @@ def __call__( logger.info( f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}') if channel: - dataset_mapped = dataset_mapped.map(lambda example, idx: {**example, 'channel': channel[idx]}, - with_indices=True) + dataset_mapped = dataset_mapped.map( + lambda example, idx: { + **example, 'channel': channel[idx] + }, with_indices=True) return dataset_mapped diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index 3f1e9a21c3..f2c88977fc 100755 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -386,7 +386,11 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch @register_loss_func(LossType.channel_loss) -def channel_loss_func(outputs, labels, loss_scale=None, num_items_in_batch=None, channels=None, +def channel_loss_func(outputs, + labels, + loss_scale=None, + num_items_in_batch=None, + channels=None, trainer=None) -> torch.Tensor: logits = outputs.logits channel_cid = trainer.channel_cid diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index 732b1aa988..da4736a0ff 100755 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -88,8 +88,8 @@ def __init__(self, *args, **kwargs): self.infer_engine = PtEngine.from_model_template( self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size) self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl')) - self.channel_cid=None - self.cid_channel=None + self.channel_cid = None + self.cid_channel = None @staticmethod def _predict_data_collator(batch): From c71a3f3959079dc3f7f227cc48ecb5c4a2c326e4 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 29 May 2025 11:32:20 +0800 Subject: [PATCH 05/13] fix --- swift/plugin/loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index f2c88977fc..4c102ba852 100755 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -392,6 +392,7 @@ def channel_loss_func(outputs, num_items_in_batch=None, channels=None, trainer=None) -> torch.Tensor: + assert channels is not None, "channels should not be None" logits = outputs.logits channel_cid = trainer.channel_cid cid_channel = trainer.cid_channel From d9bcc0e1da366a26b8ecc58c4711fe83bcf24c91 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 29 May 2025 11:38:00 +0800 Subject: [PATCH 06/13] fix example --- examples/train/plugins/channel_loss.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/train/plugins/channel_loss.sh b/examples/train/plugins/channel_loss.sh index 2fa0dd9996..16aad4c674 100644 --- a/examples/train/plugins/channel_loss.sh +++ b/examples/train/plugins/channel_loss.sh @@ -9,10 +9,9 @@ # ]} CUDA_VISIBLE_DEVICES=0 \ swift sft \ - --model Qwen/Qwen2.5-7B-Instruct \ - --dataset '/path/to/your_channel_dataset' \ + --model Qwen/Qwen2.5-0.5B-Instruct \ + --dataset '/path/to/channel_dataset' \ --train_type full \ - --dataset '/path/to/dataset' \ --torch_dtype bfloat16 \ --num_train_epochs 1 \ --per_device_train_batch_size 1 \ From baabc1df255b26dafabb65b17286c2730d290300 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Thu, 29 May 2025 11:48:54 +0800 Subject: [PATCH 07/13] code lint --- swift/plugin/loss.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index 4c102ba852..710e1c3f90 100755 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -5,12 +5,13 @@ import numpy as np import torch +import torch.distributed as dist import torch.nn.functional as F from accelerate.utils import gather_object from torch import nn from torch.nn import CrossEntropyLoss, MSELoss from transformers.utils import strtobool -import torch.distributed as dist + from swift.plugin import MeanMetric @@ -392,7 +393,7 @@ def channel_loss_func(outputs, num_items_in_batch=None, channels=None, trainer=None) -> torch.Tensor: - assert channels is not None, "channels should not be None" + assert channels is not None, 'channels should not be None' logits = outputs.logits channel_cid = trainer.channel_cid cid_channel = trainer.cid_channel From 379cfb9f3ae02703a5ced1a30d9710ed5b06c9f8 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 30 May 2025 11:18:43 +0800 Subject: [PATCH 08/13] fix multi-gpu --- swift/llm/argument/train_args.py | 3 +- swift/llm/train/sft.py | 1 + swift/plugin/loss.py | 52 +++++++++++++++----------------- swift/trainers/mixin.py | 2 ++ swift/trainers/trainers.py | 17 ++--------- 5 files changed, 33 insertions(+), 42 deletions(-) mode change 100755 => 100644 swift/plugin/loss.py mode change 100755 => 100644 swift/trainers/trainers.py diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index 7dc30e916f..1959b919a1 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import List, Literal, Optional from transformers import Seq2SeqTrainingArguments from transformers.utils.versions import require_version @@ -117,6 +117,7 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'}) optimizer: Optional[str] = None metric: Optional[str] = None + channel_list: List[str] = None # extra max_new_tokens: int = 64 diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index afbd06436d..e8484914ed 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -123,6 +123,7 @@ def run(self): train_dataset=train_dataset, eval_dataset=val_dataset, callbacks=self.callbacks, + channel_list=self.args.channel_list, template=self.template, **self._get_trainer_kwargs(), ) diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py old mode 100755 new mode 100644 index 710e1c3f90..c945145146 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -387,17 +387,11 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch @register_loss_func(LossType.channel_loss) -def channel_loss_func(outputs, - labels, - loss_scale=None, - num_items_in_batch=None, - channels=None, - trainer=None) -> torch.Tensor: - assert channels is not None, 'channels should not be None' +def channel_loss_func(outputs, labels, num_items_in_batch=None, sample_channels=None, trainer=None) -> torch.Tensor: + assert sample_channels is not None, 'Data does not have channel field.' + channel_list = trainer.channel_list + assert channel_list is not None, 'Please pass --channel_list as a hyperparameter.' logits = outputs.logits - channel_cid = trainer.channel_cid - cid_channel = trainer.cid_channel - channels_tensor = torch.tensor([channel_cid[channel] for channel in channels], device=logits.device) # compute token loss shift_logits = logits[..., :-1, :].contiguous() @@ -412,29 +406,33 @@ def channel_loss_func(outputs, state = trainer.state state.local_step += 1 - for cid in torch.unique(channels_tensor): - idx = (channels_tensor == cid) + for ch in set(sample_channels): + idx = torch.tensor([s == ch for s in sample_channels], device=logits.device) ch_loss_step = token_loss[idx].masked_select(mask[idx]) - state.ch_loss_steps.setdefault(cid.item(), []).append(ch_loss_step) + state.ch_loss_steps.setdefault(ch, []).append(ch_loss_step) # At the end of a global step, compute the mean loss for each channel if state.local_step % trainer.args.gradient_accumulation_steps == 0: - for cid, ch_name in cid_channel.items(): - ch_loss_steps = state.ch_loss_steps.get(cid, []) - if ch_loss_steps: - loss_sum_tensor = torch.tensor([sum(torch.sum(x) for x in ch_loss_steps)], device=logits.device) - num_items_tensor = torch.tensor([sum(x.numel() for x in ch_loss_steps)], device=logits.device) - ch_loss = (loss_sum_tensor / num_items_tensor) - if dist.is_initialized(): - gather_loss = [torch.zeros_like(ch_loss) for _ in range(dist.get_world_size())] - dist.all_gather(gather_loss, ch_loss) - ch_loss = torch.cat(gather_loss, dim=0) - ch_loss = ch_loss.mean().item() - - metric_key = f'loss_{ch_name}' + for ch in channel_list: + ch_loss_steps = state.ch_loss_steps.get(ch, []) + loss_sum_tensor = torch.tensor([sum(torch.sum(x) for x in ch_loss_steps)], + dtype=torch.float32, + device=logits.device) + num_items_tensor = torch.tensor([sum(x.numel() for x in ch_loss_steps)], + dtype=torch.float32, + device=logits.device) + if dist.is_initialized(): + dist.all_reduce(loss_sum_tensor, op=dist.ReduceOp.SUM) + dist.all_reduce(num_items_tensor, op=dist.ReduceOp.SUM) + loss_sum = loss_sum_tensor.item() + num_items = num_items_tensor.item() + ch_loss = loss_sum / (num_items + 1e-12) + + if ch_loss > 0.0: + metric_key = f'loss_{ch}' trainer._custom_metrics.setdefault(metric_key, MeanMetric(nan_value=None)).update(ch_loss) # Reset - state.ch_loss_steps[cid] = [] + state.ch_loss_steps[ch] = [] # return loss total_loss = token_loss.masked_select(mask).sum() diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 8c5746b084..805a698aea 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -62,6 +62,7 @@ def __init__(self, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + channel_list: Optional[List[str]] = None, **kwargs) -> None: if not hasattr(train_dataset, '__len__') and args.dataloader_num_workers > 1: args.dataloader_num_workers = 1 @@ -79,6 +80,7 @@ def __init__(self, args.evaluation_strategy = IntervalStrategy.NO args.eval_strategy = IntervalStrategy.NO + self.channel_list = channel_list self._custom_metrics = {} self.template = template self.max_memory = 0 diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py old mode 100755 new mode 100644 index da4736a0ff..f2290770b1 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -88,8 +88,6 @@ def __init__(self, *args, **kwargs): self.infer_engine = PtEngine.from_model_template( self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size) self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl')) - self.channel_cid = None - self.cid_channel = None @staticmethod def _predict_data_collator(batch): @@ -159,22 +157,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs: labels = inputs.pop('labels') - channels = inputs.pop('channel', None) - if channels is not None: - self.channel_cid = self.channel_cid or {} - self.cid_channel = self.cid_channel or {} - - for ch in set(channels): - if ch not in self.channel_cid: - cid = len(self.channel_cid) - self.channel_cid[ch] = cid - self.cid_channel[cid] = ch - + sample_channels = inputs.pop('channel', None) + if sample_channels is not None: state = self.state setattr(state, 'local_step', getattr(state, 'local_step', 0)) setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {})) - loss_kwargs['channels'] = channels + loss_kwargs['sample_channels'] = sample_channels loss_kwargs['trainer'] = self loss_scale = inputs.pop('loss_scale', None) From ce349bac4d2e22356df15f3c56baa3bfddd0dd88 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 30 May 2025 11:39:37 +0800 Subject: [PATCH 09/13] trainer --- swift/trainers/trainers.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index 1a91f5b7ea..0c45944b0c 100644 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -164,6 +164,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss_kwargs['loss_scale'] = loss_scale if compute_loss_func is None: compute_loss_func = get_loss_func('loss_scale') + + sample_channels = inputs.pop('channel', None) + if sample_channels is not None: + state = self.state + setattr(state, 'local_step', getattr(state, 'local_step', 0)) + setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {})) + + loss_kwargs['sample_channels'] = sample_channels + loss_kwargs['trainer'] = self + if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs: labels = inputs.pop('labels') From 4640b5305caef6060a3bbd833c5fb233eacc6296 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 30 May 2025 12:23:08 +0800 Subject: [PATCH 10/13] update documents --- ...21\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 1 + docs/source_en/Instruction/Command-line-parameters.md | 1 + examples/train/plugins/channel_loss.sh | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index b61c46f5fe..a57fbd8147 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -351,6 +351,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数. - check_model: 检查本地模型文件有损坏或修改并给出提示,默认为True。如果是断网环境,请设置为False。 - 🔥create_checkpoint_symlink: 额外创建checkpoint软链接,方便书写自动化训练脚本。best_model和last_model的软链接路径分别为f'{output_dir}/best'和f'{output_dir}/last'。 - loss_type: loss类型。默认为None,使用模型自带损失函数。 +- channel_list : 数据集包含的channel列表。默认为None。结合`--loss_type channel_loss`使用,可参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh)。 - 🔥packing: 是否使用序列packing提升计算效率,默认为False。当前支持`swift pt/sft`。 - 注意:使用packing请结合`--attn_impl flash_attn`使用且"transformers>=4.44",具体查看[该PR](https://github.com/huggingface/transformers/pull/31629)。 - 支持的多模态模型参考:https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 0a66e82812..53402b57cc 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -360,6 +360,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine - check_model: Check local model files for corruption or modification and give a prompt, default is True. If in an offline environment, please set to False. - 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively. - loss_type: Type of loss. Defaults to None, which uses the model's built-in loss function. +- channel_list:List of channels included in the dataset. Defaults to None. Used in conjunction with `--loss_type channel_loss`. Refer to [this example](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh) for more details. - 🔥packing: Whether to use sequence packing to improve computational efficiency. The default value is False. Currently supports `swift pt/sft`. - Note: When using packing, please combine it with `--attn_impl flash_attn` and ensure "transformers>=4.44". For details, see [this PR](https://github.com/huggingface/transformers/pull/31629). - Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh diff --git a/examples/train/plugins/channel_loss.sh b/examples/train/plugins/channel_loss.sh index 16aad4c674..a7c9b5f568 100644 --- a/examples/train/plugins/channel_loss.sh +++ b/examples/train/plugins/channel_loss.sh @@ -1,4 +1,5 @@ # use loss_type channel_loss +# channel_list specifies the channels included in the dataset # data should have 'channel' field # eg. # {"channel": "chat", @@ -27,4 +28,5 @@ swift sft \ --system 'You are a helpful assistant.' \ --warmup_ratio 0.05 \ --dataloader_num_workers 4 \ - --loss_type channel_loss + --loss_type channel_loss \ + --channel_list 'chat' 'math' 'code' From 48b4aecb6684ac938daf69a40afea7f5cfb4f7d0 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 4 Jun 2025 14:56:43 +0800 Subject: [PATCH 11/13] rename channel_list to channels --- ...\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 +- docs/source_en/Instruction/Command-line-parameters.md | 2 +- examples/train/plugins/channel_loss.sh | 4 ++-- swift/llm/argument/train_args.py | 1 - swift/llm/train/sft.py | 1 - swift/plugin/loss.py | 6 +++--- swift/trainers/arguments.py | 1 + swift/trainers/mixin.py | 2 -- 8 files changed, 8 insertions(+), 11 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index a57fbd8147..79299eafae 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -351,7 +351,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数. - check_model: 检查本地模型文件有损坏或修改并给出提示,默认为True。如果是断网环境,请设置为False。 - 🔥create_checkpoint_symlink: 额外创建checkpoint软链接,方便书写自动化训练脚本。best_model和last_model的软链接路径分别为f'{output_dir}/best'和f'{output_dir}/last'。 - loss_type: loss类型。默认为None,使用模型自带损失函数。 -- channel_list : 数据集包含的channel列表。默认为None。结合`--loss_type channel_loss`使用,可参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh)。 +- channels : 数据集包含的channel集合。默认为None。结合`--loss_type channel_loss`使用,可参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh)。 - 🔥packing: 是否使用序列packing提升计算效率,默认为False。当前支持`swift pt/sft`。 - 注意:使用packing请结合`--attn_impl flash_attn`使用且"transformers>=4.44",具体查看[该PR](https://github.com/huggingface/transformers/pull/31629)。 - 支持的多模态模型参考:https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 53402b57cc..1089185175 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -360,7 +360,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine - check_model: Check local model files for corruption or modification and give a prompt, default is True. If in an offline environment, please set to False. - 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively. - loss_type: Type of loss. Defaults to None, which uses the model's built-in loss function. -- channel_list:List of channels included in the dataset. Defaults to None. Used in conjunction with `--loss_type channel_loss`. Refer to [this example](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh) for more details. +- channels:Set of channels included in the dataset. Defaults to None. Used in conjunction with `--loss_type channel_loss`. Refer to [this example](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh) for more details. - 🔥packing: Whether to use sequence packing to improve computational efficiency. The default value is False. Currently supports `swift pt/sft`. - Note: When using packing, please combine it with `--attn_impl flash_attn` and ensure "transformers>=4.44". For details, see [this PR](https://github.com/huggingface/transformers/pull/31629). - Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh diff --git a/examples/train/plugins/channel_loss.sh b/examples/train/plugins/channel_loss.sh index a7c9b5f568..09d9b29641 100644 --- a/examples/train/plugins/channel_loss.sh +++ b/examples/train/plugins/channel_loss.sh @@ -1,5 +1,5 @@ # use loss_type channel_loss -# channel_list specifies the channels included in the dataset +# channels specifies the channels included in the dataset # data should have 'channel' field # eg. # {"channel": "chat", @@ -29,4 +29,4 @@ swift sft \ --warmup_ratio 0.05 \ --dataloader_num_workers 4 \ --loss_type channel_loss \ - --channel_list 'chat' 'math' 'code' + --channels 'chat' 'math' 'code' diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index 9a8c13eb46..e32e96e1f9 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -113,7 +113,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra # plugin loss_type: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'}) metric: Optional[str] = None - channel_list: List[str] = None # extra max_new_tokens: int = 64 diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index e9c767a3c0..4d528ab1ff 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -116,7 +116,6 @@ def run(self): train_dataset=train_dataset, eval_dataset=val_dataset, callbacks=self.callbacks, - channel_list=self.args.channel_list, template=self.template, **self._get_trainer_kwargs(), ) diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index c945145146..3c7d3d1b70 100644 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -389,8 +389,8 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch @register_loss_func(LossType.channel_loss) def channel_loss_func(outputs, labels, num_items_in_batch=None, sample_channels=None, trainer=None) -> torch.Tensor: assert sample_channels is not None, 'Data does not have channel field.' - channel_list = trainer.channel_list - assert channel_list is not None, 'Please pass --channel_list as a hyperparameter.' + channels = trainer.args.channels + assert channels is not None, 'Please pass --channels as a hyperparameter.' logits = outputs.logits # compute token loss @@ -413,7 +413,7 @@ def channel_loss_func(outputs, labels, num_items_in_batch=None, sample_channels= # At the end of a global step, compute the mean loss for each channel if state.local_step % trainer.args.gradient_accumulation_steps == 0: - for ch in channel_list: + for ch in channels: ch_loss_steps = state.ch_loss_steps.get(ch, []) loss_sum_tensor = torch.tensor([sum(torch.sum(x) for x in ch_loss_steps)], dtype=torch.float32, diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index ee1e4df47b..2a51cc4042 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -51,6 +51,7 @@ class TrainArgumentsMixin: vit_lr: Optional[float] = None optimizer: Optional[str] = None use_logits_to_keep: Optional[bool] = None + channels: List[str] = None # torchacc metric_warmup_step: Optional[float] = 0 diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 805a698aea..8c5746b084 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -62,7 +62,6 @@ def __init__(self, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - channel_list: Optional[List[str]] = None, **kwargs) -> None: if not hasattr(train_dataset, '__len__') and args.dataloader_num_workers > 1: args.dataloader_num_workers = 1 @@ -80,7 +79,6 @@ def __init__(self, args.evaluation_strategy = IntervalStrategy.NO args.eval_strategy = IntervalStrategy.NO - self.channel_list = channel_list self._custom_metrics = {} self.template = template self.max_memory = 0 From c12f56666b46597355f149febb8e152f349806e3 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 4 Jun 2025 14:57:05 +0800 Subject: [PATCH 12/13] fix map channel --- swift/llm/dataset/preprocessor/core.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py index c4a06d225c..f259857ef9 100644 --- a/swift/llm/dataset/preprocessor/core.py +++ b/swift/llm/dataset/preprocessor/core.py @@ -305,14 +305,12 @@ def __call__( if 'solution' in dataset.features: with safe_ddp_context(None, True): dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs) - channel = None - if 'channel' in dataset.features: - channel = dataset['channel'] dataset = self._rename_columns(dataset) dataset = self.prepare_dataset(dataset) dataset = self._cast_pil_image(dataset) ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False + keep_columns = ['channel'] with self._patch_arrow_writer(), safe_ddp_context(None, True): try: dataset_mapped = dataset.map( @@ -321,18 +319,13 @@ def __call__( 'strict': strict, 'ignore_max_length_error': ignore_max_length_error }, - remove_columns=list(dataset.features.keys()), + remove_columns=[col for col in dataset.features.keys() if col not in keep_columns], **map_kwargs) except NotImplementedError: pass if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped): logger.info( f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}') - if channel: - dataset_mapped = dataset_mapped.map( - lambda example, idx: { - **example, 'channel': channel[idx] - }, with_indices=True) return dataset_mapped From a07bb787f51a9fafa58a074c26a9d199955d4137 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Wed, 4 Jun 2025 15:03:23 +0800 Subject: [PATCH 13/13] remove unused import --- swift/llm/argument/train_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index e32e96e1f9..35b7b3a88a 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os from dataclasses import dataclass, field -from typing import List, Literal, Optional +from typing import Literal, Optional from transformers import Seq2SeqTrainingArguments from transformers.utils.versions import require_version