Skip to content

Commit 3cc0508

Browse files
authored
llvm no devectorize, the right way (tinygrad#8901)
* closer
* env flag + transcendental issue
1 parent 8b16c65 commit 3cc0508

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

tinygrad/codegen/rewriter.py

+15-4
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,8 @@ def _gate_srcs(u:UOp, gate:UOp) -> UOp:
421421
Ops.VECTORIZE, Ops.IF), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
422422
(UPat(Ops.CONTRACT, name="con"), do_contract),
423423
# vectorize DEFINE_ACC
424-
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
424+
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"),
425+
lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
425426
# BARRIERs aren't actually expanded
426427
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
427428
lambda ex: UOp(Ops.UNROLL, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
@@ -453,6 +454,12 @@ def no_vectorized_acc(acc:UOp):
453454
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
454455
])
455456

457+
devectorize_load_store = PatternMatcher([
458+
# TODO: add vectorized support to transcendental
459+
(UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
460+
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
461+
])
462+
456463
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
457464
if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
458465
# remove the gate from the index
@@ -508,9 +515,13 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
508515
# expand
509516
sink = graph_rewrite(sink, sym+expander)
510517

511-
# devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
512-
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
513-
mulacc_unrolled)
518+
if getenv("NO_DEVECTORIZE"):
519+
# new devectorize for load/store
520+
sink = graph_rewrite(sink, sym+devectorize_load_store)
521+
else:
522+
# devectorize + load_store_indexing + mulacc_unrolled, mulacc_unrolled must be last because it can break loop_collapse
523+
sink = graph_rewrite(sink, sym+(devectorize+float4_folding if opts is not None and opts.supports_float4 else devectorize)+load_store_indexing+
524+
mulacc_unrolled)
514525

515526
# final rules for the renderer (without sym)
516527
sink = graph_rewrite(sink, symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)+pm_render+extra_matcher)

0 commit comments

Comments
 (0)