#pragma once

#include <cmath>

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"

#include "cutlass_gemm_caller.cuh"
namespace vllm {

using namespace cute;

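// Compile-time configuration for a blockwise-scaled FP8 GEMM on SM100
// (Blackwell). The template parameters pick the MMA tile shape, how many
// scale factors cover each tile, the cluster shape, and the epilogue and
// mainloop schedules; the dispatch function at the bottom of this file
// chooses between a 1SM and a 2SM instantiation.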
template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
          class ClusterShape, typename EpilogueScheduler,
          typename MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise {
  using ElementAB = cutlass::float_e4m3_t;

  using ElementA = ElementAB;
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  using ElementB = ElementAB;
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementC = void;
  using ElementD = OutType;
  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using LayoutC = LayoutD;
  static constexpr int AlignmentC = AlignmentD;
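  // ElementC = void: the epilogue reads no source tensor, so with the default
  // LinearCombination fusion below the beta * C term drops out and the kernel
  // writes D = alpha * accumulator (alpha defaults to 1).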

  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementBlockScale = float;

  // MMA and Cluster Tile Shapes
  // Shape of the tile computed by tcgen05 MMA; it can span 2 SMs if
  // ClusterShape % 2 == 0,
  // e.g. using MmaTileShape_MNK = Shape<_128,_128,_128>;
  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
  static constexpr int ScaleGranularityM =
      size<0>(MmaTileShape{}) / ScaleMsPerTile;
  static constexpr int ScaleGranularityN =
      size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
  static constexpr int ScaleGranularityK =
      size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
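  // Worked example, using the 1SM instantiation below: MmaTileShape =
  // (128,128,128) with ScalesPerTile = (128,1,1) yields granularities
  // (M,N,K) = (1,128,128), i.e. one scale per row (token) of A and one
  // scale per 128x128 block of B.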

  // Shape of the threadblocks in a cluster
  using ClusterShape_MNK = ClusterShape;

  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
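  // LayoutSFA/LayoutSFB describe how the scale factors for A and B are
  // arranged in global memory, as deduced from the granularities above.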

  using ArchTag = cutlass::arch::Sm100;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
  using ElementScalar = float;
  // clang-format off
  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      MmaTileShape,
      ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator,
      ElementCompute,
      ElementC,
      LayoutC,
      AlignmentC,
      ElementD,
      LayoutD,
      AlignmentD,
      EpilogueScheduler,
      DefaultOperation
  >::CollectiveOp;
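  // The epilogue is built before the mainloop so the mainloop builder can
  // carve the epilogue's shared-memory footprint out of its pipeline stage
  // count (see StageCountAutoCarveout below).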

  using StageCountType = cutlass::gemm::collective::StageCountAuto;
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag,
      OperatorClass,
      ElementA,
      cute::tuple<LayoutA, LayoutSFA>,
      AlignmentA,
      ElementB,
      cute::tuple<LayoutB, LayoutSFB>,
      AlignmentB,
      ElementAccumulator,
      MmaTileShape,
      ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduler
  >::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
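  // Shape<int, int, int, int> selects runtime problem shapes (M, N, K, L);
  // enable_sm100_only restricts instantiation to SM100 builds.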

  struct GemmKernel : public KernelType {};
};

template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using LayoutSFA = typename Gemm::LayoutSFA;
  using LayoutSFB = typename Gemm::LayoutSFB;
  using ScaleConfig = typename Gemm::ScaleConfig;

  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

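  // One GEMM (batch dimension L = 1); M and K come from a, N from b.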
  int32_t m = a.size(0), n = b.size(1), k = a.size(1);
  auto prob_shape = cute::make_shape(m, n, k, 1);

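  // Packed strides follow the CUTLASS 3.x convention: A is shaped (M,K,L),
  // B is (N,K,L), and C/D are (M,N,L).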
  StrideA a_stride;
  StrideB b_stride;
  StrideC c_stride;
  a_stride =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
  b_stride =
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
  c_stride =
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));

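  // Materialize the scale-factor layouts for the full (M,N,K) problem from
  // the per-tile scale configuration.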
  LayoutSFA layout_SFA =
      ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
  LayoutSFB layout_SFB =
      ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
  auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

  typename GemmKernel::MainloopArguments mainloop_args{
      a_ptr, a_stride, b_ptr, b_stride,
      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};

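  // {} keeps the default fusion arguments (alpha = 1). ElementC is void, so
  // the C operand is unused; the output pointer/stride serve as operand D.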
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      {}, c_ptr, c_stride, c_ptr, c_stride};
  c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
                                       epilogue_args);
}

template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                               torch::Tensor const& a,
                                               torch::Tensor const& b,
                                               torch::Tensor const& a_scales,
                                               torch::Tensor const& b_scales) {
  auto m = a.size(0);
  auto k = a.size(1);
  auto n = b.size(1);
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

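  // Heuristic: count the 128x128 output tiles; once there is at least one
  // tile per SM, the problem is large enough to keep every SM busy and the
  // 2SM-cluster kernel is preferred. Smaller problems take the 1SM path.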
  auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
    return std::ceil(static_cast<float>(m) / tile1SM) *
               std::ceil(static_cast<float>(n) / tile1SM) >=
           sms;
  };
  bool use_2sm = should_use_2sm(m, n);
  if (use_2sm) {
    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
        OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
        Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
        cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
        out, a, b, a_scales, b_scales);
  } else {
    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
        OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
        Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
        out, a, b, a_scales, b_scales);
  }
}
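
// Usage sketch (assumed shapes/dtypes, inferred from the layouts above):
//   a:        [M, K] float8_e4m3, row-major
//   b:        [K, N] float8_e4m3, column-major (K contiguous)
//   a_scales: float32 blockwise scales for a (one per token per 128-element
//             K block for the configs above)
//   b_scales: float32 scales, one per 128x128 block of b
//   out:      [M, N] OutType (e.g. cutlass::bfloat16_t), row-major
// cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
//     out, a, b, a_scales, b_scales);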

}  // namespace vllm