@@ -230,6 +230,17 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
230
230
# ** mod **
231
231
# mod folding
232
232
(UPat .var ("x" ) % UPat .var ("y" ), lambda x ,y : div_and_mod_folding (x ,y ,Ops .MOD )),
233
+ # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
234
+ (UPat (Ops .GEP , src = (UPat (Ops .GEP , name = 'g2' ),), name = 'g1' ),
235
+ lambda g1 , g2 : g2 .src [0 ].gep (tuple (g2 .arg [g1 .arg [i ]] for i in range (g1 .dtype .count )))),
236
+ (UPat (Ops .GEP , src = (UPat (Ops .VECTORIZE , name = "vec" ),), name = "gep" ),
237
+ lambda gep , vec : UOp (Ops .VECTORIZE , gep .dtype , tuple (vec .src [i ] for i in gep .arg )) if len (gep .arg ) > 1 else vec .src [gep .arg [0 ]]),
238
+ (UPat (Ops .GEP , src = (UPat .cvar ("c" , vec = False ),), name = "gep" ), lambda gep , c : gep .const_like (c .arg )),
239
+ (UPat (Ops .GEP , src = (UPat (Ops .VCONST , name = "c" ),), name = "gep" ), lambda gep , c : gep .const_like (tuple (c .arg [x ] for x in gep .arg ))),
240
+ # push all GEPs through ALUs (fix arange stuff)
241
+ (UPat (Ops .GEP , src = (UPat ((* GroupOp .ALU , Ops .CAST , Ops .BITCAST ), name = 'alu' ),), name = 'gep' ),
242
+ lambda gep ,alu : UOp (alu .op , alu .dtype .scalar ().vec (gep .dtype .count ), tuple (x .gep (gep .arg ) for x in alu .src ), alu .arg ) \
243
+ if not isinstance (gep .dtype , PtrDType ) else None ),
233
244
])
234
245
235
246
symbolic_flat = symbolic + PatternMatcher ([
@@ -399,17 +410,6 @@ def gep_through_wmma(gep:UOp, wmma:UOp):
399
410
# VECTORIZE void is SINK
400
411
(UPat (Ops .VECTORIZE , dtype = dtypes .void , src = UPat (Ops .BARRIER , name = 'b' )), lambda b : b ),
401
412
(UPat (Ops .VECTORIZE , dtype = dtypes .void , name = 'x' ), lambda x : UOp (Ops .SINK , dtypes .void , x .src )),
402
- # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
403
- (UPat (Ops .GEP , src = (UPat (Ops .GEP , name = 'g2' ),), name = 'g1' ),
404
- lambda g1 , g2 : g2 .src [0 ].gep (tuple (g2 .arg [g1 .arg [i ]] for i in range (g1 .dtype .count )))),
405
- (UPat (Ops .GEP , src = (UPat (Ops .VECTORIZE , name = "vec" ),), name = "gep" ),
406
- lambda gep , vec : UOp (Ops .VECTORIZE , gep .dtype , tuple (vec .src [i ] for i in gep .arg )) if len (gep .arg ) > 1 else vec .src [gep .arg [0 ]]),
407
- (UPat (Ops .GEP , src = (UPat .cvar ("c" , vec = False ),), name = "gep" ), lambda gep , c : gep .const_like (c .arg )),
408
- (UPat (Ops .GEP , src = (UPat (Ops .VCONST , name = "c" ),), name = "gep" ), lambda gep , c : gep .const_like (tuple (c .arg [x ] for x in gep .arg ))),
409
- # push all GEPs through ALUs (fix arange stuff)
410
- (UPat (Ops .GEP , src = (UPat ((* GroupOp .ALU , Ops .CAST , Ops .BITCAST ), name = 'alu' ),), name = 'gep' ),
411
- lambda gep ,alu : UOp (alu .op , alu .dtype .scalar ().vec (gep .dtype .count ), tuple (x .gep (gep .arg ) for x in alu .src ), alu .arg ) \
412
- if not isinstance (gep .dtype , PtrDType ) else None ),
413
413
# push some GEPs through WMMAs
414
414
(UPat (Ops .GEP , src = (UPat (Ops .WMMA , name = "wmma" ),), name = "gep" ), gep_through_wmma ),
415
415
# CAT can't be rendered. It's a VECTORIZE on vectors; we expand it to a single VECTORIZE with GEPs (TODO: move this later)
0 commit comments