Commit f72a87f

add proper support for Ops.IGNORE to remove store masks (tinygrad#9692)
* add proper support for Ops.IGNORE to remove store masks
* remove useless NHWC
* revert that
1 parent 3b8d923 · commit f72a87f

File tree

3 files changed: +31 −10 lines


extra/onnx.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -724,8 +724,6 @@ def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int
     ret = _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype)
   else:
     ret = _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype)
-  # you need both NHWC=1 DONT_GROUP_REDUCES=1 for this to work
-  if getenv("NHWC") and len(ret.shape) == 4: return ret.permute(0,2,3,1).contiguous().permute(0,3,1,2)
   return ret.contiguous()
 
 def DynamicQuantizeLinear(x: Tensor):
@@ -737,10 +735,6 @@ def DynamicQuantizeLinear(x: Tensor):
   return y, scale, zero_point
 
 def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
-  WEIGHT_SHIFT = 4
-  if getenv("NHWC") and len(x.shape) == 4 and x.shape[2:] == (1,1) and x.shape[1]%WEIGHT_SHIFT == 0:
-    # DSP swizzle memory
-    x = x.reshape(x.shape[0], x.shape[1]//WEIGHT_SHIFT, WEIGHT_SHIFT).permute(1,0,2).contiguous().permute(1,0,2).reshape(x.shape)
   x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
   return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
```

tinygrad/codegen/expander.py

Lines changed: 30 additions & 2 deletions
```diff
@@ -116,9 +116,37 @@ def _gate_srcs(u:UOp, gate:UOp) -> UOp:
   (UPat(Ops.STORE, name="root"), create_gate),
 ])
 
+# **** IGNORE support ****
+
+pm_store_ignore = PatternMatcher([
+  (UPat().index(UPat(), UPat(name="mask")).store(UPat()).named("store"),
+   lambda store,mask: store.replace(src=(store.src[0], UOp(Ops.IGNORE, src=(store.src[1], mask)))) if store.src[1].op is not Ops.IGNORE else None),
+])
+
+pm_move_ignore = PatternMatcher([
+  # IGNORE on SELF is nothing
+  (UPat(Ops.IGNORE, src=(UPat(name="x"), UPat(name="x"))), lambda x: x.const_like(True)),
+  # IGNORE on a CONST is nothing
+  (UPat(Ops.IGNORE, src=(UPat((Ops.CONST, Ops.VCONST), name="c"), UPat())), lambda c: c),
+  # move the IGNOREs
+  (UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST, Ops.VECTORIZE), name="alu"), UPat.var("mask")), name="ig"),
+   lambda ig,alu,mask: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x, mask)) for x in alu.src))),
+])
+
+pm_delete_ignore = PatternMatcher([
+  # IGNORE on SELF is nothing
+  (UPat(Ops.IGNORE, src=(UPat(name="x"), UPat())), lambda x: x),
+])
+
 def expand_rewrite(sink:UOp) -> UOp:
   # initial symbolic + migrate indexing (remove this)
   sink = graph_rewrite(sink, sym+migrate_indexing)
 
-  # expand
-  return graph_rewrite(sink, sym+expander)
+  # store IGNORE
+  sink = graph_rewrite(sink, pm_store_ignore, name="store_ignore")
+
+  # move IGNORE
+  sink = graph_rewrite(sink, pm_move_ignore, name="move_ignore")
+
+  # expand + remove surviving ignores
+  return graph_rewrite(sink, pm_delete_ignore+sym+expander)
```
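
Taken together, the three new matchers form a push-then-cancel pipeline: `pm_store_ignore` wraps the stored value of every masked store in an `Ops.IGNORE` that carries the store mask, `pm_move_ignore` pushes that wrapper down through ALU/CAST/VECTORIZE sources, cancelling it where it meets the mask itself (IGNORE on SELF becomes a true constant) or a constant, and `pm_delete_ignore` strips whatever survives into expansion. A toy sketch of the same push-down idea on a plain expression tree (the `Node` type and the tiny op set are illustrative stand-ins, not tinygrad's UOp/UPat machinery):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str        # "ADD" | "MUL" | "LT" | "CONST" | "VAR" | "IGNORE"
  src: tuple = ()
  arg: object = None

ALU = {"ADD", "MUL", "LT"}

def move_ignore(n: Node) -> Node:
  if n.op != "IGNORE":
    return Node(n.op, tuple(move_ignore(s) for s in n.src), n.arg)
  x, mask = n.src
  # IGNORE on SELF is nothing (structural equality stands in for same-node matching)
  if x == mask: return Node("CONST", arg=True)
  if x.op == "CONST": return x              # IGNORE on a CONST is nothing
  if x.op in ALU:                           # move the IGNOREs into the sources
    return Node(x.op, tuple(move_ignore(Node("IGNORE", (s, mask))) for s in x.src), x.arg)
  return n                                  # e.g. a VAR: leave it for deletion

def delete_ignore(n: Node) -> Node:
  n = Node(n.op, tuple(delete_ignore(s) for s in n.src), n.arg)
  return n.src[0] if n.op == "IGNORE" else n

gidx = Node("VAR", arg="gidx")
mask = Node("LT", (gidx, Node("CONST", arg=4)))
val = Node("MUL", (mask, gidx))   # a stored value that re-computes its own mask
out = delete_ignore(move_ignore(Node("IGNORE", (val, mask))))
print(out)  # MUL(CONST True, VAR gidx): the embedded mask collapsed to a constant
```

Occurrences of the mask inside the masked computation simplify to a true constant, which is what lets the later rewrites drop redundant store masks, per the commit title.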

tinygrad/renderer/__init__.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -111,8 +111,7 @@ def __post_init__(self):
       # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
       if u.arg[0][0] == 'i': self.local_size = None
       special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
-      assert special_size is not None
-      special_size[int(u.arg[0][-1])] = u.arg[1]
+      if special_size is not None: special_size[int(u.arg[0][-1])] = u.arg[1]
     self.vars = sorted(self.vars, key=lambda v: v.arg)
     self.outs = sorted(dedup(self.outs))
     self.ins = sorted(dedup(self.ins))
```
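
For context on the renderer change: once an 'i' special has set local_size to None, a later 'l' special made the old assert fire; the guarded write just skips it. A minimal stand-alone sketch of that interaction, assuming SPECIAL args of the form ('gidx0', 16) with arg[0][0] as the kind and arg[0][-1] as the axis (inferred from this hunk alone, not checked against the full class):

```python
# stand-alone illustration, not the real __post_init__
global_size: list | None = [1, 1, 1]
local_size: list | None = [1, 1, 1]

# contrived sequence of assumed SPECIAL args (name, size)
for name, size in [("gidx0", 256), ("iidx1", 4), ("lidx0", 16)]:
  if name[0] == 'i': local_size = None   # 'i' specials disable local sizing
  special_size = local_size if name[0] == 'l' else global_size
  # the old `assert special_size is not None` would fire on 'lidx0' here;
  # the guarded write simply skips it instead
  if special_size is not None: special_size[int(name[-1])] = size

print(global_size, local_size)  # [256, 4, 1] None
```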
