[pt/sft] Feature channel loss #4405
Merged
Changes from 6 commits
Commits
14 commits
b94950e
add channel into inputs
kevssim 5744d21
compute channel loss
da264ee
channel loss example
bf8cf77
code lint
c71a3f3
fix
d9bcc0e
fix example
baabc1d
code lint
379cfb9
fix multi-gpu
c9c994a
Merge remote-tracking branch 'origin/main' into feat/channel_loss
ce349ba
trainer
4640b53
update documents
48b4aec
rename channel_list to channels
c12f566
fix map channel
a07bb78
remove unused import
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# use loss_type channel_loss | ||
# data should have 'channel' field | ||
# eg. | ||
# {"channel": "chat", | ||
# "messages": [ | ||
# {"role": "system", "content": "You are a helpful assistant"}, | ||
# {"role": "user", "content": "What color do you like?"}, | ||
# {"role": "assistant", "content": "I like blue."} | ||
# ]} | ||
CUDA_VISIBLE_DEVICES=0 \ | ||
swift sft \ | ||
--model Qwen/Qwen2.5-0.5B-Instruct \ | ||
--dataset '/path/to/channel_dataset' \ | ||
--train_type full \ | ||
--torch_dtype bfloat16 \ | ||
--num_train_epochs 1 \ | ||
--per_device_train_batch_size 1 \ | ||
--per_device_eval_batch_size 1 \ | ||
--learning_rate 1e-5 \ | ||
--gradient_accumulation_steps 8 \ | ||
--eval_steps 100 \ | ||
--save_steps 100 \ | ||
--save_total_limit 2 \ | ||
--logging_steps 1 \ | ||
--max_length 512 \ | ||
--output_dir output \ | ||
--system 'You are a helpful assistant.' \ | ||
--warmup_ratio 0.05 \ | ||
--dataloader_num_workers 4 \ | ||
--loss_type channel_loss |
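The comment block at the top of the script describes the expected record shape: each sample carries a `channel` field next to `messages`. As a minimal sketch, assuming the dataset is stored as JSON Lines, a file of that shape could be produced as follows. The file name, the `math` channel, and the `write_jsonl` helper are illustrative and not part of the PR.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical samples; only the "channel" + "messages" layout comes from the PR.
samples = [
    {"channel": "chat", "messages": [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "What color do you like?"},
        {"role": "assistant", "content": "I like blue."}]},
    {"channel": "math", "messages": [
        {"role": "user", "content": "What is 2 + 3?"},
        {"role": "assistant", "content": "5"}]},
]


def write_jsonl(path, rows):
    """Write one JSON object per line, failing early if 'channel' is missing."""
    with open(path, "w", encoding="utf-8") as f:
        for i, row in enumerate(rows):
            if "channel" not in row:
                raise ValueError(f"row {i} has no 'channel' field")
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


# Illustrative output location; point --dataset at wherever the file really lives.
out_path = Path(tempfile.gettempdir()) / "channel_dataset.jsonl"
write_jsonl(out_path, samples)
```

Checking for the field while writing surfaces a missing `channel` before training starts rather than inside the loss computation.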
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,6 +88,8 @@ def __init__(self, *args, **kwargs): | |
self.infer_engine = PtEngine.from_model_template( | ||
self.model, self.template, max_batch_size=self.args.per_device_eval_batch_size) | ||
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'predict.jsonl')) | ||
self.channel_cid = None | ||
Review comment: Please add a code comment here explaining channel_cid and cid_channel. Reply: This has been changed to pass the value in as an argument, so channel_cid and cid_channel are no longer needed. |
||
self.cid_channel = None | ||
|
||
@staticmethod | ||
def _predict_data_collator(batch): | ||
|
@@ -157,6 +159,24 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N | |
if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs: | ||
labels = inputs.pop('labels') | ||
|
||
channels = inputs.pop('channel', None) | ||
if channels is not None: | ||
self.channel_cid = self.channel_cid or {} | ||
self.cid_channel = self.cid_channel or {} | ||
|
||
for ch in set(channels): | ||
if ch not in self.channel_cid: | ||
cid = len(self.channel_cid) | ||
self.channel_cid[ch] = cid | ||
self.cid_channel[cid] = ch | ||
|
||
state = self.state | ||
setattr(state, 'local_step', getattr(state, 'local_step', 0)) | ||
setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {})) | ||
|
||
loss_kwargs['channels'] = channels | ||
loss_kwargs['trainer'] = self | ||
|
||
loss_scale = inputs.pop('loss_scale', None) | ||
if loss_scale is not None: | ||
loss_kwargs['loss_scale'] = loss_scale | ||
|
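The hunk above assigns each newly seen channel the next integer id and then hands `channels` to the loss function. The loss function itself is not visible in this view, so the following is only a sketch of the idea under stated assumptions: it works on plain Python floats with one loss value per sample, whereas the real implementation presumably operates on token-level tensors. The names `assign_channel_ids` and `per_channel_mean_loss` are hypothetical.

```python
from collections import defaultdict


def assign_channel_ids(channels, channel_cid=None, cid_channel=None):
    """Extend the channel<->id maps with any unseen channel names.

    Mirrors the loop in the diff, except that channels are sorted first so
    the assigned ids are deterministic (the diff iterates over a plain set).
    """
    channel_cid = {} if channel_cid is None else channel_cid
    cid_channel = {} if cid_channel is None else cid_channel
    for ch in sorted(set(channels)):
        if ch not in channel_cid:
            cid = len(channel_cid)  # next free id
            channel_cid[ch] = cid
            cid_channel[cid] = ch
    return channel_cid, cid_channel


def per_channel_mean_loss(sample_losses, channels):
    """Average the per-sample losses separately for each channel."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for loss, ch in zip(sample_losses, channels):
        sums[ch] += loss
        counts[ch] += 1
    return {ch: sums[ch] / counts[ch] for ch in sums}


# Illustrative batch: two "chat" samples and one "math" sample.
channels = ["chat", "math", "chat"]
losses = [1.0, 3.0, 2.0]

channel_cid, cid_channel = assign_channel_ids(channels)
channel_means = per_channel_mean_loss(losses, channels)
```

With this batch, `chat` averages to 1.5 and `math` to 3.0, so each channel can be logged separately instead of being folded into a single number.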
Review comment: Why is this needed here?
Reply: This reads the 'channel' field from the data and puts it into inputs.
Data example: {
"channel": "chat",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "What color do you like?"},
{"role": "assistant", "content": "I like blue."}]
}
Review comment: The map here should not be necessary. dataset_mapped should already contain channel.
Reply: This step deletes channel. Should we instead exclude channel from remove_columns?
Reply: Done, see c12f566.
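The exchange above concerns a preprocessing map that drops the original columns, `channel` included, and the proposal to leave `channel` out of `remove_columns`. The code under discussion is not shown in this view, so the snippet below is a plain-Python illustration of that idea only. `apply_map`, `columns_to_remove`, and `encode` are hypothetical stand-ins, not functions from the PR or from any library.

```python
def columns_to_remove(column_names, keep=("channel",)):
    """Return every column except the ones that must survive the map."""
    return [name for name in column_names if name not in keep]


def apply_map(rows, fn, remove_columns):
    """Toy stand-in for a dataset map: drop columns, then add fn's output."""
    out = []
    for row in rows:
        new_row = {k: v for k, v in row.items() if k not in remove_columns}
        new_row.update(fn(row))
        out.append(new_row)
    return out


def encode(row):
    """Placeholder for tokenization; returns fixed ids for illustration."""
    return {"input_ids": [1, 2, 3]}


rows = [{"channel": "chat", "messages": [{"role": "user", "content": "hi"}]}]
all_columns = list(rows[0])

# Removing every original column loses 'channel' ...
dropped = apply_map(rows, encode, remove_columns=all_columns)
# ... while excluding it from remove_columns carries it through.
kept = apply_map(rows, encode, remove_columns=columns_to_remove(all_columns))
```

Keeping the field through the existing map avoids a second pass over the dataset whose only job is to restore it.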