From 72e1f41f8e10553ca6eea7aec58ebef6409c4319 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sat, 1 Feb 2025 11:25:44 -0500
Subject: [PATCH 1/5] add unbind_vars pattern matcher (#8851)

* add unbind_vars pattern matcher [pr]

* this can be cvar

* this is empty
---
 .../external/process_replay/process_replay.py |  2 +-
 tinygrad/engine/schedule.py                   | 25 +++++++++----------
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py
index 69787e689879d..c6351e4166fdb 100755
--- a/test/external/process_replay/process_replay.py
+++ b/test/external/process_replay/process_replay.py
@@ -32,7 +32,7 @@ class ProcessReplayWarning(Warning): pass
 def recreate_sched(ast:UOp) -> UOp:
   # NOTE: process replay isn't meant to actually schedule anything
-  return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast
+  return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list)), {}).ast

 def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str:
   k = Kernel(ast, opts=opts)
   for opt in applied_opts: k.apply_opt(opt)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 54ebb844b3e4a..1083240ea787d 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -36,7 +36,6 @@ def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in sel
 @dataclass(frozen=True)
 class ScheduleContext:
   tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict)  # this maps BUFFER uops of this schedule to the tensor uop
-  var_vals: dict[Variable, int] = field(default_factory=dict)      # this maps a BIND's DEFINE_VAR to its value
   assigns: set[UOp] = field(default_factory=set)                   # this holds all the BUFFER uops we ASSIGN to in this schedule
   realizes: dict[UOp, UOp] = field(default_factory=dict)           # this holds all the BUFFER uops we mutate in this schedule
   allbufs: dict[UOp, UOp] = field(default_factory=dict)            # this maps BUFFER uops the actual op
@@ -165,11 +164,16 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
   (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
 ])

-def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
-  # apply swizzles (pushing views from the middle of the AST to BUFFER ops edges)
-  sink = graph_rewrite(graph_rewrite(pre, view_left), view_right)
+def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
+  ctx[var.replace(src=())] = val.arg
+  return var
+unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
+
+def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem:
+  # unbind_vars + push views to edges
+  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
   # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
-  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
+  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
   # deal with ASSIGN
   if len(ctx.assigns) != 0:
     assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
@@ -381,11 +385,6 @@ def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):

 # **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp

-def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp):
-  assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}"
-  ctx.var_vals[var.replace(src=())] = val.const_arg
-  return var
-
 def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
   # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
   return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
@@ -397,7 +396,6 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
   return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))

 break_sched = PatternMatcher([
-  (UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable),
   # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
   (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
   (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
@@ -452,9 +450,10 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va

   # create schedule items + map buffers to realized tensors
   prescheduled: list[ScheduleItem] = []
+  var_vals: dict[Variable, int] = {}
   for buf_uop,store in ctx.realizes.items():
     assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
-    prescheduled.append(schedule_uop(store.sink(), ctx))
+    prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
     # can only schedule once
     for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
     # increment refcount for this buffer
@@ -487,4 +486,4 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   # confirm everything was scheduled correctly
   if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
-  return schedule, ctx.var_vals, becomes_map
+  return schedule, var_vals, becomes_map

From 73ee2d74c080b7aa095187b403fb2be64abaee7a Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sat, 1 Feb 2025 12:11:57 -0500
Subject: [PATCH 2/5] raise RuntimeError for int base pow (#8852)

current implementation is not precise and blocking other simplification
change
---
 test/external/external_test_onnx_backend.py | 4 ++++
 test/test_ops.py                            | 1 +
 tinygrad/tensor.py                          | 2 ++
 3 files changed, 7 insertions(+)

diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py
index b9a61b40f11df..1c227d22f8dcd 100644
--- a/test/external/external_test_onnx_backend.py
+++ b/test/external/external_test_onnx_backend.py
@@ -95,6 +95,10 @@ def supports_device(cls, device: str) -> bool:
 # we don't support indexes
 backend_test.exclude('test_nonzero_*')

+# no support for int pow
+backend_test.exclude('test_pow_types_int32_int32_cpu')
+backend_test.exclude('test_pow_types_int64_int64_cpu')
+
 # no support for fmod
 backend_test.exclude('test_mod_int64_fmod_cpu')
 backend_test.exclude('test_mod_mixed_sign_float16_cpu')
diff --git a/test/test_ops.py b/test/test_ops.py
index f512992fe2ed5..85f3e31f480a6 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -619,6 +619,7 @@ def test_pow_const(self):
     # TODO: fix backward, should be nan
     helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)

+  @unittest.skip("not supported")
   def test_pow_int(self):
     def _test(base, exponent):
       helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 846929ac9aed6..f32e5139cdfc8 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3310,6 +3310,8 @@ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()

     base, exponent = self._broadcasted(x, reverse=reverse)
+    # TODO: int pow
+    if not base.is_floating_point(): raise RuntimeError("base needs to be float")
     # start with b ** e = exp(e * log(b))
     ret = base.abs().log().mul(exponent).exp()
     # correct sign of negative base with odd exponent (cos has a period of 2pi so we use it here to get the oddness of the exponent)

From 5b1fc4dcb285368b74dd10cb7a4a900a8b5278d1 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sat, 1 Feb 2025 13:55:24 -0500
Subject: [PATCH 3/5] push cast to branches in UOp where (#8850)

---
 test/unit/test_uop_symbolic.py | 14 ++++++++++++++
 tinygrad/codegen/rewriter.py   |  3 +++
 2 files changed, 17 insertions(+)

diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index 9ff40ba7ab1ac..d5fb8a5655bd4 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -531,6 +531,20 @@ def test_where_combine(self):
     # not combining
     # TODO: can combine if one is identity element const
     self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
+  def test_where_cast(self):
+    s = Variable("s", 0, 3)
+    cond = s < 2
+    a = Variable("a", 0, 3)
+    b = Variable("b", 0, 3)
+    expr = cond.where(a, b).cast(dtypes.half)
+
+    # TODO: copied from render, render does not support cast
+    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
+    uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink()))
+    rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
+
+    self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
+
   def test_symbolic_div(self):
     # from symbolic arange
     a = Variable("a", 1, 10)
diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py
index e17d582af6ae9..03cb681d970e2 100644
--- a/tinygrad/codegen/rewriter.py
+++ b/tinygrad/codegen/rewriter.py
@@ -294,6 +294,9 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
   (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x),  # an ASSIGN to a const is a NOOP
   # x!=0 -> (bool)x
   (UPat.var("x")!=0, lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
+  # ** where **
+  # push cast to branches
+  (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))),
   # ** load/store folding **
   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
   (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),

From dc34a4146f9651d02abf2d12c453246067f9a959 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Sat, 1 Feb 2025 14:50:23 -0500
Subject: [PATCH 4/5] better process_replay context print [pr] (#8856)

* better process_replay context print [pr]

* test: revert push cast

* Revert "test: revert push cast"

This reverts commit 38a2aef6f8f0b7b68c89c50eb42a20e87c52ae9b.
---
 test/external/process_replay/process_replay.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py
index c6351e4166fdb..b7d530c45bf94 100755
--- a/test/external/process_replay/process_replay.py
+++ b/test/external/process_replay/process_replay.py
@@ -61,7 +61,8 @@ def diff(offset:int, name:str, fxn:Callable) -> None:
       continue
     # try recreate
     try:
-      with Context(**{k:v.value for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
+      ctx_vars = {k:v.value for k,v in args[-2].items() if k != "DEBUG" and (var:=ContextVar._cache.get(k)) is not None and var.value != v.value}
+      with Context(**ctx_vars): good = fxn(*args[:-2])
       if good is None: continue
     except Exception as e:
       changed += 1
@@ -72,7 +73,8 @@ def diff(offset:int, name:str, fxn:Callable) -> None:
     try: assert str(args[-1]) == str(good)
     except AssertionError:
       changed += 1
-      for x in args[:-1]: logging.info(x)
+      if ctx_vars: logging.info(ctx_vars)
+      for x in args[:-2]: logging.info(x)
       changes = list(difflib.unified_diff(str(good).splitlines(), str(args[-1]).splitlines()))
       logging.info("\n".join(colored(line, "red" if line.startswith("-") else "green" if line.startswith("+") else None) for line in changes))
       warnings.warn("PROCESS REPLAY DETECTED CHANGE", ProcessReplayWarning)

From 784185287055b796ccb32a33a0de9d71f92fb29c Mon Sep 17 00:00:00 2001
From: nimlgen <138685161+nimlgen@users.noreply.github.com>
Date: Sat, 1 Feb 2025 23:42:27 +0300
Subject: [PATCH 5/5] hcq pci signal fuzzer (#8854)

* hcq pci signal fuzzer

* kk

* correct
---
 test/external/external_fuzz_hcq_signals.py | 31 ++++++++++++++++++++++
 1 file changed, 31 insertions(+)
 create mode 100644 test/external/external_fuzz_hcq_signals.py

diff --git a/test/external/external_fuzz_hcq_signals.py b/test/external/external_fuzz_hcq_signals.py
new file mode 100644
index 0000000000000..2b88cbc00e2ca
--- /dev/null
+++ b/test/external/external_fuzz_hcq_signals.py
@@ -0,0 +1,31 @@
+import random
+from tinygrad import Device
+from tinygrad.helpers import getenv, DEBUG
+
+def main():
+  seed = getenv("SEED", 1337)
+  n_gpus = getenv("GPUS", 3)
+  iters = getenv("ITERS", 10000000)
+  only_compute = bool(getenv("ONLY_COMPUTE", 0))
+
+  print(f"{n_gpus} GPUs for {iters} iterations, {only_compute=}, seed {seed}")
+  devs = tuple([Device[f"{Device.DEFAULT}:{x}"] for x in range(n_gpus)])
+
+  for i in range(iters):
+    dev = random.choice(devs)
+    q_t = random.choice([dev.hw_copy_queue_t, dev.hw_compute_queue_t] if not only_compute else [dev.hw_compute_queue_t])
+
+    deps_sigs = random.randint(0, len(devs))
+    wait_devs = random.sample(devs, deps_sigs)
+
+    q = q_t()
+    for d in wait_devs: q.wait(d.timeline_signal, d.timeline_value - 1)
+    q.wait(dev.timeline_signal, dev.timeline_value - 1).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+    dev.timeline_value += 1
+
+    if sync:=random.randint(0, 10) < 3: dev.synchronize()
+    if DEBUG >= 2: print(f"{i}: {q_t.__name__} {dev.device_id} timeline {dev.timeline_value}, wait for {[d.device_id for d in wait_devs]}, {sync=}")
+    elif i % 100 == 0: print(f"\rCompleted {i} iterations", end='')
+
+if __name__ == "__main__":
+  main()
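
Note on PATCH 5/5: the fuzzer reads SEED from the environment and prints it,
but the value appears never to be passed to random.seed, so two runs with the
same SEED will not replay the same interleaving of queue/signal operations.
A minimal sketch of the presumably intended seeding follows; the random.seed
call is an assumption on the patch's intent, not part of the committed file:

  import random
  from tinygrad.helpers import getenv

  # assumed fix, not in the committed patch: actually apply the seed so a
  # failing interleaving can be reproduced with the same SEED value
  seed = getenv("SEED", 1337)
  random.seed(seed)

With seeding in place, a typical invocation pins the fuzzed schedule, using
only the environment variables the script already reads:

  SEED=1337 GPUS=3 ITERS=100000 ONLY_COMPUTE=1 DEBUG=2 python test/external/external_fuzz_hcq_signals.py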