[loss] fix vlm channel loss #4497

Merged · 4 commits · Jun 6, 2025
1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
@@ -150,6 +150,7 @@
- 🔥report_to: Defaults to `tensorboard`. You can also specify `--report_to tensorboard wandb swanlab` or `--report_to all`.
- logging_first_step: Whether to log the first step; defaults to True.
- logging_steps: Logging interval; defaults to 5.
+ - logging_dir: Path for TensorBoard logs. Defaults to None, i.e. it is set to `f'{self.output_dir}/runs'`.
- predict_with_generate: Whether to use generation during validation; defaults to False.
- metric_for_best_model: Defaults to None, i.e. 'loss' when `predict_with_generate` is False, otherwise 'rouge-l' (no default is applied during PPO training; GRPO training sets it to 'reward').
- greater_is_better: Defaults to None, i.e. False when `metric_for_best_model` contains 'loss', otherwise True.
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -153,6 +153,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with
- 🔥report_to: Default value is `tensorboard`. You can also specify `--report_to tensorboard wandb swanlab` or `--report_to all`.
- logging_first_step: Whether to log the first step; defaults to True.
- logging_steps: Interval between log outputs; defaults to 5.
+ - logging_dir: The path for TensorBoard logs. Defaults to None, which means it is set to `f'{self.output_dir}/runs'`.
- predict_with_generate: Whether to use generation during validation; defaults to False.
- metric_for_best_model: Defaults to None, i.e. 'loss' when `predict_with_generate` is False, otherwise 'rouge-l' (no default is applied during PPO training; GRPO training sets it to 'reward').
- greater_is_better: Defaults to None, i.e. False when `metric_for_best_model` contains 'loss', otherwise True.
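
To make the new option concrete, below is a minimal sketch of pinning the TensorBoard log directory from Python rather than relying on the `f'{self.output_dir}/runs'` fallback. The model name, dataset file, and paths are placeholders, not values from this PR:

```python
from swift.llm import sft_main, TrainArguments

sft_main(
    TrainArguments(
        model='Qwen/Qwen2.5-7B-Instruct',  # placeholder model
        dataset=['my_data.jsonl'],         # placeholder dataset
        report_to=['tensorboard'],
        logging_steps=5,
        # Omitting logging_dir (None) falls back to f'{output_dir}/runs'.
        logging_dir='output/my_run/runs'))
```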
3 changes: 1 addition & 2 deletions swift/llm/dataset/preprocessor/core.py
@@ -310,7 +310,6 @@ def __call__(
        dataset = self._cast_pil_image(dataset)

        ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
-        keep_columns = ['channel']
        with self._patch_arrow_writer(), safe_ddp_context(None, True):
            try:
                dataset_mapped = dataset.map(
@@ -319,7 +318,7 @@
                        'strict': strict,
                        'ignore_max_length_error': ignore_max_length_error
                    },
-                    remove_columns=[col for col in dataset.features.keys() if col not in keep_columns],
+                    remove_columns=list(dataset.features.keys()),
                    **map_kwargs)
            except NotImplementedError:
                pass
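
The deleted `keep_columns` whitelist is no longer needed because `channel` is now re-emitted by the encoding step itself (see the `base.py` change below), so the map call can safely drop every raw column. A toy, purely illustrative `datasets` sketch of that behavior:

```python
from datasets import Dataset

raw = Dataset.from_list([{'text': 'hi', 'channel': 'aaa'}])

def encode_row(row):
    # Stand-in for template encoding; re-attaches channel like Template.encode now does.
    encoded = {'input_ids': [1, 2, 3]}
    if row.get('channel') is not None:
        encoded['channel'] = row['channel']
    return encoded

# Every raw column is removed, yet 'channel' survives via the function's output.
mapped = raw.map(encode_row, remove_columns=list(raw.features.keys()))
print(mapped[0])  # {'input_ids': [1, 2, 3], 'channel': 'aaa'}
```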
2 changes: 2 additions & 0 deletions swift/llm/template/base.py
@@ -432,6 +432,8 @@ def encode(self,
            encoded = self._kto_encode(inputs)
        elif self.mode == 'embedding':
            encoded = self._embedding_encode(inputs)
+        if inputs.channel is not None:
+            encoded['channel'] = inputs.channel
        for key in list(encoded.keys()):
            if encoded[key] is None:
                encoded.pop(key)
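
A small self-contained sketch (toy values) of how the two added lines interact with the existing None-pruning loop at the end of `encode`:

```python
def finalize(encoded, channel):
    # Mirrors the tail of Template.encode: attach the channel tag when present,
    # then drop any None-valued keys.
    if channel is not None:
        encoded['channel'] = channel
    for key in list(encoded.keys()):
        if encoded[key] is None:
            encoded.pop(key)
    return encoded

print(finalize({'input_ids': [1, 2, 3], 'loss_scale': None}, 'aaa'))
# -> {'input_ids': [1, 2, 3], 'channel': 'aaa'}
```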
3 changes: 2 additions & 1 deletion swift/llm/template/template_inputs.py
@@ -102,6 +102,7 @@ class StdTemplateInputs:

    rejected_response: Optional[str] = None
    label: Optional[int] = None
+    channel: Optional[str] = None

    images: List[Union[str, Image.Image]] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
@@ -133,7 +134,7 @@ def is_multimodal(self):
    @classmethod
    def from_dict(cls, inputs: Dict[str, Any]) -> 'StdTemplateInputs':
        kwargs = {}
-        for key in ['rejected_response', 'label']:
+        for key in ['rejected_response', 'label', 'channel']:
            if key in inputs:
                kwargs[key] = inputs[key]
        messages = inputs['messages']
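
With the new dataclass field, a raw row's `channel` tag should flow straight through `from_dict`. A hedged sketch, assuming a text-only sample where `messages` is the only required key:

```python
from swift.llm.template.template_inputs import StdTemplateInputs

row = {
    'messages': [{'role': 'user', 'content': 'hi'},
                 {'role': 'assistant', 'content': 'hello'}],
    'channel': 'aaa',
}
inputs = StdTemplateInputs.from_dict(row)
assert inputs.channel == 'aaa'  # picked up by the extended key list above
```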
2 changes: 1 addition & 1 deletion swift/plugin/loss.py
@@ -388,9 +388,9 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch

@register_loss_func(LossType.channel_loss)
def channel_loss_func(outputs, labels, num_items_in_batch=None, sample_channels=None, trainer=None) -> torch.Tensor:
-    assert sample_channels is not None, 'Data does not have channel field.'
    channels = trainer.args.channels
    assert channels is not None, 'Please pass --channels as a hyperparameter.'
+    assert sample_channels is not None, 'Data does not have channel field.'
    logits = outputs.logits

    # compute token loss
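
The reordered assertions surface the actionable error first: if `--channels` was never passed, the user is told to pass it before the data is blamed for lacking a `channel` field. For context, channel loss reports a separate average token loss per channel tag; a conceptual sketch, not the library's exact implementation:

```python
import torch

def per_channel_mean_loss(token_loss, mask, sample_channels, channels):
    # token_loss: (batch, seq) per-token CE loss; mask: 1.0 where labels exist.
    stats = {}
    for ch in channels:
        idx = [i for i, c in enumerate(sample_channels) if c == ch]
        if not idx:
            continue
        losses, m = token_loss[idx], mask[idx]
        stats[ch] = (losses * m).sum() / m.sum().clamp(min=1)
    return stats
```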
2 changes: 1 addition & 1 deletion swift/trainers/trainers.py
@@ -166,7 +166,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
            compute_loss_func = get_loss_func('loss_scale')

        sample_channels = inputs.pop('channel', None)
-        if sample_channels is not None:
+        if sample_channels is not None and self.args.channels is not None:
            state = self.state
            setattr(state, 'local_step', getattr(state, 'local_step', 0))
            setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {}))
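
This guard is the heart of the fix: now that templates forward `channel` for any dataset that has the column, the trainer must not start per-channel bookkeeping unless `--channels` was actually configured. A reduced, runnable paraphrase of the guard (stand-in types, not the trainer's real code):

```python
def should_track_channels(inputs, args):
    # Strip the batch's channel tags unconditionally, but only enable
    # per-channel bookkeeping when --channels was configured as well.
    sample_channels = inputs.pop('channel', None)
    return sample_channels is not None and getattr(args, 'channels', None) is not None

class Args:  # stand-in for trainer args without --channels
    channels = None

batch = {'input_ids': [[1, 2]], 'channel': ['aaa']}
print(should_track_channels(batch, Args()))  # False: bookkeeping skipped, no crash
print('channel' in batch)                    # False: key was still popped
```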
17 changes: 17 additions & 0 deletions tests/train/test_channel.py
@@ -0,0 +1,17 @@
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def test_channel():
    from swift.llm import sft_main, TrainArguments
    sft_main(
        TrainArguments(
            model='Qwen/Qwen2.5-VL-7B-Instruct',
            dataset=['channel.jsonl#1000'],
            channels=['aaa', 'abc'],
            loss_type='channel_loss'))


if __name__ == '__main__':
    test_channel()
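
The referenced `channel.jsonl` is not included in the PR; a hypothetical line of it might look as follows, with `messages` plus a per-sample `channel` tag drawn from the `--channels` list (an `images` field would be added for multimodal samples):

```python
import json

sample = {
    'messages': [
        {'role': 'user', 'content': 'What is in the image?'},
        {'role': 'assistant', 'content': 'A cat.'},
    ],
    'channel': 'aaa',  # must be one of channels=['aaa', 'abc']
}
print(json.dumps(sample, ensure_ascii=False))  # one line of channel.jsonl
```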