/**
* less_slow_sm70.ptx
*
* Micro-kernels for building a performance-first mindset for CUDA-capable
* GPUs using Parallel Thread eXecution (PTX) Intermediate Representation (IR)
 * for different generations of Streaming Multiprocessors (SMs) and Tensor
 * Cores (TCs), here targeting Volta-generation Nvidia GPUs.
*
* ? You should start at `less_slow.cu` before reading this file.
* ? Also read intro to PTX: https://docs.nvidia.com/cuda/parallel-thread-execution/
* ? Check the PTX ISA: https://docs.nvidia.com/cuda/pdf/ptx_isa_8.5.pdf
*
* ! PTX is higher-level than SASS, but still very similar to typical CPU
* ! assembly languages. It has slightly different syntax and semantics for
 * ! predicates and memory access, but still doesn't have `for` loops :(
* ! As for emojis, we can only use ASCII... CUDA can't JIT-compile UTF-8.
*
* You can validate this file by asking the Nvidia PTX Assembler to compile it
* to `.cubin` for some target architecture:
*
* $ ptxas -o less_slow_sm70_from_ptx.cubin -arch=sm_70 less_slow_sm70.ptx
* $ cuobjdump -sass less_slow_sm70_from_ptx.cubin | grep -i mma
*
 * Given how aggressively NVCC unrolls loops and the number of kernels in
 * this file, you may want to deduplicate the matching SASS lines:
*
 *    $ cuobjdump -sass less_slow_sm70_from_ptx.cubin | grep -i mma | \
 *        sed -r 's/\/\*[^*]+\*\///g' | \
 *        sed -r 's/^[[:space:]]+//; s/[[:space:]]+$//' | \
 *        sort -u
*
* @section Register File
*
 * GPUs manage registers differently from CPUs. The Hopper tuning guide
 * suggests that the register file holds 64K 32-bit registers per Streaming
 * Multiprocessor, and that the maximum number of registers per thread is 255.
*
 * PTX exposes read-only state through special registers, like `%tid` for
 * the thread ID, `%ctaid` for the block ID, `%aggr_smem_size` for the
 * aggregated shared memory size, or `%current_graph_exec` for the ID of the
 * currently executing CUDA graph. To read them, simply use the `mov`
 * instruction.
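 *
 * For instance, here is a minimal sketch (register names are illustrative)
 * of materializing the global thread index from those special registers:
 *
 *    .reg .u32 thread_in_block, block_id, block_dim, global_thread;
 *    mov.u32 thread_in_block, %tid.x;   // Thread index within the block
 *    mov.u32 block_id, %ctaid.x;        // Block index within the grid
 *    mov.u32 block_dim, %ntid.x;        // Number of threads per block
 *    mad.lo.u32 global_thread, block_id, block_dim, thread_in_block;
 *
 * To check per-thread register usage, you can ask `ptxas` to be verbose,
 * e.g. `ptxas -v -arch=sm_70 -o /dev/null less_slow_sm70.ptx`.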
*/
.version 6.5 // PTX version 6.5 is enough for Volta GPUs
.target sm_70 // Target architecture (SM 7.0 - Volta GPUs)
.address_size 64 // 64-bit addressing
.visible .entry tops_f16f16_sm70mma_8x8x4_loop128_ptx_kernel()
{
// Accumulator registers used for both input and output of the MMA operation
.reg .b32 accum_0, accum_1, accum_2, accum_3;
// Registers to hold packed pairs of 16-bit data for matrix a (2 registers)
.reg .b32 matrix_a_0, matrix_a_1;
// Registers to hold packed pairs of 16-bit data for matrix b (2 registers)
.reg .b32 matrix_b_0, matrix_b_1;
// General-purpose registers for loop control and constant values
.reg .b32 loop_counter, loop_limit, packed_const;
// Predicate register for conditional branching (loop exit)
.reg .pred exit_predicate;
// Set up loop counter and loop limit
mov.u32 loop_counter, 0;
mov.u32 loop_limit, 128;
// Zero-initialize the accumulator registers
mov.f32 accum_0, 0.0;
mov.f32 accum_1, 0.0;
mov.f32 accum_2, 0.0;
mov.f32 accum_3, 0.0;
// Initialize constant for packed matrix data (placeholder)
mov.b32 packed_const, 0x00010001;
// Initialize matrix a registers with the packed constant
mov.b32 matrix_a_0, packed_const;
mov.b32 matrix_a_1, packed_const;
// Initialize matrix b registers with the packed constant
mov.b32 matrix_b_0, packed_const;
mov.b32 matrix_b_1, packed_const;
// The main loop will repeat for 128 iterations
loop_start:
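// PTX has no implicit flags register: `setp` writes an explicit predicate,
// and the `@` guard makes the following branch conditional on it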
setp.ge.u32 exit_predicate, loop_counter, loop_limit;
@exit_predicate bra loop_end;
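// One warp-synchronous 8x8x4 MMA: D = A * B + C, where A and B hold packed
// `f16x2` pairs and C/D reuse the same accumulator registers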
mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16
{ accum_0, accum_1, accum_2, accum_3 },
{ matrix_a_0, matrix_a_1 },
{ matrix_b_0, matrix_b_1 },
{ accum_0, accum_1, accum_2, accum_3 };
// Increment the loop counter
add.u32 loop_counter, loop_counter, 1;
// Branch back to the beginning of the loop
bra loop_start;
loop_end:
// If we simply exit, the computation will be optimized out!
// Instead, let's check for an impossible condition, like the thread ID
// being equal to `UINT_MAX`, and only then write the accumulators to the
// global memory `NULL` address.
.reg .u32 tid;
.reg .pred impossible_predicate;
mov.u32 tid, %tid.x; //? Special system registers start with `%`
setp.ne.u32 impossible_predicate, tid, 0xFFFFFFFF;
@impossible_predicate bra loop_exit;
// Write into memory:
.reg .u64 store_ptr;
mov.u64 store_ptr, 0;
st.global.f32 [store_ptr], accum_0;
st.global.f32 [store_ptr+4], accum_1;
st.global.f32 [store_ptr+8], accum_2;
st.global.f32 [store_ptr+12], accum_3;
loop_exit:
ret;
}
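/**
 * The next kernel is identical, except it accumulates into `f32`. For the
 * same `m8n8k4` shape, an `f32` accumulator tile occupies eight 32-bit
 * registers instead of four, while A and B stay as two packed `f16x2`
 * registers each.
 */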
.visible .entry tops_f16f32_sm70mma_8x8x4_loop128_ptx_kernel()
{
// Accumulator registers used for both input and output of the MMA operation
.reg .b32 accum_0, accum_1, accum_2, accum_3,
accum_4, accum_5, accum_6, accum_7;
// Registers to hold packed pairs of 16-bit data for matrix a (2 registers)
.reg .b32 matrix_a_0, matrix_a_1;
// Registers to hold packed pairs of 16-bit data for matrix b (2 registers)
.reg .b32 matrix_b_0, matrix_b_1;
// General-purpose registers for loop control and constant values
.reg .b32 loop_counter, loop_limit, packed_const;
// Predicate register for conditional branching (loop exit)
.reg .pred exit_predicate;
// Set up loop counter and loop limit
mov.u32 loop_counter, 0;
mov.u32 loop_limit, 128;
// Zero-initialize all eight accumulator registers
mov.f32 accum_0, 0.0;
mov.f32 accum_1, 0.0;
mov.f32 accum_2, 0.0;
mov.f32 accum_3, 0.0;
mov.f32 accum_4, 0.0;
mov.f32 accum_5, 0.0;
mov.f32 accum_6, 0.0;
mov.f32 accum_7, 0.0;
// Initialize constant for packed matrix data (placeholder)
mov.b32 packed_const, 0x00010001;
// Initialize matrix a registers with the packed constant
mov.b32 matrix_a_0, packed_const;
mov.b32 matrix_a_1, packed_const;
// Initialize matrix b registers with the packed constant
mov.b32 matrix_b_0, packed_const;
mov.b32 matrix_b_1, packed_const;
// The main loop will repeat for 128 iterations
loop_start:
setp.ge.u32 exit_predicate, loop_counter, loop_limit;
@exit_predicate bra loop_end;
mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32
{ accum_0, accum_1, accum_2, accum_3,
accum_4, accum_5, accum_6, accum_7 },
{ matrix_a_0, matrix_a_1 },
{ matrix_b_0, matrix_b_1 },
{ accum_0, accum_1, accum_2, accum_3,
accum_4, accum_5, accum_6, accum_7 };
// Increment the loop counter
add.u32 loop_counter, loop_counter, 1;
// Branch back to the beginning of the loop
bra loop_start;
loop_end:
// If we simply exit, the computation will be optimized out!
// Instead, let's check for an impossible condition, like the thread ID
// being equal to `UINT_MAX`, and only then write the accumulators to the
// global memory `NULL` address.
.reg .u32 tid;
.reg .pred impossible_predicate;
mov.u32 tid, %tid.x; //? Special system registers start with `%`
setp.ne.u32 impossible_predicate, tid, 0xFFFFFFFF;
@impossible_predicate bra loop_exit;
// Write all eight accumulators into memory, so none of the MMA
// results are dead and eligible for elimination:
.reg .u64 store_ptr;
mov.u64 store_ptr, 0;
st.global.f32 [store_ptr], accum_0;
st.global.f32 [store_ptr+4], accum_1;
st.global.f32 [store_ptr+8], accum_2;
st.global.f32 [store_ptr+12], accum_3;
st.global.f32 [store_ptr+16], accum_4;
st.global.f32 [store_ptr+20], accum_5;
st.global.f32 [store_ptr+24], accum_6;
st.global.f32 [store_ptr+28], accum_7;
loop_exit:
ret;
}
/**
* Here are some potentially counterintuitive facts about PTX and SASS:
*
* - NVCC unrolls loops more aggressively than any other mainstream compiler.
*
 * - If you are coming from the CPU side, you shouldn't expect instructions
 *   to be forward-compatible or to perform equally well or better on the
 *   next generation! Entire instruction families may live for just one
 *   generation and be completely abandoned a couple of years later.
*
 * - Some instructions, like Tensor Core Gen 4 and 5 operations, can't keep
 *   both multiplication operands in GPU registers; at least one of them has
 *   to be in shared memory. Moreover, they may run up to 10% faster with
 *   both arguments in shared memory!
*
* Because only one `.version` directive can be placed in each file, for newer
* kernels, go to `less_slow_sm80.ptx` for Ampere and `less_slow_sm90a.ptx`
* for Hopper.
*
* @see PTX module-level directives:
* https://docs.nvidia.com/cuda/parallel-thread-execution/#ptx-module-directives
*/