@@ -156,10 +156,6 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
156
156
// PREDICATES
157
157
//
158
158
159
- // // Allocate predicate tensors for m and n
160
- // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
161
- // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
162
-
163
159
// Construct identity layout for sQ and sK
164
160
Tensor cQ = make_identity_tensor (make_shape (size<0 >(sQ ), size<1 >(sQ ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
165
161
Tensor cKV = make_identity_tensor (make_shape (size<0 >(sK ), size<1 >(sK ))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
@@ -434,9 +430,6 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
434
430
if (num_cols > 0 ) {
435
431
auto * cols_ptr = params.column_index + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_V ;
436
432
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
437
- // tKgKBlock.data() = tKgKBlockData + blks_ptr[0] * int64_t(params.k_row_stride);
438
- // flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV,
439
- // binfo.actual_seqlen_k - blks_ptr[0]);
440
433
#pragma unroll
441
434
for (int m = 0 ; m < size<1 >(tKgKToken); ++m) {
442
435
if (Is_even_MN || get<0 >(tKVcKV (0 , m, 0 )) < num_cols) { // Is_even_MN
@@ -445,7 +438,7 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
445
438
for (int k = 0 ; k < size<2 >(tKgKToken); ++k) {
446
439
if (Is_even_K || tKVpKV (k)) {
447
440
cute::copy (gmem_tiled_copy_QKV, tKgKToken (_, m, k), tKsK (_, m, k));
448
- } else if ( true ) { // Clear_OOB_K
441
+ } else { // Clear_OOB_K
449
442
cute::clear (tKsK (_, m, k));
450
443
}
451
444
}
@@ -463,7 +456,6 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
463
456
464
457
// Advance gV
465
458
if (n < num_cols_block - 1 ) {
466
- // flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV);
467
459
#pragma unroll
468
460
for (int m = 0 ; m < size<1 >(tVgVToken); ++m) {
469
461
if (true ) { // Is_even_MN
@@ -480,9 +472,6 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
480
472
}
481
473
} else {
482
474
// Clear the smem tiles to account for predicated off loads
483
- // flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
484
- // gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - start_n
485
- // );
486
475
#pragma unroll
487
476
for (int m = 0 ; m < size<1 >(tVgVToken); ++m) {
488
477
if (Is_even_MN || n * kBlockN + get<0 >(tKVcKV (0 , m, 0 )) < num_cols) { // Is_even_MN
@@ -491,11 +480,11 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
491
480
for (int k = 0 ; k < size<2 >(tVgVToken); ++k) {
492
481
if (Is_even_K || tKVpKV (k)) {
493
482
cute::copy (gmem_tiled_copy_QKV, tVgVToken (_, m, k), tVsV (_, m, k));
494
- } else if ( true ) { // Clear_OOB_K
483
+ } else { // Clear_OOB_K
495
484
cute::clear (tVsV (_, m, k));
496
485
}
497
486
}
498
- } else if ( true ) { // Clear_OOB_MN
487
+ } else { // Clear_OOB_MN
499
488
cute::clear (tVsV (_, m, _));
500
489
}
501
490
}
@@ -511,9 +500,6 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
511
500
flash::apply_softcap (acc_s, params.softcap );
512
501
}
513
502
514
- // mask.template apply_mask<Is_causal, Is_even_MN>(
515
- // acc_s, cols_ptr[n * kBlockN], m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
516
- // );
517
503
if (n >= num_cols_block - n_masking_steps) {
518
504
Tensor tensor = make_tensor (acc_s.data (), flash::convert_layout_acc_rowcol (acc_s.layout ()));
519
505
const int lane_id = threadIdx.x % 32 ;
@@ -546,8 +532,6 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
546
532
flash::cp_async_wait<0 >();
547
533
__syncthreads ();
548
534
if (n < num_cols_block - 2 ) {
549
- // tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index + 1] * int64_t(params.k_row_stride);
550
- // flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV);
551
535
#pragma unroll
552
536
for (int m = 0 ; m < size<1 >(tKgKToken); ++m) {
553
537
if (true ) { // Is_even_MN
@@ -567,9 +551,6 @@ inline __device__ void sparse_attn_1rowblock(const Params ¶ms, const int bid
567
551
// isn't right and we get race conditions.
568
552
cute::cp_async_fence ();
569
553
} else if (n == num_cols_block - 2 ) {
570
- // tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index + 1] * int64_t(params.k_row_stride);
571
- // flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV,
572
- // binfo.actual_seqlen_k - blks_ptr[block_index + 1]);
573
554
#pragma unroll
574
555
for (int m = 0 ; m < size<1 >(tKgKToken); ++m) {
575
556
if (Is_even_MN || (n + 1 ) * kBlockN + get<0 >(tKVcKV (0 , m, 0 )) < num_cols) { // Is_even_MN
0 commit comments