Skip to content

[pt/sft] Feature channel loss #4405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 4, 2025
Merged

Conversation

kevssim
Copy link
Contributor

@kevssim kevssim commented May 29, 2025

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Monitoring the model's performance on different tasks(channels).

Experiment results

20250529114018_2

@@ -323,6 +328,11 @@ def __call__(
if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped):
logger.info(
f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}')
if channel:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是为什么呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是为了读取数据的‘channel’字段,把它放到inputs中。
数据示例:{
"channel": "chat",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "What color do you like?"},
{"role": "assistant", "content": "I like blue."}]
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的map应该是不需要的。dataset_mapped中应该已经包含channel了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
这一步会把channel删掉,是否改成把channel从remove_columns中移除?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已处理,见c12f566

@@ -88,6 +88,8 @@ def __init__(self, *args, **kwargs):
self.infer_engine = PtEngine.from_model_template(
self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size)
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl'))
self.channel_cid = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个写个注释吧

channel_cid和cid_channel

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改为通过参数方式传入,不需要channel_cid和cid_channel了

@kevssim kevssim marked this pull request as draft May 29, 2025 04:30
@kevssim kevssim requested a review from Jintao-Huang May 30, 2025 04:23
@kevssim kevssim closed this May 30, 2025
@kevssim kevssim reopened this May 30, 2025
@kevssim kevssim marked this pull request as ready for review May 30, 2025 04:26
@Jintao-Huang Jintao-Huang changed the title Feature channel loss [pt/sft] Feature channel loss May 30, 2025
@@ -351,6 +351,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数.
- check_model: 检查本地模型文件有损坏或修改并给出提示,默认为True。如果是断网环境,请设置为False。
- 🔥create_checkpoint_symlink: 额外创建checkpoint软链接,方便书写自动化训练脚本。best_model和last_model的软链接路径分别为f'{output_dir}/best'和f'{output_dir}/last'。
- loss_type: loss类型。默认为None,使用模型自带损失函数。
- channel_list : 数据集包含的channel列表。默认为None。结合`--loss_type channel_loss`使用,可参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh)。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

名字请改成channels

这有可能在训练的时候自动记录channel的实现么

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已处理,见48b4aec

没有在训练的时候自动记录的原因是:多卡情况下,需要保持channels一致(否则可能死锁),自动记录可能会导致不一致,需要通信来同步,增加通信开销

@@ -360,6 +360,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
- check_model: Check local model files for corruption or modification and give a prompt, default is True. If in an offline environment, please set to False.
- 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively.
- loss_type: Type of loss. Defaults to None, which uses the model's built-in loss function.
- channel_list:List of channels included in the dataset. Defaults to None. Used in conjunction with `--loss_type channel_loss`. Refer to [this example](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh) for more details.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

@@ -116,6 +116,7 @@ def run(self):
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks=self.callbacks,
channel_list=self.args.channel_list,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个写在TrainArgumentsMixin吧,不在trainer的init方法传入了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已处理,见48b4aec

@@ -323,6 +328,11 @@ def __call__(
if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped):
logger.info(
f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}')
if channel:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的map应该是不需要的。dataset_mapped中应该已经包含channel了

@Jintao-Huang
Copy link
Collaborator

还有更新不,打算merge了

@kevssim
Copy link
Contributor Author

kevssim commented Jun 4, 2025

没有更新了,merge吧,谢谢

@Jintao-Huang Jintao-Huang merged commit 40ccffc into modelscope:main Jun 4, 2025
2 of 3 checks passed
tastelikefeet added a commit to tastelikefeet/swift that referenced this pull request Jun 5, 2025
* commit '58b45225287864b4981ac9c0f761a421b4a3f8c0':
  [vlm] fix llm_lora vlm_full (modelscope#4482)
  fix infer in client (modelscope#4480)
  [pt/sft] Feature channel loss (modelscope#4405)
  [megatron] fix val_dataset (modelscope#4478)
  fix (modelscope#4475)
  [seq_parallel] fix sp compute_acc (modelscope#4474)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants