
Commit 5725ea8

wenscarl authored and Mu Huai committed

Add cutlass support for blackwell fp8 blockwise gemm (vllm-project#14383)

Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>

1 parent 3ab9328 commit 5725ea8

File tree: 11 files changed, +332 −64 lines

CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -418,6 +418,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     set(SRCS
       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
csrc/cutlass_extensions/common.hpp
Lines changed: 10 additions & 0 deletions

@@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
 #endif
   }
 };
+
+template <typename Kernel>
+struct enable_sm100_only : Kernel {
+  template <typename... Args>
+  CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
+    Kernel::operator()(std::forward<Args>(args)...);
+#endif
+  }
+};
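
The new enable_sm100_only wrapper follows the same pattern as the existing enable_sm90_only guard above it: the wrapped kernel body is compiled only when the translation unit targets compute capability 10.0 (Blackwell), so a multi-architecture build can include the SM100 kernel source without instantiating tcgen05 code paths on older targets. A minimal usage sketch, with a hypothetical DummyKernel that is not part of this commit:

    // Hypothetical kernel type, for illustration only.
    struct DummyKernel {
      CUTLASS_DEVICE void operator()(int /*arg*/) {
        // real device-side work would go here
      }
    };

    // The guarded type is a drop-in replacement for DummyKernel;
    // its operator() becomes a no-op unless __CUDA_ARCH__ == 1000.
    using GuardedKernel = enable_sm100_only<DummyKernel>;
    // e.g. inside a __global__ entry point compiled for several archs:
    //   GuardedKernel{}(42);  // executes only on sm100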
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu (new file)
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales) {
+  TORCH_CHECK(
+      a.size(0) % 4 == 0,
+      "Input tensor must have a number of rows that is a multiple of 4, "
+      "but got: ", a.size(0), " rows.");
+  if (out.dtype() == torch::kBFloat16) {
+    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    TORCH_CHECK(out.dtype() == torch::kFloat16);
+    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
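
The TORCH_CHECK above hard-rejects activations whose row count M is not a multiple of 4. A caller wanting to serve arbitrary M would have to round up and pad before the call; a sketch of the round-up arithmetic, where the helper name is illustrative and not part of this commit:

    // Round M up to the next multiple of 4, e.g. 510 -> 512.
    inline int64_t round_up_to_4(int64_t m) { return (m + 3) & ~int64_t{3}; }
    // Note: both the activations and their per-token scales would need
    // matching zero-padding; the commit itself simply rejects unaligned M.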
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh (new file)
Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/kernel/tile_scheduler_params.h"
+#include "cutlass/epilogue/dispatch_policy.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+
+#include "cutlass_extensions/gemm/dispatch_policy.hpp"
+#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
+
+#include "cutlass_gemm_caller.cuh"
+
+namespace vllm {
+
+using namespace cute;
+
+template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
+          class ClusterShape, typename EpilogueScheduler,
+          typename MainloopScheduler>
+struct cutlass_3x_gemm_fp8_blockwise {
+  using ElementAB = cutlass::float_e4m3_t;
+
+  using ElementA = ElementAB;
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
+
+  using ElementB = ElementAB;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
+
+  using ElementC = void;
+  using ElementD = OutType;
+  using LayoutD = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+
+  using LayoutC = LayoutD;
+  static constexpr int AlignmentC = AlignmentD;
+
+  using ElementAccumulator = float;
+  using ElementCompute = float;
+  using ElementBlockScale = float;
+
+  // MMA and Cluster Tile Shapes
+  // Shape of the tile computed by tcgen05 MMA; it can span 2 SMs when
+  // Cluster Shape % 2 == 0, e.g. MmaTileShape_MNK = Shape<_128,_128,_128>.
+  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
+  static constexpr int ScaleGranularityM =
+      size<0>(MmaTileShape{}) / ScaleMsPerTile;
+  static constexpr int ScaleGranularityN =
+      size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
+  static constexpr int ScaleGranularityK =
+      size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
+
+  // Shape of the threadblocks in a cluster
+  using ClusterShape_MNK = ClusterShape;
+
+  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
+      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+
+  using ArchTag = cutlass::arch::Sm100;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+  using ElementScalar = float;
+  // clang-format off
+  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator,
+      ElementCompute,
+      ElementC,
+      LayoutC,
+      AlignmentC,
+      ElementD,
+      LayoutD,
+      AlignmentD,
+      EpilogueScheduler,
+      DefaultOperation
+  >::CollectiveOp;
+
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag,
+      OperatorClass,
+      ElementA,
+      cute::tuple<LayoutA, LayoutSFA>,
+      AlignmentA,
+      ElementB,
+      cute::tuple<LayoutB, LayoutSFB>,
+      AlignmentB,
+      ElementAccumulator,
+      MmaTileShape,
+      ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      MainloopScheduler
+  >::CollectiveOp;
+  // clang-format on
+
+  using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
+
+  struct GemmKernel : public KernelType {};
+};
+
+template <typename Gemm>
+void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
+                                   torch::Tensor const& b,
+                                   torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales) {
+  using GemmKernel = typename Gemm::GemmKernel;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using LayoutSFA = typename Gemm::LayoutSFA;
+  using LayoutSFB = typename Gemm::LayoutSFB;
+  using ScaleConfig = typename Gemm::ScaleConfig;
+
+  using ElementAB = typename Gemm::ElementAB;
+  using ElementD = typename Gemm::ElementD;
+
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
+  auto prob_shape = cute::make_shape(m, n, k, 1);
+
+  StrideA a_stride;
+  StrideB b_stride;
+  StrideC c_stride;
+  a_stride =
+      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
+  b_stride =
+      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
+  c_stride =
+      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+
+  LayoutSFA layout_SFA =
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
+  LayoutSFB layout_SFB =
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
+
+  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
+  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
+
+  typename GemmKernel::MainloopArguments mainloop_args{
+      a_ptr, a_stride, b_ptr, b_stride,
+      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+
+  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+  typename GemmKernel::EpilogueArguments epilogue_args{
+      {}, c_ptr, c_stride, c_ptr, c_stride};
+  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
+                                       epilogue_args);
+}
+
+template <typename OutType>
+void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
+                                               torch::Tensor const& a,
+                                               torch::Tensor const& b,
+                                               torch::Tensor const& a_scales,
+                                               torch::Tensor const& b_scales) {
+  auto m = a.size(0);
+  auto k = a.size(1);
+  auto n = b.size(1);
+  int sms;
+  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
+
+  auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
+    return std::ceil(static_cast<float>(m) / tile1SM) *
+               std::ceil(static_cast<float>(n) / tile1SM) >=
+           sms;
+  };
+  bool use_2sm = should_use_2sm(m, n);
+  if (use_2sm) {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
+        Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
+        Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+}  // namespace vllm
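
The should_use_2sm heuristic above selects the 2SM configuration (256x128x128 MMA tile, 2x2x1 cluster) once the grid of 128x128 output tiles is large enough to occupy every SM, and otherwise falls back to the 1SM kernel. A standalone sketch of the same arithmetic; the SM count of 148 is an assumed example value, whereas the real code queries cudaDevAttrMultiProcessorCount as shown above:

    #include <cmath>
    #include <cstdio>

    int main() {
      int sms = 148;  // assumed device SM count, for illustration only
      auto should_use_2sm = [&](int m, int n, int tile = 128) {
        return std::ceil(static_cast<float>(m) / tile) *
                   std::ceil(static_cast<float>(n) / tile) >= sms;
      };
      // 4096 x 4096: 32 * 32 = 1024 tiles >= 148 SMs -> 2SM kernel
      std::printf("%s\n", should_use_2sm(4096, 4096) ? "2SM" : "1SM");
      // 256 x 512: 2 * 4 = 8 tiles < 148 SMs -> 1SM kernel
      std::printf("%s\n", should_use_2sm(256, 512) ? "2SM" : "1SM");
    }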
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp (new file)
Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+#include <torch/all.h>
+#include "cuda_utils.h"
+
+template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
+void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
+                        torch::Tensor const& b, torch::Tensor const& a_scales,
+                        torch::Tensor const& b_scales,
+                        std::optional<torch::Tensor> const& bias,
+                        Fp8Func fp8_func, Int8Func int8_func,
+                        BlockwiseFunc blockwise_func) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  int M = a.size(0), N = b.size(1), K = a.size(1);
+
+  if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+      (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
+    // Standard per-tensor/per-token/per-channel scaling
+    TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+    if (a.dtype() == torch::kFloat8_e4m3fn) {
+      fp8_func(c, a, b, a_scales, b_scales, bias);
+    } else {
+      TORCH_CHECK(a.dtype() == torch::kInt8);
+      if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
+        int8_func(c, a, b, a_scales, b_scales, bias);
+      } else {
+        TORCH_CHECK(false, "Int8 not supported for this architecture");
+      }
+    }
+  } else {
+    using GroupShape = std::array<int64_t, 2>;
+    auto make_group_shape = [](torch::Tensor const& x,
+                               torch::Tensor const& s) -> GroupShape {
+      TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
+      return {cuda_utils::ceil_div(x.size(0), s.size(0)),
+              cuda_utils::ceil_div(x.size(1), s.size(1))};
+    };
+
+    GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
+    GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
+
+    // 1x128 per-token group scales for activations
+    // 128x128 blockwise scales for weights
+    TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
+                 b_scale_group_shape == GroupShape{128, 128} &&
+                 a.dtype() == torch::kFloat8_e4m3fn &&
+                 b.dtype() == torch::kFloat8_e4m3fn),
+                "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
+                "a_scale_group_shape must be [1, 128]. Got: [",
+                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                "]\n"
+                "b_scale_group_shape must be [128, 128]. Got: [",
+                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
+    TORCH_CHECK(!bias, "Bias not yet supported for blockwise scaled_mm");
+    blockwise_func(c, a, b, a_scales, b_scales);
+  }
+}
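
make_group_shape classifies the scaling granularity by dividing each tensor extent by the corresponding scales extent with ceil_div. A self-contained worked example under assumed DeepSeek-style shapes (M=4096, K=7168, N=4096; these numbers are illustrative, and the local ceil_div mirrors cuda_utils::ceil_div):

    #include <array>
    #include <cassert>
    #include <cstdint>

    constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

    int main() {
      using GroupShape = std::array<int64_t, 2>;
      // a: [4096, 7168] with a_scales [4096, 56] -> {1, 128}: one scale
      // per token row, each covering 128 elements along K.
      GroupShape a_group{ceil_div(4096, 4096), ceil_div(7168, 56)};
      assert((a_group == GroupShape{1, 128}));
      // b: [7168, 4096] with b_scales [56, 32] -> {128, 128} weight blocks.
      GroupShape b_group{ceil_div(7168, 56), ceil_div(4096, 32)};
      assert((b_group == GroupShape{128, 128}));
    }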

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
Lines changed: 5 additions & 0 deletions

@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                  torch::Tensor const& b_scales,
                                  std::optional<torch::Tensor> const& bias);
 
+void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
+                                           torch::Tensor const& a,
+                                           torch::Tensor const& b,
+                                           torch::Tensor const& a_scales,
+                                           torch::Tensor const& b_scales);
 }  // namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
Lines changed: 5 additions & 17 deletions

@@ -1,8 +1,6 @@
-#include <cudaTypedefs.h>
+#include "c3x/scaled_mm_helper.hpp"
 #include "c3x/scaled_mm_kernels.hpp"
 
-#include "cuda_utils.h"
-
 /*
    This file defines quantized GEMM operations using the CUTLASS 3.x API, for
    NVIDIA GPUs with sm100 (Blackwell).
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              std::optional<torch::Tensor> const& bias) {
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
-  int M = a.size(0), N = b.size(1), K = a.size(1);
-  TORCH_CHECK(
-      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
-          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
-      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
-
-  // Standard per-tensor/per-token/per-channel scaling
-  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
-              "Currently, only fp8 gemm is implemented for Blackwell");
-  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
+  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
+                     vllm::cutlass_scaled_mm_sm100_fp8,
+                     nullptr,  // int8 not supported on SM100
+                     vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
 }
 
 #endif
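
Passing a literal nullptr for the int8 slot works because dispatch_scaled_mm takes each function as a template parameter: nullptr deduces Int8Func = std::nullptr_t, and the if constexpr branch in scaled_mm_helper.hpp then compiles the int8 call out entirely instead of failing to link. A minimal standalone illustration of that trick (not vLLM code):

    #include <cstdio>
    #include <type_traits>

    template <typename F>
    void call_if_provided(F f) {
      if constexpr (!std::is_same_v<F, std::nullptr_t>) {
        f();  // only instantiated when a real callable is supplied
      } else {
        std::puts("path not compiled for this architecture");
      }
    }

    int main() {
      call_if_provided([] { std::puts("int8 kernel"); });  // runs the lambda
      call_if_provided(nullptr);  // branch with f() never instantiated
    }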
