[pull] master from tinygrad:master #80

Merged: 9 commits merged on Feb 5, 2025
4 changes: 3 additions & 1 deletion test/test_linearizer.py
@@ -1112,7 +1112,9 @@ def test_tensor_cores_multi_reduce(self):
# check that get_kernel_actions produces all 9 options
from tinygrad.engine.search import get_kernel_actions
tc_actions = [k for i, k in get_kernel_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
assert len(tc_actions) == 9, f"get_kernel_actions should contain 9 possible TC actions, only got {len(tc_actions)}"

available_tc = len([x for x in Device[Device.DEFAULT].renderer.tensor_cores if x.dtype_in == tc.dtype_in and x.dtype_out == tc.dtype_out])
assert len(tc_actions) == 9 * available_tc, f"should contain 9 possible TC actions for every available TC, got {len(tc_actions)}"

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_unroll_phi(self):
4 changes: 2 additions & 2 deletions test/test_ops.py
@@ -607,6 +607,7 @@ def test_pow(self):
helper_test_op([], lambda: b**1.1, lambda: a**1.1)

def test_pow_const(self):
helper_test_op([(45,65)], lambda x: x**0.0)
helper_test_op([(45,65)], lambda x: x**1.0)
helper_test_op([(45,65)], lambda x: x**-1.0)
helper_test_op([(45,65)], lambda x: 1.0**x)
@@ -616,8 +617,7 @@ def test_pow_const(self):
helper_test_op([()], lambda x: 2.0**x)
# TODO: fix backward
helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
# TODO: fix backward, should be nan
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]])

@unittest.skip("not supported")
def test_pow_int(self):
29 changes: 29 additions & 0 deletions test/test_search.py
@@ -102,6 +102,35 @@ def test_get_kernel_actions(self):
if Opt(OptOps.GROUPTOP, 0, 0) in actions:
assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP"

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_search_over_shape(self):
from test.test_linearizer import helper_realized_ast
from tinygrad.engine.search import get_kernel_actions

dtype_pairs = [(tc.dtype_in, tc.dtype_out) for tc in Device[Device.DEFAULT].renderer.tensor_cores]
multi_shape_dtype_pairs = [dts for dts in dtype_pairs if dtype_pairs.count(dts) > 1]

if len(multi_shape_dtype_pairs) == 0: raise unittest.SkipTest("only one tc available per dtype pair to search over")

for (dtype_in, dtype_out) in multi_shape_dtype_pairs:
a = Tensor.rand(16, 16, dtype=dtype_in)
b = Tensor.rand(16, 16, dtype=dtype_in)
realized_ast, _ = helper_realized_ast(a.matmul(b, acc_dtype=dtype_out))

lins = get_kernel_actions(Kernel(realized_ast)).values()
assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1

def test_get_kernel_actions_preserves_actions_state(self):
from test.test_linearizer import helper_realized_ast
from tinygrad.engine.search import get_kernel_actions
a = Tensor.rand(16, 16)
b = Tensor.rand(16, 16)
realized_ast, _ = helper_realized_ast(a @ b)
actions_before = actions.copy()
get_kernel_actions(Kernel(realized_ast))
actions_after = actions.copy()
assert actions_after == actions_before, "actions state was not preserved"

def test_filter_global_buffer(self):
# taken from https://github.com/tinygrad/tinygrad/issues/4612
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
3 changes: 2 additions & 1 deletion tinygrad/codegen/rewriter.py
@@ -6,7 +6,7 @@
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer

# ***** float4/image store handling *****
@@ -124,6 +124,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
pat.append((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src)))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops:
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
10 changes: 10 additions & 0 deletions tinygrad/codegen/transcendental.py
@@ -254,3 +254,13 @@ def xlog2(d:UOp) -> UOp:
r = d.ne(d).where(r.const_like(math.nan), r)
# log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))

def xpow(base:UOp, exponent:UOp) -> UOp:
# start with b ** e = exp2(e * log2(b))
ret = (base < 0).where(-base, base).log2().mul(exponent).exp2()
# negative base adjustment: nan for non-integer exponent and -1 for odd exponent
adj = (base < 0).where((exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)).where(
ret.const_like(math.nan),
(exponent.cast(dtypes.int32).cast(dtypes.uint32)%2).eq(1).where(ret.const_like(-1), ret.const_like(1))), ret.const_like(1))
# fix 0 ** 0 = 1
return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * adj)
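
The xpow decomposition above (b ** e as exp2(e * log2(|b|)) with nan/sign fixups for negative bases and the 0 ** 0 special case) can be sanity-checked with plain Python floats. A minimal reference sketch, independent of the UOp machinery; the helper name and structure here are illustrative only:

```python
import math

def xpow_reference(base: float, exponent: float) -> float:
  # b ** e computed as 2 ** (e * log2(|b|)), mirroring the UOp graph built above
  ret = 2.0 ** (exponent * math.log2(abs(base))) if base != 0 else 0.0
  if base < 0:
    if exponent != int(exponent): return math.nan   # non-integer exponent of a negative base
    if int(exponent) % 2 == 1: ret = -ret            # odd integer exponent flips the sign
  if base == 0 and exponent == 0: return 1.0         # fix 0 ** 0 = 1
  return ret

# spot checks against Python's built-in pow
for b, e in [(2.0, 10.0), (-2.0, 3.0), (-2.0, 2.0), (0.0, 0.0)]:
  assert math.isclose(xpow_reference(b, e), b ** e), (b, e)
assert math.isnan(xpow_reference(-2.0, 0.5))         # negative base, fractional exponent
```
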
24 changes: 13 additions & 11 deletions tinygrad/engine/schedule.py
@@ -349,6 +349,16 @@ def _append_buf(ctx:KernelContext, x:UOp) -> UOp:
ctx.bufs.append(x)
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)

def check_load_st(glbl:UOp, view:UOp):
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return
# if it has a single view and it's equal when you shrink a contig, it's fine
if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return
# otherwise, it's not fine
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))

to_si = PatternMatcher([
# BUFFER -> DEFINE_GLOBAL
(UPat(Ops.BUFFER, name="x"), _append_buf),
@@ -365,6 +375,8 @@ def _append_buf(ctx:KernelContext, x:UOp) -> UOp:
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
# once images are loaded they become the base dtype
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
])

def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
@@ -384,17 +396,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp:
# we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
# PRELOAD tells the toposort this kernel should run before ASSIGN
if x.op is Ops.PRELOAD:
assign_preloads[x.buf_uop] = None
# if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
# if it has a single view and it's equal when you shrink a contig, it's fine
elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
# otherwise, it's not fine
else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
if x.op is Ops.PRELOAD: assign_preloads[x.buf_uop] = None
# NOTE: we only add the metadata for fused tensors
metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None))
return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata))
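
check_load_st is the function behind the "self operand of augmented assign must be contiguous" error that previously lived inline in schedule_uop. A hedged sketch of the user-facing behaviour it guards, assuming (as the error text implies) that an augmented assign on a realized Tensor lowers to an ASSIGN kernel and the check fires when that kernel is scheduled:

```python
from tinygrad import Tensor

a = Tensor.rand(4, 4).realize()
try:
  a += a.T       # self operand read through a non-contiguous (transposed) view
  a.realize()    # scheduling the ASSIGN should trip check_load_st
except RuntimeError as e:
  print(e)       # help: consider using .contiguous(): a += a.T.contiguous()

b = Tensor.rand(4, 4).realize()
b += b.T.contiguous()  # the suggested fix: materialize the transpose first
b.realize()
```
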
14 changes: 11 additions & 3 deletions tinygrad/engine/search.py
@@ -5,7 +5,7 @@
from tinygrad.ops import UOp, Ops, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.helpers import IGNORE_BEAM_CACHE
from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
from tinygrad.tensor import Tensor
@@ -103,9 +103,17 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]:
# get dictionary of all possible actions
def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
for i,a in enumerate(actions):
kernel_actions = actions.copy()

if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first
for i, action in enumerate(kernel_actions):
if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1:
# replace every tc_action with default tc with one tc_action for each available tc
kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)]

for i,a in enumerate(kernel_actions):
if a.axis is not None and a.op is not OptOps.TC:
if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue
if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue
lin2 = lin.copy()
try:
lin2.apply_opt(a)
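
The TC_SEARCH_OVER_SHAPE branch above only touches TC actions whose arg still carries the default tc_select of -1, splicing in one concrete TC action per tensor core the renderer exposes, so BEAM can search over tensor cores with different shapes but the same dtype pair. A toy illustration of that list splice; FakeOpt and the string op names are hypothetical stand-ins, not the real Opt API:

```python
from typing import NamedTuple

class FakeOpt(NamedTuple):  # stand-in for tinygrad's Opt(op, axis, arg)
  op: str
  axis: int
  arg: tuple

kernel_actions = [FakeOpt("TC", 0, (-1, 2)), FakeOpt("UPCAST", 0, (4,))]
num_tensor_cores = 3        # pretend the renderer exposes three tensor cores

for i, action in enumerate(kernel_actions):
  if action.op == "TC" and action.arg[0] == -1:
    # replace the single default-TC action with one action per available tensor core
    kernel_actions[i:i+1] = [FakeOpt("TC", action.axis, (tc_select, action.arg[1]))
                             for tc_select in range(num_tensor_cores)]

assert [a.arg for a in kernel_actions if a.op == "TC"] == [(0, 2), (1, 2), (2, 2)]
```
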
1 change: 1 addition & 0 deletions tinygrad/gradient.py
@@ -22,6 +22,7 @@ def reduce_gradient(ctx:UOp, ret:UOp):
(UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
(UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
(UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
(UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))),
(UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
(ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
(UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
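
The POW entry added above encodes dz/dx = z*y/x (i.e. y*x**(y-1)) and dz/dy = z*ln(x), with ln(x) written as log2(x)*ln(2) since LOG2 is the primitive. A quick finite-difference check of those two closed forms with plain floats; pow_grads is an illustrative helper, not tinygrad code:

```python
import math

def pow_grads(x: float, y: float) -> tuple[float, float]:
  z = x ** y
  return z * y / x, z * math.log2(x) * math.log(2.0)  # matches the POW gradient pattern above

x, y, eps = 3.0, 2.5, 1e-6
dx, dy = pow_grads(x, y)
assert math.isclose(dx, ((x + eps) ** y - (x - eps) ** y) / (2 * eps), rel_tol=1e-4)
assert math.isclose(dy, (x ** (y + eps) - x ** (y - eps)) / (2 * eps), rel_tol=1e-4)
```
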
2 changes: 1 addition & 1 deletion tinygrad/helpers.py
@@ -106,7 +106,7 @@ def __lt__(self, x): return self.value < x
JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1)
WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0)
TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1)
TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1)
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
14 changes: 10 additions & 4 deletions tinygrad/ops.py
@@ -89,6 +89,7 @@ def sqrt(self): return self.alu(Ops.SQRT)
def sin(self): return self.alu(Ops.SIN)
def log2(self): return self.alu(Ops.LOG2)
def exp2(self): return self.alu(Ops.EXP2)
def pow(self, x): return self.alu(Ops.POW, x)

# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
@@ -133,7 +134,7 @@ class Ops(FastEnum):

# BinaryOps
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702

# TernaryOps
WHERE = auto(); MULACC = auto() # noqa: E702
@@ -155,7 +156,7 @@ class Ops(FastEnum):
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
Ops.SUB, Ops.FDIV}
Ops.SUB, Ops.FDIV, Ops.POW}
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)

@@ -175,7 +176,7 @@ class GroupOp:
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}

# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}

All = set(Ops)

@@ -675,10 +676,15 @@ def safe_exp2(x):
try: return 2 ** x
except OverflowError: return math.inf

def safe_pow(x, y):
try: return math.nan if isinstance(p:=pow(x, y), complex) else p
except ZeroDivisionError: return math.inf
except ValueError: return math.inf if x > 0 else -math.inf

python_alu: dict[Ops, Callable] = {
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
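
safe_pow keeps Python-level constant folding total by mapping complex results to nan and the two exceptions built-in pow can raise to ±inf. A few spot checks of the intended edge cases, relying only on documented built-in pow behaviour (negative base with non-integer exponent returns complex, 0 to a negative power raises ZeroDivisionError):

```python
import math

def safe_pow(x, y):
  try: return math.nan if isinstance(p:=pow(x, y), complex) else p
  except ZeroDivisionError: return math.inf
  except ValueError: return math.inf if x > 0 else -math.inf

assert safe_pow(2.0, 10.0) == 1024.0
assert math.isnan(safe_pow(-2.0, 0.5))  # complex result folded to nan
assert safe_pow(0, -1) == math.inf      # ZeroDivisionError folded to inf
```
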
17 changes: 10 additions & 7 deletions tinygrad/renderer/cstyle.py
@@ -300,10 +300,11 @@ class CUDARenderer(CStyleLanguage):
local_max = (1024, 1024, 64)
shared_max = 49152
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]]
tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))]
tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
(dtypes.half,dtypes.half)]]
tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]

@@ -344,7 +345,8 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

dt_map = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" }
dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
@@ -353,10 +355,11 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):

# mma operands => {c}, {a}, {b}, {c}
prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{
int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32"
int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c);
asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}"
"{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}},"
"{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};"
: {", ".join([f'"+f"(c.{_nms[i]})' for i in range(n_operands[2])])}
: {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])}
: {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])});
return c;\n}}""")

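With dt_map split into dt_map_in/dt_map_out, the rendered opcode now also encodes the accumulator type, which is what lets the new half-accumulate tensor cores emit a distinct instruction. A small standalone check of the opcode formatting, using string keys in place of tinygrad dtypes (dims (8,16,16) unpack as N,M,K and render as m16n8k16):

```python
# standalone reproduction of the opcode string built in render_kernel above
dt_map_in  = {"float": "tf32", "half": "f16", "bfloat16": "bf16"}
dt_map_out = {"float": "f32",  "half": "f16"}

def mma_opcode(N, M, K, dtype_in, dtype_out):
  return (f"mma.sync.aligned.m{M}n{N}k{K}.row.col."
          f"{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}")

assert mma_opcode(8, 16, 16, "half", "float") == "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
assert mma_opcode(8, 16, 16, "half", "half")  == "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"  # new half-accumulate path
```
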
2 changes: 1 addition & 1 deletion tinygrad/renderer/ptx.py
@@ -65,7 +65,7 @@ def render_wmma(ctx: "PTXRenderer", wmma: UOp):
if (elems_per_reg := 4 // src.dtype.scalar().itemsize) == 1: yield f"mov.b32 {reg}, {ctx.r[src][i]};"
else: yield f"mov.b32 {reg}, {{{', '.join(ctx.r[src][i * elems_per_reg : (i+1) * elems_per_reg])}}};"

dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32"}
dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32", dtypes.half: "f16"}
yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}{" "*12}'+\
f'{{{", ".join(ctx.wmma_r[2])}}}, {{{", ".join(ctx.wmma_r[0])}}}, {{{", ".join(ctx.wmma_r[1])}}}, {{{", ".join(ctx.wmma_r[2])}}};'
