
Commit 117b7a1

VALIDATE_WITH_CPU [pr] (tinygrad#9488)
* VALIDATE_WITH_CPU [pr]
* fix test
1 parent 935cd01 commit 117b7a1

8 files changed (+28, -13 lines)
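In brief: the new VALIDATE_WITH_CPU flag makes run_schedule re-run every SINK kernel on the CPU backend and assert that the CPU output matches the device output. A minimal usage sketch (this script is illustrative, not part of the commit; it assumes the default device is not already CPU):

# sketch: enable CPU validation for a small computation
# VALIDATE_WITH_CPU is a ContextVar, so it can also be set from the
# environment: VALIDATE_WITH_CPU=1 python script.py
from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(VALIDATE_WITH_CPU=1):
  # each SINK kernel runs on the default device, then again on CPU;
  # run_schedule asserts the outputs agree to rtol=1e-3, atol=1e-3
  out = (Tensor.randn(64, 64) @ Tensor.randn(64, 64)).sum()
  print(out.item())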

extra/reduce_speed.py (+1, -1)

@@ -113,7 +113,7 @@
 GlobalCounters.reset()
 out = a.sum()
 sis = out.schedule()
-for i,ei in enumerate(lower_schedule(sis)):
+for i,(_,ei) in enumerate(lower_schedule(sis)):
   if i == 0:
     # change the source code
     prg_spec = ei.prg.p

test/test_image_dtype.py (+1, -1)

@@ -123,7 +123,7 @@ def test_lil_model(self):
 loss = x.image_dot(w1).image_dot(w2).float().max()
 loss.backward()
 sched = unwrap(w1.grad).schedule()
-for s,ei in zip(sched, lower_schedule(sched[:])):
+for s,(_,ei) in zip(sched, lower_schedule(sched[:])):
   ei.run()
   if s.bufs[0].dtype == dtypes.float:
     lst = s.bufs[0].as_buffer().cast("f").tolist()

test/test_linearizer.py (+1, -1)

@@ -71,7 +71,7 @@ def test_arg_dedup(self):
 a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
 np_a, np_b = a.numpy(), b.numpy()
 c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
-lowered = list(lower_schedule(c.schedule()))
+lowered = [x[1] for x in lower_schedule(c.schedule())]
 for ei in lowered: ei.run()
 rawbufs = lowered[-1].bufs
 assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}

test/test_multitensor.py (+1, -1)

@@ -81,7 +81,7 @@ def test_shard_no_recompile(self):
 out = (X + X)
 sched = out.schedule()
 names = []
-for si, ei in zip(sched[:], lower_schedule(sched)):
+for si, ei in lower_schedule(sched):
   if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
   ei.run()
 self.assertEqual(len(set(names)), 3), "function was relinearized"

test/test_randomness.py (+1, -1)

@@ -99,7 +99,7 @@ def test_threefry_against_reference(self):

 @unittest.skipIf(getenv("PTX"), "fails with PTX")
 def test_threefry_doesnt_use_long(self):
-  for ei in lower_schedule(Tensor.rand(20).schedule()):
+  for (_,ei) in lower_schedule(Tensor.rand(20).schedule()):
     if isinstance(ei.prg, CompiledRunner):
       for u in ei.prg.p.uops:
         self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")

test/test_schedule.py (+2, -2)

@@ -31,7 +31,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
 assert isinstance(t, UOp), f"can't schedule {t}"
 sched, _, __ = create_schedule_with_vars(t.sink())
 # test lowering all the ScheduleItems to ExecItems
-lowered = list(lower_schedule(sched.copy()))
+lowered = [x[1] for x in lower_schedule(sched.copy())]
 if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
 if len(sched) != allowed:
   print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")

@@ -1614,7 +1614,7 @@ def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
 with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
   lst = [xt] if isinstance(xt, Tensor) else xt
   s = Tensor.schedule(*lst)
-  lowered = list(lower_schedule(s.copy()))
+  lowered = [x[1] for x in lower_schedule(s.copy())]
   kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
   if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
   for ei in lowered: ei.run(do_update_stats=True)

tinygrad/engine/realize.py (+20, -5)

@@ -2,7 +2,7 @@
 import time, pprint
 from dataclasses import dataclass, replace
 from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
-from tinygrad.helpers import DEVECTORIZE, time_to_str
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
 from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates

@@ -150,10 +150,10 @@ def run(self, _var_vals:Optional[dict[Variable, int]]=None, wait=False, jit=Fals
 ])
 def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)

-def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
+def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
   while len(schedule):
     si = schedule.pop(0)
-    try: yield lower_schedule_item(si)
+    try: yield (si, lower_schedule_item(si))
     except Exception as e:
       if DEBUG >= 2:
         print(f"error lowering {si.ast.op}")

@@ -166,6 +166,21 @@ def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, Non
 capturing: list = [] # put classes with an add method in here

 def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, int]]=None, do_update_stats=True):
-  for ei in lower_schedule(schedule):
+  for si, ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
-    ei.run(var_vals, do_update_stats=do_update_stats)
+    if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
+      # copy in allocated buffers from the GPU
+      nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
+      for cpu_b, gpu_b in zip(nb, si.bufs):
+        if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
+
+      # run on GPU
+      ei.run(var_vals, do_update_stats=do_update_stats)
+
+      # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
+      lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
+      import numpy as np
+      np.testing.assert_allclose(nb[0].numpy(), si.bufs[0].numpy(), rtol=1e-3, atol=1e-3)
+    else:
+      ei.run(var_vals, do_update_stats=do_update_stats)
+
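Since lower_schedule now yields (ScheduleItem, ExecItem) pairs instead of bare ExecItems, callers either unpack the tuple or index x[1], as the test updates above show. A sketch of iterating the new signature (illustrative, assuming these imports resolve at this commit):

from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule, CompiledRunner

sched = (Tensor.randn(16) + 1).schedule()
for si, ei in lower_schedule(sched):
  # si is the ScheduleItem (AST + buffers); ei is the lowered ExecItem
  if isinstance(ei.prg, CompiledRunner): print(si.ast.op, ei.prg.p.name)
  ei.run()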

tinygrad/helpers.py (+1, -1)

@@ -113,7 +113,7 @@ def __lt__(self, x): return self.value < x
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
 DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
-QUANTIZE = ContextVar("QUANTIZE", 0)
+QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)

 @dataclass(frozen=True)
 class Metadata:
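Because tinygrad's ContextVar reads its default from the environment, the flag needs no code changes to enable. A small sketch of that convention (an assumption about ContextVar's env-backed initialization; the variable must be set before tinygrad.helpers is first imported):

import os
os.environ["VALIDATE_WITH_CPU"] = "1"  # must precede the first tinygrad import
from tinygrad.helpers import VALIDATE_WITH_CPU
assert VALIDATE_WITH_CPU.value == 1  # ContextVar picked up the env value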
