Skip to content

Commit cb7a7f6

Browse files
authored
quantization preprocessor from DSP, should be universal (tinygrad#9437)
* quantization preprocessor from DSP, should be universal
* touchups
* fix tests
1 parent ca5064a commit cb7a7f6

File tree

7 files changed

+106
-20
lines changed

7 files changed

+106
-20
lines changed

.github/workflows/test.yml

+3-1
Original file line number | Diff line number | Diff line change
@@ -423,6 +423,8 @@ jobs:
423423
run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
424424
- name: Test Additional ONNX Ops (CPU)
425425
run: CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py
426+
- name: Test Quantize ONNX
427+
run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py
426428
- name: Run CLOUD=1 Test
427429
run: |
428430
CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py
@@ -467,7 +469,7 @@ jobs:
467469
testdsp:
468470
name: Linux (DSP)
469471
runs-on: ubuntu-24.04
470-
timeout-minutes: 10
472+
timeout-minutes: 15
471473
steps:
472474
- name: Checkout Code
473475
uses: actions/checkout@v4

extra/replay_pkl.py

+2
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@
2626
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
2727
k.apply_opt(Opt(OptOps.PADTO, 2, 128))
2828
k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
29+
elif knum == 3:
30+
k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=128))
2931
else:
3032
k.hand_coded_optimizations()
3133
p2 = k.to_program()

test/test_quantize_onnx.py

+36-13
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,9 @@
22
import unittest
33
from dataclasses import replace
44
from tinygrad import Tensor, Context, Device, dtypes
5+
from tinygrad.ops import Ops
56
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
6-
from tinygrad.engine.realize import CompiledRunner, ExecItem
7+
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item
78

89
N = 512
910

@@ -44,24 +45,46 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
4445
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
4546
for _ in range(run_count): ei.run(wait=True)
4647

48+
def get_quantized_model(sz):
49+
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
50+
class FakeDataReader(CalibrationDataReader):
51+
def __init__(self): self.cnt = 0
52+
def get_next(self) -> dict:
53+
self.cnt += 1
54+
if self.cnt == 100: return None
55+
return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
56+
out_file = "/tmp/test_out.onnx"
57+
quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
58+
FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
59+
activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
60+
extra_options={"ActivationSymmetric": False})
61+
return out_file
62+
63+
@unittest.skipIf(Device.DEFAULT != "CPU", "only tests for CPU")
64+
class TestQuantizeOnnxCPU(unittest.TestCase):
65+
def test_quant_128(self, sz=128):
66+
try:
67+
import onnx
68+
except ImportError:
69+
raise unittest.SkipTest()
70+
from extra.onnx import OnnxRunner
71+
out_file = get_quantized_model(sz)
72+
onnx_model = onnx.load(out_file)
73+
run_onnx = OnnxRunner(onnx_model)
74+
inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
75+
with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):
76+
sched = run_onnx({"input":inp})["output"].schedule()
77+
ei = lower_schedule_item(sched[-2])
78+
daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_ACC]
79+
assert all(u.dtype.scalar() is dtypes.int for u in daccs)
80+
4781
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
4882
class TestQuantizeOnnx(unittest.TestCase):
4983
def test_quant_128(self): self.test_quant(128)
5084
def test_quant(self, sz=512):
51-
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
5285
from examples.benchmark_onnx import load_onnx_model
53-
class FakeDataReader(CalibrationDataReader):
54-
def __init__(self): self.cnt = 0
55-
def get_next(self) -> dict:
56-
self.cnt += 1
57-
if self.cnt == 100: return None
58-
return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
59-
out_file = "/tmp/test_out.onnx"
6086
# divide is ~1500-2000 without reduce_range, 750-900 with it
61-
quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
62-
FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
63-
activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
64-
extra_options={"ActivationSymmetric": False})
87+
out_file = get_quantized_model(sz)
6588
run_onnx_jit, _ = load_onnx_model(out_file)
6689
with Context(DONT_REALIZE_EXPAND=1):
6790
run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))

tinygrad/codegen/lowerer.py

+60-3
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,12 @@
22
import functools, itertools, operator, math
33
from dataclasses import dataclass
44
from typing import cast
5-
from tinygrad.dtype import dtypes, PtrDType
6-
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
5+
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
6+
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop, GroupOp
77
from tinygrad.renderer import Renderer
8-
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
8+
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
99
from tinygrad.codegen.expander import expand_rewrite
10+
from tinygrad.codegen.symbolic import symbolic
1011

1112
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
1213
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
@@ -156,9 +157,65 @@ def lower_const(x:UOp):
156157
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
157158
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
158159
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
160+
(UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]),
161+
])
162+
163+
# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
164+
165+
def view_to_mask(x:UOp):
166+
from tinygrad.shape.shapetracker import ShapeTracker, View
167+
st = cast(ShapeTracker, x.st)
168+
if len(st.views) > 1: return None
169+
if st.views[-1].mask is None: return None
170+
return ShapeTracker((View(st.shape, (0,)*len(st.shape), 0, st.views[-1].mask, False),))
171+
172+
FP = (1 << 16)
173+
pm_quant = symbolic+PatternMatcher([
174+
# cast after add/mul
175+
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
176+
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
177+
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
178+
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
179+
# MUL after reduce
180+
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c),
181+
# CAST after reduce (doesn't work if it's a size change)
182+
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
183+
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
184+
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
185+
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
186+
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
187+
# mul 0 * c1 is 0
188+
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
189+
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
190+
# mul (with plus) 0 * c1 is 0
191+
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
192+
(UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
193+
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
194+
lambda ld,v,c1: ld*c1),
195+
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
196+
((UPat.var("x").cast(dtypes.float)*UPat.cvar("c1")+UPat.cvar("c2")).cast(dtypes.int),
197+
lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
198+
# where move
199+
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
200+
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
201+
((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
202+
(UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
203+
(x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
204+
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
205+
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
206+
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
207+
# don't care
208+
(UPat(Ops.STORE, name="x"), lambda x:
209+
x.replace(src=(x.src[0], UOp(Ops.IGNORE, src=(x.src[1],), arg=mm), UOp(Ops.IGNORE, x.src[2].dtype, src=(x.src[2],), arg=mm),)) \
210+
if x.src[1].op is not Ops.IGNORE and (mm:=view_to_mask(x.src[1])) is not None else None),
211+
(UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST), name="alu"),), name="ig"),
212+
lambda ig,alu: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x,), ig.arg) for x in alu.src))),
213+
(UPat(Ops.IGNORE, src=(UPat.cvar("c"),), name="ig"), lambda ig, c: c),
214+
(UPat(Ops.IGNORE, src=(UPat(Ops.VALID, name="v"),), name="ig"), lambda ig, v: UOp.const(dtypes.bool, True) if v.src[0].arg == ig.arg else None),
159215
])
160216

161217
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
218+
if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize")
162219
sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
163220
# expand_rewrite turns this into a vectorized program
164221
return expand_rewrite(sink)

tinygrad/helpers.py

+1
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,7 @@ def __lt__(self, x): return self.value < x
113113
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
114114
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
115115
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
116+
QUANTIZE = ContextVar("QUANTIZE", 0)
116117

117118
@dataclass(frozen=True)
118119
class Metadata:

tinygrad/ops.py

+1
Original file line number | Diff line number | Diff line change
@@ -154,6 +154,7 @@ class Ops(FastEnum):
154154

155155
# CUSTOMI is inline
156156
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
157+
IGNORE = auto()
157158

158159
class GroupOp:
159160
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}

tinygrad/runtime/ops_dsp.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,9 @@
2020
])
2121

2222
dsp_pm_late = PatternMatcher([
23-
(UPat.var("x")+UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
24-
(UPat.var("x")*UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
25-
(UPat.var("x")//UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
23+
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
24+
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
25+
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
2626
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
2727
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
2828
])

0 commit comments

Comments
 (0)