Skip to content

Commit 1e5d9ad

Browse files
authored
extra/gemm/max_matmul: start of custom kernels for GEMM (tinygrad#6926)
* extra/gemm/max_matmul: start of custom kernels for GEMM * add an unoptimized FP16/FP16 MMA example * add slow 3-stage fp16 acc example * add correct 3-stage pipeline with unswizzled/flat smem input (slow) * add acc fp16 example with 3 stages and swizzle (no bank conflicts) * add max version of NV fp16_fp16_fp16 * fix up comments and removed unused code in max variations * add start of no_xor example * fix to account for UOps to Ops
1 parent 865f23d commit 1e5d9ad

11 files changed

+4418
-0
lines changed

extra/gemm/max_kernels/nv.fp16_fp16_fp16.2_stage.cu

Lines changed: 508 additions & 0 deletions
Large diffs are not rendered by default.

extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage.cu

Lines changed: 465 additions & 0 deletions
Large diffs are not rendered by default.

extra/gemm/max_kernels/nv.fp16_fp16_fp16.3_stage_swizzled.cu

Lines changed: 517 additions & 0 deletions
Large diffs are not rendered by default.

extra/gemm/max_kernels/nv.fp16_fp16_fp16.max.cu

Lines changed: 482 additions & 0 deletions
Large diffs are not rendered by default.

extra/gemm/max_kernels/nv.fp16_fp16_fp16.no_xor.cu

Lines changed: 486 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
#define INFINITY (__int_as_float(0x7f800000))
2+
#define NAN (__int_as_float(0x7fffffff))
3+
#include <cuda_fp16.h>
4+
// Four packed halves, 8-byte aligned so the compiler can move it as one 64-bit access.
struct __align__(8) half4 { half x, y, z, w; };

// Build a half4 from four scalar halves (CUDA has no built-in make_half4).
__device__ half4 make_half4(half x, half y, half z, half w) {
  half4 r;
  r.x = x;
  r.y = y;
  r.z = z;
  r.w = w;
  return r;
}
5+
// Eight packed halves, 16-byte aligned (one 128-bit vector load/store).
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };

// Build a half8 from eight scalar halves.
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) {
  half8 r;
  r.x = x;
  r.y = y;
  r.z = z;
  r.w = w;
  r.a = a;
  r.b = b;
  r.c = c;
  r.d = d;
  return r;
}
6+
// One tensor-core matrix-multiply-accumulate: returns A*B + C for a single
// m16n8k16 tile via PTX mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
// (fp16 A/B fragments, fp32 accumulator). Per the PTX ISA, this shape
// requires sm_80 or newer.
//   a: this thread's A fragment, 8 halves = four packed 32-bit registers
//   b: this thread's B fragment, 4 halves = two packed 32-bit registers
//   c: accumulator fragment; updated value is returned
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);  // reinterpret the half pairs as the packed "r" registers the instruction consumes
  // "+f" ties D and C to the same registers: the { %0..%3 } group appears as
  // both destination and addend, i.e. c is read-modify-written in place.
  asm( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
    : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w) : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
  return c;}
10+
/*
 * wmma_example: tensor-core GEMM kernel, half inputs / float accumulation /
 * half output, loading straight from global memory (no shared-memory staging
 * anywhere in this kernel -- this is the baseline "unoptimized" variant).
 *
 * Launch geometry (from __launch_bounds__(128) and the inline comments):
 *   grid  = (32, 64), block = (16, 2, 4) = 128 threads.
 *
 * NOTE(review): the constant strides strongly suggest a 4096x4096x4096 GEMM
 * over row-major 4096-wide matrices (row stride 4096; gidx1*262144 =
 * gidx1*64*4096; the K loop runs 256 iterations of 16 => K = 4096) -- confirm
 * against the host-side launcher.
 *
 * data1 supplies the A fragments (first mma operand), data2 the B fragments
 * (second operand), data0 receives the output. Each thread carries 16 float4
 * accumulators: a 4x4 grid of m16n8k16 fragment tiles.
 */
extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2) {
  int gidx0 = blockIdx.x; /* 32 */
  int gidx1 = blockIdx.y; /* 64 */
  int lidx0 = threadIdx.x; /* 16 */
  int lidx1 = threadIdx.y; /* 2 */
  int lidx2 = threadIdx.z; /* 4 */
  // Zero fragment used to initialize every accumulator below.
  float4 cast0 = make_float4(0.0f,0.0f,0.0f,0.0f);
  // Per-thread base-offset decomposition (generated code; names kept as emitted).
  int alu0 = (gidx0*128);     // block's column offset in the output
  int alu1 = (gidx1*262144);  // block's row offset (64 rows * 4096 stride)
  int alu2 = (lidx1*32768);   // lidx1 selects an 8-row group (8 * 4096)
  int alu3 = (lidx2*32);      // lidx2 selects a 32-column group
  int alu4 = (lidx0/8);
  int alu5 = (alu4*16384);    // 4-row step (4 * 4096)
  int alu6 = (lidx0%2);
  int alu7 = (alu6*2);
  int alu8 = ((lidx0/2)%2);
  int alu9 = (alu8*4);
  int alu10 = ((lidx0/4)%2);
  int alu11 = (alu10*8192);   // 2-row step (2 * 4096)
  // alu12: this thread's base index into data0; alu13: base row index into
  // data1 (same terms minus the column term alu0/alu3... note alu3 is also absent).
  int alu12 = (alu1+alu0+alu7+alu9+alu11+alu5+alu2+alu3);
  int alu13 = (alu1+alu7+alu9+alu11+alu5+alu2);
  // 16 accumulator fragments, all starting at zero.
  float4 acc0 = cast0;
  float4 acc1 = cast0;
  float4 acc2 = cast0;
  float4 acc3 = cast0;
  float4 acc4 = cast0;
  float4 acc5 = cast0;
  float4 acc6 = cast0;
  float4 acc7 = cast0;
  float4 acc8 = cast0;
  float4 acc9 = cast0;
  float4 acc10 = cast0;
  float4 acc11 = cast0;
  float4 acc12 = cast0;
  float4 acc13 = cast0;
  float4 acc14 = cast0;
  float4 acc15 = cast0;
  // K loop: 256 iterations * 16 elements per mma = full reduction dimension.
  for (int ridx0 = 0; ridx0 < 256; ridx0++) {
    int alu14 = (ridx0*16);      // K offset for this iteration
    int alu15 = (alu13+alu14);
    int alu16 = (alu14+alu13);   // same value as alu15; redundant duplicate left by the code generator
    // data2 (B) index: steps along K by whole 16-row panels (ridx0*65536 = 16*4096).
    int alu17 = (alu0+(alu6*8192)+(alu8*16384)+alu10+(alu4*2)+(lidx1*4)+alu3+(ridx0*65536));
    // Scalar B loads: 4 column groups (+0/+8/+16/+24) x 4 row groups
    // (+0/+4096/+32768/+36864). Scalar because the in-fragment layout is strided.
    half val0 = data2[alu17+8];
    half val1 = data2[alu17+16];
    half val2 = data2[alu17+24];
    half val3 = data2[alu17+4096];
    half val4 = data2[alu17+4104];
    half val5 = data2[alu17+4112];
    half val6 = data2[alu17+4120];
    half val7 = data2[alu17+32768];
    half val8 = data2[alu17+32776];
    half val9 = data2[alu17+32784];
    half val10 = data2[alu17+32792];
    half val11 = data2[alu17+36864];
    half val12 = data2[alu17+36872];
    // Four B fragments, one per output-column group.
    half4 cast1 = make_half4(val0,val4,val8,val12);
    half val13 = data2[alu17+36880];
    half4 cast2 = make_half4(val1,val5,val9,val13);
    half val14 = data2[alu17+36888];
    half4 cast3 = make_half4(val2,val6,val10,val14);
    half val15 = data2[alu17];
    half4 cast4 = make_half4(val15,val3,val7,val11);
    // A loads as half2 pairs: four 16-row panels of data1
    // (+0, +65536, +131072, +196608 = panels of 16*4096).
    half2 val16 = *((half2*)(data1+alu15+4096));
    half2 val17 = *((half2*)(data1+alu15+65536));
    half2 val18 = *((half2*)(data1+alu15+69632));
    half2 val19 = *((half2*)(data1+alu15+131072));
    half2 val20 = *((half2*)(data1+alu15+135168));
    half2 val21 = *((half2*)(data1+alu15+196608));
    half2 val22 = *((half2*)(data1+alu15+200704));
    half2 val23 = *((half2*)(data1+alu15));
    half2 val24 = *((half2*)(data1+alu16+8));
    half2 val25 = *((half2*)(data1+alu16+4104));
    // A fragment 0; each A fragment feeds all four B fragments. The accumulator
    // numbering is rotated one step versus the B fragments (cast1->acc1, ...,
    // cast4->acc0), matching the rotated write-back after the mma calls.
    half8 cast5 = make_half8(val23.x,val23.y,val16.x,val16.y,val24.x,val24.y,val25.x,val25.y);
    float4 wmma0 = __WMMA_8_16_16_half_float(cast5, cast1, acc1);
    float4 wmma1 = __WMMA_8_16_16_half_float(cast5, cast2, acc2);
    float4 wmma2 = __WMMA_8_16_16_half_float(cast5, cast3, acc3);
    float4 wmma3 = __WMMA_8_16_16_half_float(cast5, cast4, acc0);
    // A fragment 1 (second 16-row panel).
    half2 val26 = *((half2*)(data1+alu16+65544));
    half2 val27 = *((half2*)(data1+alu16+69640));
    half8 cast6 = make_half8(val17.x,val17.y,val18.x,val18.y,val26.x,val26.y,val27.x,val27.y);
    float4 wmma4 = __WMMA_8_16_16_half_float(cast6, cast1, acc5);
    float4 wmma5 = __WMMA_8_16_16_half_float(cast6, cast2, acc6);
    float4 wmma6 = __WMMA_8_16_16_half_float(cast6, cast3, acc7);
    float4 wmma7 = __WMMA_8_16_16_half_float(cast6, cast4, acc4);
    // A fragment 2 (third 16-row panel).
    half2 val28 = *((half2*)(data1+alu16+131080));
    half2 val29 = *((half2*)(data1+alu16+135176));
    half8 cast7 = make_half8(val19.x,val19.y,val20.x,val20.y,val28.x,val28.y,val29.x,val29.y);
    float4 wmma8 = __WMMA_8_16_16_half_float(cast7, cast1, acc9);
    float4 wmma9 = __WMMA_8_16_16_half_float(cast7, cast2, acc10);
    float4 wmma10 = __WMMA_8_16_16_half_float(cast7, cast3, acc11);
    float4 wmma11 = __WMMA_8_16_16_half_float(cast7, cast4, acc8);
    // A fragment 3 (fourth 16-row panel).
    half2 val30 = *((half2*)(data1+alu16+196616));
    half2 val31 = *((half2*)(data1+alu16+200712));
    half8 cast8 = make_half8(val21.x,val21.y,val22.x,val22.y,val30.x,val30.y,val31.x,val31.y);
    float4 wmma12 = __WMMA_8_16_16_half_float(cast8, cast1, acc13);
    float4 wmma13 = __WMMA_8_16_16_half_float(cast8, cast2, acc14);
    float4 wmma14 = __WMMA_8_16_16_half_float(cast8, cast3, acc15);
    float4 wmma15 = __WMMA_8_16_16_half_float(cast8, cast4, acc12);
    // Carry the mma results into the accumulators for the next K iteration
    // (undoing the rotation applied at the call sites above).
    acc0 = wmma3;
    acc1 = wmma0;
    acc2 = wmma1;
    acc3 = wmma2;
    acc4 = wmma7;
    acc5 = wmma4;
    acc6 = wmma5;
    acc7 = wmma6;
    acc8 = wmma11;
    acc9 = wmma8;
    acc10 = wmma9;
    acc11 = wmma10;
    acc12 = wmma15;
    acc13 = wmma12;
    acc14 = wmma13;
    acc15 = wmma14;
  }
  // Write-back: each float4 fragment produces two half2 stores -- .x/.y on one
  // output row, .z/.w one sub-row group lower (+4096 = next row). Column
  // groups step by 8/16/24; row panels by 65536 (16 rows). Values are
  // narrowed float -> half at the store.
  *((half2*)(data0+alu12+8)) = make_half2((half)(acc1.x),(half)(acc1.y));
  *((half2*)(data0+alu12+16)) = make_half2((half)(acc2.x),(half)(acc2.y));
  *((half2*)(data0+alu12+24)) = make_half2((half)(acc3.x),(half)(acc3.y));
  *((half2*)(data0+alu12+4096)) = make_half2((half)(acc0.z),(half)(acc0.w));
  *((half2*)(data0+alu12+4104)) = make_half2((half)(acc1.z),(half)(acc1.w));
  *((half2*)(data0+alu12+4112)) = make_half2((half)(acc2.z),(half)(acc2.w));
  *((half2*)(data0+alu12+4120)) = make_half2((half)(acc3.z),(half)(acc3.w));
  *((half2*)(data0+alu12+65536)) = make_half2((half)(acc4.x),(half)(acc4.y));
  *((half2*)(data0+alu12+65544)) = make_half2((half)(acc5.x),(half)(acc5.y));
  *((half2*)(data0+alu12+65552)) = make_half2((half)(acc6.x),(half)(acc6.y));
  *((half2*)(data0+alu12+65560)) = make_half2((half)(acc7.x),(half)(acc7.y));
  *((half2*)(data0+alu12+69632)) = make_half2((half)(acc4.z),(half)(acc4.w));
  *((half2*)(data0+alu12+69640)) = make_half2((half)(acc5.z),(half)(acc5.w));
  *((half2*)(data0+alu12+69648)) = make_half2((half)(acc6.z),(half)(acc6.w));
  *((half2*)(data0+alu12+69656)) = make_half2((half)(acc7.z),(half)(acc7.w));
  *((half2*)(data0+alu12+131072)) = make_half2((half)(acc8.x),(half)(acc8.y));
  *((half2*)(data0+alu12+131080)) = make_half2((half)(acc9.x),(half)(acc9.y));
  *((half2*)(data0+alu12+131088)) = make_half2((half)(acc10.x),(half)(acc10.y));
  *((half2*)(data0+alu12+131096)) = make_half2((half)(acc11.x),(half)(acc11.y));
  *((half2*)(data0+alu12+135168)) = make_half2((half)(acc8.z),(half)(acc8.w));
  *((half2*)(data0+alu12+135176)) = make_half2((half)(acc9.z),(half)(acc9.w));
  *((half2*)(data0+alu12+135184)) = make_half2((half)(acc10.z),(half)(acc10.w));
  *((half2*)(data0+alu12+135192)) = make_half2((half)(acc11.z),(half)(acc11.w));
  *((half2*)(data0+alu12+196608)) = make_half2((half)(acc12.x),(half)(acc12.y));
  *((half2*)(data0+alu12+196616)) = make_half2((half)(acc13.x),(half)(acc13.y));
  *((half2*)(data0+alu12+196624)) = make_half2((half)(acc14.x),(half)(acc14.y));
  *((half2*)(data0+alu12+196632)) = make_half2((half)(acc15.x),(half)(acc15.y));
  *((half2*)(data0+alu12+200704)) = make_half2((half)(acc12.z),(half)(acc12.w));
  *((half2*)(data0+alu12+200712)) = make_half2((half)(acc13.z),(half)(acc13.w));
  *((half2*)(data0+alu12+200720)) = make_half2((half)(acc14.z),(half)(acc14.w));
  *((half2*)(data0+alu12+200728)) = make_half2((half)(acc15.z),(half)(acc15.w));
  *((half2*)(data0+alu12)) = make_half2((half)(acc0.x),(half)(acc0.y));
}

0 commit comments

Comments
 (0)