Skip to content

Commit 910ae26

Browse files
authored
dsp float4 fold + revectorize [pr] (tinygrad#8995)
* dsp float4 fold [pr]
* revectorize
* fix reg issue
* no bool vectorize
* cleanups
* no need for that
1 parent e618efc commit 910ae26

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

tinygrad/codegen/rewriter.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Any, Callable
22
import functools, itertools, operator
33
from collections import defaultdict
4+
from tinygrad.device import Device
45
from tinygrad.dtype import dtypes, ImageDType, PtrDType
56
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
67
from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
@@ -11,11 +12,18 @@
1112
# ***** float4/image store handling *****
1213

1314
def fold_expanded(ex, buf):
14-
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
1515
new_srcs = dedup(list(ex.src))
1616
old_new_srcs = new_srcs[:]
1717
is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
1818

19+
# TODO: get the device from the buffer somehow
20+
if Device.DEFAULT == "DSP":
21+
if buf.dtype.base == dtypes.bool: return None
22+
lengths = [128,4]
23+
else:
24+
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
25+
lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
26+
1927
# first, extract all the relevant offsets
2028
offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
2129
for i,s in enumerate(new_srcs):
@@ -30,7 +38,6 @@ def fold_expanded(ex, buf):
3038
offsets_rootsrc[root_src][arg] = i
3139

3240
# then rewrite everything we can
33-
lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
3441
used: set[tuple[UOp, UOp]] = set()
3542
for rootsrc, offsets in offsets_rootsrc.items():
3643
for o in offsets:

tinygrad/renderer/cstyle.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
1818
(UPat(Ops.VECTORIZE, name="x"),
1919
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
20-
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
20+
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CLANG', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
2121
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
2222
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
2323
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
@@ -49,7 +49,8 @@
4949
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
5050
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
5151
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
52-
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
52+
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CLANG', 'DSP'} else \
53+
f".{'xyzwabcd'[x.arg[0]]}")),
5354
])
5455

5556
extra_pm = PatternMatcher([
@@ -104,7 +105,8 @@ def render_dtype(self, dt:DType, mutable=True) -> str:
104105
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
105106
if isinstance(dt, PtrDType):
106107
return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
107-
return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
108+
if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
109+
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
108110

109111
def __getitem__(self, key): return self.r[key] # hacky helper
110112
def render(self, name:str, uops:list[UOp]) -> str:

tinygrad/runtime/ops_dsp.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,30 @@
99
from tinygrad.runtime.autogen import libc, qcom_dsp
1010
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
1111

12+
from tinygrad.helpers import all_same
13+
from tinygrad.ops import PatternMatcher, UPat, GroupOp
14+
15+
def revectorize(v:UOp):
  """Rewrite a VECTORIZE of per-lane ops into one vector op over vectorized operands.

  `v` is the matched VECTORIZE UOp; each of its sources is one lane. When the rewrite applies, the i-th operands
  of all lanes are gathered into a new VECTORIZE, and a single UOp with lane 0's op and arg and `v`'s dtype is
  returned. Returns None (no rewrite) when the lanes are not uniform or when any operand is a bool.
  """
  first = v.src[0]
  if not all_same([x.op for x in v.src]): return None
  # the rebuilt op reuses lane 0's arg and operand dtypes, so every lane must agree with lane 0 on arg, arity and
  # operand dtypes; otherwise the per-operand VECTORIZE below would mix dtypes (or silently vectorize bools)
  if not all_same([(x.arg, tuple(s.dtype for s in x.src)) for x in v.src]): return None
  # bool operands are never vectorized
  if any(dtypes.is_bool(s.dtype) for s in first.src): return None
  new_srcs = [UOp(Ops.VECTORIZE, s.dtype.vec(v.dtype.count), tuple(x.src[i] for x in v.src)) for i,s in enumerate(first.src)]
  return UOp(first.op, v.dtype, tuple(new_srcs), first.arg)
19+
20+
# rewrite rules that turn a VECTORIZE of per-lane ops back into a single vector op
revectorize_pm = PatternMatcher([
  # VECTORIZE whose lanes are ALU/ASSIGN/CAST ops: handled by revectorize (which may decline by returning None)
  (UPat(Ops.VECTORIZE, src=UPat((*GroupOp.ALU, Ops.ASSIGN, Ops.CAST)), name="v"), revectorize),
  # vectorize DEFINE_ACC (similar to expander)
  # the new accumulator broadcasts lane 0's initial constant to v.dtype.count lanes and keeps lane 0's other sources and arg
  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC), name="v"),
   lambda v: UOp(Ops.DEFINE_ACC, v.dtype,
     (UOp.broadcast(UOp.const(v.dtype.scalar(), v.src[0].src[0].arg), v.dtype.count),)+v.src[0].src[1:], v.src[0].arg)),
  # vectorize increasing GEPs = nothing
  # only valid when all GEPs read the same source, select elements 0..count-1 in order, and that source already
  # has the VECTORIZE's dtype (a wider source would otherwise be returned with the wrong dtype)
  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP), name="v"),
   lambda v: v.src[0].src[0] if all_same([x.src for x in v.src]) and v.src[0].src[0].dtype == v.dtype and \
     [x.arg[0] if len(x.arg) == 1 else None for x in v.src] == list(range(v.dtype.count)) else None),
])
31+
1232
class DSPRenderer(ClangRenderer):
1333
device = "DSP"
14-
supports_float4 = False
34+
supports_float4 = True
35+
extra_matcher = revectorize_pm+ClangRenderer.extra_matcher
1536
buffer_suffix = " restrict __attribute__((align_value(128)))"
1637
kernel_prefix = "__attribute__((noinline)) "
1738
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
@@ -233,8 +254,8 @@ def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str
233254
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
234255
# control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
235256
msrc = ['''static long syscall(long r0, long r1, long r2, long r3, long r4, long r5, long r6) {
236-
long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = #%7; trap0(#1); %0 = r0" : "=r" (retval)
237-
: "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "i" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; }
257+
long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = %7; trap0(#1); %0 = r0" : "=r" (retval)
258+
: "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "r" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; }
238259
static int read(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 63); }}
239260
static int write(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 64); }}
240261
static int exit(int ret) {{ return syscall(ret, 0, 0, 0, 0, 0, 93); }}

0 commit comments

Comments
 (0)