vllm/model_executor/models (1 file changed, +5 -1 lines changed)

@@ -89,6 +89,7 @@ class LlamaAttention(nn.Module):
     def __init__(
         self,
+        config: LlamaConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -115,7 +116,9 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -189,6 +192,7 @@ def __init__(
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
         self.self_attn = LlamaAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=getattr(config, "num_key_value_heads",
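
For context, here is a minimal sketch (not part of the diff) of the fallback the new code relies on: Mistral-Nemo-style configs carry an explicit `head_dim` that differs from `hidden_size // num_attention_heads`, so the attention layer must prefer the config attribute when it exists. The helper name `resolve_head_dim` and the specific config values below are illustrative assumptions, not vLLM code.

```python
# Illustrative sketch of the getattr fallback used in the diff (assumed values).
from types import SimpleNamespace


def resolve_head_dim(config, hidden_size: int, total_num_heads: int) -> int:
    """Prefer an explicit config.head_dim; otherwise derive it from hidden_size."""
    return getattr(config, "head_dim", hidden_size // total_num_heads)


# Llama-style config: no head_dim attribute, so the quotient is used.
llama_like = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert resolve_head_dim(llama_like, 4096, 32) == 128

# Mistral-Nemo-style config: head_dim is set explicitly and wins over the quotient.
nemo_like = SimpleNamespace(hidden_size=5120, num_attention_heads=32, head_dim=128)
assert resolve_head_dim(nemo_like, 5120, 32) == 128  # not 5120 // 32 == 160
```

This is why the constructor now takes `config` and the decoder layer passes `config=config`: without the config object, `LlamaAttention` cannot see the optional `head_dim` field.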