-
Notifications
You must be signed in to change notification settings - Fork 704
[pt/sft] Feature channel loss #4405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -323,6 +328,11 @@ def __call__( | |||
if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped): | |||
logger.info( | |||
f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}') | |||
if channel: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是为什么呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是为了读取数据的‘channel’字段,把它放到inputs中。
数据示例:{
"channel": "chat",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "What color do you like?"},
{"role": "assistant", "content": "I like blue."}]
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的map应该是不需要的。dataset_mapped中应该已经包含channel了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已处理,见c12f566
swift/trainers/trainers.py
Outdated
@@ -88,6 +88,8 @@ def __init__(self, *args, **kwargs): | |||
self.infer_engine = PtEngine.from_model_template( | |||
self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size) | |||
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl')) | |||
self.channel_cid = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个写个注释吧
channel_cid和cid_channel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改为通过参数方式传入,不需要channel_cid和cid_channel了
# Conflicts: # swift/trainers/trainers.py
docs/source/Instruction/命令行参数.md
Outdated
@@ -351,6 +351,7 @@ Vera使用`target_modules`, `target_regex`, `modules_to_save`三个参数. | |||
- check_model: 检查本地模型文件有损坏或修改并给出提示,默认为True。如果是断网环境,请设置为False。 | |||
- 🔥create_checkpoint_symlink: 额外创建checkpoint软链接,方便书写自动化训练脚本。best_model和last_model的软链接路径分别为f'{output_dir}/best'和f'{output_dir}/last'。 | |||
- loss_type: loss类型。默认为None,使用模型自带损失函数。 | |||
- channel_list : 数据集包含的channel列表。默认为None。结合`--loss_type channel_loss`使用,可参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh)。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
名字请改成channels
这有可能在训练的时候自动记录channel的实现么
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已处理,见48b4aec
没有在训练的时候自动记录的原因是:多卡情况下,需要保持channels一致(否则可能死锁),自动记录可能会导致不一致,需要通信来同步,增加通信开销
@@ -360,6 +360,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine | |||
- check_model: Check local model files for corruption or modification and give a prompt, default is True. If in an offline environment, please set to False. | |||
- 🔥create_checkpoint_symlink: Creates additional checkpoint symlinks to facilitate writing automated training scripts. The symlink paths for `best_model` and `last_model` are `f'{output_dir}/best'` and `f'{output_dir}/last'` respectively. | |||
- loss_type: Type of loss. Defaults to None, which uses the model's built-in loss function. | |||
- channel_list:List of channels included in the dataset. Defaults to None. Used in conjunction with `--loss_type channel_loss`. Refer to [this example](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/channel_loss.sh) for more details. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
swift/llm/train/sft.py
Outdated
@@ -116,6 +116,7 @@ def run(self): | |||
train_dataset=train_dataset, | |||
eval_dataset=val_dataset, | |||
callbacks=self.callbacks, | |||
channel_list=self.args.channel_list, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个写在TrainArgumentsMixin吧,不在trainer的init方法传入了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已处理,见48b4aec
@@ -323,6 +328,11 @@ def __call__( | |||
if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped): | |||
logger.info( | |||
f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}') | |||
if channel: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的map应该是不需要的。dataset_mapped中应该已经包含channel了
还有更新不,打算merge了 |
没有更新了,merge吧,谢谢 |
* commit '58b45225287864b4981ac9c0f761a421b4a3f8c0': [vlm] fix llm_lora vlm_full (modelscope#4482) fix infer in client (modelscope#4480) [pt/sft] Feature channel loss (modelscope#4405) [megatron] fix val_dataset (modelscope#4478) fix (modelscope#4475) [seq_parallel] fix sp compute_acc (modelscope#4474)
PR type
PR information
Monitoring the model's performance on different tasks(channels).
Experiment results