Skip to content

Commit c9493e4

Browse files
geohot and Qazalin authored
reorder expand (tinygrad#9051)
* reorder expand
* symbolic ops needs resolve here
* s/arg/st + whitespace
* viz

---------

Co-authored-by: qazal <qazal.software@gmail.com>
1 parent 14aa239 commit c9493e4

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from tinygrad import Tensor, GlobalCounters, Context
2+
3+
if __name__ == "__main__":
4+
with Context(TRACK_MATCH_STATS=0): test = Tensor.ones(32, 10).contiguous().realize()
5+
GlobalCounters.reset()
6+
7+
# this is the softmax from scaled_dot_product_attention
8+
# it becomes 3 kernels
9+
print("*** softmax ***")
10+
with Context(NOOPT=1):
11+
out = test.softmax(-1)
12+
out.realize()

test/test_schedule.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2527,7 +2527,6 @@ def test_new_flat_buffer(self):
25272527

25282528
# sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer
25292529

2530-
@unittest.expectedFailure
25312530
def test_reorder_expand(self):
25322531
a = Tensor.empty(4, 1)
25332532
b = a.expand(4, 4).reciprocal()

tinygrad/engine/schedule.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
9292
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
9393
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
9494
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
95+
# put UnaryOps before EXPANDs
96+
(UPat(GroupOp.Unary, src=UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"), name="alu"),
97+
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
9598
])
9699

97100
remove_movement_ops = merge_views+PatternMatcher([

0 commit comments

Comments
 (0)