@@ -421,7 +421,8 @@ def _gate_srcs(u:UOp, gate:UOp) -> UOp:
421
421
Ops .VECTORIZE , Ops .IF ), name = "root" , custom_early_reject = set ([Ops .UNROLL ])), do_expand ),
422
422
(UPat (Ops .CONTRACT , name = "con" ), do_contract ),
423
423
# vectorize DEFINE_ACC
424
- (UPat (Ops .VECTORIZE , src = UPat (Ops .DEFINE_ACC , name = "acc" ), name = "v" ), lambda acc ,v : acc .replace (dtype = v .dtype )),
424
+ (UPat (Ops .VECTORIZE , src = UPat (Ops .DEFINE_ACC , name = "acc" ), name = "v" ),
425
+ lambda acc ,v : acc .replace (dtype = v .dtype , src = (acc .src [0 ].broadcast (v .dtype .count ),)+ acc .src [1 :])),
425
426
# BARRIERs aren't actually expanded
426
427
(UPat (Ops .BARRIER , src = (UPat (Ops .UNROLL , name = "ex" ),)),
427
428
lambda ex : UOp (Ops .UNROLL , dtypes .void , (UOp (Ops .BARRIER , dtypes .void , ex .src ),)* len (ex .src ), ex .arg )),
@@ -453,6 +454,12 @@ def no_vectorized_acc(acc:UOp):
453
454
(UPat ((Ops .LOAD , Ops .STORE ), name = "ls" ), no_vectorized_load_store ),
454
455
])
455
456
457
+ devectorize_load_store = PatternMatcher ([
458
+ # TODO: add vectorized support to transcendental
459
+ (UPat ((Ops .INDEX , Ops .EXP2 , Ops .LOG2 , Ops .SIN ), name = "alu" ), no_vectorized_alu ),
460
+ (UPat ((Ops .LOAD , Ops .STORE ), name = "ls" ), no_vectorized_load_store ),
461
+ ])
462
+
456
463
def delete_redundant_gates (buf :UOp , idx :UOp , val :UOp , store_gate :UOp , cast :UOp | None = None ) -> UOp | None :
457
464
if store_gate not in [gate .src [0 ] for gate in val .toposort if gate .op is Ops .IF ]: return None
458
465
# remove the gate from the index
@@ -508,9 +515,13 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
508
515
# expand
509
516
sink = graph_rewrite (sink , sym + expander )
510
517
511
- # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
512
- sink = graph_rewrite (sink , sym + (devectorize + float4_folding if opts is not None and opts .supports_float4 else devectorize )+ load_store_indexing +
513
- mulacc_unrolled )
518
+ if getenv ("NO_DEVECTORIZE" ):
519
+ # new devectorize for load/store
520
+ sink = graph_rewrite (sink , sym + devectorize_load_store )
521
+ else :
522
+ # devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
523
+ sink = graph_rewrite (sink , sym + (devectorize + float4_folding if opts is not None and opts .supports_float4 else devectorize )+ load_store_indexing +
524
+ mulacc_unrolled )
514
525
515
526
# final rules for the renderer (without sym)
516
527
sink = graph_rewrite (sink , symbolic_simple + get_late_rewrite_patterns (supported_ops , TRANSCENDENTAL >= 2 )+ pm_render + extra_matcher )
0 commit comments