|
13 | 13 | from tinygrad.dtype import DType, ImageDType
|
14 | 14 | from tinygrad.shape.shapetracker import ShapeTracker
|
15 | 15 | from tinygrad.shape.view import View
|
16 |
| -from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views |
| 16 | +from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views, GroupOp |
17 | 17 | from tinygrad.spec import type_verify, shape_spec
|
18 | 18 | from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
|
19 | 19 | from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
|
@@ -559,11 +559,30 @@ def test_double_from(self):
|
559 | 559 | out = x.to('python')
|
560 | 560 | check_schedule(out, 0, filter_sink=False)
|
561 | 561 |
|
562 |
| - def test_pow_const_tensor_simplified(self): |
563 |
| - x = Tensor([1,2,3,4]) |
564 |
| - # NOTE: this does not test ** Tensor(2) is simpler in ast than ** Tensor(2.5) |
565 |
| - out = x ** Tensor(2.0) |
566 |
| - check_schedule(out, 1) |
| 562 | + def _alu_from_tensor(self, t:Tensor): |
| 563 | + s = [s for s in t.schedule() if s.ast.op is Ops.SINK] |
| 564 | + self.assertEqual(len(s), 1) |
| 565 | + return [u.op for u in s[0].ast.toposort if u.op in GroupOp.ALU] |
| 566 | + |
| 567 | + def test_2_pow_is_exp2(self): |
| 568 | + t = 2.0 ** Tensor([1.0, 2.0, 3.0]) |
| 569 | + self.assertEqual(self._alu_from_tensor(t), [Ops.EXP2]) |
| 570 | + |
| 571 | + def test_pow_05_is_sqrt(self): |
| 572 | + t = Tensor([1.0, 2.0, 3.0]) ** 0.5 |
| 573 | + self.assertEqual(self._alu_from_tensor(t), [Ops.SQRT]) |
| 574 | + |
| 575 | + def test_pow_neg_05_is_rsqrt(self): |
| 576 | + t = Tensor([1.0, 2.0, 3.0]) ** -0.5 |
| 577 | + self.assertEqual(self._alu_from_tensor(t), [Ops.RECIP, Ops.SQRT]) |
| 578 | + |
| 579 | + def test_pow_2_has_1_mul(self): |
| 580 | + t = Tensor([1.0, 2.0, 3.0]) ** Tensor(2.0) |
| 581 | + self.assertEqual(self._alu_from_tensor(t), [Ops.MUL]) |
| 582 | + |
| 583 | + def test_pow_8_has_3_muls(self): |
| 584 | + t = Tensor([1.0, 2.0, 3.0]) ** 8 |
| 585 | + self.assertEqual(self._alu_from_tensor(t), [Ops.MUL, Ops.MUL, Ops.MUL]) |
567 | 586 |
|
568 | 587 | def test_pow_const_tensor_to_zero(self):
|
569 | 588 | x = Tensor([1,2,3,4])
|
|
0 commit comments