|
| 1 | +#define FATTN_KQ_STRIDE 256 |
| 2 | +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. |
| 3 | +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. |
| 4 | + |
| 5 | +template<int D, int parallel_blocks> // D == head size |
| 6 | +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) |
| 7 | +__launch_bounds__(D, 1) |
| 8 | +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) |
| 9 | +static __global__ void flash_attn_combine_results( |
| 10 | + const float * __restrict__ VKQ_parts, |
| 11 | + const float2 * __restrict__ VKQ_meta, |
| 12 | + float * __restrict__ dst) { |
| 13 | + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; |
| 14 | + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; |
| 15 | + dst += D * gridDim.y*blockIdx.x; |
| 16 | + |
| 17 | + const int tid = threadIdx.x; |
| 18 | + __builtin_assume(tid < D); |
| 19 | + |
| 20 | + __shared__ float2 meta[parallel_blocks]; |
| 21 | + if (tid < 2*parallel_blocks) { |
| 22 | + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; |
| 23 | + } |
| 24 | + |
| 25 | + __syncthreads(); |
| 26 | + |
| 27 | + float kqmax = meta[0].x; |
| 28 | +#pragma unroll |
| 29 | + for (int l = 1; l < parallel_blocks; ++l) { |
| 30 | + kqmax = max(kqmax, meta[l].x); |
| 31 | + } |
| 32 | + |
| 33 | + float VKQ_numerator = 0.0f; |
| 34 | + float VKQ_denominator = 0.0f; |
| 35 | +#pragma unroll |
| 36 | + for (int l = 0; l < parallel_blocks; ++l) { |
| 37 | + const float diff = meta[l].x - kqmax; |
| 38 | + const float KQ_max_scale = expf(diff); |
| 39 | + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); |
| 40 | + *((uint32_t *) &KQ_max_scale) &= ftz_mask; |
| 41 | + |
| 42 | + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; |
| 43 | + VKQ_denominator += KQ_max_scale * meta[l].y; |
| 44 | + } |
| 45 | + |
| 46 | + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; |
| 47 | +} |
0 commit comments