@@ -44,6 +44,20 @@ __device__ __forceinline__ half2 silu(half2 x)
    return result;
}

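+// GELU, tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
+// evaluated in fp32 and rounded back to half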
+__device__ __forceinline__ half gelu(half x)
+{
+    float xf = __half2float(x);
+    const float c = 0.797884560803f; // sqrt(2/Pi)
+    float tanh_arg = c * (xf + 0.044715f * pow(xf, 3));
+    xf = 0.5f * xf * (1.0 + tanh(tanh_arg));
+    return __float2half_rn(xf);
+}
+
+__device__ __forceinline__ half2 gelu(half2 x)
+{
+    return __halves2half2(gelu(x.x), gelu(x.y));
+}
+
typedef void (*fp_silu_mul_kernel)
(
    half*,
@@ -54,7 +68,7 @@ typedef void (*fp_silu_mul_kernel)
    const int
);

-template <bool use_half2, bool use_r_weights>
+template <bool use_half2, bool use_r_weights, bool act_fn_gelu>
__global__ void silu_mul_kernel
(
    half* __restrict__ x,
@@ -90,7 +104,11 @@ __global__ void silu_mul_kernel
        half2 x_item = x_.item_half2(row, column);
        half2 y_item = y_.item_half2(row, column);

-        x_item = silu(x_item);
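+        // act_fn_gelu is a compile-time template parameter, so the unselected branch is compiled out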
+        if constexpr (act_fn_gelu)
+            x_item = gelu(x_item);
+        else
+            x_item = silu(x_item);
+
        x_item = __hmul2(x_item, y_item);

        x_.set_half2(row, column, x_item);
@@ -100,19 +118,33 @@ __global__ void silu_mul_kernel
        half x_item = x_.item(row, column);
        half y_item = y_.item(row, column);

-        x_item = silu(x_item);
+        if constexpr (act_fn_gelu)
+            x_item = gelu(x_item);
+        else
+            x_item = silu(x_item);
+
        x_item = __hmul(x_item, y_item);

        x_.set(row, column, x_item);
    }
}

-fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights)
+fp_silu_mul_kernel pick_silu_mul_kernel(bool use_half2, bool mul_r_weights, bool act_fn_gelu)
{
-    if ( use_half2 && !mul_r_weights) return silu_mul_kernel<true,  false>;
-    if ( use_half2 &&  mul_r_weights) return silu_mul_kernel<true,  true>;
-    if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false>;
-    if (!use_half2 &&  mul_r_weights) return silu_mul_kernel<false, true>;
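+    // map the three runtime flags onto one of the eight template instantiations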
+    if (act_fn_gelu)
+    {
+        if ( use_half2 && !mul_r_weights) return silu_mul_kernel<true,  false, true>;
+        if ( use_half2 &&  mul_r_weights) return silu_mul_kernel<true,  true,  true>;
+        if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false, true>;
+        if (!use_half2 &&  mul_r_weights) return silu_mul_kernel<false, true,  true>;
+    }
+    else
+    {
+        if ( use_half2 && !mul_r_weights) return silu_mul_kernel<true,  false, false>;
+        if ( use_half2 &&  mul_r_weights) return silu_mul_kernel<true,  true,  false>;
+        if (!use_half2 && !mul_r_weights) return silu_mul_kernel<false, false, false>;
+        if (!use_half2 &&  mul_r_weights) return silu_mul_kernel<false, true,  false>;
+    }

    return NULL;
};

@@ -129,7 +161,8 @@ QMLP::QMLP
    half* _temp_a,
    half* _temp_b,
    half* _temp_dq,
-    int _max_rows
+    int _max_rows,
+    bool _act_gelu
):
    layernorm(_layernorm),
    layernorm_bias(_layernorm_bias),
@@ -142,7 +175,8 @@ QMLP::QMLP
    temp_a(_temp_a),
    temp_b(_temp_b),
    temp_dq(_temp_dq),
-    max_rows(_max_rows)
+    max_rows(_max_rows),
+    act_gelu(_act_gelu)
{
}

@@ -179,7 +213,7 @@ void QMLP::forward_
    gridDim.x = DIVIDE(up->width, THREADS_X) / (use_half2 ? 2 : 1);
    gridDim.y = DIVIDE(rows, THREADS_Y);

-    fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false);
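+    // act_gelu was stored by the constructor; it selects the GELU variant of the fused kernel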
+    fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, false, act_gelu);
    kernel<<<gridDim, blockDim>>>(temp_a, temp_b, rows, intermediate_size, NULL, 0);

    gemm_half_q_half_cuda(cublas_handle, temp_a, down, x, rows, columns, intermediate_size, false, temp_dq);
@@ -207,7 +241,8 @@ QMoEMLP::QMoEMLP
    half* _temp_logits,
    half* _temp_dq,
    int _max_rows,
-    int _hidden_dim
+    int _hidden_dim,
+    bool _act_gelu
):
    layernorm(_layernorm),
    layernorm_bias(_layernorm_bias),
@@ -226,7 +261,8 @@ QMoEMLP::QMoEMLP
    temp_logits(_temp_logits),
    temp_dq(_temp_dq),
    max_rows(_max_rows),
-    hidden_dim(_hidden_dim)
+    hidden_dim(_hidden_dim),
+    act_gelu(_act_gelu)
{
//    for (int i = 0; i < num_experts; ++i)
//    {
@@ -299,7 +335,7 @@ void QMoEMLP::forward_
    if (rows <= MAX_Q_GEMM_WEIGHTS)
    {
        int intermediate_size = w1[0]->width;
-        fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true);
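+        // unlike the dense MLP path, the MoE path requests the mul_r_weights kernel variant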
+        fp_silu_mul_kernel kernel = pick_silu_mul_kernel(use_half2, true, act_gelu);

        for (int i = 0; i < num_experts; i++)
        {