Skip to content

Commit 84fd873

Browse files
authored
Merge pull request #123 from JohannesGaessler/cuda-fa-no-tc-14
CUDA: faster large batch FA without tensor cores
2 parents dec7622 + cc0332d commit 84fd873

7 files changed

+823
-15
lines changed

ggml-cuda/fattn-tile-f16.cu

+395
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,395 @@
1+
#include "common.cuh"
2+
#include "fattn-common.cuh"
3+
#include "fattn-tile-f16.cuh"
4+
5+
#define FATTN_KQ_STRIDE_TILE_F16 64
6+
7+
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
8+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
9+
__launch_bounds__(nwarps*WARP_SIZE, 1)
10+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
11+
static __global__ void flash_attn_tile_ext_f16(
12+
const char * __restrict__ Q,
13+
const char * __restrict__ K,
14+
const char * __restrict__ V,
15+
const char * __restrict__ mask,
16+
float * __restrict__ dst,
17+
float2 * __restrict__ dst_meta,
18+
const float scale,
19+
const float max_bias,
20+
const float m0,
21+
const float m1,
22+
const uint32_t n_head_log2,
23+
const int ne00,
24+
const int ne01,
25+
const int ne02,
26+
const int ne03,
27+
const int ne10,
28+
const int ne11,
29+
const int ne12,
30+
const int ne13,
31+
const int ne31,
32+
const int nb31,
33+
const int nb01,
34+
const int nb02,
35+
const int nb03,
36+
const int nb11,
37+
const int nb12,
38+
const int nb13,
39+
const int ne0,
40+
const int ne1,
41+
const int ne2,
42+
const int ne3) {
43+
#if FP16_AVAILABLE
44+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
45+
46+
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
47+
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
48+
49+
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
50+
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
51+
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
52+
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
53+
const half * maskh = (const half *) mask + ne11*ic0;
54+
55+
const int stride_KV2 = nb11 / sizeof(half2);
56+
57+
half slopeh = __float2half(1.0f);
58+
59+
// ALiBi
60+
if (max_bias > 0.0f) {
61+
const uint32_t h = blockIdx.y;
62+
63+
const float base = h < n_head_log2 ? m0 : m1;
64+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
65+
66+
slopeh = __float2half(powf(base, exph));
67+
}
68+
69+
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
70+
71+
__shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
72+
half2 * KQ2 = (half2 *) KQ;
73+
74+
__shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
75+
76+
half kqmax[ncols/nwarps];
77+
#pragma unroll
78+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
79+
kqmax[j0/nwarps] = -HALF_MAX_HALF;
80+
}
81+
half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
82+
83+
half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
84+
85+
// Convert Q to half2 and store in registers:
86+
__shared__ half2 Q_h2[ncols][D/2];
87+
#pragma unroll
88+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
89+
const int j = j0 + threadIdx.y;
90+
91+
#pragma unroll
92+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
93+
const int i = i0 + threadIdx.x;
94+
95+
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
96+
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
97+
}
98+
}
99+
100+
__syncthreads();
101+
102+
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
103+
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
104+
// Calculate KQ tile and keep track of new maximum KQ values:
105+
106+
half kqmax_new[ncols/nwarps];
107+
#pragma unroll
108+
for (int j = 0; j < ncols/nwarps; ++j) {
109+
kqmax_new[j] = kqmax[j];
110+
}
111+
112+
#pragma unroll
113+
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
114+
const int i_KQ = i_KQ_0 + threadIdx.y;
115+
116+
#pragma unroll
117+
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
118+
const int k_KQ = k_KQ_0 + threadIdx.x;
119+
120+
KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
121+
}
122+
}
123+
124+
__syncthreads();
125+
126+
half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
127+
128+
#pragma unroll
129+
for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
130+
half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
131+
half2 Q_k[ncols/nwarps];
132+
133+
#pragma unroll
134+
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
135+
const int i_KQ = i_KQ_0 + threadIdx.x;
136+
137+
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
138+
}
139+
#pragma unroll
140+
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
141+
const int j_KQ = j_KQ_0 + threadIdx.y;
142+
143+
Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
144+
}
145+
146+
#pragma unroll
147+
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
148+
#pragma unroll
149+
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
150+
sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
151+
}
152+
}
153+
}
154+
155+
#pragma unroll
156+
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
157+
const int i_KQ = i_KQ_0 + threadIdx.x;
158+
159+
#pragma unroll
160+
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
161+
const int j_KQ = j_KQ_0 + threadIdx.y;
162+
163+
half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
164+
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
165+
166+
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
167+
168+
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
169+
}
170+
}
171+
172+
__syncthreads();
173+
174+
#pragma unroll
175+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
176+
const int j = j0 + threadIdx.y;
177+
178+
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
179+
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
180+
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
181+
182+
#pragma unroll
183+
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
184+
const int i = i0 + threadIdx.x;
185+
186+
const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
187+
const half2 val = h2exp(diff);
188+
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
189+
KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
190+
}
191+
192+
#pragma unroll
193+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
194+
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
195+
}
196+
}
197+
198+
__syncthreads();
199+
200+
#pragma unroll
201+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
202+
const int k = k0 + threadIdx.y;
203+
204+
#pragma unroll
205+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
206+
const int i = i0 + threadIdx.x;
207+
208+
KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
209+
}
210+
}
211+
212+
__syncthreads();
213+
214+
#pragma unroll
215+
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
216+
half2 V_k[(D/2)/WARP_SIZE][2];
217+
half2 KQ_k[ncols/nwarps];
218+
219+
#pragma unroll
220+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
221+
const int i = i0 + threadIdx.x;
222+
223+
V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
224+
V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
225+
}
226+
#pragma unroll
227+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
228+
const int j = j0 + threadIdx.y;
229+
230+
KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
231+
}
232+
233+
#pragma unroll
234+
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
235+
#pragma unroll
236+
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
237+
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
238+
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
239+
}
240+
}
241+
}
242+
243+
__syncthreads();
244+
}
245+
246+
#pragma unroll
247+
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
248+
const int j_VKQ = j_VKQ_0 + threadIdx.y;
249+
250+
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
251+
kqsum_j = warp_reduce_sum(kqsum_j);
252+
253+
#pragma unroll
254+
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
255+
const int i0 = i00 + 2*threadIdx.x;
256+
257+
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
258+
if (parallel_blocks == 1) {
259+
dst_val /= __half2half2(kqsum_j);
260+
}
261+
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
262+
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
263+
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
264+
}
265+
266+
if (parallel_blocks != 1 && threadIdx.x == 0) {
267+
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
268+
}
269+
}
270+
#else
271+
NO_DEVICE_CODE;
272+
#endif // FP16_AVAILABLE
273+
}
274+
275+
template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
276+
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
277+
ggml_cuda_pool & pool, cudaStream_t main_stream
278+
) {
279+
ggml_cuda_pool_alloc<float> dst_tmp(pool);
280+
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
281+
282+
if (parallel_blocks > 1) {
283+
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
284+
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
285+
}
286+
287+
constexpr int nwarps = 8;
288+
const dim3 block_dim(WARP_SIZE, nwarps, 1);
289+
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
290+
const int shmem = 0;
291+
292+
float scale = 1.0f;
293+
float max_bias = 0.0f;
294+
295+
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
296+
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
297+
298+
const uint32_t n_head = Q->ne[2];
299+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
300+
301+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
302+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
303+
304+
flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
305+
<<<blocks_num, block_dim, shmem, main_stream>>> (
306+
(const char *) Q->data,
307+
(const char *) K->data,
308+
(const char *) V->data,
309+
mask ? ((const char *) mask->data) : nullptr,
310+
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
311+
scale, max_bias, m0, m1, n_head_log2,
312+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
313+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
314+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
315+
Q->nb[1], Q->nb[2], Q->nb[3],
316+
K->nb[1], K->nb[2], K->nb[3],
317+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
318+
);
319+
CUDA_CHECK(cudaGetLastError());
320+
321+
if (parallel_blocks == 1) {
322+
return;
323+
}
324+
325+
const dim3 block_dim_combine(D, 1, 1);
326+
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
327+
const int shmem_combine = 0;
328+
329+
flash_attn_combine_results<D, parallel_blocks>
330+
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
331+
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
332+
CUDA_CHECK(cudaGetLastError());
333+
}
334+
335+
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
336+
const ggml_tensor * Q = dst->src[0];
337+
const ggml_tensor * K = dst->src[1];
338+
const ggml_tensor * V = dst->src[2];
339+
340+
const ggml_tensor * mask = dst->src[3];
341+
342+
ggml_tensor * KQV = dst;
343+
344+
const int32_t precision = KQV->op_params[2];
345+
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
346+
GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
347+
348+
if (Q->ne[1] <= 16) {
349+
constexpr int cols_per_block = 16;
350+
constexpr int parallel_blocks = 4;
351+
switch (Q->ne[0]) {
352+
case 64:
353+
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
354+
break;
355+
case 128:
356+
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
357+
break;
358+
default:
359+
GGML_ASSERT(false);
360+
break;
361+
}
362+
return;
363+
}
364+
365+
if (Q->ne[1] <= 32) {
366+
constexpr int cols_per_block = 32;
367+
constexpr int parallel_blocks = 4;
368+
switch (Q->ne[0]) {
369+
case 64:
370+
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
371+
break;
372+
case 128:
373+
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
374+
break;
375+
default:
376+
GGML_ASSERT(false);
377+
break;
378+
}
379+
return;
380+
}
381+
382+
constexpr int cols_per_block = 32;
383+
constexpr int parallel_blocks = 1;
384+
switch (Q->ne[0]) {
385+
case 64:
386+
launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
387+
break;
388+
case 128:
389+
launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
390+
break;
391+
default:
392+
GGML_ASSERT(false);
393+
break;
394+
}
395+
}

ggml-cuda/fattn-tile-f16.cuh

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#include "common.cuh"
2+
3+
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)