@@ -180,6 +180,10 @@ def div_and_mod_folding(x: UOp, y: UOp, which: Literal[Ops.MOD, Ops.IDIV], split
180
180
lambda gep , vec : UOp (Ops .VECTORIZE , gep .dtype , tuple (vec .src [i ] for i in gep .arg )) if len (gep .arg ) > 1 else vec .src [gep .arg [0 ]]),
181
181
(UPat (Ops .GEP , src = (UPat .cvar ("c" , vec = False ),), name = "gep" ), lambda gep , c : gep .const_like (c .arg )),
182
182
(UPat (Ops .GEP , src = (UPat (Ops .VCONST , name = "c" ),), name = "gep" ), lambda gep , c : gep .const_like (tuple (c .arg [x ] for x in gep .arg ))),
183
+ # GEP on void is skipped
184
+ (UPat (Ops .GEP , src = (UPat (dtype = dtypes .void , name = "x" ),)), lambda x : x ),
185
+ # GEP in order is removed
186
+ (UPat (Ops .GEP , name = "g" ), lambda g : g .src [0 ] if not isinstance (g .dtype , PtrDType ) and g .arg == tuple (range (g .src [0 ].dtype .count )) else None ),
183
187
# push all GEPs through ALUs (fix arange stuff)
184
188
(UPat (Ops .GEP , src = (UPat ((* GroupOp .ALU , Ops .CAST , Ops .BITCAST ), name = 'alu' ),), name = 'gep' ),
185
189
lambda gep ,alu : UOp (alu .op , alu .dtype .scalar ().vec (gep .dtype .count ), tuple (x .gep (gep .arg ) for x in alu .src ), alu .arg ) \
0 commit comments