We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9f824ef · commit de451f9 · Copy full SHA for de451f9
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
@@ -128,10 +128,10 @@ def cce_forward(
128
129
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
130
assert labels is not None
131
- # scale weight by logit_scale in-place of logits
+ # scale hidden_states by logit_scale in-place of logits
132
loss = apply_lce(
133
- hidden_states[:, slice_indices, :],
134
- self.lm_head.weight * self.logit_scale,
+ hidden_states[:, slice_indices, :] * self.logit_scale,
+ self.lm_head.weight,
135
labels,
136
_PATCH_OPTS,
137
**kwargs,
0 commit comments