
Commit 5ec321f

fix
1 parent 7f32257 commit 5ec321f

4 files changed: +292 −5 lines


Diff for: kernels/layernorm/non_pc/layer_norm.cu

+2 −2

@@ -235,8 +235,8 @@ std::tuple<torch::Tensor, torch::Tensor> fused_layernorm(
     TORCH_CHECK(norm_weight.size(0) == d, "norm_weight is d_model?");
     TORCH_CHECK(norm_bias.size(0) == d, "norm_bias is d_model?");

-    TORCH_CHECK(x.size(1) % kittens::TILE_DIM == 0, "sequence length is divisible by 16?");
-    TORCH_CHECK(residual.size(1) % kittens::TILE_DIM == 0, "sequence length is divisible by 16?");
+    TORCH_CHECK(x.size(1) % kittens::TILE_ROW_DIM<bf16> == 0, "sequence length is divisible by 16?");
+    TORCH_CHECK(residual.size(1) % kittens::TILE_ROW_DIM<bf16> == 0, "sequence length is divisible by 16?");

     torch::Tensor out = torch::empty({b, n, d}, x.options());
     torch::Tensor out_resid = torch::empty({b, n, d}, x.options());
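
Note: the check above rejects inputs whose sequence length is not a multiple of the bf16 tile row dimension (16, per the error message). A minimal caller-side padding sketch, assuming the extension is exposed to Python as fused_layernorm (the binding name and signature are assumptions here):

import torch

def pad_seq_to_16(t: torch.Tensor) -> torch.Tensor:
    # Zero-pad dim 1 (sequence length) up to the next multiple of 16.
    b, n, d = t.shape
    pad = (-n) % 16
    return t if pad == 0 else torch.nn.functional.pad(t, (0, 0, 0, pad))

# Hypothetical usage:
# out, out_resid = fused_layernorm(pad_seq_to_16(x), pad_seq_to_16(residual), norm_weight, norm_bias)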

Diff for: kernels/matmul/baselines/cublas_fp8/Makefile

+18 (new file)

@@ -0,0 +1,18 @@
GPU_TARGET=H100

# Compiler
NVCC=nvcc

NVCCFLAGS=-DNDEBUG -Xcompiler=-fPIE -Xcompiler -fopenmp --expt-extended-lambda --expt-relaxed-constexpr -Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing --use_fast_math -forward-unknown-to-host-compiler -O3 -Xnvlink=--verbose -Xptxas=--verbose -Xptxas=--warn-on-spills -std=c++20 -MD -MT -MF -x cu -lrt -lpthread -ldl -DKITTENS_HOPPER -arch=sm_90a -lcuda -lcudadevrt -lcudart_static -lcublas -lcublasLt -lgomp -I${THUNDERKITTENS_ROOT}/include -I${THUNDERKITTENS_ROOT}/prototype
TARGET=matmul
SRC=matmul.cu

# Default target
all: $(TARGET)

$(TARGET): $(SRC)
	$(NVCC) $(SRC) $(NVCCFLAGS) -o $(TARGET)

# Clean target
clean:
	rm -f $(TARGET)
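
Build note (hedged): the include paths above assume THUNDERKITTENS_ROOT is set to the ThunderKittens checkout. With that variable exported, running make in this directory should produce the matmul binary for sm_90a, and make clean removes it.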

Diff for: kernels/matmul/baselines/cublas_fp8/matmul.cu

+264 (new file)

@@ -0,0 +1,264 @@
#include <iostream>
#include <vector>
#include <chrono>
#include <random>
#include <cuda_runtime.h>
#include <cublasLt.h>
#include <cuda_fp8.h>
#include <iomanip>

// Error checking macros
#define CHECK_CUDA(call) \
    do { \
        cudaError_t err = call; \
        if (err != cudaSuccess) { \
            std::cerr << "CUDA error in " << __FILE__ << " line " << __LINE__ << ": " \
                      << cudaGetErrorString(err) << std::endl; \
            exit(EXIT_FAILURE); \
        } \
    } while(0)

#define CHECK_CUBLAS(call) \
    do { \
        cublasStatus_t status = call; \
        if (status != CUBLAS_STATUS_SUCCESS) { \
            std::cerr << "cuBLAS error in " << __FILE__ << " line " << __LINE__ << ": " \
                      << "code " << status << std::endl; \
            exit(EXIT_FAILURE); \
        } \
    } while(0)

void check(cudaError_t error) {
    if (error != cudaSuccess) {
        std::cerr << "CUDA error: " << cudaGetErrorString(error) << std::endl;
        exit(EXIT_FAILURE);
    }
}

void checkCublas(cublasStatus_t status) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        std::cerr << "cuBLAS error: " << status << std::endl;
        exit(EXIT_FAILURE);
    }
}

void cpu_gemm(float* a, float* b, float* c, int M, int N, int K) {
    #pragma omp parallel for collapse(2) // otherwise the CPU version takes for everrrrrr
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < K; k++) {
                // sum += a[i * K + k] * b[j * K + k]; // mma_ABt
                sum += a[i * K + k] * b[k * N + j]; // mma_AB
            }
            c[i * N + j] = sum;
        }
    }
}

void check_result(float* h_C, float* h_C_ref, int M, int N) {
    float max_error = 0.0f;
    int error_count = 0;

    // Same tolerance and error reporting as your code
    for (int i = 0; i < M * N; ++i) {
        float error = std::abs(h_C[i] - h_C_ref[i]);
        if(1) { // large tolerance because of fp8 vs fp32 numerics
            if(error_count < 25) {
                std::cout << "Error at row " << i / N << " col " << i % N
                          << ": " << h_C[i] << " != " << h_C_ref[i]
                          << " (ref)" << std::endl;
            }
            else if(error_count == 25) {
                std::cout << "Too many errors to show them all.\n";
            }
            error_count++;
        }
        max_error = std::max(max_error, error);
    }

    std::cout << "Max error: " << max_error << std::endl;
    std::cout << "Error count: " << error_count << std::endl;
}

void benchmark(int m, int n, int k) {
    // Align dimensions
    m = (m + 15) & ~15;
    n = (n + 15) & ~15;
    k = (k + 15) & ~15;

    // Initialize host memory with same layout as your code
    std::vector<float> h_A(m * k); // A[M,K]
    std::vector<float> h_B(n * k); // B[N,K]
    std::vector<float> h_D(m * n);
    std::vector<float> h_D_ref(m * n);

    // Initialize with random values just like your code
    std::mt19937 gen(42);
    std::uniform_real_distribution<> dis(-0.5, 0.5);
    for (int i = 0; i < m * k; ++i) h_A[i] = 1; //dis(gen) * 0.5f;
    for (int i = 0; i < n * k; ++i) h_B[i] = dis(gen) * 0.5f;

    // Convert to FP8
    std::vector<__nv_fp8_e4m3> h_A_fp8(m * k);
    std::vector<__nv_fp8_e4m3> h_B_fp8(n * k);
    for (int i = 0; i < m * k; ++i) h_A_fp8[i] = __nv_fp8_e4m3(h_A[i]);
    for (int i = 0; i < n * k; ++i) h_B_fp8[i] = __nv_fp8_e4m3(h_B[i]);

    // Allocate device memory
    __nv_fp8_e4m3 *d_A, *d_B, *d_D;
    __nv_bfloat16 *d_C;
    CHECK_CUDA(cudaMalloc(&d_A, m * k * sizeof(__nv_fp8_e4m3)));
    CHECK_CUDA(cudaMalloc(&d_B, n * k * sizeof(__nv_fp8_e4m3)));
    CHECK_CUDA(cudaMalloc(&d_C, m * n * sizeof(__nv_bfloat16)));
    CHECK_CUDA(cudaMalloc(&d_D, m * n * sizeof(__nv_fp8_e4m3)));

    // Copy to device with same layout
    CHECK_CUDA(cudaMemcpy(d_A, h_A_fp8.data(), m * k * sizeof(__nv_fp8_e4m3), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_B, h_B_fp8.data(), n * k * sizeof(__nv_fp8_e4m3), cudaMemcpyHostToDevice));

    // Create cuBLAS handle
    cublasLtHandle_t handle;
    CHECK_CUBLAS(cublasLtCreate(&handle));

    // Create matrix descriptors
    cublasLtMatrixLayout_t matA, matB, matC, matD;
    CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&matA, CUDA_R_8F_E4M3, k, m, k)); // A[K,M]
    CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&matB, CUDA_R_8F_E4M3, k, n, k)); // B[K,N]
    CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&matC, CUDA_R_16BF, m, n, m)); // C[M,N] in BF16
    CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&matD, CUDA_R_8F_E4M3, m, n, m)); // D[M,N] in FP8

    // Create operation descriptor
    cublasLtMatmulDesc_t operationDesc;
    CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));

    // Set operation attributes - "TN" format required for FP8
    const int32_t transa = CUBLAS_OP_T;
    const int32_t transb = CUBLAS_OP_N;
    CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(int32_t)));
    CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(int32_t)));

    const float alpha = 1.0f;
    const float beta = 0.0f;

    // Allocate workspace
    size_t workspaceSize = 32 * 1024 * 1024; // 32MB workspace
    void* workspace = nullptr;
    CHECK_CUDA(cudaMalloc(&workspace, workspaceSize));

    // Query the best algorithm
    // Create preference descriptor
    cublasLtMatmulPreference_t preference;
    checkCublas(cublasLtMatmulPreferenceCreate(&preference));
    checkCublas(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
    int returnedResults = 0;
    cublasLtMatmulHeuristicResult_t heuristicResult;
    checkCublas(cublasLtMatmulAlgoGetHeuristic(
        handle, operationDesc, matA, matB, matC, matD, preference, 1, &heuristicResult, &returnedResults
    ));
    std::cout << "Returned results: " << returnedResults << std::endl;

    // Warmup runs
    for (int i = 0; i < 10; ++i) {
        CHECK_CUBLAS(cublasLtMatmul(
            handle,
            operationDesc,
            &alpha,
            d_A, matA,
            d_B, matB,
            &beta,
            d_C, matC,
            d_D, matD,
            &heuristicResult.algo,
            workspace,
            workspaceSize,
            0 // Default stream
        ));
    }

    CHECK_CUDA(cudaDeviceSynchronize());

    // Benchmark runs
    const int NUM_ITERATIONS = 100;
    cudaEvent_t start, stop;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));

    CHECK_CUDA(cudaEventRecord(start));
    for (int i = 0; i < NUM_ITERATIONS; ++i) {
        CHECK_CUBLAS(cublasLtMatmul(
            handle,
            operationDesc,
            &alpha,
            d_A, matA,
            d_B, matB,
            &beta,
            d_C, matC,
            d_D, matD,
            &heuristicResult.algo,
            workspace,
            workspaceSize,
            0
        ));
    }
    CHECK_CUDA(cudaEventRecord(stop));
    CHECK_CUDA(cudaEventSynchronize(stop));

    float milliseconds = 0;
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    float avg_time = milliseconds / NUM_ITERATIONS;

    // Calculate TFLOPS
    double num_ops = 2.0 * static_cast<double>(m) * static_cast<double>(n) * static_cast<double>(k); // multiply-add counts as 2
    double seconds = avg_time / 1000.0; // convert ms to seconds
    double tflops = (num_ops / seconds) / 1e12;

    std::cout << "Matrix size: " << m << "x" << n << "x" << k << std::endl;
    std::cout << std::fixed << std::setprecision(3);
    std::cout << "Average time: " << avg_time << " ms" << std::endl;
    std::cout << std::setprecision(2);
    std::cout << "Performance: " << tflops << " TFLOPS" << std::endl << std::endl;

    // Get cuBLAS result
    cpu_gemm(h_A.data(), h_B.data(), h_D_ref.data(), m, n, k);

    // Allocate FP8 host buffer
    std::vector<__nv_fp8_e4m3> h_D_fp8(m * n);
    CHECK_CUDA(cudaMemcpy(h_D_fp8.data(), d_D, m * n * sizeof(__nv_fp8_e4m3), cudaMemcpyDeviceToHost));

    // Convert FP8 to float for comparison
    for (int i = 0; i < m * n; i++) {
        h_D[i] = float(h_D_fp8[i]); // Convert FP8 to float
    }

    // Now compare the float values
    check_result(h_D.data(), h_D_ref.data(), m, n);

    // Cleanup
    CHECK_CUDA(cudaFree(workspace));
    CHECK_CUDA(cudaFree(d_A));
    CHECK_CUDA(cudaFree(d_B));
    CHECK_CUDA(cudaFree(d_C));
    CHECK_CUDA(cudaFree(d_D));
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));
    cublasLtMatrixLayoutDestroy(matA);
    cublasLtMatrixLayoutDestroy(matB);
    cublasLtMatrixLayoutDestroy(matC);
    cublasLtMatrixLayoutDestroy(matD);
    cublasLtMatmulDescDestroy(operationDesc);
    cublasLtDestroy(handle);
}

int main() {
    // Benchmark different matrix sizes
    // std::vector<int> sizes = {3072, 4096, 6144, 8192, 12288, 16384};
    std::vector<int> sizes = {2048};

    for (int size : sizes) {
        benchmark(size, size, size);
    }

    return 0;
}
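
For a rough cross-check of the FP8 numerics above (e4m3 inputs, fp32 accumulation, fp8 output), a hedged PyTorch sketch; it assumes torch >= 2.1 with float8_e4m3fn support and a CUDA device, and only approximates the cuBLASLt path rather than reproducing it:

import torch

m = n = k = 2048
a = torch.full((m, k), 1.0, device="cuda")          # mirrors h_A[i] = 1 above
b = (torch.rand(k, n, device="cuda") - 0.5) * 0.5   # roughly mirrors dis(gen) * 0.5f
# Quantize inputs to e4m3 and back, then accumulate in fp32 (analogue of CUBLAS_COMPUTE_32F).
a8 = a.to(torch.float8_e4m3fn).float()
b8 = b.to(torch.float8_e4m3fn).float()
ref = a8 @ b8
# The kernel stores D in fp8, so quantize the reference the same way before comparing.
ref_fp8 = ref.to(torch.float8_e4m3fn).float()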

Diff for: tests/python/layernorm/test_correctness.py

+8 −3

@@ -29,7 +29,7 @@ def run_tk(x, residual, drop_path, dropout, norm, residual_in_fp32=False):


 def run_flash(x, residual, drop_path, dropout, norm, residual_in_fp32=True):
-    from layer_norm_triton import layer_norm_fn, RMSNorm
+    from baselines.layer_norm_triton import layer_norm_fn, RMSNorm

     rowscale = torch.ones(x.shape[:-1], device=x.device, dtype=x.dtype, )
     # drop_path()

@@ -102,7 +102,7 @@ def run_naive(x, residual, drop_path, dropout, norm, residual_in_fp32=False):
 norm = nn.LayerNorm(d).cuda()
 dropout = nn.Dropout(p)
 drop_path = None #StochasticDepth(p_path, mode="row")
-run_naive(x, residual, drop_path, dropout, norm)
+# run_naive(x, residual, drop_path, dropout, norm)

 torch.manual_seed(0)
 torch.cuda.manual_seed_all(0)

@@ -113,7 +113,8 @@ def run_naive(x, residual, drop_path, dropout, norm, residual_in_fp32=False):

 outs = []
 resids = []
-for fn in [run_tk, run_flash]:
+for _name, fn in [('tk', run_tk), ('triton', run_flash)]:
+    print(f"Running {_name}")
     torch.manual_seed(0)
     torch.cuda.manual_seed_all(0)
     norm = nn.LayerNorm(d).cuda()

@@ -132,4 +133,8 @@ def run_naive(x, residual, drop_path, dropout, norm, residual_in_fp32=False):
     print(fn_resid[4,2,:8])
     print(f"Resid Diff: {diff}")

+    print("----"*10)
+
+    # breakpoint()
+
