Skip to content

Commit 86e9c8d

Browse files
LucasWilkinson, mgoin, divakar-amd, tlrmchlsmth
authored
[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)
Co-authored-by: mgoin <michael@neuralmagic.com> Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
1 parent ee5f34b commit 86e9c8d

27 files changed

+1005
-246
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
223223
"csrc/quantization/gguf/gguf_kernel.cu"
224224
"csrc/quantization/fp8/fp8_marlin.cu"
225225
"csrc/custom_all_reduce.cu"
226+
"csrc/permute_cols.cu"
226227
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
227228
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
228229
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")

benchmarks/kernels/benchmark_machete.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import math
55
import pickle as pkl
66
import time
7-
from typing import Callable, Iterable, List, Tuple
7+
from itertools import product
8+
from typing import Callable, Iterable, List, Optional, Tuple
89

10+
import pandas as pd
911
import torch
1012
import torch.utils.benchmark as TBenchmark
1113
from torch.utils.benchmark import Measurement as TMeasurement
@@ -84,6 +86,10 @@ def loop_over_weights(
8486
fn(a, w_ref, w_q, w_s)
8587

8688

89+
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
90+
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
91+
92+
8793
def bench(atype: torch.dtype,
8894
wtype: ScalarType,
8995
group_size: int,
@@ -94,6 +100,8 @@ def bench(atype: torch.dtype,
94100
sub_label: str,
95101
benchmark_marlinv1: bool = True,
96102
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
103+
global _SWEEP_SCHEDULES_RESULTS
104+
97105
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
98106
sub_label += f", L={len(weights)}"
99107

@@ -163,6 +171,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
163171
best_schedule = None
164172
schedules = ops.machete_supported_schedules(wtype)
165173
for schedule in reversed(schedules):
174+
schedule_M = int(schedule.split("_")[0].split("x")[1])
175+
176+
# Prune known bad schedules
177+
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
178+
continue
166179

167180
def run(a, _, w_q, w_s, schedule=schedule):
168181
ops.machete_gemm(a,
@@ -175,6 +188,20 @@ def run(a, _, w_q, w_s, schedule=schedule):
175188
res = bench_fn(label, sub_label, "machete_best",
176189
lambda: loop_over_weights(a, weights_machete, run))
177190

191+
results_row = {
192+
"M": m,
193+
"K": k,
194+
"N": n,
195+
"group_size": group_size,
196+
"schedule": schedule,
197+
"median": res.median,
198+
}
199+
if _SWEEP_SCHEDULES_RESULTS is None:
200+
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
201+
columns=results_row.keys())
202+
_SWEEP_SCHEDULES_RESULTS.\
203+
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
204+
178205
print(f" {res.median:5.5} ", schedule)
179206
if not best or res.median < best.median:
180207
best = res
@@ -235,18 +262,22 @@ def run_square_bench(args):
235262
dim_sizes = list(
236263
range(args.dim_start, args.dim_end + 1, args.dim_increment))
237264
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
265+
238266
data = run(args.dtype, args.sweep_schedules, MKNs)
239267

240268
make_output(data, MKNs, f"square_bench-{args.dtype}")
241269

242270

243271
def run_range_bench(args):
    """Benchmark over independent M/K/N ranges.

    ``--dim-start``, ``--dim-end`` (inclusive) and ``--dim-increment`` are
    each a comma separated triple ``M,K,N``; the benchmark runs the cartesian
    product of the three resulting ranges.
    """

    def _parse_mkn_triple(value: str, flag: str) -> List[int]:
        # Fail early with a readable message instead of an opaque
        # ValueError from int() or an unpacking error on wrong arity.
        parts = [p.strip() for p in value.split(",")]
        if len(parts) != 3:
            raise ValueError(
                f"{flag} must be a comma separated list of 3 ints (M,K,N), "
                f"got: {value!r}")
        return [int(p) for p in parts]

    m_start, k_start, n_start = _parse_mkn_triple(args.dim_start,
                                                  "--dim-start")
    m_end, k_end, n_end = _parse_mkn_triple(args.dim_end, "--dim-end")
    m_increment, k_increment, n_increment = _parse_mkn_triple(
        args.dim_increment, "--dim-increment")

    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    # Every (M, K, N) combination across the three ranges.
    MKNs = list(product(Ms, Ks, Ns))

    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")
@@ -333,6 +364,9 @@ def to_torch_dtype(dt):
333364
action="store_true",
334365
help="Run a sweep over all supported schedules",
335366
)
367+
parser.add_argument("--sweep-csv-out",
368+
help="CSV to store sweep results",
369+
default="sch_sweep_results.csv")
336370
subparsers = parser.add_subparsers(dest="cmd", required=True)
337371

338372
square_parser = subparsers.add_parser("square_bench")
@@ -342,12 +376,21 @@ def to_torch_dtype(dt):
342376
square_parser.set_defaults(func=run_square_bench)
343377

344378
range_parser = subparsers.add_parser("range_bench")
345-
range_parser.add_argument("--dim-start", type=int, required=True)
346-
range_parser.add_argument("--dim-end", type=int, required=True)
347-
range_parser.add_argument("--dim-increment", type=int, required=True)
348-
range_parser.add_argument("--m-constant", type=int, default=None)
349-
range_parser.add_argument("--n-constant", type=int, default=None)
350-
range_parser.add_argument("--k-constant", type=int, default=None)
379+
range_parser.add_argument(
380+
"--dim-start",
381+
type=str,
382+
required=True,
383+
help="Start value for M,K,N as common separated list")
384+
range_parser.add_argument(
385+
"--dim-end",
386+
type=str,
387+
required=True,
388+
help="End value (inclusive) for M,K,N as common separated list")
389+
range_parser.add_argument(
390+
"--dim-increment",
391+
type=str,
392+
required=True,
393+
help="Increment value for M,K,N as common separated list")
351394
range_parser.set_defaults(func=run_range_bench)
352395

353396
model_parser = subparsers.add_parser("model_bench")
@@ -369,4 +412,9 @@ def to_torch_dtype(dt):
369412
model_parser.set_defaults(func=run_model_bench)
370413

371414
args = parser.parse_args()
415+
416+
_SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
372417
args.func(args)
418+
419+
if _SWEEP_SCHEDULES_RESULTS is not None:
420+
_SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)

benchmarks/kernels/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pandas

csrc/cutlass_extensions/torch_utils.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
6868
name, ".stride(", idx, ") to be ", StrideEle::value);
6969
return StrideEle{};
7070
} else {
71-
return tensor.stride(idx);
71+
if (tensor.size(idx) == 1) {
72+
// use 0 stride for dim with size 1, this is easier for
73+
// cute/cutlass to optimize (helps the TMA code flatten dims)
74+
return StrideEle{0};
75+
} else {
76+
return tensor.stride(idx);
77+
}
7278
}
7379
} else {
7480
// Extra strides are assumed to be 0 or 1

csrc/ops.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,
113113

114114
}; // namespace machete
115115

116+
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
117+
116118
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
117119
torch::Tensor& b_meta,
118120
torch::Tensor& b_scales,

csrc/permute_cols.cu

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#include <torch/all.h>
2+
3+
#include <ATen/cuda/CUDAContext.h>
4+
#include <c10/cuda/CUDAGuard.h>
5+
6+
#include <cuda_fp16.h>
7+
8+
// Thread-block size shared by permute_cols_kernel and its launcher.
static constexpr int default_threads = 256;

// Ceiling division for non-negative operands: smallest q with q * den >= num.
static constexpr int div_ceil(int num, int den) {
  return (num + den - 1) / den;
}
10+
11+
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
// Currently only supports 16bit types (since we permute half types). The
// kernel only copies elements (never interprets them numerically), so any
// 16-bit payload works when reinterpreted as half.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Each thread block owns a contiguous slab of `block_rows` rows
  // [start_row, finish_row), clamped to the matrix height.
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = std::max(finish_row - start_row, 0);

  // Row stride in int4 (16-byte) units: size_k halves * 2 bytes / 16.
  // The base pointers are int4*, so offsets below are in int4 elements.
  int row_stride = size_k * sizeof(half) / 16;

  // Permute one row: thread t copies columns t, t+default_threads, ... so
  // the whole thread block sweeps the K columns in chunks of
  // default_threads.
  auto permute_row = [&](int row) {
    int iters = size_k / default_threads;
    int rest = size_k % default_threads;  // leftover columns < one chunk

    int offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      // Gather: output column cur_k comes from input column perm[cur_k].
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += default_threads;
    }

    // Tail: only the first `rest` threads have a column left to copy.
    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}
64+
65+
// More efficient version of A[..., perm]
66+
// taken from gptq_marlin.cu
67+
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
68+
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
69+
auto dev = A.get_device();
70+
auto stream = at::cuda::getCurrentCUDAStream(dev);
71+
72+
TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
73+
"Currently only 16bit types are supported");
74+
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
75+
TORCH_CHECK(A.size(-1) % 8 == 0,
76+
"A columns must be a multiple of 8 (128bits)");
77+
auto A_2d = A.view({-1, A.size(-1)});
78+
79+
torch::Tensor D = torch::empty_like(A);
80+
int sms;
81+
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
82+
int block_rows = div_ceil(A_2d.size(0), sms);
83+
permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
84+
reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
85+
perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
86+
A_2d.size(0), A_2d.size(1), block_rows);
87+
return D;
88+
}

0 commit comments

Comments
 (0)