Skip to content

Commit 82bb648

Browse files
committed
Fix Granite3 logit scaling
1 parent bee449d commit 82bb648

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

exllamav2/config.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -311,11 +311,15 @@ def prepare(self, no_tensors: bool = False):
311311

312312
# Logit/embedding/residual scale
313313

314-
self.logit_scale = read(read_config, float, ["logit_scale", "logits_scaling"], 1)
314+
self.logit_scale = read(read_config, float, "logit_scale", 1)
315315
if self.arch.lm.logit_scale_basedim:
316316
dim_model_base = read(read_config, int, "dim_model_base", self.hidden_size)
317317
self.logit_scale /= (self.hidden_size / dim_model_base)
318318

319+
logit_scaling = read(read_config, float, "logits_scaling", None) # Granite is backwards
320+
if logit_scaling:
321+
self.logit_scale = 1.0 / logit_scaling
322+
319323
self.scale_emb = read(read_config, float, ["scale_emb", "embedding_multiplier"], 1)
320324
residual_multiplier = read(read_config, float, "residual_multiplier", None)
321325
scale_depth = read(read_config, float, "scale_depth", None)

0 commit comments

Comments
 (0)