Skip to content

Commit c672716

Browse files
authored
improve vmin/vmax for IDIV (tinygrad#9678)
1 parent 8dd88ad commit c672716

File tree

4 files changed

+32
-9
lines changed

test/external/fuzz_symbolic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def gt(expr, rng=None):
7070
v = [v1,v2,v3]
7171
rn = 0
7272
for t,r in zip(tape, rngs): rn, _ = t(rn, r)
73-
num = eval(expr.render())
73+
num = eval(expr.render(simplify=False))
7474
if num != rn:
7575
unsimplified_num = eval(expr.render(simplify=False))
7676
assert unsimplified_num == rn, "UNSIMPLIFIED MISMATCH!"

test/unit/test_uop_symbolic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,8 @@ def test_mul_div(self):
323323

324324
def test_add_div(self):
325325
# careful about the lower bounds and upper bounds
326-
self.helper_test_variable((Variable("a", 0, 5)-2)//4, -1, 0, "((a+-2)//4)")
327-
self.helper_test_variable((Variable("a", 0, 5)-1)//4, -1, 1, "((a+-1)//4)")
326+
self.helper_test_variable((Variable("a", 0, 5)-2)//4, 0, 0, "0")
327+
self.helper_test_variable((Variable("a", 0, 5)-1)//4, 0, 1, "((a+-1)//4)")
328328
self.helper_test_variable((Variable("a", 0, 5))//4, 0, 1, "(a//4)")
329329
self.helper_test_variable((Variable("a", 0, 5)+1)//4, 0, 1, "((a+1)//4)")
330330
self.helper_test_variable((Variable("a", 0, 5)+2)//4, 0, 1, "((a+2)//4)")

test/unit/test_uop_vmin_vmax.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,32 @@ def test_vmin_vmax_division_positive(self):
120120

121121
def test_vmin_vmax_division_negative(self):
122122
# vmin and vmax for division of a variable by a negative constant
123+
# always positive
123124
x = UOp.variable('x', 10, 20)
124125
uop = x // -2
125126
self.assertEqual(uop.vmin, -10)
126127
self.assertEqual(uop.vmax, -5)
128+
uop = x // -3
129+
self.assertEqual(uop.vmin, -6)
130+
self.assertEqual(uop.vmax, -3)
131+
132+
# always negative
133+
x = UOp.variable('x', -20, -10)
134+
uop = x // -2
135+
self.assertEqual(uop.vmin, 5)
136+
self.assertEqual(uop.vmax, 10)
137+
uop = x // -3
138+
self.assertEqual(uop.vmin, 3)
139+
self.assertEqual(uop.vmax, 6)
140+
141+
# cross 0
142+
x = UOp.variable('x', -10, 10)
143+
uop = x // -2
144+
self.assertEqual(uop.vmin, -5)
145+
self.assertEqual(uop.vmax, 5)
146+
uop = x // -3
147+
self.assertEqual(uop.vmin, -3)
148+
self.assertEqual(uop.vmax, 3)
127149

128150
def test_vmin_vmax_mod_positive(self):
129151
# vmin and vmax for modulo of a variable by a positive constant
@@ -144,7 +166,7 @@ def test_vmin_vmax_division_with_mixed_range(self):
144166
# vmin and vmax for division of a variable with a range crossing zero
145167
x = UOp.variable('x', -10, 10)
146168
uop = x // 3
147-
self.assertEqual(uop.vmin, -4) # -10//3 = -4
169+
self.assertEqual(uop.vmin, -3) # -10//3 = -3 (in C)
148170
self.assertEqual(uop.vmax, 3) # 10//3 = 3
149171

150172
def test_vmin_vmax_mod_with_mixed_range(self):

tinygrad/ops.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -616,12 +616,13 @@ def _min_max(self) -> tuple[ConstType, ConstType]:
616616
if self.op is Ops.MOD and s1_vmin > 0:
617617
return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
618618
if self.op is Ops.IDIV:
619-
if s1_vmin == s1_vmax: # min/max are equal in a CONST
620-
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
621-
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
619+
assert isinstance(s0_vmin, int) and isinstance(s0_vmax, int) and isinstance(s1_vmin, int) and isinstance(s1_vmax, int)
620+
if s1_vmin == s1_vmax: # s1 is a const
621+
if s1_vmin > 0: return cdiv(s0_vmin, s1_vmin), cdiv(s0_vmax, s1_vmin)
622+
if s1_vmin < 0: return cdiv(s0_vmax, s1_vmin), cdiv(s0_vmin, s1_vmin)
622623
# don't know exact bounds, but know the sign
623-
if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
624-
if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
624+
if (s0_vmax <= 0 and s1_vmax < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
625+
if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmax < 0): return dtypes.min(self.dtype), 0
625626
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
626627
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
627628
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))

0 commit comments

Comments (0)