|
2 | 2 | import functools, itertools, operator, math
|
3 | 3 | from dataclasses import dataclass
|
4 | 4 | from typing import cast
|
5 |
| -from tinygrad.dtype import dtypes, PtrDType |
6 |
| -from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop |
| 5 | +from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype |
| 6 | +from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop, GroupOp |
7 | 7 | from tinygrad.renderer import Renderer
|
8 |
| -from tinygrad.helpers import all_int, prod, partition, flatten, unwrap |
| 8 | +from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE |
9 | 9 | from tinygrad.codegen.expander import expand_rewrite
|
| 10 | +from tinygrad.codegen.symbolic import symbolic |
10 | 11 |
|
11 | 12 | # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
12 | 13 | def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
|
@@ -156,9 +157,65 @@ def lower_const(x:UOp):
|
156 | 157 | # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
|
157 | 158 | (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
|
158 | 159 | (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
|
| 160 | + (UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]), |
| 161 | +]) |
| 162 | + |
| 163 | +# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints **** |
| 164 | + |
def view_to_mask(x:UOp):
  """Strip a single-view ShapeTracker down to just its mask.

  Returns a new ShapeTracker with the same shape and mask but zero strides and
  zero offset, or None when x has multiple views or no mask to extract.
  """
  from tinygrad.shape.shapetracker import ShapeTracker, View
  tracker = cast(ShapeTracker, x.st)
  # only a lone, masked view can be reduced to a pure mask
  if len(tracker.views) > 1 or (mask:=tracker.views[-1].mask) is None: return None
  zero_strides = (0,)*len(tracker.shape)
  return ShapeTracker((View(tracker.shape, zero_strides, 0, mask, False),))
| 171 | + |
# fixed-point scale used by the integer-mul rewrite below: 16 fractional bits
FP = (1 << 16)

# rewrite rules for the quantization preprocessor, combined with symbolic simplification
pm_quant = symbolic+PatternMatcher([
  # cast after add/mul: do the op in the promoted source dtype, cast the result to float32
  (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
  (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
  # MUL after reduce: hoist a constant multiplier out of the REDUCE_AXIS
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c),
  # CAST after reduce (doesn't work if it's a size change): reduce in the source dtype, cast afterwards
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
   lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
  # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
  (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
   lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
  # mul 0 * c1 is 0: the VALID's where can be dropped when it shares its VIEW with the LOAD
  # NOTE(review): presumably safe because out-of-mask lanes contribute 0 either way — confirm against VALID semantics
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
   UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
  # mul (with plus) 0 * c1 is 0: same as above, but the LOAD has a masked constant added before the float cast
  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
   (UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
    UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
    lambda ld,v,c1: ld*c1),
  # fixed point mult, replace (x.float()*c1+c2).int() with an int expression scaled by FP (2^16)
  ((UPat.var("x").cast(dtypes.float)*UPat.cvar("c1")+UPat.cvar("c2")).cast(dtypes.int),
   lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
  # where move: pull the payload out of the where so the select becomes a 0/1 multiply
  # (skipped when yes is already the constant 1, since that form is the target of the rules below)
  (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
   (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
  # move a constant multiplier outside a 0/1 where-mask
  ((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
  # apply the 0/1 where-mask in x's dtype before the cast, not after
  (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
   (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
  # fuse two 0/1 where-masks into a single where on the AND of the conditions
  ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
    UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
    x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
  # don't care: wrap a STORE's index and value in IGNORE tagged with the store's mask-only tracker,
  # so validity checks already implied by the store mask can be dropped below
  (UPat(Ops.STORE, name="x"), lambda x:
    x.replace(src=(x.src[0], UOp(Ops.IGNORE, src=(x.src[1],), arg=mm), UOp(Ops.IGNORE, x.src[2].dtype, src=(x.src[2],), arg=mm),)) \
      if x.src[1].op is not Ops.IGNORE and (mm:=view_to_mask(x.src[1])) is not None else None),
  # push IGNORE down through ALU ops and CASTs onto their sources
  (UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST), name="alu"),), name="ig"),
   lambda ig,alu: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x,), ig.arg) for x in alu.src))),
  # IGNORE of a constant is just the constant
  (UPat(Ops.IGNORE, src=(UPat.cvar("c"),), name="ig"), lambda ig, c: c),
  # a VALID whose view matches the IGNORE's mask is always true under that mask
  (UPat(Ops.IGNORE, src=(UPat(Ops.VALID, name="v"),), name="ig"), lambda ig, v: UOp.const(dtypes.bool, True) if v.src[0].arg == ig.arg else None),
])
|
160 | 216 |
|
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
  """Lower VIEW-based LOAD/STOREs in ast to indexed form, then vectorize."""
  # run the quantization preprocessor first where it is enabled and supported
  if QUANTIZE and opts.device in {"CPU", "DSP"}:
    ast = graph_rewrite(ast, pm_quant, name="quantize")
  lowered = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
  # expand_rewrite turns this into a vectorized program
  return expand_rewrite(lowered)
|
0 commit comments