Commit f454870

Merge pull request #113 from JohannesGaessler/cuda-fa-no-tc-11
CUDA: add FP32 FlashAttention vector kernel
2 parents: becc8d7 + e0d1184

9 files changed: +849 −433 lines

ggml-cuda.cu (+10 −1)

@@ -2714,6 +2714,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 }
 
 GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -2841,8 +2842,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_FLASH_ATTN_EXT:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
+#else
+            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }
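
In words: head sizes 64 and 128 are now always reported as supported (presumably the sizes covered by the new FP32 vector kernel), while other head sizes still require the FP16 tensor-core path and therefore a compute capability of at least Volta on NVIDIA; on HIP/AMD builds only head sizes 64 and 128 are accepted. A minimal host-side restatement of that gate, illustrative only and not part of the diff (the literal 700 stands in for CC_VOLTA):

#include <cstdio>

// Hypothetical helper mirroring the new GGML_OP_FLASH_ATTN_EXT support check.
static bool flash_attn_ext_supported(int head_size, int cc, bool hip_amd_build) {
    if (head_size == 64 || head_size == 128) {
        return true;                        // accepted regardless of architecture
    }
    return !hip_amd_build && cc >= 700;     // other head sizes need Volta+ (CC_VOLTA)
}

int main() {
    printf("%d %d %d\n",
           flash_attn_ext_supported(128, 610, false),   // 1: always-supported head size
           flash_attn_ext_supported(256, 610, false),   // 0: pre-Volta, other head size
           flash_attn_ext_supported(256, 800, false));  // 1: Ampere, tensor-core path
    return 0;
}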

ggml-cuda/common.cuh (+4 −0)

@@ -321,6 +321,10 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 
 #define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 
+static bool fast_fp16_available(const int cc) {
+    return cc >= CC_PASCAL && cc != 610;
+}
+
 static bool fp16_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
 }
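
The new fast_fp16_available() predicate complements fp16_mma_available(): it reports whether a device has usable FP16 arithmetic throughput at all (Pascal or newer, excluding compute capability 6.1, whose FP16 throughput is far below its FP32 throughput), whereas the existing predicate gates the tensor-core path. A quick host-side check of both, assuming the usual constants from common.cuh (CC_PASCAL = 600, CC_VOLTA = 700, CC_OFFSET_AMD = 1000000):

#include <cstdio>

static bool fast_fp16_available(const int cc) { return cc >= 600 && cc != 610; }    // CC_PASCAL = 600
static bool fp16_mma_available (const int cc) { return cc < 1000000 && cc >= 700; } // CC_OFFSET_AMD, CC_VOLTA

int main() {
    const int ccs[] = {600, 610, 700, 860};  // e.g. P100, GTX 10xx, V100, RTX 30xx
    for (const int cc : ccs) {
        printf("cc %3d: fast_fp16=%d fp16_mma=%d\n", cc, fast_fp16_available(cc), fp16_mma_available(cc));
    }
    return 0;
}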

ggml-cuda/fattn-common.cuh (new file, +47 −0)

@@ -0,0 +1,47 @@
+#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                    // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+
+template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_combine_results(
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
+        float * __restrict__ dst) {
+    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
+    dst       +=                 D * gridDim.y*blockIdx.x;
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+    }
+
+    __syncthreads();
+
+    float kqmax = meta[0].x;
+#pragma unroll
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = max(kqmax, meta[l].x);
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+#pragma unroll
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y;
+    }
+
+    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+}
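
The kernel merges the `parallel_blocks` partial FlashAttention results produced for each output column: VKQ_parts holds each block's partial numerator (V weighted by the unnormalized softmax of KQ), and VKQ_meta holds each block's running KQ maximum (meta[l].x) and softmax denominator (meta[l].y). The merge is the standard numerically stable one: every block is rescaled by exp(m_l − m), where m is the maximum over blocks, scales below SOFTMAX_FTZ_THRESHOLD are flushed to zero, and the summed numerator is divided by the summed denominator. A host-side sketch of that arithmetic for a single output element (struct and names are illustrative, not from the kernel):

#include <cmath>
#include <cstdio>

struct Partial { float numerator, kq_max, denominator; };  // one entry per parallel block

static float combine(const Partial * parts, int parallel_blocks) {
    // Global running maximum across all partial results.
    float kqmax = parts[0].kq_max;
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = fmaxf(kqmax, parts[l].kq_max);
    }
    // Rescale each block relative to the global maximum and accumulate.
    float num = 0.0f, den = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff  = parts[l].kq_max - kqmax;
        const float scale = diff > -20.0f ? expf(diff) : 0.0f;  // SOFTMAX_FTZ_THRESHOLD
        num += scale * parts[l].numerator;
        den += scale * parts[l].denominator;
    }
    return num / den;
}

int main() {
    const Partial parts[2] = { {1.5f, 3.0f, 2.0f}, {0.5f, 1.0f, 1.0f} };
    printf("combined = %f\n", combine(parts, 2));
    return 0;
}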
