Commit cad44f5

ignaciosica and Qazalin authored

add Half-Precision Accumulation Support for Tensor Cores in NV, CUDA, and PTX (tinygrad#8680)

* ptx and nv rendering refactor to work with half acc
* ptx fix!
* use same reg for acc and out
* fix comment
* another fix
* minor change in comment
* fix

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
1 parent 17f9b1c commit cad44f5
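
For orientation, the sketch below shows one way the new path might be exercised from the Python API. Everything in it is layered on top of this diff as an assumption: it presumes this revision of tinygrad is installed, the CUDA backend is active on a GPU with tensor cores, and that Tensor.matmul accepts an acc_dtype keyword at this commit.

# Hedged sketch: request a half-precision accumulator for a half matmul so the
# tensor-core search can consider the new (half, half) entries added below.
# Assumptions (not part of this diff): tinygrad installed at this revision,
# CUDA backend active, and matmul exposing acc_dtype here.
from tinygrad import Tensor, dtypes

a = Tensor.rand(256, 256, dtype=dtypes.half)
b = Tensor.rand(256, 256, dtype=dtypes.half)
c = a.matmul(b, acc_dtype=dtypes.half).realize()  # accumulate in f16 instead of f32
print(c.dtype)  # dtypes.half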

File tree

2 files changed: +11, -8 lines

tinygrad/renderer/cstyle.py

Lines changed: 10 additions & 7 deletions
@@ -300,10 +300,11 @@ class CUDARenderer(CStyleLanguage):
   local_max = (1024, 1024, 64)
   shared_max = 49152
   # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
-  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
-    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]]
-  tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts,
-    swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))]
+  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
+    (dtypes.half,dtypes.half)]]
+  tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
+    swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
   tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
     swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
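
The practical effect of this hunk is that the CUDA renderer now advertises (half, half) input/accumulator pairs alongside the existing float accumulators. A quick, hedged way to inspect the advertised combinations, using the CUDARenderer class attributes shown in the hunk above (assumes this tinygrad revision is on the import path; no GPU is needed just to read the declarations):

# List the (dims, dtype_in, dtype_out) combinations declared by the class
# attributes above. Assumption: tinygrad is importable at this revision.
from tinygrad.renderer.cstyle import CUDARenderer

for tc in CUDARenderer.tc_81616 + CUDARenderer.tc_8168_f16 + CUDARenderer.tc_8168_tf32:
  print(tc.dims, tc.dtype_in, tc.dtype_out)
# expected among the output: (8, 16, 16) half -> half and (8, 16, 8) half -> half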

@@ -344,7 +345,8 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

-    dt_map = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
+    dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
       wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -353,10 +355,11 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):

       # mma operands => {c}, {a}, {b}, {c}
       prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
-  int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32"
+  int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
+  asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
       "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
       "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
-    : {", ".join([f'"+f"(c.{_nms[i]})' for i in range(n_operands[2])])}
+    : {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
     : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
   return c;\n}}""")
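
Taken together, these two hunks replace the hard-coded f32 accumulator in the mma opcode with a lookup keyed on the WMMA output dtype, and switch the output constraints from per-element float registers ("+f") to packed 32-bit registers ("+r"), since two f16 accumulator values share one register. A minimal standalone sketch of that opcode selection, with plain strings standing in for tinygrad's dtype objects:

# Standalone sketch of the opcode selection above; plain strings stand in for
# tinygrad dtypes. The accumulator dtype now chooses between .f32 and .f16.
dt_map_in  = {"float": "tf32", "half": "f16", "bfloat16": "bf16"}
dt_map_out = {"float": "f32", "half": "f16"}

def mma_opcode(M, N, K, dtype_in, dtype_out):
  return (f"mma.sync.aligned.m{M}n{N}k{K}.row.col."
          f"{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}")

print(mma_opcode(16, 8, 16, "half", "float"))  # mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
print(mma_opcode(16, 8, 16, "half", "half"))   # mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16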

tinygrad/renderer/ptx.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def render_wmma(ctx: "PTXRenderer", wmma: UOp):
       if (elems_per_reg := 4 // src.dtype.scalar().itemsize) == 1: yield f"mov.b32 {reg}, {ctx.r[src][i]};"
       else: yield f"mov.b32 {reg}, {{{', '.join(ctx.r[src][i * elems_per_reg : (i+1) * elems_per_reg])}}};"

-  dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32"}
+  dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32", dtypes.half: "f16"}
   yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}{" "*12}'+\
     f'{{{", ".join(ctx.wmma_r[2])}}}, {{{", ".join(ctx.wmma_r[0])}}}, {{{", ".join(ctx.wmma_r[1])}}}, {{{", ".join(ctx.wmma_r[2])}}};'
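
The PTX side only needs the extra dt_map_out entry because render_wmma already packs every operand, including the accumulator, into 32-bit registers. A short sketch of that packing arithmetic, mirroring the elems_per_reg expression in the context lines above:

# Sketch of the packing rule used by render_wmma: PTX mma operands are 32-bit
# registers, so 4 // itemsize scalar values are packed into each one.
for name, itemsize in [("float / tf32", 4), ("half", 2)]:
  elems_per_reg = 4 // itemsize
  print(f"{name}: {elems_per_reg} value(s) per 32-bit register")
# float / tf32: 1 value per register -> mov.b32 copies it directly
# half: 2 values per register        -> mov.b32 {lo, hi} packs a pair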

0 commit comments
