1
+ #include < cudaTypedefs.h>
2
+
1
3
#include < c10/cuda/CUDAGuard.h>
2
- #include < cuda_runtime.h>
3
4
#include < torch/extension.h>
4
5
5
6
void cutlass_scaled_mm_dq_sm75 (torch::Tensor& c, torch::Tensor const & a,
@@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
17
18
torch::Tensor const & a_scales,
18
19
torch::Tensor const & b_scales);
19
20
21
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
20
22
void cutlass_scaled_mm_dq_sm90 (torch::Tensor& c, torch::Tensor const & a,
21
23
torch::Tensor const & b,
22
24
torch::Tensor const & a_scales,
23
25
torch::Tensor const & b_scales);
26
+ #endif
24
27
25
28
void cutlass_scaled_mm_dq (torch::Tensor& c, torch::Tensor const & a,
26
29
torch::Tensor const & b, torch::Tensor const & a_scales,
@@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
51
54
52
55
if (version_num >= 90 ) {
53
56
// Hopper
57
+
58
+ // Guard against compilation issues for sm90 kernels
59
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
54
60
cutlass_scaled_mm_dq_sm90 (c, a, b, a_scales, b_scales);
61
+ #else
62
+ cutlass_scaled_mm_dq_sm80 (c, a, b, a_scales, b_scales);
63
+ #endif
55
64
} else if (version_num == 89 ) {
56
65
// Ada Lovelace
57
66
cutlass_scaled_mm_dq_sm89 (c, a, b, a_scales, b_scales);
0 commit comments