From 15f94ac964e0760639b94fc157643429210a73cd Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Wed, 5 Feb 2025 13:03:46 -0300 Subject: [PATCH 1/9] TC_SEARCH_OVER_SHAPE to search multiple TC shapes (#8793) * squash search over search * refactor assert * init benchmark * cleaner get_kernel_actions * cleaner get_kernel_actions * add comment --- test/test_linearizer.py | 4 +++- test/test_search.py | 18 ++++++++++++++++++ tinygrad/engine/search.py | 13 ++++++++++--- tinygrad/helpers.py | 2 +- 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 3bb3e77ac2b05..fa3df34608d30 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1112,7 +1112,9 @@ def test_tensor_cores_multi_reduce(self): # check that get_kernel_actions produces all 9 options from tinygrad.engine.search import get_kernel_actions tc_actions = [k for i, k in get_kernel_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC] - assert len(tc_actions) == 9, f"get_kernel_actions should contain 9 possible TC actions, only got {len(tc_actions)}" + + available_tc = len([x for x in Device[Device.DEFAULT].renderer.tensor_cores if x.dtype_in == tc.dtype_in and x.dtype_out == tc.dtype_out]) + assert len(tc_actions) == 9 * available_tc, f"should contain 9 possible TC actions for every available TC, got {len(tc_actions)}" @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_unroll_phi(self): diff --git a/test/test_search.py b/test/test_search.py index d0d6cf9114f2a..d6bda9aa5df0f 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -102,6 +102,24 @@ def test_get_kernel_actions(self): if Opt(OptOps.GROUPTOP, 0, 0) in actions: assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP" + @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") + def test_search_over_shape(self): + from test.test_linearizer import helper_realized_ast + from tinygrad.engine.search import get_kernel_actions + + dtype_pairs = [(tc.dtype_in, tc.dtype_out) for tc in Device[Device.DEFAULT].renderer.tensor_cores] + multi_shape_dtype_pairs = [dts for dts in dtype_pairs if dtype_pairs.count(dts) > 1] + + if len(multi_shape_dtype_pairs) == 0: raise unittest.SkipTest("only one tc available per dtype pair to search over") + + for (dtype_in, dtype_out) in multi_shape_dtype_pairs: + a = Tensor.rand(16, 16, dtype=dtype_in) + b = Tensor.rand(16, 16, dtype=dtype_in) + realized_ast, _ = helper_realized_ast(a.matmul(b, acc_dtype=dtype_out)) + + lins = get_kernel_actions(Kernel(realized_ast)).values() + assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1 + def test_filter_global_buffer(self): # taken from https://github.com/tinygrad/tinygrad/issues/4612 ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 87ff15aea8833..443d22f0e963b 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -5,7 +5,7 @@ from tinygrad.ops import UOp, Ops, Variable, sym_infer from tinygrad.device import Device, Buffer, Compiler from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name -from tinygrad.helpers import IGNORE_BEAM_CACHE +from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE from tinygrad.dtype 
import ImageDType, PtrDType from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError from tinygrad.tensor import Tensor @@ -102,8 +102,15 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]: # get dictionary of all possible actions def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: - acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024) - for i,a in enumerate(actions): + acted_lins, max_up, max_lcl, kernel_actions = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024), actions + + if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first + for i, action in enumerate(kernel_actions): + if action.op == OptOps.TC and (tc_arg := cast(tuple, action.arg))[0] == -1: + # replace every tc_action with default tc with one tc_action for each available tc + kernel_actions[i:i+1] = [Opt(op=OptOps.TC, axis=action.axis, arg=(tc_select, tc_arg[1])) for tc_select,_ in enumerate(lin.opts.tensor_cores)] + + for i,a in enumerate(kernel_actions): if a.axis is not None and a.op is not OptOps.TC: if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue lin2 = lin.copy() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 42fb309546fea..09054b4419b09 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -106,7 +106,7 @@ def __lt__(self, x): return self.value < x JIT = ContextVar("JIT", 2 if platform.system() == 'Darwin' and ('Intel' in platform.processor() or 'i386' in platform.processor()) else 1) WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1) USE_TC, TC_SELECT, TC_OPT, AMX = ContextVar("TC", 1), ContextVar("TC_SELECT", -1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0) -TRANSCENDENTAL = ContextVar("TRANSCENDENTAL", 1) +TRANSCENDENTAL, TC_SEARCH_OVER_SHAPE = ContextVar("TRANSCENDENTAL", 1), ContextVar("TC_SEARCH_OVER_SHAPE", 1) FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) From 0f6109ec007518167d0be64f46d09fe4103d6cff Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Wed, 5 Feb 2025 15:10:05 -0300 Subject: [PATCH 2/9] hotfix bug in `get_kernel_actions` after `TC_SEARCH_OVER_SHAPE` was introduced (#8904) * hotfix search bug * copy actions --- tinygrad/engine/search.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 443d22f0e963b..b22f63d459b8e 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -102,7 +102,8 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> list[Buffer]: # get dictionary of all possible actions def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: - acted_lins, max_up, max_lcl, kernel_actions = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024), actions + acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024) + kernel_actions = actions.copy() if TC_SEARCH_OVER_SHAPE and len(lin.applied_opts) == 0: # tensor core opts must be first for i, 
action in enumerate(kernel_actions): @@ -112,7 +113,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: for i,a in enumerate(kernel_actions): if a.axis is not None and a.op is not OptOps.TC: - if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue + if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in kernel_actions): continue lin2 = lin.copy() try: lin2.apply_opt(a) From e71497aabc99be08b4831461143152530375bcaa Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 5 Feb 2025 19:47:20 +0100 Subject: [PATCH 3/9] move assign ShapeTracker check to pattern matcher [pr] (#8906) * move assign ShapeTracker check to pattern matcher [pr] * rename the st uop to view --- tinygrad/engine/schedule.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index f8307ad028f93..4f6721e48dec0 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -349,6 +349,16 @@ def _append_buf(ctx:KernelContext, x:UOp) -> UOp: ctx.bufs.append(x) return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1) +def check_load_st(glbl:UOp, view:UOp): + if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return + # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine + if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: return + # if it has a single view and it's equal when you shrink a contig, it's fine + if len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): return + # otherwise, it's not fine + raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" + +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) + to_si = PatternMatcher([ # BUFFER -> DEFINE_GLOBAL (UPat(Ops.BUFFER, name="x"), _append_buf), @@ -365,6 +375,8 @@ def _append_buf(ctx:KernelContext, x:UOp) -> UOp: (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)), # once images are loaded they become the base dtype (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), + # if this kernel also assigns to the loaded buffer, ensure we can index it correctly + (UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st), ]) def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp): @@ -384,17 +396,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> UOp: # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph") # PRELOAD tells the toposort this kernel should run before ASSIGN - if x.op is Ops.PRELOAD: - assign_preloads[x.buf_uop] = None - # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD - if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous: - # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine - if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, 
st.views[0].strides))).contiguous: pass - # if it has a single view and it's equal when you shrink a contig, it's fine - elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass - # otherwise, it's not fine - else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" - +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) + if x.op is Ops.PRELOAD: assign_preloads[x.buf_uop] = None # NOTE: we only add the metadata for fused tensors metadata = tuple(dedup(m for x in pre.toposort if x.op is not Ops.BUFFER and (m:=ctx.ops_metadata.get(x)) is not None)) return UOp(Ops.KERNEL, src=tuple(si_ctx.bufs), arg=Kernel(ast, metadata)) From aec3b8d5158149ae4e70083eac6cfa94e92020db Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Wed, 5 Feb 2025 16:13:01 -0300 Subject: [PATCH 4/9] add regression test: `test_get_kernel_actions_preserves_actions_state` (#8907) * test_get_kernel_actions_preserves_actions_state * simplify * simplify * refactor assert message --- test/test_search.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_search.py b/test/test_search.py index d6bda9aa5df0f..074008118f26a 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -120,6 +120,17 @@ def test_search_over_shape(self): lins = get_kernel_actions(Kernel(realized_ast)).values() assert len(set(lin.tensor_core.dims for lin in lins if lin.tensor_core is not None)) > 1 + def test_get_kernel_actions_preserves_actions_state(self): + from test.test_linearizer import helper_realized_ast + from tinygrad.engine.search import get_kernel_actions + a = Tensor.rand(16, 16) + b = Tensor.rand(16, 16) + realized_ast, _ = helper_realized_ast(a @ b) + actions_before = actions.copy() + get_kernel_actions(Kernel(realized_ast)) + actions_after = actions.copy() + assert actions_after == actions_before, "actions state was not preserved" + def test_filter_global_buffer(self): # taken from https://github.com/tinygrad/tinygrad/issues/4612 ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( From bff7c70eef18d2365eb5d47be4f808eac299f934 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 5 Feb 2025 22:38:59 +0300 Subject: [PATCH 5/9] hcq: better var check (#8908) --- tinygrad/runtime/support/hcq.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index e2136fcd17aa3..1eede775eb81c 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -4,7 +4,7 @@ from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up from tinygrad.renderer import Renderer from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent -from tinygrad.ops import sym_infer, sint, Variable +from tinygrad.ops import sym_infer, sint, Variable, UOp from tinygrad.runtime.autogen import libc class HWInterface: @@ -83,10 +83,10 @@ def q(self, *values): """ for v in values: - if isinstance(v, int): self._q.append(v) - else: + if isinstance(v, UOp): self.q_sints.append((len(self._q), self._new_sym(v))) self._q.append(0xbadc0ded) + else: self._q.append(v) # *** common commands *** From 9307572fe3e8eb8f4aa98f17698300e486179b4e Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 5 Feb 2025 15:15:59 -0500 Subject: [PATCH 6/9] Ops.POW and transcendental (#8911) --- tinygrad/codegen/rewriter.py 
| 3 ++- tinygrad/codegen/transcendental.py | 10 ++++++++++ tinygrad/gradient.py | 1 + tinygrad/ops.py | 14 ++++++++++---- tinygrad/tensor.py | 9 +++------ 5 files changed, 26 insertions(+), 11 deletions(-) diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index 22da6d9871d60..c49bcf2e4d2d6 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -6,7 +6,7 @@ from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same -from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES +from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES from tinygrad.renderer import Renderer # ***** float4/image store handling ***** @@ -124,6 +124,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: def get_late_rewrite_patterns(ops, force_transcendental=False): pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \ ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental] + pat.append((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))) # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1) if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)] diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 753aed378754f..5611cae1d1f9d 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -254,3 +254,13 @@ def xlog2(d:UOp) -> UOp: r = d.ne(d).where(r.const_like(math.nan), r) # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal. 
return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf)) + +def xpow(base:UOp, exponent:UOp) -> UOp: + # start with b ** e = exp2(e * log2(b)) + ret = (base < 0).where(-base, base).log2().mul(exponent).exp2() + # negative base adjustment: nan for non-integer exponent and -1 for odd exponent + adj = (base < 0).where((exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)).where( + ret.const_like(math.nan), + (exponent.cast(dtypes.int32).cast(dtypes.uint32)%2).eq(1).where(ret.const_like(-1), ret.const_like(1))), ret.const_like(1)) + # fix 0 ** 0 = 1 + return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * adj) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 91d0d3a59994b..54e6a3465f96d 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -22,6 +22,7 @@ def reduce_gradient(ctx:UOp, ret:UOp): (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)), (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)), (UPat(Ops.ADD), lambda ctx: (ctx, ctx)), + (UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))), (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)), (ret.src[0] 0 else -math.inf + python_alu: dict[Ops, Callable] = { Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2, Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), - Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, + Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max, Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0, diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 529554f18a3f9..95bb5837bdb78 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3313,12 +3313,9 @@ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: base, exponent = self._broadcasted(x, reverse=reverse) # TODO: int pow if not base.is_floating_point(): raise RuntimeError("base needs to be float") - # start with b ** e = exp(e * log(b)) - ret = base.abs().log().mul(exponent).exp() - # negative base adjustment: nan for non-integer exponent and -1 for odd exponent - adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1) - # fix 0 ** 0 = 1 - ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj) + + # NOTE: pow(int, float) -> int + ret = base._apply_uop(UOp.pow, exponent) return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret def maximum(self, x:Union[Tensor, ConstType]) -> Tensor: From 189bfa164e95a8439a6a7496e90bc88e5b535931 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 5 Feb 2025 15:35:21 -0500 Subject: [PATCH 7/9] enable backward test for pow(neg const ** x) (#8912) backward works now. 
0**x still does not work because it's a special case fixed in transcendental --- test/test_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 091201244f705..4b42ffdedbc37 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -607,6 +607,7 @@ def test_pow(self): helper_test_op([], lambda: b**1.1, lambda: a**1.1) def test_pow_const(self): + helper_test_op([(45,65)], lambda x: x**0.0) helper_test_op([(45,65)], lambda x: x**1.0) helper_test_op([(45,65)], lambda x: x**-1.0) helper_test_op([(45,65)], lambda x: 1.0**x) @@ -616,8 +617,7 @@ def test_pow_const(self): helper_test_op([()], lambda x: 2.0**x) # TODO: fix backward helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True) - # TODO: fix backward, should be nan - helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True) + helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]]) @unittest.skip("not supported") def test_pow_int(self): From 17f9b1cef65c7c51d4ad0dd863d4bff9d6865437 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 6 Feb 2025 00:02:09 +0300 Subject: [PATCH 8/9] am: load fw based on versions (#8913) * am: load fw based on versions * ops * ops2 --- tinygrad/runtime/support/am/amdev.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index bf2f6a5b90d6b..b06a2abcae25c 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -32,11 +32,13 @@ def write(self, value=0, **kwargs): def read(self, **kwargs): return self.adev.rreg(self.reg_off) & self._parse_kwargs(**kwargs)[0] class AMFirmware: - def __init__(self): + def __init__(self, adev): + def fmt_ver(hwip): return f"{adev.ip_versions[hwip]//10000}_{(adev.ip_versions[hwip]//100)%100}_{adev.ip_versions[hwip]%100}" + # Load SOS firmware self.sos_fw = {} - blob, sos_hdr = self.load_fw("psp_13_0_0_sos.bin", am.struct_psp_firmware_header_v2_0) + blob, sos_hdr = self.load_fw(f"psp_{fmt_ver(am.MP0_HWIP)}_sos.bin", am.struct_psp_firmware_header_v2_0) fw_bin = sos_hdr.psp_fw_bin for fw_i in range(sos_hdr.psp_fw_bin_count): @@ -48,17 +50,17 @@ def __init__(self): self.ucode_start: dict[str, int] = {} self.descs: list[tuple[int, memoryview]] = [] - blob, hdr = self.load_fw("smu_13_0_0.bin", am.struct_smc_firmware_header_v1_0) + blob, hdr = self.load_fw(f"smu_{fmt_ver(am.MP1_HWIP)}.bin", am.struct_smc_firmware_header_v1_0) self.smu_psp_desc = self.desc(am.GFX_FW_TYPE_SMU, blob, hdr.header.ucode_array_offset_bytes, hdr.header.ucode_size_bytes) # SDMA firmware - blob, hdr = self.load_fw("sdma_6_0_0.bin", am.struct_sdma_firmware_header_v2_0) + blob, hdr = self.load_fw(f"sdma_{fmt_ver(am.SDMA0_HWIP)}.bin", am.struct_sdma_firmware_header_v2_0) self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH0, blob, hdr.header.ucode_array_offset_bytes, hdr.ctx_ucode_size_bytes)] self.descs += [self.desc(am.GFX_FW_TYPE_SDMA_UCODE_TH1, blob, hdr.ctl_ucode_offset, hdr.ctl_ucode_size_bytes)] # PFP, ME, MEC firmware for (fw_name, fw_cnt) in [('PFP', 2), ('ME', 2), ('MEC', 4)]: - blob, hdr = self.load_fw(f"gc_11_0_0_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0) + blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_{fw_name.lower()}.bin", am.struct_gfx_firmware_header_v2_0) # Code part self.descs += [self.desc(getattr(am, f'GFX_FW_TYPE_RS64_{fw_name}'), blob, 
hdr.header.ucode_array_offset_bytes, hdr.ucode_size_bytes)] @@ -69,12 +71,12 @@ def __init__(self): self.ucode_start[fw_name] = hdr.ucode_start_addr_lo | (hdr.ucode_start_addr_hi << 32) # IMU firmware - blob, hdr = self.load_fw("gc_11_0_0_imu.bin", am.struct_imu_firmware_header_v1_0) + blob, hdr = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_imu.bin", am.struct_imu_firmware_header_v1_0) imu_i_off, imu_i_sz, imu_d_sz = hdr.header.ucode_array_offset_bytes, hdr.imu_iram_ucode_size_bytes, hdr.imu_dram_ucode_size_bytes self.descs += [self.desc(am.GFX_FW_TYPE_IMU_I, blob, imu_i_off, imu_i_sz), self.desc(am.GFX_FW_TYPE_IMU_D, blob, imu_i_off + imu_i_sz, imu_d_sz)] # RLC firmware - blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw("gc_11_0_0_rlc.bin", am.struct_rlc_firmware_header_v2_0, + blob, hdr0, hdr1, hdr2, hdr3 = self.load_fw(f"gc_{fmt_ver(am.GC_HWIP)}_rlc.bin", am.struct_rlc_firmware_header_v2_0, am.struct_rlc_firmware_header_v2_1, am.struct_rlc_firmware_header_v2_2, am.struct_rlc_firmware_header_v2_3) for mem in ['GPM', 'SRM']: @@ -263,7 +265,7 @@ def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_ba # Memory manager & firmware self.mm = AMMemoryManager(self, self.vram_size) - self.fw = AMFirmware() + self.fw = AMFirmware(self) # Initialize IP blocks self.soc21:AM_SOC21 = AM_SOC21(self) From cad44f5f4270a4bf19c90184881a96140030e281 Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Wed, 5 Feb 2025 18:56:37 -0300 Subject: [PATCH 9/9] add Half-Precision Accumulation Support for Tensor Cores in NV, CUDA, and PTX (#8680) * ptx and nv rendering refactor to work with half acc * ptx fix! * use same reg for acc and out * fix comment * another fix * minor change in commet * fix --------- Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com> --- tinygrad/renderer/cstyle.py | 17 ++++++++++------- tinygrad/renderer/ptx.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 23a8e7375900c..8c7131d0760c1 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -300,10 +300,11 @@ class CUDARenderer(CStyleLanguage): local_max = (1024, 1024, 64) shared_max = 49152 # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions - tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts, - swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]] - tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts, - swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))] + tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts, + swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float), + (dtypes.half,dtypes.half)]] + tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts, + swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]] tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts, 
swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))] @@ -344,7 +345,8 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include ") prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}] - dt_map = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" } + dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16" } + dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" } for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes] wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] @@ -353,10 +355,11 @@ def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): # mma operands => {c}, {a}, {b}, {c} prefix.append(f"""__device__ {wmma_dtypes[2]} __{name}({wmma_dtypes[0]} a, {wmma_dtypes[1]} b, {wmma_dtypes[2]} c){{ - int *a_pk = (int *)(&a), *b_pk = (int *)(&b);\n asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32" + int *a_pk = (int *)(&a), *b_pk = (int *)(&b), *c_pk = (int *)(&c); + asm("mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}" "{{{", ".join(operands[:n_operands[2]])}}}, {{{", ".join(operands[n_operands[2]:n_operands[2]+n_operands[0]])}}}," "{{{", ".join(operands[-n_operands[1]:])}}}, {{{", ".join(operands[:n_operands[2]])}}};" - : {", ".join([f'"+f"(c.{_nms[i]})' for i in range(n_operands[2])])} + : {", ".join([f'"+r"(c_pk[{i}])' for i in range(n_operands[2])])} : {", ".join([f'"r"(a_pk[{i}])' for i in range(n_operands[0])])}, {", ".join([f'"r"(b_pk[{i}])' for i in range(n_operands[1])])}); return c;\n}}""") diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 5cbc2face4bff..3f01673e9c1f2 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -65,7 +65,7 @@ def render_wmma(ctx: "PTXRenderer", wmma: UOp): if (elems_per_reg := 4 // src.dtype.scalar().itemsize) == 1: yield f"mov.b32 {reg}, {ctx.r[src][i]};" else: yield f"mov.b32 {reg}, {{{', '.join(ctx.r[src][i * elems_per_reg : (i+1) * elems_per_reg])}}};" - dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32"} + dt_map_in, dt_map_out = {dtypes.float: "tf32", dtypes.half: "f16"}, {dtypes.float: "f32", dtypes.half: "f16"} yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}{" "*12}'+\ f'{{{", ".join(ctx.wmma_r[2])}}}, {{{", ".join(ctx.wmma_r[0])}}}, {{{", ".join(ctx.wmma_r[1])}}}, {{{", ".join(ctx.wmma_r[2])}}};'
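PATCH 1/9 teaches get_kernel_actions to fan the default tensor-core action (tc_select == -1) out into one TC action per tensor core the renderer exposes, gated by the new TC_SEARCH_OVER_SHAPE context variable, and PATCH 2/9 fixes the follow-up bug where that expansion mutated the shared module-level actions list. A minimal standalone sketch of that expansion; the Opt dataclass below is a simplified stand-in, not tinygrad's real Opt/Kernel classes:

# Sketch of the TC_SEARCH_OVER_SHAPE expansion (PATCH 1/9) plus the copy from the PATCH 2/9 hotfix.
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Opt:
  op: str
  axis: Optional[int] = None
  arg: Optional[tuple] = None

def expand_tc_actions(actions: list, num_tensor_cores: int) -> list:
  kernel_actions = actions.copy()  # copy first so the module-level action list is never mutated (the 2/9 hotfix)
  for i, action in enumerate(kernel_actions):
    if action.op == "TC" and action.arg is not None and action.arg[0] == -1:
      # fan the single "default tensor core" action out into one action per available tensor core
      kernel_actions[i:i+1] = [Opt("TC", action.axis, (tc_select, action.arg[1])) for tc_select in range(num_tensor_cores)]
  return kernel_actions

if __name__ == "__main__":
  base = [Opt("UPCAST", 0, (4,)), Opt("TC", 0, (-1, 1)), Opt("LOCAL", 1, (16,))]
  for a in expand_tc_actions(base, 3): print(a)  # the one TC action becomes three, tc_select 0..2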
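PATCH 3/9 moves the "self operand of augmented assign must be contiguous" check from schedule_uop into the to_si pattern matcher (check_load_st), but the user-visible rule is unchanged: a kernel may only load the buffer it assigns to through a contiguous (or suitably masked) view. A usage sketch of that rule, assuming a working tinygrad install and that the error surfaces when the schedule is created, i.e. at realize time:

# Assumed behavior sketch: `a += a.T` loads the assign target through a non-contiguous view and is
# rejected by check_load_st, while .contiguous() realizes the transposed operand first.
from tinygrad import Tensor

a = Tensor.rand(4, 4).realize()
a += a.T.contiguous()   # accepted: the transposed operand is materialized before the assign reads it
a.realize()

try:
  a += a.T              # builds an assign whose load of `a` goes through a transposed (non-contiguous) view
  a.realize()           # the scheduler rejects it with the message shown in check_load_st
except RuntimeError as e:
  print("rejected:", e)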
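PATCH 6/9 lowers Ops.POW in the transcendental pass as b ** e = exp2(e * log2(|b|)), then patches up negative bases (nan for non-integer exponents, a sign flip for odd ones) and the 0 ** 0 = 1 special case. A plain-Python reference of that arithmetic, as an illustration of the formula rather than of tinygrad's UOp graph:

# Reference of the xpow decomposition from PATCH 6/9, written against Python floats only.
import math

def xpow_ref(base: float, exponent: float) -> float:
  if base == 0.0 and exponent == 0.0: return 1.0            # the 0 ** 0 = 1 fix-up at the end of xpow
  if base == 0.0: return 0.0 if exponent > 0 else math.inf  # keep the demo away from log2(0)
  ret = 2.0 ** (exponent * math.log2(abs(base)))            # b ** e = exp2(e * log2(|b|))
  if base < 0:
    if exponent != int(exponent): return math.nan           # negative base, non-integer exponent -> nan
    if int(exponent) % 2 == 1: return -ret                  # negative base, odd exponent -> flip sign
  return ret

for b, e in [(-2.0, 3.0), (-2.0, 0.5), (3.0, 2.5), (0.0, 0.0)]:
  print(f"xpow_ref({b}, {e}) = {xpow_ref(b, e)}")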
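PATCH 8/9 stops hard-coding firmware file names like psp_13_0_0_sos.bin and derives them from the device's IP versions through the local fmt_ver helper. The packing it relies on is major*10000 + minor*100 + revision, which round-trips like this:

# Mirrors the fmt_ver closure added in AMFirmware.__init__ (PATCH 8/9).
def fmt_ver(ver: int) -> str:
  # IP version packed as major*10000 + minor*100 + rev -> "major_minor_rev"
  return f"{ver // 10000}_{(ver // 100) % 100}_{ver % 100}"

assert fmt_ver(130000) == "13_0_0"   # -> psp_13_0_0_sos.bin, smu_13_0_0.bin
assert fmt_ver(110000) == "11_0_0"   # -> gc_11_0_0_pfp.bin, gc_11_0_0_rlc.bin, ...
print(fmt_ver(130000), fmt_ver(110000), fmt_ver(60000))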
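PATCH 9/9 makes the accumulator dtype of the generated mma instruction follow dtype_out via dt_map_out instead of being hard-coded to f32, which is what enables half-precision accumulation in the CUDA and PTX renderers. A small sketch of how the instruction name is assembled from the TensorCore dims and the two dtype maps, mirroring the strings in the patch (string formatting only, no GPU required):

# How the mma instruction name is built after PATCH 9/9: the out dtype is looked up, not hard-coded.
dt_map_in  = {"float": "tf32", "half": "f16", "bfloat16": "bf16"}
dt_map_out = {"float": "f32",  "half": "f16"}

def mma_name(dims: tuple, dtype_in: str, dtype_out: str) -> str:
  N, M, K = dims  # tinygrad's TensorCore dims are (N, M, K); the instruction shape reads m{M}n{N}k{K}
  return (f"mma.sync.aligned.m{M}n{N}k{K}.row.col."
          f"{dt_map_out[dtype_out]}.{dt_map_in[dtype_in]}.{dt_map_in[dtype_in]}.{dt_map_out[dtype_out]}")

print(mma_name((8, 16, 16), "half", "float"))  # mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 (previous behavior)
print(mma_name((8, 16, 16), "half", "half"))   # mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 (new half-acc TC)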