4
4
from collections import defaultdict
5
5
from tinygrad .ops import Ops , PatternMatcher , UPat , UOp , GroupOp , exec_alu
6
6
from tinygrad .dtype import ConstType , dtypes , PtrDType
7
- from tinygrad .helpers import partition , all_same , prod , flatten , get_single_element , cdiv , cmod
7
+ from tinygrad .helpers import partition , all_same , prod , flatten , get_single_element
8
8
from tinygrad .codegen .transcendental import xpow
9
9
10
10
# ******** phase 1 of symbolic used to live in ops; these are the most generic folding rules ********
@@ -78,7 +78,7 @@ def split_uop(x:UOp, sep:Ops):
78
78
def fold_unrolled_divs (divs :UOp , denominator : int , fac = 1 ) -> UOp | None :
79
79
# div pattern in unrolled arange
80
80
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4) -> x
81
- seen_const , ans , offset = [], None , 0
81
+ seen_const , ans = [], None
82
82
for u in split_uop (divs , Ops .ADD ):
83
83
if fac != 1 :
84
84
if u .op is not Ops .MUL or u .src [1 ].op is not Ops .CONST or u .src [1 ].arg != fac : return None
@@ -88,9 +88,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
88
88
if (s0 := u .src [0 ]).vmin < 0 : return None
89
89
# assumed CONST is the last of an ADD
90
90
if s0 .op is Ops .ADD and s0 .src [1 ].op is Ops .CONST and s0 .src [1 ].op is Ops .CONST :
91
- const = s0 .src [1 ].arg
92
- offset += cdiv (const , denominator )
93
- seen_const .append (cmod (const , denominator ))
91
+ seen_const .append (s0 .src [1 ].arg )
94
92
s0 = s0 .src [0 ]
95
93
else : seen_const .append (0 )
96
94
if ans is None : ans = s0
@@ -100,7 +98,7 @@ def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
100
98
for i in range (denominator - len (seen_const )):
101
99
if ans is not None and 0 <= ans .vmin and ans .vmax + i < denominator : seen_const .append (i )
102
100
if sorted (seen_const )== list (range (denominator )):
103
- return fac * ( ans + offset )
101
+ return fac * ans
104
102
return None
105
103
106
104
def lt_folding (x :UOp , c :int ) -> UOp | None :
@@ -283,6 +281,8 @@ def gep_through_wmma(gep:UOp, wmma:UOp):
283
281
# div folding
284
282
((UPat .var ("x" )// UPat .cvar ("c" ) + UPat .cvar ("a" ))// UPat .cvar ("d" ), lambda x ,c ,a ,d : (x + a * c )// (c * d )), # (x//c+a)//d -> (x+a*c)//(c*d)
285
283
(UPat .var ("x" , dtypes .sints ) // UPat .var ("y" ), lambda x ,y : div_and_mod_folding (x ,y ,Ops .IDIV )),
284
+ ((UPat .var ("x" , dtypes .sints )+ UPat .cvar ("c" )).named ("n" )// UPat .cvar ("d" ),
285
+ lambda x ,c ,n ,d : (- (- (c .arg % d .arg + x - (d .arg - 1 ))// d ) + c .arg // d .arg ) if x .vmax <= 0 and n .vmin >= 0 and d .arg > 0 else None ),
286
286
# ** mod **
287
287
# mod folding
288
288
(UPat .var ("x" ) % UPat .var ("y" ), lambda x ,y : div_and_mod_folding (x ,y ,Ops .MOD )),
0 commit comments