+            out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        }
         if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
     } else {
         out = torch::empty_like(q_padded);
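The `head_size_og % 8 != 0` branch drops any caller-supplied output buffer because the kernel works on a head dimension rounded up to a multiple of 8, so the result has to land in a freshly allocated tensor of the padded width (and is sliced back afterwards). A minimal sketch of that padding arithmetic with libtorch; the helper name `round_multiple` and the exact pad call are illustrative, not necessarily what flash_api.cpp does verbatim:

    #include <torch/torch.h>

    // Illustrative helper: round x up to the next multiple of m.
    static int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

    // Sketch: pad the last (head_size) dimension of q to a multiple of 8.
    // The output buffer is then allocated as torch::empty_like(q_padded),
    // which is why a caller-supplied `out` of the original width cannot be reused.
    at::Tensor pad_head_dim(const at::Tensor &q, int head_size_og) {
        namespace F = torch::nn::functional;
        const int head_size = round_multiple(head_size_og, 8);
        return head_size_og % 8 != 0
            ? F::pad(q, F::PadFuncOptions({0, head_size - head_size_og}))
            : q;
    }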
@@ -494,12 +498,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
                const int max_seqlen_k,
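With a block_table, the KV cache is no longer one contiguous total_k-long buffer: keys and values live in fixed-size pages of page_block_size tokens each, and block_table[b][i] names the physical block that holds the i-th page of sequence b. A minimal standalone sketch of that address computation, with all names illustrative rather than taken from the kernel:

    #include <cstdint>
    #include <vector>

    // Illustrative paged-KV addressing, assuming
    //   k/v:         num_blocks x page_block_size x num_heads_k x head_size
    //   block_table: batch_size x max_num_blocks_per_seq (int32)
    struct PagedKvIndex {
        int64_t block;  // physical block index into k/v
        int64_t row;    // token slot inside that block
    };

    PagedKvIndex locate_token(const std::vector<std::vector<int32_t>> &block_table,
                              int batch, int64_t token, int64_t page_block_size) {
        const int64_t logical_block = token / page_block_size;           // page index within the sequence
        const int64_t row           = token % page_block_size;           // offset within the page
        const int64_t physical      = block_table[batch][logical_block]; // physical block id
        return {physical, row};
    }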
@@ -535,6 +540,15 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);
 
+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
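A caller satisfies these checks by building the table as a contiguous int32 tensor on the same device as q/k/v. A libtorch sketch of such a construction; the function name, shapes, and zero fill are placeholder assumptions, since real callers fill each row with the physical block ids assigned to that sequence:

    #include <torch/torch.h>

    // Sketch: allocate a block_table that passes the checks above
    // (int32 dtype, contiguous last dimension, CUDA device).
    at::Tensor make_block_table(int64_t batch_size, int64_t max_num_blocks_per_seq) {
        auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
        return torch::zeros({batch_size, max_num_blocks_per_seq}, opts);
    }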
@@ -546,8 +560,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s