Skip to content

Commit ce349ba

Browse files
author
weikaiwen
committed
trainer
1 parent c9c994a commit ce349ba

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

swift/trainers/trainers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,16 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
164164
loss_kwargs['loss_scale'] = loss_scale
165165
if compute_loss_func is None:
166166
compute_loss_func = get_loss_func('loss_scale')
167+
168+
sample_channels = inputs.pop('channel', None)
169+
if sample_channels is not None:
170+
state = self.state
171+
setattr(state, 'local_step', getattr(state, 'local_step', 0))
172+
setattr(state, 'ch_loss_steps', getattr(state, 'ch_loss_steps', {}))
173+
174+
loss_kwargs['sample_channels'] = sample_channels
175+
loss_kwargs['trainer'] = self
176+
167177
if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs:
168178
labels = inputs.pop('labels')
169179

0 commit comments

Comments
 (0)