Skip to content

Commit cfd2851

Browse files
authored
move pow folding tests to test_schedule [pr] (tinygrad#8955)
not really belongs to test_const_folding
1 parent c2b4c43 commit cfd2851

File tree

2 files changed

+26
-35
lines changed

2 files changed

+26
-35
lines changed

test/test_const_folding.py

+1-29
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest, math
22
from tinygrad import Tensor, Device, dtypes
3-
from tinygrad.ops import Ops, GroupOp
3+
from tinygrad.ops import Ops
44
from tinygrad.helpers import CI
55
import numpy as np
66
from tinygrad.device import is_dtype_supported
@@ -97,34 +97,6 @@ def test_literal_one_pow(self):
9797
def test_tensor_one_pow(self):
9898
_check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))
9999

100-
def test_2_pow_is_exp2(self):
101-
t = 2.0 ** Tensor([1.0, 2.0, 3.0])
102-
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
103-
self.assertEqual(len(s), 1)
104-
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
105-
self.assertEqual(alu, [Ops.EXP2])
106-
107-
def test_pow_05_is_sqrt(self):
108-
t = Tensor([1.0, 2.0, 3.0]) ** 0.5
109-
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
110-
self.assertEqual(len(s), 1)
111-
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
112-
self.assertEqual(alu, [Ops.SQRT])
113-
114-
def test_pow_neg_05_is_rsqrt(self):
115-
t = Tensor([1.0, 2.0, 3.0]) ** -0.5
116-
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
117-
self.assertEqual(len(s), 1)
118-
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
119-
self.assertEqual(alu, [Ops.RECIP, Ops.SQRT])
120-
121-
def test_pow_8_has_3_muls(self):
122-
t = Tensor([1.0, 2.0, 3.0]) ** 8
123-
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
124-
self.assertEqual(len(s), 1)
125-
alu = [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
126-
self.assertEqual(alu, [Ops.MUL, Ops.MUL, Ops.MUL])
127-
128100
# folds advance indexing into basic indexing
129101
class TestIndexingConstFolding(unittest.TestCase):
130102
def test_scalar_index(self):

test/test_schedule.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from tinygrad.dtype import DType, ImageDType
1414
from tinygrad.shape.shapetracker import ShapeTracker
1515
from tinygrad.shape.view import View
16-
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
16+
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views, GroupOp
1717
from tinygrad.spec import type_verify, shape_spec
1818
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
1919
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
@@ -559,11 +559,30 @@ def test_double_from(self):
559559
out = x.to('python')
560560
check_schedule(out, 0, filter_sink=False)
561561

562-
def test_pow_const_tensor_simplified(self):
563-
x = Tensor([1,2,3,4])
564-
# NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5)
565-
out = x ** Tensor(2.0)
566-
check_schedule(out, 1)
562+
def _alu_from_tensor(self, t:Tensor):
563+
s = [s for s in t.schedule() if s.ast.op is Ops.SINK]
564+
self.assertEqual(len(s), 1)
565+
return [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU]
566+
567+
def test_2_pow_is_exp2(self):
568+
t = 2.0 ** Tensor([1.0, 2.0, 3.0])
569+
self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2])
570+
571+
def test_pow_05_is_sqrt(self):
572+
t = Tensor([1.0, 2.0, 3.0]) ** 0.5
573+
self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT])
574+
575+
def test_pow_neg_05_is_rsqrt(self):
576+
t = Tensor([1.0, 2.0, 3.0]) ** -0.5
577+
self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT])
578+
579+
def test_pow_2_has_1_mul(self):
580+
t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0)
581+
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL])
582+
583+
def test_pow_8_has_3_muls(self):
584+
t = Tensor([1.0, 2.0, 3.0]) ** 8
585+
self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL])
567586

568587
def test_pow_const_tensor_to_zero(self):
569588
x = Tensor([1,2,3,4])

0 commit comments

Comments
 (0)