Skip to content

Commit 935cd01

Browse files
authored
simple failing test for graph_rewrite children [pr] (tinygrad#9489)
* simple failing test for graph_rewrite children [pr] * lint * update too
1 parent d20494e commit 935cd01

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

test/test_rewrite_tracked_childen.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
from tinygrad import Tensor
3-
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
3+
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp, merge_views
4+
from tinygrad.engine.schedule import sym
45

56
class TestRewriteTrackedChildren(unittest.TestCase):
67
def test_children_in_context(self):
@@ -47,5 +48,15 @@ def test_simple_child(self):
4748
print([x.arg for x in sink.get_children_map()[view_w_child]])
4849
self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((3,4)))
4950

51+
@unittest.expectedFailure
52+
def test_child_after_parent_update(self):
53+
def print_children(ctx, r):
54+
ctx.update_children()
55+
print(ctx.children[r])
56+
extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
57+
a = Tensor.empty(3, 3)
58+
r = (a+0).sum()
59+
graph_rewrite(r.lazydata, merge_views+sym+extra, track_children=True)
60+
5061
if __name__ == '__main__':
5162
unittest.main()

0 commit comments

Comments
 (0)