@@ -300,10 +300,11 @@ class CUDARenderer(CStyleLanguage):
300
300
local_max = (1024 , 1024 , 64 )
301
301
shared_max = 49152
302
302
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
303
- tc_81616 = [TensorCore (dims = (8 ,16 ,16 ), threads = 32 , elements_per_thread = (8 ,4 ,4 ), dtype_in = di ,dtype_out = do , opts = cuda_tc_opts ,
304
- swizzle = (((6 ,7 ,2 ,3 ,4 ),(0 ,1 ,9 ,5 ,10 ,8 )), ((6 ,7 ,9 ,0 ,1 ),(2 ,3 ,4 ,10 ,5 ,8 )))) for di ,do in [(dtypes .half ,dtypes .float ), (dtypes .bfloat16 ,dtypes .float )]]
305
- tc_8168_f16 = [TensorCore (dims = (8 ,16 ,8 ), threads = 32 , elements_per_thread = (4 ,2 ,4 ), dtype_in = dtypes .half , dtype_out = dtypes .float , opts = cuda_tc_opts ,
306
- swizzle = (((6 ,7 ,2 ,3 ,4 ),(0 ,1 ,8 ,5 ,9 )), ((6 ,7 ,8 ,0 ,1 ),(2 ,3 ,4 ,9 ,5 ))))]
303
+ tc_81616 = [TensorCore (dims = (8 ,16 ,16 ), threads = 32 , elements_per_thread = (8 ,4 ,4 ), dtype_in = di , dtype_out = do , opts = cuda_tc_opts ,
304
+ swizzle = (((6 ,7 ,2 ,3 ,4 ),(0 ,1 ,9 ,5 ,10 ,8 )), ((6 ,7 ,9 ,0 ,1 ),(2 ,3 ,4 ,10 ,5 ,8 )))) for di ,do in [(dtypes .half ,dtypes .float ), (dtypes .bfloat16 ,dtypes .float ),
305
+ (dtypes .half ,dtypes .half )]]
306
+ tc_8168_f16 = [TensorCore (dims = (8 ,16 ,8 ), threads = 32 , elements_per_thread = (4 ,2 ,4 ), dtype_in = di , dtype_out = do , opts = cuda_tc_opts ,
307
+ swizzle = (((6 ,7 ,2 ,3 ,4 ),(0 ,1 ,8 ,5 ,9 )), ((6 ,7 ,8 ,0 ,1 ),(2 ,3 ,4 ,9 ,5 )))) for di ,do in [(dtypes .half ,dtypes .float ), (dtypes .half ,dtypes .half )]]
307
308
tc_8168_tf32 = [TensorCore (dims = (8 ,16 ,8 ), threads = 32 , elements_per_thread = (4 ,2 ,4 ), dtype_in = dtypes .float , dtype_out = dtypes .float , opts = cuda_tc_opts ,
308
309
swizzle = (((5 ,6 ,2 ,3 ,4 ),(0 ,1 ,8 ,9 ,7 )), ((5 ,6 ,8 ,0 ,1 ),(2 ,3 ,4 ,9 ,7 ))))]
309
310
@@ -344,7 +345,8 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
344
345
if any (dt .scalar () == dtypes .bfloat16 for dt in used_dtypes ): prefix .append ("#include <cuda_bf16.h>" )
345
346
prefix += [self .render_vector_prefix (dt ) for dt in used_dtypes if dt .count in (4 ,8 ) and dt .scalar () in {dtypes .half , dtypes .bfloat16 }]
346
347
347
- dt_map = { dtypes .float : "tf32" , dtypes .half : "f16" , dtypes .bfloat16 : "bf16" }
348
+ dt_map_in = { dtypes .float : "tf32" , dtypes .half : "f16" , dtypes .bfloat16 : "bf16" }
349
+ dt_map_out = { dtypes .float : "f32" , dtypes .half : "f16" }
348
350
for name , (N , M , K ), dtype_in , dtype_out , _ , _ , upcast_axes , _ in dedup ([uop .arg for uop in uops if uop .op is Ops .WMMA ]):
349
351
upcast_sizes = [prod (size for _ , size in upcast ) for upcast in upcast_axes ]
350
352
wmma_dtypes = [self .render_dtype (dtype .vec (size )) for dtype , size in zip ([dtype_in , dtype_in , dtype_out ], upcast_sizes )]
@@ -353,10 +355,11 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
353
355
354
356
# mma operands => {c}, {a}, {b}, {c}
355
357
prefix .append (f"""__device__ { wmma_dtypes [2 ]} __{ name } ({ wmma_dtypes [0 ]} a, { wmma_dtypes [1 ]} b, { wmma_dtypes [2 ]} c){{
356
- int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{ M } n{ N } k{ K } .row.col.f32.{ dt_map [dtype_in ]} .{ dt_map [dtype_in ]} .f32"
358
+ int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
359
+ asm("mma.sync.aligned.m{ M } n{ N } k{ K } .row.col.{ dt_map_out [dtype_out ]} .{ dt_map_in [dtype_in ]} .{ dt_map_in [dtype_in ]} .{ dt_map_out [dtype_out ]} "
357
360
"{{{ ", " .join (operands [:n_operands [2 ]])} }}, {{{ ", " .join (operands [n_operands [2 ]:n_operands [2 ]+ n_operands [0 ]])} }},"
358
361
"{{{ ", " .join (operands [- n_operands [1 ]:])} }}, {{{ ", " .join (operands [:n_operands [2 ]])} }};"
359
- : { ", " .join ([f'"+f"(c. { _nms [ i ] } )' for i in range (n_operands [2 ])])}
362
+ : { ", " .join ([f'"+r"(c_pk[ { i } ] )' for i in range (n_operands [2 ])])}
360
363
: { ", " .join ([f'"r"(a_pk[{ i } ])' for i in range (n_operands [0 ])])} , { ", " .join ([f'"r"(b_pk[{ i } ])' for i in range (n_operands [1 ])])} );
361
364
return c;\n }}""" )
362
365
0 commit comments