Skip to content

Commit 9c9e960

Browse files
authored
[loss] fix vlm channel loss (#4497)
1 parent 9120a17 commit 9c9e960

File tree

8 files changed

+26
-5
lines changed

8 files changed

+26
-5
lines changed

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@
150150
- 🔥report_to: 默认值为`tensorboard`。你也可以指定`--report_to tensorboard wandb swanlab``--report_to all`
151151
- logging_first_step: 是否记录第一个step的日志,默认为True。
152152
- logging_steps: 日志打印间隔,默认为5。
153+
- logging_dir: tensorboard日志路径。默认为None,即设置为`f'{self.output_dir}/runs'`
153154
- predict_with_generate: 验证时使用生成式的方式,默认为False。
154155
- metric_for_best_model: 默认为None,即当`predict_with_generate`设置为False时,设置为'loss',否则设置为'rouge-l'(在PPO训练时,不进行默认值设置;GRPO训练设置为'reward')。
155156
- greater_is_better: 默认为None,即当`metric_for_best_model`含'loss'时,设置为False,否则设置为True。

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with
153153
- 🔥report_to: Default value is `tensorboard`. You can also specify `--report_to tensorboard wandb swanlab` or `--report_to all`.
154154
- logging_first_step: Whether to log the first step, defaults to True.
155155
- logging_steps: Interval for logging, defaults to 5.
156+
- logging_dir: The path for TensorBoard logs. Defaults to None, which means it is set to `f'{self.output_dir}/runs'`.
156157
- predict_with_generate: Whether to use generative method during validation, default is False.
157158
- metric_for_best_model: Default is None, which means that when predict_with_generate is set to False, it is set to 'loss'; otherwise, it is set to 'rouge-l' (during PPO training, the default value is not set; in GRPO training, it is set to 'reward').
158159
- greater_is_better: Defaults to None, which sets it to False when `metric_for_best_model` contains 'loss', otherwise sets to True.

swift/llm/dataset/preprocessor/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,6 @@ def __call__(
310310
dataset = self._cast_pil_image(dataset)
311311

312312
ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
313-
keep_columns = ['channel']
314313
with self._patch_arrow_writer(), safe_ddp_context(None, True):
315314
try:
316315
dataset_mapped = dataset.map(
@@ -319,7 +318,7 @@ def __call__(
319318
'strict': strict,
320319
'ignore_max_length_error': ignore_max_length_error
321320
},
322-
remove_columns=[col for col in dataset.features.keys() if col not in keep_columns],
321+
remove_columns=list(dataset.features.keys()),
323322
**map_kwargs)
324323
except NotImplementedError:
325324
pass

swift/llm/template/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,8 @@ def encode(self,
432432
encoded = self._kto_encode(inputs)
433433
elif self.mode == 'embedding':
434434
encoded = self._embedding_encode(inputs)
435+
if inputs.channel is not None:
436+
encoded['channel'] = inputs.channel
435437
for key in list(encoded.keys()):
436438
if encoded[key] is None:
437439
encoded.pop(key)

swift/llm/template/template_inputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class StdTemplateInputs:
102102

103103
rejected_response: Optional[str] = None
104104
label: Optional[int] = None
105+
channel: Optional[str] = None
105106

106107
images: List[Union[str, Image.Image]] = field(default_factory=list)
107108
audios: List[str] = field(default_factory=list)
@@ -133,7 +134,7 @@ def is_multimodal(self):
133134
@classmethod
134135
def from_dict(cls, inputs: Dict[str, Any]) -> 'StdTemplateInputs':
135136
kwargs = {}
136-
for key in ['rejected_response', 'label']:
137+
for key in ['rejected_response', 'label', 'channel']:
137138
if key in inputs:
138139
kwargs[key] = inputs[key]
139140
messages = inputs['messages']

swift/plugin/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,9 @@ def online_contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch
388388

389389
@register_loss_func(LossType.channel_loss)
390390
def channel_loss_func(outputs, labels, num_items_in_batch=None, sample_channels=None, trainer=None) -> torch.Tensor:
391-
assert sample_channels is not None, 'Data does not have channel field.'
392391
channels = trainer.args.channels
393392
assert channels is not None, 'Please pass --channels as a hyperparameter.'
393+
assert sample_channels is not None, 'Data does not have channel field.'
394394
logits = outputs.logits
395395

396396
# compute token loss

swift/trainers/trainers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
166166
compute_loss_func = get_loss_func('loss_scale')
167167

168168
sample_channels = inputs.pop('channel', None)
169-
if sample_channels is not None:
169+
if sample_channels is not None and self.args.channels is not None:
170170
state = self.state
171171
setattr(state, 'local_step', getattr(state, 'local_step', 0))
172172
setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {}))

tests/train/test_channel.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import os
2+
3+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4+
5+
6+
def test_channel():
7+
from swift.llm import sft_main, TrainArguments
8+
sft_main(
9+
TrainArguments(
10+
model='Qwen/Qwen2.5-VL-7B-Instruct',
11+
dataset=['channel.jsonl#1000'],
12+
channels=['aaa', 'abc'],
13+
loss_type='channel_loss'))
14+
15+
16+
if __name__ == '__main__':
17+
test_channel()

0 commit comments

Comments
 (0)