Skip to content

Commit c2b4c43

Browse files
geohot and Qazalin
authored
handle stride 0 reduce (tinygrad#8068)
* handle stride 0 reduce [pr]
* more test fixups
* a few more

---------

Co-authored-by: qazal <qazal.software@gmail.com>
1 parent cf21e27 commit c2b4c43

File tree

3 files changed

+17
-15
lines changed

3 files changed

+17
-15
lines changed

test/test_linearizer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1233,9 +1233,11 @@ def test_default_global_reversed(self):
12331233
def test_sum_collapse(self):
12341234
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
12351235
sched = [si for si in t.schedule() if si.ast.op is Ops.SINK]
1236+
# sum_collapse is a full collapse now
12361237
assert len(sched) == 1
1237-
lin = Kernel(sched[0].ast)
1238-
assert not any(u.op is Ops.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
1238+
assert not any(u.op is Ops.REDUCE_AXIS for u in sched[0].ast.toposort), "found reduce in sum collapse"
1239+
#lin = Kernel(sched[0].ast)
1240+
#assert not any(u.op is Ops.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
12391241

12401242
def test_assign_fold(self):
12411243
a = Tensor.ones(4, 4).contiguous().realize()

tinygrad/engine/schedule.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616

1717
# **** schedule simplifier
1818

19-
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
20-
if not all_int(x.shape): return None
21-
# remove reduce on unmasked const
22-
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
23-
ret = x.const_arg
19+
def simplify_stride0_reduce(reduce:UOp, x:UOp):
20+
# must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
21+
if any(v.mask is not None for v in unwrap(x.st).views): return None
22+
# must have all stride 0 in the relevant axis (NOTE: can do partial)
23+
if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
24+
prshape = prod(x.shape[i] for i in reduce.arg[1])
25+
ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
2426
match reduce.arg[0]:
25-
case Ops.ADD: ret *= prshape
26-
case Ops.MUL: ret **= prshape
27-
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
28-
case _: return None
29-
return reduce.const_like(ret)
27+
case Ops.ADD: return ret*prshape
28+
case Ops.MUL: return ret.pow(prshape)
29+
case Ops.MAX: return ret # NOTE: Ops.MAX is passthrough
3030

3131
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
3232
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
@@ -45,8 +45,8 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
4545
# reduce of size 0 is the identity element
4646
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
4747
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
48-
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
49-
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
48+
# reduce on stride 0 is collapsed
49+
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
5050
# COPY(CONST) creates a new CONST on the destination device
5151
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
5252
# no COPY to same device, except clone (arg is True)

tinygrad/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def sqrt(self): return self.alu(Ops.SQRT)
8989
def sin(self): return self.alu(Ops.SIN)
9090
def log2(self): return self.alu(Ops.LOG2)
9191
def exp2(self): return self.alu(Ops.EXP2)
92-
def pow(self, x): return self.alu(Ops.POW, x)
92+
def pow(self, x): return self.alu(Ops.POW, self.ufix(x))
9393

9494
# the order of these Ops controls the order of the toposort
9595
class Ops(FastEnum):

0 commit comments

Comments
 (0)