Skip to content

Commit ef3b5ba

Browse files
committed
Support Mistral-Nemo
1 parent 5f0b993 commit ef3b5ba

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

vllm/model_executor/models/llama.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class LlamaAttention(nn.Module):
8989

9090
def __init__(
9191
self,
92+
config: LlamaConfig,
9293
hidden_size: int,
9394
num_heads: int,
9495
num_kv_heads: int,
@@ -115,7 +116,9 @@ def __init__(
115116
# the KV heads across multiple tensor parallel GPUs.
116117
assert tp_size % self.total_num_kv_heads == 0
117118
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
118-
self.head_dim = hidden_size // self.total_num_heads
119+
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
120+
self.head_dim = getattr(config, "head_dim",
121+
self.hidden_size // self.total_num_heads)
119122
self.q_size = self.num_heads * self.head_dim
120123
self.kv_size = self.num_kv_heads * self.head_dim
121124
self.scaling = self.head_dim**-0.5
@@ -189,6 +192,7 @@ def __init__(
189192
attention_bias = getattr(config, "attention_bias", False) or getattr(
190193
config, "bias", False)
191194
self.self_attn = LlamaAttention(
195+
config=config,
192196
hidden_size=self.hidden_size,
193197
num_heads=config.num_attention_heads,
194198
num_kv_heads=getattr(config, "num_key_value_heads",

0 commit comments

Comments
 (0)