16
16
17
17
# **** schedule simplifier
18
18
19
- def simplify_reduceop (reduce :UOp , x :UOp ) -> UOp | None :
20
- if not all_int (x .shape ): return None
21
- # remove reduce on unmasked const
22
- prshape = prod (unwrap (x .st ).shape [i ] for i in reduce .arg [1 ])
23
- ret = x .const_arg
19
+ def simplify_stride0_reduce (reduce :UOp , x :UOp ):
20
+ # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
21
+ if any (v .mask is not None for v in unwrap (x .st ).views ): return None
22
+ # must have all stride 0 in the relevant axis (NOTE: can do partial)
23
+ if not all (unwrap (x .st ).views [- 1 ].strides [axis ] == 0 for axis in reduce .arg [1 ]) or not all_int (x .shape ): return None
24
+ prshape = prod (x .shape [i ] for i in reduce .arg [1 ])
25
+ ret = x .shrink (tuple ((0 ,s ) if i not in reduce .arg [1 ] else (0 ,1 ) for i ,s in enumerate (x .shape )))
24
26
match reduce .arg [0 ]:
25
- case Ops .ADD : ret *= prshape
26
- case Ops .MUL : ret **= prshape
27
- case Ops .MAX : pass # NOTE: Ops.MAX is passthrough
28
- case _: return None
29
- return reduce .const_like (ret )
27
+ case Ops .ADD : return ret * prshape
28
+ case Ops .MUL : return ret .pow (prshape )
29
+ case Ops .MAX : return ret # NOTE: Ops.MAX is passthrough
30
30
31
31
def found_contiguous (ctx :dict [UOp , UOp ], contig :UOp , src :UOp ):
32
32
if (sti := unwrap (src .st ).invert (src .base .shape )) is not None : ctx [src .base ] = contig .view (sti )
@@ -45,8 +45,8 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
45
45
# reduce of size 0 is the identity element
46
46
(UPat (Ops .REDUCE_AXIS , name = "reduce" , src = (UPat .var ("x" ),)),
47
47
lambda reduce ,x : reduce .const_like (identity_element (reduce .arg [0 ], reduce .dtype )) if x .size == 0 and reduce .size != 0 else None ),
48
- # reduce of const is collapsed (TODO: make this a generic rule for stride0)
49
- (UPat (Ops .REDUCE_AXIS , name = "reduce" , src = (UPat .cvar ("x" ),)), simplify_reduceop ),
48
+ # reduce on stride 0 is collapsed
49
+ (UPat (Ops .REDUCE_AXIS , name = "reduce" , src = (UPat .var ("x" ),)), simplify_stride0_reduce ),
50
50
# COPY(CONST) creates a new CONST on the destination device
51
51
(UPat (Ops .COPY , name = "root" , src = (UPat (), UPat .cvar ("x" ),)), lambda root ,x : root .const_like (x .const_arg )),
52
52
# no COPY to same device, except clone (arg is True)
0 commit comments