
Commit bee449d

Support Granite 3.x arch
1 parent ab4d9e1 commit bee449d

3 files changed (+23 -6 lines)

exllamav2/architecture.py (+11)

@@ -681,6 +681,17 @@ class Params:
             self.lm.expect_keys += \
                 expect_keys_llama
 
+        # Granite (v3)
+
+        if arch_string == "GraniteForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.lm.expect_keys += \
+                expect_keys_llama
+
         # Llama (default + fallback)
 
         if arch_string != "LlamaForCausalLM" and not arch_recognized:

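The added branch registers Granite 3.x (arch string GraniteForCausalLM, as listed in the checkpoint's config.json "architectures" field) and reuses the existing Llama layer-key and expect-key lists, so the weight loader looks for the standard Llama tensor layout. A minimal sketch of that recognition pattern follows; the standalone function and the shortened placeholder key lists are illustrative, not exllamav2's actual API:

# Illustrative sketch only: exllamav2 performs this matching inside the Params class
# in architecture.py; the function and the abbreviated key lists are stand-ins.
def recognize_granite(arch_string: str) -> tuple[bool, list, list]:
    layer_keys_llama = [["input_layernorm"], ["self_attn.q_proj"], ["mlp.gate_proj"]]
    expect_keys_llama = [["model.embed_tokens"], ["model.norm"], ["lm_head"]]

    layer_keys: list = []
    expect_keys: list = []
    recognized = False

    if arch_string == "GraniteForCausalLM":  # Granite 3.x shares the Llama tensor layout
        recognized = True
        layer_keys += layer_keys_llama
        expect_keys += expect_keys_llama

    return recognized, layer_keys, expect_keys

print(recognize_granite("GraniteForCausalLM")[0])  # -> True
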
exllamav2/attn.py (+3 -1)

@@ -211,7 +211,9 @@ def __init__(
         if cfg.use_qk_norm:
             self.submodules += [self.q_norm, self.k_norm]
 
-        if cfg.query_pre_attn_scalar:
+        if cfg.attention_multiplier:
+            self.scaling = cfg.attention_multiplier
+        elif cfg.query_pre_attn_scalar:
             self.scaling = cfg.query_pre_attn_scalar ** (-0.5)
         else:
             self.scaling = 1 / math.sqrt(self.head_dim)

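With this change the attention scaling factor gains a new highest-priority case: if the config provides attention_multiplier (as Granite 3.x does), it is used directly as the scaling value; otherwise query_pre_attn_scalar is converted via ** (-0.5); otherwise the usual 1 / sqrt(head_dim) applies. A hedged standalone sketch of that precedence, where the helper function and the example values are illustrative rather than part of exllamav2:

import math

# Illustrative helper mirroring the precedence introduced above.
def attn_scaling(head_dim: int,
                 attention_multiplier: float | None = None,
                 query_pre_attn_scalar: float | None = None) -> float:
    if attention_multiplier:                  # Granite 3.x: explicit multiplier, used as-is
        return attention_multiplier
    elif query_pre_attn_scalar:               # e.g. Gemma-2 style scalar, converted here
        return query_pre_attn_scalar ** (-0.5)
    else:                                     # default scaled dot-product attention
        return 1 / math.sqrt(head_dim)

print(attn_scaling(128, attention_multiplier=0.0078125))  # -> 0.0078125
print(attn_scaling(128))                                  # -> ~0.0884 (1 / sqrt(128))
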
exllamav2/config.py (+9 -5)

@@ -126,6 +126,7 @@ class ExLlamaV2Config:
     checkpoint_fused_mlp: bool
     checkpoint_offset_qzeros: bool
     mrope_section: list | None
+    attention_multiplier: float | None
 
     vision_model_type: str | None
     vision_head_dim: int | None
@@ -289,6 +290,7 @@ def prepare(self, no_tensors: bool = False):
         self.use_qk_norm = read(read_config, bool, ["use_qk_norm"], False)
 
         self.query_pre_attn_scalar = read(read_config, float, "query_pre_attn_scalar", None)
+        self.attention_multiplier = read(read_config, float, "attention_multiplier", None)
 
         # MLP params
 
@@ -309,16 +311,18 @@ def prepare(self, no_tensors: bool = False):
 
         # Logit/embedding/residual scale
 
-        self.logit_scale = read(read_config, float, "logit_scale", 1)
+        self.logit_scale = read(read_config, float, ["logit_scale", "logits_scaling"], 1)
         if self.arch.lm.logit_scale_basedim:
             dim_model_base = read(read_config, int, "dim_model_base", self.hidden_size)
             self.logit_scale /= (self.hidden_size / dim_model_base)
 
-        self.scale_emb = read(read_config, float, "scale_emb", 1)
+        self.scale_emb = read(read_config, float, ["scale_emb", "embedding_multiplier"], 1)
+        residual_multiplier = read(read_config, float, "residual_multiplier", None)
         scale_depth = read(read_config, float, "scale_depth", None)
-        if scale_depth is None:
-            self.scale_depth = 1
-        else:
+        self.scale_depth = 1
+        if residual_multiplier:
+            self.scale_depth = residual_multiplier
+        elif scale_depth:
             self.scale_depth = scale_depth / math.sqrt(self.num_hidden_layers)
 
         self.attn_logit_softcapping = read(read_config, float, "attn_logit_softcapping", None)

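The config changes map Granite's config.json field names onto the existing scaling attributes: logits_scaling is read as an alias for logit_scale, embedding_multiplier as an alias for scale_emb, and residual_multiplier now takes priority over the existing scale_depth path when setting scale_depth. A simplified sketch of that mapping, using an illustrative config dict and a stand-in for config.py's read() helper:

import math

# Field values below are illustrative, not taken from a specific Granite checkpoint.
granite_config = {
    "architectures": ["GraniteForCausalLM"],
    "attention_multiplier": 0.0078125,
    "embedding_multiplier": 12.0,
    "logits_scaling": 16.0,
    "residual_multiplier": 0.22,
    "num_hidden_layers": 40,
}

def read(cfg: dict, keys, default):
    # Simplified stand-in for config.py's read(): try each alias in order.
    for k in (keys if isinstance(keys, list) else [keys]):
        if k in cfg:
            return cfg[k]
    return default

logit_scale = read(granite_config, ["logit_scale", "logits_scaling"], 1)
scale_emb = read(granite_config, ["scale_emb", "embedding_multiplier"], 1)
residual_multiplier = read(granite_config, "residual_multiplier", None)
scale_depth_raw = read(granite_config, "scale_depth", None)

scale_depth = 1
if residual_multiplier:
    scale_depth = residual_multiplier         # Granite 3.x residual scaling
elif scale_depth_raw:
    scale_depth = scale_depth_raw / math.sqrt(granite_config["num_hidden_layers"])

print(logit_scale, scale_emb, scale_depth)    # -> 16.0 12.0 0.22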