Skip to content

Commit dbb7aee

Browse files
authored
Split constant in div with negative x (tinygrad#10088)
* add rule
* change test
* lower complexity limit
* remove offset in fold_unrolled_divs
* remove import
* add one more condition
1 parent 610ee79 commit dbb7aee

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

test/test_arange.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(
4242

4343
if Device.default.renderer.has_local:
4444
# TODO: fix limit
45-
def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=100000)
45+
def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920)
4646
def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
4747

4848
def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)

test/unit/test_uop_symbolic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def test_add_div(self):
333333
self.helper_test_variable((Variable("a", 0, 5)+5)//4, 1, 2, "(((a+1)//4)+1)")
334334

335335
def test_div_neg_rem(self):
336-
self.helper_test_variable((-Variable("a", 0, 255)+256)//2, 0, 128, "(((a*-1)+256)//2)")
336+
self.helper_test_variable((-Variable("a", 0, 255)+256)//2, 0, 128, "((((a+1)//2)*-1)+128)")
337337

338338
def test_mul_div_factor_mul(self):
339339
self.helper_test_variable((Variable("a", 0, 10)*8)//4, 0, 20, "(a*2)")

tinygrad/codegen/symbolic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections import defaultdict
55
from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
66
from tinygrad.dtype import ConstType, dtypes, PtrDType
7-
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod
7+
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element
88
from tinygrad.codegen.transcendental import xpow
99

1010
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
@@ -78,7 +78,7 @@ def split_uop(x:UOp, sep:Ops):
7878
def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
7979
# div pattern in unrolled arange
8080
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4) -> x
81-
seen_const, ans, offset = [], None, 0
81+
seen_const, ans = [], None
8282
for u in split_uop(divs, Ops.ADD):
8383
if fac!=1:
8484
if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
@@ -88,9 +88,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
8888
if (s0:=u.src[0]).vmin < 0: return None
8989
# assumes the CONST is the last operand of an ADD
9090
if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
91-
const = s0.src[1].arg
92-
offset += cdiv(const, denominator)
93-
seen_const.append(cmod(const, denominator))
91+
seen_const.append(s0.src[1].arg)
9492
s0 = s0.src[0]
9593
else: seen_const.append(0)
9694
if ans is None: ans = s0
@@ -100,7 +98,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
10098
for i in range(denominator-len(seen_const)):
10199
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
102100
if sorted(seen_const)==list(range(denominator)):
103-
return fac*(ans + offset)
101+
return fac*ans
104102
return None
105103

106104
def lt_folding(x:UOp, c:int) -> UOp|None:
@@ -283,6 +281,8 @@ def gep_through_wmma(gep:UOp, wmma:UOp):
283281
# div folding
284282
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)), # (x//c+a)//d -> (x+a*c)//(c*d)
285283
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
284+
((UPat.var("x", dtypes.sints)+UPat.cvar("c")).named("n")//UPat.cvar("d"),
285+
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
286286
# ** mod **
287287
# mod folding
288288
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),

0 commit comments

Comments
 (0)