
Commit cc1094a

Support Gemma
1 parent a19a2ec commit cc1094a

15 files changed: +170 additions, -74 deletions

exllamav2/attn.py

+5 -6

@@ -267,7 +267,7 @@ def scratch_space(self):

 def temp_state_size(self):

-    return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.hidden_size * 2 + 128
+    return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.num_attention_heads * self.model.config.head_dim * 2 + 128


 def temp_q_size(self):

@@ -465,15 +465,15 @@ def forward(self, hidden_states, cache = None, attn_params = None, past_len = No
 v_states = None

 attn_output = attn_output.transpose(1, 2)
-attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
+attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

 # Flash Attention 2

 else:

 # TODO: Enable flash-attn with input mask
 attn_output = flash_attn_func(q_states, k_states, v_states, causal = True)
-attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
+attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

 # xformers memory_efficient_attention

@@ -661,17 +661,16 @@ def forward_torch(self, hidden_states, cache = None, attn_params = None, past_le
 attn_output = torch.matmul(attn_weights, value_states)

 attn_output = attn_output.transpose(1, 2)
-attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
+attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

 # Flash Attention 2

 else:

 attn_output = flash_attn_func(query_states, key_states, value_states, causal = True)
-attn_output = attn_output.reshape((batch_size, q_len, hidden_size))
+attn_output = attn_output.reshape((batch_size, q_len, self.model.config.num_attention_heads * self.model.config.head_dim))

 # Update 8-bit cache
-# TODO: Only update changed positions of the cache

 if cache is not None:
     cache.store_kv_state(self.layer_idx, batch_size, past_len, q_len)
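Note on the attn.py change: for Llama-family models, num_attention_heads * head_dim always equals hidden_size, so the old reshape and scratch-size math were safe. Gemma decouples the two, so the attention output can be wider than the residual stream and must be sized from the head geometry. A minimal sketch of the shape arithmetic, using assumed Gemma-7B values (not taken from this commit):

```python
# Sketch only: why the reshape target changed. The config values below are
# assumed Gemma-7B settings (hidden_size 3072, 16 heads of dimension 256).
hidden_size = 3072
num_attention_heads = 16
head_dim = 256

attn_width = num_attention_heads * head_dim   # 4096: width of the attention output
assert attn_width != hidden_size              # the old hidden_size-based reshape breaks here
# attn_output: (batch_size, q_len, attn_width) --o_proj--> (batch_size, q_len, hidden_size)
```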

exllamav2/config.py

+16 -1

@@ -97,6 +97,9 @@ def prepare(self, no_tensors = False):
 expect_keys_llama = [["lm_head"],
                      ["model.norm"],
                      ["model.embed_tokens"]]
+expect_keys_gemma = [["model.norm"],
+                     ["model.embed_tokens"]]
+

 if "LlamaForCausalLM" in read_config["architectures"]:
     self.architecture = "Llama"

@@ -157,6 +160,15 @@ def prepare(self, no_tensors = False):
 self.attention_bias_qkv = True
 self.attention_bias_o = False

+elif "GemmaForCausalLM" in read_config["architectures"]:
+    self.architecture = "Gemma"
+    layer_keys += \
+        layer_keys_llama_norms + \
+        layer_keys_llama_attn + \
+        layer_keys_llama_mlp
+    expect_keys += \
+        expect_keys_gemma
+
 else:
     print(f" !! Warning, unknown architecture: {repr(read_config['architectures'])}")
     print(f" !! Loading as LlamaForCausalLM")

@@ -206,7 +218,10 @@ def prepare(self, no_tensors = False):

 # Model dimensions

-self.head_dim = self.hidden_size // self.num_attention_heads
+if "head_dim" in read_config:
+    self.head_dim = read_config["head_dim"]
+else:
+    self.head_dim = self.hidden_size // self.num_attention_heads

 # Create map of model tensors
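Two things happen in config.py: Gemma checkpoints can carry an explicit head_dim that does not equal hidden_size // num_attention_heads, and they ship without a separate lm_head tensor, which is why expect_keys_gemma drops "lm_head". A rough sketch of the head_dim fallback, assuming a Hugging Face style config.json on disk (path and keys are illustrative):

```python
import json

# Minimal sketch (not the library's loader): mirror the head_dim fallback
# added above when reading a Hugging Face config.json.
with open("config.json") as f:
    read_config = json.load(f)

hidden_size = read_config["hidden_size"]
num_attention_heads = read_config["num_attention_heads"]
head_dim = read_config.get("head_dim", hidden_size // num_attention_heads)
```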
exllamav2/embedding.py

+4

@@ -68,6 +68,10 @@ def forward(self, hidden_states, cache = None, attn_params = None, past_len = No

 hidden_states = self.embedding.forward(hidden_states)

+# Normalize the input embeddings for Gemma
+if self.model.config.architecture == "Gemma":
+    hidden_states = hidden_states * (self.model.config.hidden_size ** 0.5)
+
 if intermediates:
     return {"hidden_states": hidden_states}
 else:
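Gemma scales its token embeddings by sqrt(hidden_size) before the first decoder layer, which is exactly what the added branch does. A standalone sketch of the same normalization, with toy sizes (hidden_size matches Gemma-7B, the vocabulary here is deliberately small):

```python
import torch

# Illustrative sketch of Gemma's embedding normalization; vocabulary size
# is a toy value, hidden_size = 3072 is the assumed Gemma-7B width.
hidden_size = 3072
embedding = torch.nn.Embedding(8000, hidden_size)

token_ids = torch.tensor([[2, 651, 2121]])              # arbitrary example ids
hidden_states = embedding(token_ids) * (hidden_size ** 0.5)
```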

exllamav2/exllamav2_ext/cuda/q_attn.cu

+1 -1

@@ -161,7 +161,7 @@ void QAttn::forward_cuda_2
     half* lora_temp
 )
 {
-    gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, hidden_size, false, temp_dq);
+    gemm_half_q_half_cuda(cublas_handle, attn_output, o_proj, hidden_state, q_len * batch_size, o_proj->width, o_proj->height, false, temp_dq);

     apply_loras_cuda(cublas_handle, o_proj_lora, loras, o_proj, attn_output, hidden_state, lora_temp, q_len * batch_size);
 }
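Here o_proj->height corresponds to the input width of the output projection (num_attention_heads * head_dim) and o_proj->width to hidden_size; passing hidden_size as the inner GEMM dimension only worked while the two coincided. A shape-only sketch of the projection, with assumed Gemma-7B geometry:

```python
import torch

# Shape sketch only (assumed Gemma-7B geometry): the GEMM's inner dimension
# must be o_proj's input width, not hidden_size.
rows = 8                                   # q_len * batch_size
attn_width, hidden_size = 16 * 256, 3072   # o_proj "height" and "width"

attn_out = torch.randn(rows, attn_width)
o_proj = torch.randn(attn_width, hidden_size)
hidden_state = attn_out @ o_proj           # (rows, hidden_size)
```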

exllamav2/exllamav2_ext/cuda/q_mlp.cu

+50 -14

@@ -44,6 +44,20 @@ __device__ __forceinline__ half2 silu(half2 x)
     return result;
 }

+__device__ __forceinline__ half gelu(half x)
+{
+    float xf = __half2float(x);
+    const float c = 0.797884560803f; // sqrt(2/Pi)
+    float tanh_arg = c * (xf + 0.044715f * pow(xf, 3));
+    xf = 0.5f * xf * (1.0 + tanh(tanh_arg));
+    return __float2half_rn(xf);
+}
+
+__device__ __forceinline__ half2 gelu(half2 x)
+{
+    return __halves2half2(gelu(x.x), gelu(x.y));
+}
+
 typedef void (*fp_silu_mul_kernel)
 (
     half*,

@@ -54,7 +68,7 @@ typedef void (*fp_silu_mul_kernel)
     const int
 );

-template <bool use_half2, bool use_r_weights>
+template <bool use_half2, bool use_r_weights, bool act_fn_gelu>
 __global__ void silu_mul_kernel
 (
     half* __restrict__ x,

@@ -90,7 +104,11 @@ __global__ void silu_mul_kernel
     half2 x_item = x_.item_half2(row, column);
     half2 y_item = y_.item_half2(row, column);

-    x_item = silu(x_item);
+    if constexpr (act_fn_gelu)
+        x_item = gelu(x_item);
+    else
+        x_item = silu(x_item);
+
     x_item = __hmul2(x_item, y_item);

     x_.set_half2(row, column, x_item);

@@ -100,19 +118,33 @@ __global__ void silu_mul_kernel
     half x_item = x_.item(row, column);
     half y_item = y_.item(row, column);

-    x_item = silu(x_item);
+    if constexpr (act_fn_gelu)
+        x_item = gelu(x_item);
+    else
+        x_item = silu(x_item);
+
     x_item = __hmul(x_item, y_item);

     x_.set(row, column, x_item);
     }
 }

-fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights)
+fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights, bool act_fn_gelu)
 {
-    if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false>;
-    if ( use_half2 &&  mul_r_weights) return silu_mul_kernel< true,  true>;
-    if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false>;
-    if (!use_half2 &&  mul_r_weights) return silu_mul_kernel<false,  true>;
+    if (act_fn_gelu)
+    {
+        if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false, true>;
+        if ( use_half2 &&  mul_r_weights) return silu_mul_kernel< true,  true, true>;
+        if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false, true>;
+        if (!use_half2 &&  mul_r_weights) return silu_mul_kernel<false,  true, true>;
+    }
+    else
+    {
+        if ( use_half2 && !mul_r_weights) return silu_mul_kernel< true, false, false>;
+        if ( use_half2 &&  mul_r_weights) return silu_mul_kernel< true,  true, false>;
+        if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false, false>;
+        if (!use_half2 &&  mul_r_weights) return silu_mul_kernel<false,  true, false>;
+    }
     return NULL;
 };

@@ -129,7 +161,8 @@ QMLP::QMLP
     half* _temp_a,
     half* _temp_b,
     half* _temp_dq,
-    int _max_rows
+    int _max_rows,
+    bool _act_gelu
 ):
     layernorm(_layernorm),
     layernorm_bias(_layernorm_bias),

@@ -142,7 +175,8 @@ QMLP::QMLP
     temp_a(_temp_a),
     temp_b(_temp_b),
     temp_dq(_temp_dq),
-    max_rows(_max_rows)
+    max_rows(_max_rows),
+    act_gelu(_act_gelu)
 {
 }

@@ -179,7 +213,7 @@ void QMLP::forward_
     gridDim.x = DIVIDE(up->width, THREADS_X) / (use_half2 ? 2 : 1);
     gridDim.y = DIVIDE(rows, THREADS_Y);

-    fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false);
+    fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false, act_gelu);
     kernel<<<gridDim, blockDim>>>(temp_a, temp_b, rows, intermediate_size, NULL, 0);

     gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, false, temp_dq);

@@ -207,7 +241,8 @@ QMoEMLP::QMoEMLP
     half* _temp_logits,
     half* _temp_dq,
     int _max_rows,
-    int _hidden_dim
+    int _hidden_dim,
+    bool _act_gelu
 ):
     layernorm(_layernorm),
     layernorm_bias(_layernorm_bias),

@@ -226,7 +261,8 @@ QMoEMLP::QMoEMLP
     temp_logits(_temp_logits),
     temp_dq(_temp_dq),
     max_rows(_max_rows),
-    hidden_dim(_hidden_dim)
+    hidden_dim(_hidden_dim),
+    act_gelu(_act_gelu)
 {
 //    for (int i = 0; i < num_experts; ++i)
 //    {

@@ -299,7 +335,7 @@ void QMoEMLP::forward_
     if (rows <= MAX_Q_GEMM_WEIGHTS)
     {
         int intermediate_size = w1[0]->width;
-        fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true);
+        fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true, act_gelu);

         for (int i = 0; i < num_experts; i++)
         {
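The new gelu() device function is the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), and it is selected at compile time through the extra act_fn_gelu template parameter, so the existing SiLU path pays no runtime branch. A quick reference check of the same formula against PyTorch's default (erf-based) GELU:

```python
import math
import torch

# Reference sketch of the tanh GELU approximation used by the kernel above.
def gelu_tanh(x: float) -> float:
    c = 0.797884560803  # sqrt(2 / pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

xs = torch.linspace(-4.0, 4.0, 17)
approx = torch.tensor([gelu_tanh(v.item()) for v in xs])
exact = torch.nn.functional.gelu(xs)        # exact (erf-based) GELU
assert torch.allclose(approx, exact, atol=1e-2)
```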

exllamav2/exllamav2_ext/cuda/q_mlp.cuh

+8 -2

@@ -34,6 +34,8 @@ public:
     std::unordered_map<uintptr_t, std::tuple<half*, half*, int>> up_proj_lora;
     std::unordered_map<uintptr_t, std::tuple<half*, half*, int>> down_proj_lora;

+    bool act_gelu;
+
     QMLP
     (
         half* _layernorm,

@@ -47,7 +49,8 @@ public:
         half* _temp_a,
         half* _temp_b,
         half* _temp_dq,
-        int _max_rows
+        int _max_rows,
+        bool _act_gelu
     );

     ~QMLP();

@@ -94,6 +97,8 @@ public:
     int max_rows;
     int hidden_dim;

+    bool act_gelu;
+
 //    std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w1_lora;
 //    std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w2_lora;
 //    std::vector<std::unordered_map<uintptr_t, std::tuple<half*, half*, int>>> w3_lora;

@@ -117,7 +122,8 @@ public:
         half* _temp_logits,
         half* _temp_dq,
         int _max_rows,
-        int _hidden_dim
+        int _hidden_dim,
+        bool _act_gelu
     );

     ~QMoEMLP();

exllamav2/exllamav2_ext/ext_qattn.cpp

+1 -1

@@ -48,7 +48,7 @@ uintptr_t make_q_attn
     if (qm_q_proj && !layernorm.is_meta()) TORCH_CHECK(qm_q_proj->height == layernorm.size(0), "q_proj is wrong shape")
     if (qm_k_proj && !layernorm.is_meta()) TORCH_CHECK(qm_k_proj->height == layernorm.size(0), "k_proj is wrong shape")
     if (qm_v_proj && !layernorm.is_meta()) TORCH_CHECK(qm_v_proj->height == layernorm.size(0), "v_proj is wrong shape")
-    if (!layernorm.is_meta()) TORCH_CHECK(qm_o_proj->height == layernorm.size(0), "o_proj is wrong shape")
+    if (!layernorm.is_meta()) TORCH_CHECK(qm_o_proj->width == layernorm.size(0), "o_proj is wrong shape")

     QAttn* attn = new QAttn
     (

exllamav2/exllamav2_ext/ext_qmlp.cpp

+8 -4

@@ -28,7 +28,8 @@ uintptr_t make_q_mlp
     torch::Tensor temp_a,
     torch::Tensor temp_b,
     torch::Tensor temp_dq,
-    int max_rows
+    int max_rows,
+    bool act_gelu
 )
 {
     QMatrix* qm_gate = reinterpret_cast<QMatrix*> (q_gate);

@@ -52,7 +53,8 @@ uintptr_t make_q_mlp
         (half*) temp_a.data_ptr(),
         (half*) temp_b.data_ptr(),
         (half*) temp_dq.data_ptr(),
-        max_rows
+        max_rows,
+        act_gelu
     );

     return reinterpret_cast<uintptr_t> (mlp);

@@ -163,7 +165,8 @@ uintptr_t make_q_moe_mlp
     torch::Tensor temp_b,
     torch::Tensor temp_logits,
     torch::Tensor temp_dq,
-    int max_rows
+    int max_rows,
+    bool act_gelu
 )
 {
     std::vector<QMatrix*> qm_w1;

@@ -202,7 +205,8 @@ uintptr_t make_q_moe_mlp
         (half*) temp_logits.data_ptr(),
         (half*) temp_dq.data_ptr(),
         max_rows,
-        hidden_dim
+        hidden_dim,
+        act_gelu
     );

     return reinterpret_cast<uintptr_t> (moe_mlp);

exllamav2/exllamav2_ext/ext_qmlp.h

+4 -2

@@ -12,7 +12,8 @@ uintptr_t make_q_mlp
     torch::Tensor temp_a,
     torch::Tensor temp_b,
     torch::Tensor temp_dq,
-    int max_rows
+    int max_rows,
+    bool act_gelu
 );

 void free_q_mlp

@@ -57,7 +58,8 @@ uintptr_t make_q_moe_mlp
     torch::Tensor temp_b,
     torch::Tensor temp_logits,
     torch::Tensor temp_dq,
-    int max_rows
+    int max_rows,
+    bool act_gelu
 );

 void free_q_moe_mlp

exllamav2/mlp.py

+3 -2

@@ -91,7 +91,8 @@ def load(self):
     device_tensors.get_scratch_slice(self.temp_a_size()),
     device_tensors.get_scratch_slice(self.temp_b_size()),
     device_tensors.get_scratch_slice(self.temp_dq_size()),
-    self.model.config.max_input_len * self.model.config.max_batch_size)
+    self.model.config.max_input_len * self.model.config.max_batch_size,
+    self.model.config.architecture == "Gemma")


 def unload(self):

@@ -195,7 +196,7 @@ def forward_torch(self, hidden_states, cache = None, attn_params = None, interme
 post_norm = self.post_attention_layernorm.forward(hidden_states)

 gate = self.gate_proj.forward(post_norm, loras = loras)
-y = F.silu(gate)
+y = F.gelu(gate) if self.model.config.architecture == "Gemma" else F.silu(gate)
 up = self.up_proj.forward(post_norm, loras = loras)
 y *= up
 y.clamp_(min = -65504.0, max = 65504.0)
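The Python fallback mirrors the CUDA change: the gate activation becomes GELU for Gemma and stays SiLU otherwise, and load() now passes an architecture flag down to make_q_mlp so the fused kernel picks the matching activation. (F.gelu defaults to the exact, erf-based GELU, while the CUDA kernel uses the tanh approximation; the difference is marginal.) A compact sketch of the gated-MLP math being switched, with toy sizes and names that are illustrative rather than the module's real attributes:

```python
import torch
import torch.nn.functional as F

# Hypothetical standalone gated MLP (toy sizes), showing only the activation
# switch; the real module uses quantized projection layers.
def gated_mlp(x, gate_w, up_w, down_w, is_gemma: bool):
    gate = x @ gate_w
    y = F.gelu(gate) if is_gemma else F.silu(gate)   # Gemma: GELU, Llama: SiLU
    y = y * (x @ up_w)
    return y @ down_w

x = torch.randn(4, 64)
out = gated_mlp(x, torch.randn(64, 256), torch.randn(64, 256),
                torch.randn(256, 64), is_gemma=True)
```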

exllamav2/model.py

+4 -1

@@ -164,10 +164,13 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
     self.modules_dict[self.modules[-1].key] = self.modules[-1]

     self.head_layer_idx = len(self.modules)
+
     self.modules.append(ExLlamaV2Linear(self, "lm_head", self.config.hidden_size, self.config.vocab_size, False))
     self.modules_dict[self.modules[-1].key] = self.modules[-1]
+    if self.config.architecture == "Gemma":
+        self.modules[-1].alt_key = "model.embed_tokens"

-    # Find last layer that affects k/v cache
+    # Find last layer that affects k/v cache

     layer_idx = len(self.modules)
     while True:
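Gemma ties its output head to the input embedding matrix, so the checkpoint contains no lm_head tensor (which matches expect_keys_gemma in config.py). Setting alt_key on the head module lets it load its weight from "model.embed_tokens" instead. The idea in isolation, as a small sketch with toy sizes:

```python
import torch

# Weight-tying sketch (toy sizes): one matrix serves as both the token
# embedding and the output projection, so no separate lm_head tensor exists.
vocab_size, hidden_size = 1000, 64
embedding = torch.nn.Embedding(vocab_size, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight = embedding.weight          # shared parameter

h = torch.randn(1, hidden_size)
logits = lm_head(h)                        # (1, vocab_size)
```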
