Skip to content

Commit de451f9

Browse files
authored
fix: cohere cce scaling wrong tensor (axolotl-ai-cloud#2483)
1 parent 9f824ef commit de451f9

File tree

1 file changed

+3
-3
lines changed
  • src/axolotl/integrations/cut_cross_entropy/monkeypatch

1 file changed

+3
-3
lines changed

src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,10 @@ def cce_forward(
128128

129129
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
130130
assert labels is not None
131-
# scale weight by logit_scale in-place of logits
131+
# scale hidden_states by logit_scale in-place of logits
132132
loss = apply_lce(
133-
hidden_states[:, slice_indices, :],
134-
self.lm_head.weight * self.logit_scale,
133+
hidden_states[:, slice_indices, :] * self.logit_scale,
134+
self.lm_head.weight,
135135
labels,
136136
_PATCH_OPTS,
137137
**kwargs,

0 commit comments

Comments
 (0)