1
+ #define INFINITY (__int_as_float(0x7f800000 ))
2
+ #define NAN (__int_as_float(0x7fffffff ))
3
+ #include < cuda_fp16.h>
4
+ struct __align__ (8 ) half4 { half x, y, z, w; }; __device__ half4 make_half4 (half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
5
+ struct __align__ (16 ) half8 { half x, y, z, w, a, b, c, d; }; __device__ half8 make_half8 (half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
6
+ __device__ float4 __WMMA_8_16_16_half_float (half8 a, half4 b, float4 c) { int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
7
+ asm ( " mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
8
+ : " +f" (c.x ), " +f" (c.y ), " +f" (c.z ), " +f" (c.w ) : " r" (a_pk[0 ]), " r" (a_pk[1 ]), " r" (a_pk[2 ]), " r" (a_pk[3 ]), " r" (b_pk[0 ]), " r" (b_pk[1 ]) );
9
+ return c;}
10
+ extern " C" __global__ void __launch_bounds__ (128 ) wmma_example(half* data0, const half* data1, const half* data2) {
11
+ int gidx0 = blockIdx .x ; /* 32 */
12
+ int gidx1 = blockIdx .y ; /* 64 */
13
+ int lidx0 = threadIdx .x ; /* 16 */
14
+ int lidx1 = threadIdx .y ; /* 2 */
15
+ int lidx2 = threadIdx .z ; /* 4 */
16
+ float4 cast0 = make_float4 (0 .0f ,0 .0f ,0 .0f ,0 .0f );
17
+ int alu0 = (gidx0*128 );
18
+ int alu1 = (gidx1*262144 );
19
+ int alu2 = (lidx1*32768 );
20
+ int alu3 = (lidx2*32 );
21
+ int alu4 = (lidx0/8 );
22
+ int alu5 = (alu4*16384 );
23
+ int alu6 = (lidx0%2 );
24
+ int alu7 = (alu6*2 );
25
+ int alu8 = ((lidx0/2 )%2 );
26
+ int alu9 = (alu8*4 );
27
+ int alu10 = ((lidx0/4 )%2 );
28
+ int alu11 = (alu10*8192 );
29
+ int alu12 = (alu1+alu0+alu7+alu9+alu11+alu5+alu2+alu3);
30
+ int alu13 = (alu1+alu7+alu9+alu11+alu5+alu2);
31
+ float4 acc0 = cast0;
32
+ float4 acc1 = cast0;
33
+ float4 acc2 = cast0;
34
+ float4 acc3 = cast0;
35
+ float4 acc4 = cast0;
36
+ float4 acc5 = cast0;
37
+ float4 acc6 = cast0;
38
+ float4 acc7 = cast0;
39
+ float4 acc8 = cast0;
40
+ float4 acc9 = cast0;
41
+ float4 acc10 = cast0;
42
+ float4 acc11 = cast0;
43
+ float4 acc12 = cast0;
44
+ float4 acc13 = cast0;
45
+ float4 acc14 = cast0;
46
+ float4 acc15 = cast0;
47
+ for (int ridx0 = 0 ; ridx0 < 256 ; ridx0++) {
48
+ int alu14 = (ridx0*16 );
49
+ int alu15 = (alu13+alu14);
50
+ int alu16 = (alu14+alu13);
51
+ int alu17 = (alu0+(alu6*8192 )+(alu8*16384 )+alu10+(alu4*2 )+(lidx1*4 )+alu3+(ridx0*65536 ));
52
+ half val0 = data2[alu17+8 ];
53
+ half val1 = data2[alu17+16 ];
54
+ half val2 = data2[alu17+24 ];
55
+ half val3 = data2[alu17+4096 ];
56
+ half val4 = data2[alu17+4104 ];
57
+ half val5 = data2[alu17+4112 ];
58
+ half val6 = data2[alu17+4120 ];
59
+ half val7 = data2[alu17+32768 ];
60
+ half val8 = data2[alu17+32776 ];
61
+ half val9 = data2[alu17+32784 ];
62
+ half val10 = data2[alu17+32792 ];
63
+ half val11 = data2[alu17+36864 ];
64
+ half val12 = data2[alu17+36872 ];
65
+ half4 cast1 = make_half4 (val0,val4,val8,val12);
66
+ half val13 = data2[alu17+36880 ];
67
+ half4 cast2 = make_half4 (val1,val5,val9,val13);
68
+ half val14 = data2[alu17+36888 ];
69
+ half4 cast3 = make_half4 (val2,val6,val10,val14);
70
+ half val15 = data2[alu17];
71
+ half4 cast4 = make_half4 (val15,val3,val7,val11);
72
+ half2 val16 = *((half2*)(data1+alu15+4096 ));
73
+ half2 val17 = *((half2*)(data1+alu15+65536 ));
74
+ half2 val18 = *((half2*)(data1+alu15+69632 ));
75
+ half2 val19 = *((half2*)(data1+alu15+131072 ));
76
+ half2 val20 = *((half2*)(data1+alu15+135168 ));
77
+ half2 val21 = *((half2*)(data1+alu15+196608 ));
78
+ half2 val22 = *((half2*)(data1+alu15+200704 ));
79
+ half2 val23 = *((half2*)(data1+alu15));
80
+ half2 val24 = *((half2*)(data1+alu16+8 ));
81
+ half2 val25 = *((half2*)(data1+alu16+4104 ));
82
+ half8 cast5 = make_half8 (val23.x ,val23.y ,val16.x ,val16.y ,val24.x ,val24.y ,val25.x ,val25.y );
83
+ float4 wmma0 = __WMMA_8_16_16_half_float (cast5, cast1, acc1);
84
+ float4 wmma1 = __WMMA_8_16_16_half_float (cast5, cast2, acc2);
85
+ float4 wmma2 = __WMMA_8_16_16_half_float (cast5, cast3, acc3);
86
+ float4 wmma3 = __WMMA_8_16_16_half_float (cast5, cast4, acc0);
87
+ half2 val26 = *((half2*)(data1+alu16+65544 ));
88
+ half2 val27 = *((half2*)(data1+alu16+69640 ));
89
+ half8 cast6 = make_half8 (val17.x ,val17.y ,val18.x ,val18.y ,val26.x ,val26.y ,val27.x ,val27.y );
90
+ float4 wmma4 = __WMMA_8_16_16_half_float (cast6, cast1, acc5);
91
+ float4 wmma5 = __WMMA_8_16_16_half_float (cast6, cast2, acc6);
92
+ float4 wmma6 = __WMMA_8_16_16_half_float (cast6, cast3, acc7);
93
+ float4 wmma7 = __WMMA_8_16_16_half_float (cast6, cast4, acc4);
94
+ half2 val28 = *((half2*)(data1+alu16+131080 ));
95
+ half2 val29 = *((half2*)(data1+alu16+135176 ));
96
+ half8 cast7 = make_half8 (val19.x ,val19.y ,val20.x ,val20.y ,val28.x ,val28.y ,val29.x ,val29.y );
97
+ float4 wmma8 = __WMMA_8_16_16_half_float (cast7, cast1, acc9);
98
+ float4 wmma9 = __WMMA_8_16_16_half_float (cast7, cast2, acc10);
99
+ float4 wmma10 = __WMMA_8_16_16_half_float (cast7, cast3, acc11);
100
+ float4 wmma11 = __WMMA_8_16_16_half_float (cast7, cast4, acc8);
101
+ half2 val30 = *((half2*)(data1+alu16+196616 ));
102
+ half2 val31 = *((half2*)(data1+alu16+200712 ));
103
+ half8 cast8 = make_half8 (val21.x ,val21.y ,val22.x ,val22.y ,val30.x ,val30.y ,val31.x ,val31.y );
104
+ float4 wmma12 = __WMMA_8_16_16_half_float (cast8, cast1, acc13);
105
+ float4 wmma13 = __WMMA_8_16_16_half_float (cast8, cast2, acc14);
106
+ float4 wmma14 = __WMMA_8_16_16_half_float (cast8, cast3, acc15);
107
+ float4 wmma15 = __WMMA_8_16_16_half_float (cast8, cast4, acc12);
108
+ acc0 = wmma3;
109
+ acc1 = wmma0;
110
+ acc2 = wmma1;
111
+ acc3 = wmma2;
112
+ acc4 = wmma7;
113
+ acc5 = wmma4;
114
+ acc6 = wmma5;
115
+ acc7 = wmma6;
116
+ acc8 = wmma11;
117
+ acc9 = wmma8;
118
+ acc10 = wmma9;
119
+ acc11 = wmma10;
120
+ acc12 = wmma15;
121
+ acc13 = wmma12;
122
+ acc14 = wmma13;
123
+ acc15 = wmma14;
124
+ }
125
+ *((half2*)(data0+alu12+8 )) = make_half2 ((half)(acc1.x ),(half)(acc1.y ));
126
+ *((half2*)(data0+alu12+16 )) = make_half2 ((half)(acc2.x ),(half)(acc2.y ));
127
+ *((half2*)(data0+alu12+24 )) = make_half2 ((half)(acc3.x ),(half)(acc3.y ));
128
+ *((half2*)(data0+alu12+4096 )) = make_half2 ((half)(acc0.z ),(half)(acc0.w ));
129
+ *((half2*)(data0+alu12+4104 )) = make_half2 ((half)(acc1.z ),(half)(acc1.w ));
130
+ *((half2*)(data0+alu12+4112 )) = make_half2 ((half)(acc2.z ),(half)(acc2.w ));
131
+ *((half2*)(data0+alu12+4120 )) = make_half2 ((half)(acc3.z ),(half)(acc3.w ));
132
+ *((half2*)(data0+alu12+65536 )) = make_half2 ((half)(acc4.x ),(half)(acc4.y ));
133
+ *((half2*)(data0+alu12+65544 )) = make_half2 ((half)(acc5.x ),(half)(acc5.y ));
134
+ *((half2*)(data0+alu12+65552 )) = make_half2 ((half)(acc6.x ),(half)(acc6.y ));
135
+ *((half2*)(data0+alu12+65560 )) = make_half2 ((half)(acc7.x ),(half)(acc7.y ));
136
+ *((half2*)(data0+alu12+69632 )) = make_half2 ((half)(acc4.z ),(half)(acc4.w ));
137
+ *((half2*)(data0+alu12+69640 )) = make_half2 ((half)(acc5.z ),(half)(acc5.w ));
138
+ *((half2*)(data0+alu12+69648 )) = make_half2 ((half)(acc6.z ),(half)(acc6.w ));
139
+ *((half2*)(data0+alu12+69656 )) = make_half2 ((half)(acc7.z ),(half)(acc7.w ));
140
+ *((half2*)(data0+alu12+131072 )) = make_half2 ((half)(acc8.x ),(half)(acc8.y ));
141
+ *((half2*)(data0+alu12+131080 )) = make_half2 ((half)(acc9.x ),(half)(acc9.y ));
142
+ *((half2*)(data0+alu12+131088 )) = make_half2 ((half)(acc10.x ),(half)(acc10.y ));
143
+ *((half2*)(data0+alu12+131096 )) = make_half2 ((half)(acc11.x ),(half)(acc11.y ));
144
+ *((half2*)(data0+alu12+135168 )) = make_half2 ((half)(acc8.z ),(half)(acc8.w ));
145
+ *((half2*)(data0+alu12+135176 )) = make_half2 ((half)(acc9.z ),(half)(acc9.w ));
146
+ *((half2*)(data0+alu12+135184 )) = make_half2 ((half)(acc10.z ),(half)(acc10.w ));
147
+ *((half2*)(data0+alu12+135192 )) = make_half2 ((half)(acc11.z ),(half)(acc11.w ));
148
+ *((half2*)(data0+alu12+196608 )) = make_half2 ((half)(acc12.x ),(half)(acc12.y ));
149
+ *((half2*)(data0+alu12+196616 )) = make_half2 ((half)(acc13.x ),(half)(acc13.y ));
150
+ *((half2*)(data0+alu12+196624 )) = make_half2 ((half)(acc14.x ),(half)(acc14.y ));
151
+ *((half2*)(data0+alu12+196632 )) = make_half2 ((half)(acc15.x ),(half)(acc15.y ));
152
+ *((half2*)(data0+alu12+200704 )) = make_half2 ((half)(acc12.z ),(half)(acc12.w ));
153
+ *((half2*)(data0+alu12+200712 )) = make_half2 ((half)(acc13.z ),(half)(acc13.w ));
154
+ *((half2*)(data0+alu12+200720 )) = make_half2 ((half)(acc14.z ),(half)(acc14.w ));
155
+ *((half2*)(data0+alu12+200728 )) = make_half2 ((half)(acc15.z ),(half)(acc15.w ));
156
+ *((half2*)(data0+alu12)) = make_half2 ((half)(acc0.x ),(half)(acc0.y ));
157
+ }
0 commit comments