// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
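//
// Converts per-(batch, head) "vertical" (column) and "slash" (diagonal)
// sparse-attention index lists into per-row-block metadata: block_offset /
// block_count hold the starting columns of BLOCK_SIZE_N-wide blocks swept by
// the selected diagonals, and column_index / column_count hold the remaining
// individual columns selected by the vertical indexes.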

#include <assert.h>

#include <cuda.h>

#include <torch/all.h>

__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
                               int64_t range_end, int64_t block_size,
                               int64_t input_block_count, int64_t kv_seqlen) {
  if (range_start >= kv_seqlen) {
    return input_block_count;
  }
  if (range_end > kv_seqlen) {
    range_end = kv_seqlen;
  }
  int64_t current_block_count = input_block_count;
  for (int idx = range_start; idx < range_end; idx += block_size) {
    block_offset[current_block_count++] = idx;
  }
  return current_block_count;
}
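// Example: with range [128, 256), block_size 64, and kv_seqlen 1000,
// save_blocks appends column offsets 128 and 192 and returns
// input_block_count + 2; a range starting at or beyond kv_seqlen is skipped,
// and a range overrunning kv_seqlen is clipped to it.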

__global__ void convert_vertical_slash_indexes_kernel(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
    int64_t NNZ_V, int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

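  // One thread handles one BLOCK_SIZE_M-tall row block of queries: the
  // launchers below use grid (N_HEADS, BATCH, cdiv(N_ROWS, blockDim.x)), so
  // block_idx_m is the row-block index within this (batch, head) pair.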
  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  if (start_m >= q_seqlen) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

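  // Find the first slash (diagonal) index relevant to this row block and map
  // it to a column window: after the transform, [s_idx - BLOCK_SIZE_M, s_idx)
  // is the BLOCK_SIZE_M-wide span of columns that the diagonal sweeps across
  // rows [start_m, end_m).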
  bool has_slash = true;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  int64_t v_idx = vertical_indexes[v++];
  int64_t s_idx = slash_indexes[s++];
  if (causal) {
    while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
    s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
  } else {
    while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + kv_seqlen) has_slash = false;
    s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
  }

  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    if (causal) {
      range_start = (kv_seqlen - q_seqlen) + end_m;
      range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

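  // Merge the vertical indexes with the slash windows (both are consumed in
  // sorted order): a vertical index below the current window becomes an
  // individual sparse column (indexes inside the window are already covered
  // by its blocks); windows that touch or overlap are coalesced, and each
  // finished window is flushed to block_offset via save_blocks.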
  bool slash_finished = false;
  while (1) {
    if (v_idx < range_end) {
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        if (causal)
          v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
        else
          v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
      }
    } else {
      if ((s < NNZ_S && causal) ||
          (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
        if (causal)
          s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
                      BLOCK_SIZE_M);
        else
          s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        if (v == NNZ_V || (v_idx > range_start && causal)) {
          // add the last vertical if there are no more slashes
          if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
            column_index[tmp_col_cnt++] = v_idx;
          }
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          break;
        } else {
          if (causal) {
            range_start = (kv_seqlen - q_seqlen) + end_m;
            range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
          } else {
            // slashes are finished but verticals remain: save the current
            // blocks
            tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                      BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
            range_start = kv_seqlen;
            range_end = kv_seqlen + BLOCK_SIZE_N;
          }
          slash_finished = true;
        }
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}

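// Host-side launcher: 64 threads per CTA, one thread per BLOCK_SIZE_M row
// block, grid (N_HEADS, BATCH_SIZE, cdiv(N_ROWS, 64)).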
void convert_vertical_slash_indexes_64x64(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
  const int N_THREADS = 64;
  const dim3 dimBlock(N_THREADS);
  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
  convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
      block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
      BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
}

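// PyTorch entry point. It assumes (but does not check) contiguous int32 CUDA
// tensors on the same device, with the output buffers pre-allocated to the
// shapes noted below; row blocks starting at or beyond q_seqlen are left
// untouched by the kernel, so callers typically zero-initialize the outputs.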
void convert_vertical_slash_indexes(
    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,         // [BATCH, ]
    torch::Tensor kv_seqlens,        // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal) {
  cudaSetDevice(q_seqlens.get_device());

  int batch_size = slash_indexes.size(0);
  int num_heads = slash_indexes.size(1);
  int nnz_slash = slash_indexes.size(2);
  int nnz_vertical = vertical_indexes.size(2);
  int num_rows = (context_size + block_size_M - 1) / block_size_M;

  convert_vertical_slash_indexes_64x64(
      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
      block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
      column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
      num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
      causal);
}

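// Identical to convert_vertical_slash_indexes_kernel above, except that the
// effective NNZ_V / NNZ_S for each head are read from per_head_vertical_topkv
// and per_head_slash_topkv rather than being uniform across heads.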
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
    int64_t NNZ_V, int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  if (start_m >= q_seqlen) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  // MergeHead: each head has its own effective top-k NNZ_V / NNZ_S. (The
  // NNZ_V / NNZ_S arguments above are the padded buffer sizes and are only
  // used to compute the per-head offsets.)
  NNZ_S = per_head_slash_topkv[head_idx];
  NNZ_V = per_head_vertical_topkv[head_idx];

  bool has_slash = true;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  int64_t v_idx = vertical_indexes[v++];
  int64_t s_idx = slash_indexes[s++];
  if (causal) {
    while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
    s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
  } else {
    while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + kv_seqlen) has_slash = false;
    s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
  }

  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    if (causal) {
      range_start = (kv_seqlen - q_seqlen) + end_m;
      range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  bool slash_finished = false;
  while (1) {
    if (v_idx < range_end) {
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        if (causal)
          v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
        else
          v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
      }
    } else {
      if ((s < NNZ_S && causal) ||
          (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
        if (causal)
          s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
                      BLOCK_SIZE_M);
        else
          s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        if (v == NNZ_V || (v_idx > range_start && causal)) {
          // add the last vertical if there are no more slashes
          if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
            column_index[tmp_col_cnt++] = v_idx;
          }
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          break;
        } else {
          if (causal) {
            range_start = (kv_seqlen - q_seqlen) + end_m;
            range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
          } else {
            // slashes are finished but verticals remain: save the current
            // blocks
            tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                      BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
            range_start = kv_seqlen;
            range_end = kv_seqlen + BLOCK_SIZE_N;
          }
          slash_finished = true;
        }
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}

void convert_vertical_slash_indexes_64x64_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* per_head_vertical_topkv, int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
  const int N_THREADS = 64;
  const dim3 dimBlock(N_THREADS);
  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
  convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
      per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
      column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
      NNZ_V, NNZ_S, causal);
}

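// PyTorch entry point for the merge-head variant: vertical_indices_count and
// slash_indices_count give each head its own effective top-k, while the NNZ_V
// and NNZ_S taken from the index tensor shapes remain the padded buffer sizes.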
void convert_vertical_slash_indexes_mergehead(
    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,         // [BATCH, ]
    torch::Tensor kv_seqlens,        // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
    torch::Tensor slash_indices_count, int64_t context_size,
    int64_t block_size_M, int64_t block_size_N, bool causal) {
  cudaSetDevice(q_seqlens.get_device());

  int batch_size = slash_indexes.size(0);
  int num_heads = slash_indexes.size(1);
  int nnz_slash = slash_indexes.size(2);
  int nnz_vertical = vertical_indexes.size(2);
  int num_rows = (context_size + block_size_M - 1) / block_size_M;

  convert_vertical_slash_indexes_64x64_mergehead(
      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
      vertical_indices_count.data_ptr<int>(),
      slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
      block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
      column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
      block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}