Commit 52ae9af

Fast DSP for MobileNetV2 (try 2) (tinygrad#9467)
* Fast DSP for MobileNetV2 (try 2)
* enable fast path on uchar
* fix tests
1 parent 15ee742 commit 52ae9af

File tree

6 files changed (+51, -16 lines)

extra/onnx.py

+5 -1

@@ -728,7 +728,11 @@ def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
 def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
   out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
   y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
-  return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
+  if out_dtype == dtypes.uchar:
+    # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff
+    return _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype).contiguous()
+  else:
+    return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()

 def DynamicQuantizeLinear(x: Tensor):
   # only support uint8
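The fast path trades .round() for add-and-truncate: everything that survives the clamp lands in [0, 255], where truncation toward zero acts like floor, and floor(v + 0.4999999) agrees with round-half-up except when v sits within about 1e-7 below a .5 boundary (negative inputs are clamped to 0 regardless). A standalone sanity check of that claim, using NumPy as a stand-in for the tinygrad ops:

# Standalone check of the add-and-truncate rounding trick on the uchar range.
# floor(v + 0.4999999) can only disagree with floor(v + 0.5) when frac(v) falls in
# [0.5 - 1e-7, 0.5), which quantized activations essentially never hit.
import numpy as np

v = np.random.uniform(0.0, 255.0, size=1_000_000).astype(np.float32)
fast = np.floor(v + 0.4999999)   # what the uchar fast path computes before the cast
ref = np.floor(v + 0.5)          # round-half-up reference
print("mismatches:", int((fast != ref).sum()))   # expect 0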

extra/replay_pkl.py

+27 -9

@@ -1,17 +1,19 @@
 import pickle, sys
 from dataclasses import replace
-from tinygrad import Device
+from tinygrad import Device, Context
+from tinygrad.device import Buffer
 from tinygrad.helpers import getenv
 from tinygrad.engine.jit import TinyJit
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.renderer import ProgramSpec
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps

 if __name__ == "__main__":
-  with open(sys.argv[1], "rb") as f:
-    fxn: TinyJit = pickle.load(f)
-    print(f"{f.tell()/1e6:.2f}M loaded")
-    print(type(fxn))
+  with Context(DEBUG=0):
+    with open(sys.argv[1], "rb") as f:
+      fxn: TinyJit = pickle.load(f)
+      print(f"{f.tell()/1e6:.2f}M loaded")
+      print(type(fxn))

   knum = 1
   for ei in fxn.captured.jit_cache:
@@ -21,17 +23,33 @@
     p: ProgramSpec = ei.prg.p
     k = Kernel(p.ast, Device["DSP"].renderer)
     if not getenv("NOOPT"):
-      if knum == 2:
+      if knum in [6,7,9,11]:
+        k.apply_opt(Opt(OptOps.PADTO, 1, 128))
+        k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
+      elif knum in [5,8]:
         k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
         k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
         k.apply_opt(Opt(OptOps.PADTO, 2, 128))
         k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
+      elif knum == 2:
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
+        k.apply_opt(Opt(OptOps.PADTO, 2, 128))
+        k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
+        #k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=4))
+      elif knum == 1:
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=2, arg=0))
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
+        #k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
+        k.apply_opt(Opt(OptOps.PADTO, 2, 128))
+        k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
       elif knum == 3:
-        k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=128))
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4))
+        k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
       else:
         k.hand_coded_optimizations()
+      #if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
     p2 = k.to_program()
-    new_ei = replace(ei, prg=CompiledRunner(p2))
+    new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 128+b.size*2, b.dtype).view(b.size, b.dtype, 128) for b in ei.bufs])
     new_ei.run()
     knum += 1
-
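The script takes a pickled TinyJit as its first argument; knum counts kernels in the captured jit_cache, so the Opt lists above are hand-tuned per kernel of the MobileNetV2 graph. The rebuilt bufs give each buffer a view of b.size elements at offset 128 inside a larger allocation, presumably headroom so the oversized vector accesses this commit enables stay inside valid memory. A hypothetical way to produce such a pickle (the function and filename here are illustrative, not part of this commit):

# Hypothetical capture script: warm up a TinyJit so its kernels land in jit_cache,
# then pickle the jitted function for replay_pkl.py to re-optimize and re-run.
import pickle
from tinygrad import Tensor, TinyJit

@TinyJit
def fxn(x: Tensor) -> Tensor:
  return (x.relu() * 2).realize()

for _ in range(3): fxn(Tensor.rand(32, 32))  # the JIT captures kernels on repeat calls
with open("fxn.pkl", "wb") as f: pickle.dump(fxn, f)
# replay with: python3 extra/replay_pkl.py fxn.pkl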

tinygrad/codegen/devectorizer.py

+8 -3

@@ -45,7 +45,8 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
     global_offset += len(grp)
   assert None not in idxs, f"some idxs are missing {idxs}"
   # this base thing is for image, we want the CAT to be a normal pointer
-  return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs)))
+  post_cat = UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)) if len(ret) > 1 else ret[0]
+  return post_cat.gep(tuple(cast(list[int], idxs)))

 def cat_after_store(cat:UOp, data:UOp):
   # TODO: this is written in many places
@@ -143,7 +144,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   if (sz:=ls.src[0].dtype.count) == 1: return None
   lengths = []
   buf = idx.src[0]
-  if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
+  must_divide = True
+  if ctx is not None and ctx.device == "DSP":
+    lengths = [128,64,32,16,8,4]
+    must_divide = False
+  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
     pass
   elif isinstance(buf.dtype, ImageDType):
     lengths = [4]
@@ -158,7 +163,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   for fold_length in lengths:
     if global_offset+fold_length > sz: continue
     oidx = idx.src[1] + global_offset
-    if oidx.simplify().divides(fold_length) is None: continue
+    if must_divide and oidx.simplify().divides(fold_length) is None: continue
     lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
     if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
     if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
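With must_divide=False, the DSP path no longer requires the access index to divide the fold length, so a wide load or store gets carved greedily from the HVX-friendly lengths [128,64,32,16,8,4]. A standalone sketch of the resulting chunking (the helper is mine; the real pass splits UOps, not integers):

# Greedy chunking as the DSP split performs it (hypothetical standalone helper).
# A 150-wide access becomes 128 + 16 + 4 + 1 + 1 instead of 150 scalar ops.
def split_lengths(sz: int, lengths=(128, 64, 32, 16, 8, 4)) -> list[int]:
  out, off = [], 0
  while off < sz:
    for fl in lengths:
      if off + fl <= sz:
        out.append(fl); off += fl; break
    else:  # no vector length fits, fall back to a scalar access
      out.append(1); off += 1
  return out

print(split_lengths(150))  # [128, 16, 4, 1, 1]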

tinygrad/codegen/kernel.py

+3 -1

@@ -501,10 +501,12 @@ def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve
     for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

     # potentially do more upcasts of non reduce axes based on a heuristic
+    is_dsp = self.opts is not None and self.opts.device == "DSP"
     upcasted_axis: set[int] = set()
     while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
       xb_choices = []
-      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
+      # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
+      for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
         # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
         if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
           xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
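On the DSP the heuristic considers a single 128-wide upcast instead of repeated 3s and 4s: once one axis has been upcasted, upcasted_axis is non-empty and the DSP candidate list is empty, so the heuristic stops after one pass. A quick standalone evaluation of the selection expression (the values are hypothetical):

# Evaluate the upcast-amount selection in isolation (hypothetical values).
import itertools
is_dsp, upcasted_axis, first_reduce = True, set(), 3
amounts = ([128] if not len(upcasted_axis) else []) if is_dsp else [3, 4]
print(list(itertools.product(range(first_reduce), amounts)))
# [(0, 128), (1, 128), (2, 128)] -- after one DSP upcast, amounts is [] and nothing more is tried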

tinygrad/renderer/cstyle.py

+2 -2

@@ -197,8 +197,8 @@ class ClangRenderer(CStyleLanguage):
   if sys.platform == 'win32':
     kernel_prefix = "__attribute__((ms_abi)) "
   def render_vector_prefix(self, dt:DType) -> str:
-    # round (down) to power of two
-    alignment = 2**int(math.log2(dt.itemsize))
+    # round (down) to power of two (this is actually the default clang behavior)
+    alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1
     return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
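The typedef's alignment defaults to the itemsize rounded down to a power of two, which (as the new comment notes) is what clang already does for vector types; ALIGNED=0 drops it to aligned(1), letting clang emit unaligned vector loads and stores, something the DSP path can need once split_load_store stops forcing offsets to divide the vector length. A quick sketch of the attribute string this renders (env handling simplified from tinygrad's getenv):

# Sketch of the alignment attribute render_vector_prefix produces for a 128-byte vector.
import math, os
itemsize = 128                                 # e.g. a 128-wide uchar vector
aligned = int(os.environ.get("ALIGNED", "1"))  # stand-in for tinygrad's getenv("ALIGNED", 1)
alignment = 2**int(math.log2(itemsize)) if aligned else 1
print(f"__attribute__((aligned({alignment}),vector_size({itemsize})))")
# ALIGNED=1 -> aligned(128); ALIGNED=0 -> aligned(1), i.e. unaligned access allowed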

tinygrad/runtime/ops_dsp.py

+6 -0

@@ -27,13 +27,19 @@
     lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
 ])

+# NOTE: this just increases readability of the generated code
+dsp_string = PatternMatcher([
+  (UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
+])
+
 class DSPRenderer(ClangRenderer):
   device = "DSP"
   supports_float4 = True
   buffer_suffix = " restrict __attribute__((align_value(128)))"
   kernel_prefix = "__attribute__((noinline)) "
   pre_matcher = dsp_pm
   extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
+  string_rewrite = dsp_string+ClangRenderer.string_rewrite
   type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
   code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
                  Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",
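PatternMatcher concatenation is ordered, so the int8/uint8 CONST rule in dsp_string fires before the base ClangRenderer string rendering, and those constants show up as bare decimals in the generated kernel source. A minimal illustration of what the rule's lambda returns (the UOp built here is a sketch of the matched node):

# The dsp_string rule maps an int8/uint8 CONST UOp to its bare decimal value.
from tinygrad import dtypes
from tinygrad.ops import UOp, Ops

c = UOp(Ops.CONST, dtypes.uint8, arg=7)
print(str(c.arg))  # "7" -- rendered inline instead of a more verbose cast expression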
