#include <iostream>
#include <vector>
#include <random>
#include <cmath>
#include <algorithm>
#include <iomanip>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cublasLt.h>

// Error checking macros
#define CHECK_CUDA(call) \
    do { \
        cudaError_t err = call; \
        if (err != cudaSuccess) { \
            std::cerr << "CUDA error in " << __FILE__ << " line " << __LINE__ << ": " \
                      << cudaGetErrorString(err) << std::endl; \
            exit(EXIT_FAILURE); \
        } \
    } while (0)

#define CHECK_CUBLAS(call) \
    do { \
        cublasStatus_t status = call; \
        if (status != CUBLAS_STATUS_SUCCESS) { \
            std::cerr << "cuBLAS error in " << __FILE__ << " line " << __LINE__ << ": " \
                      << "code " << status << std::endl; \
            exit(EXIT_FAILURE); \
        } \
    } while (0)

// CPU reference GEMM. B is stored as B[N,K] (row-major), matching the TN
// operand layout used by the FP8 matmul below, so the inner loop reads
// b[j * K + k].
void cpu_gemm(float* a, float* b, float* c, int M, int N, int K) {
#pragma omp parallel for collapse(2) // otherwise the CPU reference takes forever
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < N; j++) {
            float sum = 0.0f;
            for (int k = 0; k < K; k++) {
                sum += a[i * K + k] * b[j * K + k]; // mma_ABt
                // sum += a[i * K + k] * b[k * N + j]; // mma_AB (B stored [K,N])
            }
            c[i * N + j] = sum;
        }
    }
}
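
// cpu_gemm relies on OpenMP for the reference pass. A plausible build line
// (an assumption, not taken from any particular repo) would be:
//   nvcc -O3 -arch=sm_90 -Xcompiler -fopenmp fp8_gemm.cu -lcublasLt -o fp8_gemm
// FP8 matmuls need an sm_89+ (Ada/Hopper) GPU and a recent CUDA toolkit.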

// Compare the GPU result against the CPU reference.
void check_result(float* h_C, float* h_C_ref, int M, int N) {
    float max_error = 0.0f;
    int error_count = 0;

    // The tolerance is a guess: FP8 e4m3 output quantization alone can
    // introduce absolute errors of a few tenths at these magnitudes.
    const float tolerance = 0.5f;
    for (int i = 0; i < M * N; ++i) {
        float error = std::abs(h_C[i] - h_C_ref[i]);
        if (error > tolerance) {
            if (error_count < 25) {
                std::cout << "Error at row " << i / N << " col " << i % N
                          << ": " << h_C[i] << " != " << h_C_ref[i]
                          << " (ref)" << std::endl;
            } else if (error_count == 25) {
                std::cout << "Too many errors to show them all.\n";
            }
            error_count++;
        }
        max_error = std::max(max_error, error);
    }

    std::cout << "Max error: " << max_error << std::endl;
    std::cout << "Error count: " << error_count << std::endl;
}

void benchmark(int m, int n, int k) {
    // Round dimensions up to the next multiple of 16
    m = (m + 15) & ~15;
    n = (n + 15) & ~15;
    k = (k + 15) & ~15;
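    // (x + 15) & ~15 rounds x up to the next multiple of 16 (e.g. 100 -> 112);
    // cuBLASLt's FP8 path wants 16-element-aligned shapes, and keeping all
    // three dimensions multiples of 16 satisfies that for every operand.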

    // Host buffers
    std::vector<float> h_A(m * k); // A[M,K], row-major
    std::vector<float> h_B(n * k); // B[N,K], row-major
    std::vector<float> h_D(m * n);
    std::vector<float> h_D_ref(m * n);

    // Initialize with random values
    std::mt19937 gen(42);
    std::uniform_real_distribution<> dis(-0.5, 0.5);
    for (int i = 0; i < m * k; ++i) h_A[i] = dis(gen) * 0.5f;
    for (int i = 0; i < n * k; ++i) h_B[i] = dis(gen) * 0.5f;

    // Convert to FP8, then round-trip back to float so the CPU reference
    // multiplies exactly the values the GPU sees
    std::vector<__nv_fp8_e4m3> h_A_fp8(m * k);
    std::vector<__nv_fp8_e4m3> h_B_fp8(n * k);
    for (int i = 0; i < m * k; ++i) { h_A_fp8[i] = __nv_fp8_e4m3(h_A[i]); h_A[i] = float(h_A_fp8[i]); }
    for (int i = 0; i < n * k; ++i) { h_B_fp8[i] = __nv_fp8_e4m3(h_B[i]); h_B[i] = float(h_B_fp8[i]); }
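    // Note: e4m3 tops out at a magnitude of 448 and carries roughly two
    // decimal digits of precision, which is why the inputs are kept in
    // [-0.25, 0.25] above.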

    // Allocate device memory
    __nv_fp8_e4m3 *d_A, *d_B, *d_D;
    __nv_bfloat16 *d_C;
    CHECK_CUDA(cudaMalloc(&d_A, m * k * sizeof(__nv_fp8_e4m3)));
    CHECK_CUDA(cudaMalloc(&d_B, n * k * sizeof(__nv_fp8_e4m3)));
    CHECK_CUDA(cudaMalloc(&d_C, m * n * sizeof(__nv_bfloat16)));
    CHECK_CUDA(cudaMalloc(&d_D, m * n * sizeof(__nv_fp8_e4m3)));

    // Copy inputs to the device
    CHECK_CUDA(cudaMemcpy(d_A, h_A_fp8.data(), m * k * sizeof(__nv_fp8_e4m3), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_B, h_B_fp8.data(), n * k * sizeof(__nv_fp8_e4m3), cudaMemcpyHostToDevice));

    // Create cuBLASLt handle
    cublasLtHandle_t handle;
    CHECK_CUBLAS(cublasLtCreate(&handle));

    // Create matrix layouts. cuBLASLt layouts are column-major by default, so
    // row-major A[M,K] is described as a K x M column-major matrix, and so on.
    cublasLtMatrixLayout_t matA, matB, matC, matD;
    CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&matA, CUDA_R_8F_E4M3, k, m, k)); // A[M,K] row-major = [K,M] col-major
    CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&matB, CUDA_R_8F_E4M3, k, n, k)); // B[N,K] row-major = [K,N] col-major
    CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&matC, CUDA_R_16BF, m, n, m));    // C[M,N] col-major, BF16
    CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&matD, CUDA_R_8F_E4M3, m, n, m)); // D[M,N] col-major, FP8

    // Create operation descriptor
    cublasLtMatmulDesc_t operationDesc;
    CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));

    // FP8 matmuls require the "TN" operand combination: A transposed, B not
    const int32_t transa = CUBLAS_OP_T;
    const int32_t transb = CUBLAS_OP_N;
    CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)));
    CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transb)));
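
    // Optional: FP8 matmuls can attach per-tensor scale factors through device
    // pointers. With inputs this small the default scale of 1.0f is fine, so
    // this is left as a sketch of the shape of the call (d_scale_a is a
    // hypothetical device pointer to a single float):
    //
    //   float* d_scale_a;
    //   CHECK_CUDA(cudaMalloc(&d_scale_a, sizeof(float)));
    //   CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
    //       operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
    //       &d_scale_a, sizeof(d_scale_a)));
    //   // ...and likewise for CUBLASLT_MATMUL_DESC_B_SCALE_POINTER and
    //   // CUBLASLT_MATMUL_DESC_D_SCALE_POINTER.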

    const float alpha = 1.0f;
    const float beta = 0.0f;

    // Allocate workspace
    size_t workspaceSize = 32 * 1024 * 1024; // 32 MiB
    void* workspace = nullptr;
    CHECK_CUDA(cudaMalloc(&workspace, workspaceSize));
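    // 32 MiB matches the workspace size the cuBLAS docs suggest for
    // Hopper-class GPUs; a larger workspace mostly just lets the heuristic
    // consider more algorithms.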

    // Create a preference descriptor and query the heuristic for the best
    // available algorithm
    cublasLtMatmulPreference_t preference;
    CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference));
    CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
    int returnedResults = 0;
    cublasLtMatmulHeuristicResult_t heuristicResult = {};
    CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(
        handle, operationDesc, matA, matB, matC, matD, preference, 1, &heuristicResult, &returnedResults
    ));
    std::cout << "Returned results: " << returnedResults << std::endl;

    // Warmup runs
    for (int i = 0; i < 10; ++i) {
        CHECK_CUBLAS(cublasLtMatmul(
            handle,
            operationDesc,
            &alpha,
            d_A, matA,
            d_B, matB,
            &beta,
            d_C, matC,
            d_D, matD,
            &heuristicResult.algo,
            workspace,
            workspaceSize,
            0 // default stream
        ));
    }

    CHECK_CUDA(cudaDeviceSynchronize());

    // Benchmark runs
    const int NUM_ITERATIONS = 100;
    cudaEvent_t start, stop;
    CHECK_CUDA(cudaEventCreate(&start));
    CHECK_CUDA(cudaEventCreate(&stop));

    CHECK_CUDA(cudaEventRecord(start));
    for (int i = 0; i < NUM_ITERATIONS; ++i) {
        CHECK_CUBLAS(cublasLtMatmul(
            handle,
            operationDesc,
            &alpha,
            d_A, matA,
            d_B, matB,
            &beta,
            d_C, matC,
            d_D, matD,
            &heuristicResult.algo,
            workspace,
            workspaceSize,
            0
        ));
    }
    CHECK_CUDA(cudaEventRecord(stop));
    CHECK_CUDA(cudaEventSynchronize(stop));

    float milliseconds = 0;
    CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
    float avg_time = milliseconds / NUM_ITERATIONS;

    // Calculate TFLOPS
    double num_ops = 2.0 * static_cast<double>(m) * static_cast<double>(n) * static_cast<double>(k); // multiply-add counts as 2 ops
    double seconds = avg_time / 1000.0; // ms -> s
    double tflops = (num_ops / seconds) / 1e12;
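    // Sanity check for the numbers below: at m = n = k = 2048,
    // num_ops = 2 * 2048^3, about 1.72e10, so an average time of 0.050 ms
    // would come out to roughly 344 TFLOPS.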

    std::cout << "Matrix size: " << m << "x" << n << "x" << k << std::endl;
    std::cout << std::fixed << std::setprecision(3);
    std::cout << "Average time: " << avg_time << " ms" << std::endl;
    std::cout << std::setprecision(2);
    std::cout << "Performance: " << tflops << " TFLOPS" << std::endl << std::endl;

    // Compute the CPU reference
    cpu_gemm(h_A.data(), h_B.data(), h_D_ref.data(), m, n, k);

    // Copy the FP8 result back to the host
    std::vector<__nv_fp8_e4m3> h_D_fp8(m * n);
    CHECK_CUDA(cudaMemcpy(h_D_fp8.data(), d_D, m * n * sizeof(__nv_fp8_e4m3), cudaMemcpyDeviceToHost));

    // D comes back column-major (see matD above): transpose it into row-major
    // h_D while converting FP8 -> float, so it lines up with the CPU reference
    for (int j = 0; j < n; j++) {
        for (int i = 0; i < m; i++) {
            h_D[i * n + j] = float(h_D_fp8[i + j * m]);
        }
    }

    // Now compare the float values
    check_result(h_D.data(), h_D_ref.data(), m, n);

    // Cleanup
    CHECK_CUDA(cudaFree(workspace));
    CHECK_CUDA(cudaFree(d_A));
    CHECK_CUDA(cudaFree(d_B));
    CHECK_CUDA(cudaFree(d_C));
    CHECK_CUDA(cudaFree(d_D));
    CHECK_CUDA(cudaEventDestroy(start));
    CHECK_CUDA(cudaEventDestroy(stop));
    cublasLtMatmulPreferenceDestroy(preference);
    cublasLtMatrixLayoutDestroy(matA);
    cublasLtMatrixLayoutDestroy(matB);
    cublasLtMatrixLayoutDestroy(matC);
    cublasLtMatrixLayoutDestroy(matD);
    cublasLtMatmulDescDestroy(operationDesc);
    cublasLtDestroy(handle);
}

int main() {
    // Benchmark different matrix sizes
    // std::vector<int> sizes = {3072, 4096, 6144, 8192, 12288, 16384};
    std::vector<int> sizes = {2048};

    for (int size : sizes) {
        benchmark(size, size, size);
    }

    return 0;
}