|
9 | 9 | from tinygrad.runtime.autogen import libc, qcom_dsp
|
10 | 10 | if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import
|
11 | 11 |
|
| 12 | +from tinygrad.helpers import all_same |
| 13 | +from tinygrad.ops import PatternMatcher, UPat, GroupOp |
| 14 | + |
| 15 | +def revectorize(v:UOp): |
| 16 | + if not all_same([x.op for x in v.src]) or any(dtypes.is_bool(x.dtype) for x in v.src[0].src): return None |
| 17 | + new_srcs = [UOp(Ops.VECTORIZE, v.src[0].src[i].dtype.vec(v.dtype.count), tuple(x.src[i] for x in v.src)) for i in range(len(v.src[0].src))] |
| 18 | + return UOp(v.src[0].op, v.dtype, tuple(new_srcs), v.src[0].arg) |
| 19 | + |
| 20 | +revectorize_pm = PatternMatcher([ |
| 21 | + (UPat(Ops.VECTORIZE, src=UPat((*GroupOp.ALU, Ops.ASSIGN, Ops.CAST)), name="v"), revectorize), |
| 22 | + # vectorize DEFINE_ACC (similar to expander) |
| 23 | + (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC), name="v"), |
| 24 | + lambda v: UOp(Ops.DEFINE_ACC, v.dtype, |
| 25 | + (UOp.broadcast(UOp.const(v.dtype.scalar(), v.src[0].src[0].arg), v.dtype.count),)+v.src[0].src[1:], v.src[0].arg)), |
| 26 | + # vectorize increasing GEPs = nothing (wrong if dtypes don't match!) |
| 27 | + (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP), name="v"), |
| 28 | + lambda v: v.src[0].src[0] if all_same([x.src for x in v.src]) and \ |
| 29 | + [x.arg[0] if len(x.arg) == 1 else None for x in v.src] == list(range(v.dtype.count)) else None), |
| 30 | +]) |
| 31 | + |
12 | 32 | class DSPRenderer(ClangRenderer):
|
13 | 33 | device = "DSP"
|
14 |
| - supports_float4 = False |
| 34 | + supports_float4 = True |
| 35 | + extra_matcher = revectorize_pm+ClangRenderer.extra_matcher |
15 | 36 | buffer_suffix = " restrict __attribute__((align_value(128)))"
|
16 | 37 | kernel_prefix = "__attribute__((noinline)) "
|
17 | 38 | type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
|
@@ -233,8 +254,8 @@ def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str
|
233 | 254 | # https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
|
234 | 255 | # control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
|
235 | 256 | msrc = ['''static long syscall(long r0, long r1, long r2, long r3, long r4, long r5, long r6) {
|
236 |
| - long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = #%7; trap0(#1); %0 = r0" : "=r" (retval) |
237 |
| - : "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "i" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; } |
| 257 | + long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = %7; trap0(#1); %0 = r0" : "=r" (retval) |
| 258 | + : "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "r" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; } |
238 | 259 | static int read(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 63); }}
|
239 | 260 | static int write(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 64); }}
|
240 | 261 | static int exit(int ret) {{ return syscall(ret, 0, 0, 0, 0, 0, 93); }}
|
|
0 commit comments