From 9119716761efbbac26178f61a30bf09f137edc3e Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 9 Feb 2025 21:26:27 -0500 Subject: [PATCH 1/4] update Tensor.maximum (#8992) now it's just broadcast and UOp.maximum --- tinygrad/tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index eed97122c000c..2cc5bdbaf2554 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3307,9 +3307,7 @@ def maximum(self, x:Union[Tensor, ConstType]) -> Tensor: print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy()) ``` """ - # NOTE: the mid-point is for backward, revisit after new gradient API - if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self)) - return (self<x).detach().where(x, self) + return self._apply_broadcasted_uop(UOp.maximum, x) def minimum(self, x:Union[Tensor, ConstType]) -> Tensor: """ From 29832853159a63c0784aa4faf20b501a76f67a78 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 10 Feb 2025 11:07:35 +0800 Subject: [PATCH 2/4] use HEX_REG_QEMU_INSN_CNT from qemu as a DSP timer [pr] (#8993) * use HEX_REG_QEMU_INSN_CNT from qemu as a DSP timer [pr] * add quantize test to dsp * fix tests * older onnx * debug, let's see what's happening --- .github/actions/setup-tinygrad/action.yml | 8 ++- .github/workflows/test.yml | 8 ++- test/test_quantize_onnx.py | 88 +++++++++++++++++++++++ tinygrad/codegen/kernel.py | 2 +- tinygrad/runtime/ops_dsp.py | 12 ++-- 5 files changed, 108 insertions(+), 10 deletions(-) create mode 100644 test/test_quantize_onnx.py diff --git a/.github/actions/setup-tinygrad/action.yml b/.github/actions/setup-tinygrad/action.yml index 272de00df0304..7487461de936a 100644 --- a/.github/actions/setup-tinygrad/action.yml +++ b/.github/actions/setup-tinygrad/action.yml @@ -13,6 +13,10 @@ inputs: description: 'Extra dependency groups (comma separated)' required: false default: '' + pydeps: description: 'Extra Python dependency groups (space separated)' required: false default: '' opencl: description: "Install OpenCL?" required: false @@ -83,11 +87,11 @@ runs: - name: Install dependencies (with extra) if: inputs.deps != '' shell: bash run: pip install ${{ (runner.os == 'macOS' && '--user') || (runner.os != 'macOS' && '') }} -e ".[${{ inputs.deps }}]" --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ + run: pip install ${{ (runner.os == 'macOS' && '--user') || (runner.os != 'macOS' && '') }} -e ".[${{ inputs.deps }}]" ${{ inputs.pydeps }} --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ - name: Install dependencies (without extra) if: inputs.deps == '' shell: bash run: pip install ${{ (runner.os == 'macOS' && '--user') || (runner.os != 'macOS' && '') }} -e . 
${{ inputs.pydeps }} # **** OpenCL **** diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6496121013aaf..c101165221222 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,8 +42,7 @@ jobs: uses: ./.github/actions/setup-tinygrad with: deps: docs - - name: Install capstone for CLANG disassembly - run: pip install capstone + pydeps: "capstone" - name: Use as an external package run: | mkdir $HOME/test_external_dir @@ -403,6 +402,9 @@ jobs: uses: actions/checkout@v4 - name: Setup Environment uses: ./.github/actions/setup-tinygrad + with: + key: dsp + pydeps: "onnx==1.16.0 onnxruntime" - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Build QEMU Docker with cache @@ -416,6 +418,8 @@ jobs: cache-to: type=gha,mode=min - name: Run test_tiny on DSP run: DEBUG=2 DSP=1 python test/test_tiny.py + - name: Test quantize onnx + run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py testwebgpu: name: Linux (WebGPU) diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py new file mode 100644 index 0000000000000..96e9e226133a3 --- /dev/null +++ b/test/test_quantize_onnx.py @@ -0,0 +1,88 @@ +import numpy as np +import unittest +from tinygrad import Tensor, Context, Device +from tinygrad.codegen.kernel import Kernel, Opt, OptOps +from tinygrad.engine.realize import CompiledRunner, ExecItem + +N = 1024 + +def create_gemm_model(model_path:str, in_size=N, out_size=N): + import onnx + from onnx import helper, numpy_helper, TensorProto + # Define input and output + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, in_size]) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, out_size]) + + # Create random weights and bias + W_data = np.random.randn(in_size, out_size).astype(np.float32) + B_data = np.random.randn(out_size).astype(np.float32) + + W_init = numpy_helper.from_array(W_data, name="W") + B_init = numpy_helper.from_array(B_data, name="B") + + gemm_node = helper.make_node("Gemm", inputs=["input", "W", "B"], outputs=["output"], alpha=1.0, beta=1.0, transB=0) + graph_def = helper.make_graph([gemm_node], "SingleGemmGraph", [input_tensor], [output_tensor], initializer=[W_init, B_init]) + + # Create and save the model + model_def = helper.make_model(graph_def, producer_name="single_gemm_example") + onnx.save_model(model_def, model_path) + return model_path + +def sexec(out:Tensor, opts:list[Opt]): + si = out.schedule()[-1] + k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer) + #opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128)] #, Opt(op=OptOps.UNROLL, axis=0, arg=4)] + for opt in opts: k.apply_opt(opt) + prg = k.to_program() + ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata) + for _ in range(3): ei.run(wait=True) + +@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP") +class TestQuantizeOnnx(unittest.TestCase): + def test_quant(self): + from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader + from examples.benchmark_onnx import load_onnx_model + class FakeDataReader(CalibrationDataReader): + def __init__(self): self.cnt = 0 + def get_next(self) -> dict: + self.cnt += 1 + if self.cnt == 100: return None + return {"input": np.random.uniform(size=(1, N)).astype(np.float32)} + out_file = "/tmp/test_out.onnx" + quantize_static(create_gemm_model("/tmp/test_in.onnx"), out_file, + FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, + 
activation_type=QuantType.QInt8, weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": True}) + run_onnx_jit, _ = load_onnx_model(out_file) + with Context(NOOPT=1): + run_onnx_jit(input=Tensor(np.random.uniform(size=(1, N)).astype(np.float32))) + + def test_prequant_conv2d_1x1(self): + X = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8)) + W = Tensor(np.random.uniform(0, 255, size=(64, 32, 1, 1)).astype(np.uint8)) + out = X.conv2d(W, acc_dtype=X.dtype) + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] + sexec(out, opts) + + def test_prequant_gemm(self): + N = 512 + # ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant + X = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8)) + W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8)) + out = X.matmul(W, acc_dtype=X.dtype) + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] + sexec(out, opts) + + def test_prequant_gemv(self): + N = 2048 + # ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant + X = Tensor(np.random.uniform(0, 255, size=(1,N)).astype(np.uint8)) + W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8)) + #out = X.cast(dtypes.int) @ W.cast(dtypes.int) + #out = X @ W + out = X.matmul(W, acc_dtype=X.dtype) + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)] + sexec(out, opts) + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index a9df66b8101fd..66668ae5fc1d5 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -411,7 +411,7 @@ def apply_opt(self, opt:Opt, append_opt:bool=True): elif opt.op is OptOps.UPCAST: # yellow check(axis < self.first_reduce, "upcast is for non-reduce") check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals") - check(amt <= 16, "don't upcast more than 16") + check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16") self.shift_to(axis, amt, insert_before=None) self.upcast() elif opt.op is OptOps.NOLOCALS: diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index 588e6fbf2d26d..e15f9b12699d4 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -1,5 +1,5 @@ from __future__ import annotations -import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, time, struct +import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct assert sys.platform != 'win32' from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator from tinygrad.dtype import dtypes, DType, PtrDType @@ -231,12 +231,14 @@ class MockDSPRenderer(DSPRenderer): def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str: ret = ClangRenderer.render_kernel(self, function_name, kernel, bufs, uops, prefix) # https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html + # control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it msrc = ['''static long syscall(long r0, long r1, long r2, long r3, long r4, long r5, long r6) { long retval; __asm__ volatile("r0 = 
%1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = #%7; trap0(#1); %0 = r0" : "=r" (retval) : "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "i" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; } static int read(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 63); }} static int write(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 64); }} static int exit(int ret) {{ return syscall(ret, 0, 0, 0, 0, 0, 93); }} + static unsigned int inscount(void) {{ unsigned int ret; __asm__ volatile(".word 0x6a15c000; %0 = R0" : "=r" (ret) : : "r0"); return ret; }} static void *mmap2(void *addr, unsigned int length, int prot, int flags, int fd, unsigned long offset) {{ return (void*)syscall((long)addr, length, prot, flags, fd, offset, 222); }}''', 'void _start(void) {'] for i,b in enumerate(bufs): @@ -245,7 +247,9 @@ def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); read(0, buf{i}, {sz});") else: msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);") + msrc.append("unsigned int st = inscount();") msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])});") + msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));") for i,b in enumerate(bufs): if isinstance(b[1][0], PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].size*b[1][0].itemsize});") msrc.append('exit(0); }') @@ -259,13 +263,11 @@ def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False): dsp_lib.flush() os.chmod(dsp_lib.name, 0o0777) # NOTE: this timing includes a docker launch - start = time.perf_counter() proc = subprocess.run(["docker", "run", "--rm", "-i", "-v", f"{os.path.abspath(os.path.dirname(dsp_lib.name))}:/work", "-w", "/work", "qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 3 else ''} /work/"+os.path.basename(dsp_lib.name)], input=b''.join([bytes(x) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True) - elapsed = time.perf_counter() - start - offset = 0 + offset = 4 for x in bufs: x[:] = proc.stdout[offset:offset+len(x)] offset += len(x) - return elapsed + return struct.unpack("I", proc.stdout[0:4])[0] / 1e9 # pretend it's 1 Ghz, but this is an inscount, not a time From e618efce226207bc6b1b9d2a1f4b88e9bbe4b9a8 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 10 Feb 2025 12:01:28 +0800 Subject: [PATCH 3/4] COMMUTATIVE flipping is only for ints (#8996) * COMMUTATIVE flipping is only for ints [pr] * no pr * comm fixes this --- test/test_linearizer_failures.py | 4 ++-- tinygrad/ops.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index b1b385029b50f..8bacd0b6203bc 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1368,7 +1368,7 @@ def test_failure_56(self): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=2, arg=32)] - helper_test_lin(Kernel(ast, 
opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"]) + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_57(self): ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( @@ -1409,7 +1409,7 @@ def test_failure_57(self): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32)] - helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"]) + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 503a707bd7d8e..353f51731e000 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1170,8 +1170,8 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, ]) symbolic = symbolic_simple+PatternMatcher([ - # ** COMMUTATIVE flipping ** - (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), + # ** COMMUTATIVE flipping (only for ints) ** + (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), # ** boolean algebra ** (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x # ** combine terms ** From 910ae260cd1d45a1326299081e6cc70832cfd21f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 10 Feb 2025 12:14:32 +0800 Subject: [PATCH 4/4] dsp float4 fold + revectorize [pr] (#8995) * dsp float4 fold [pr] * revectorize * fix reg issue * no bool vectorize * cleanups * no need for that --- tinygrad/codegen/rewriter.py | 11 +++++++++-- tinygrad/renderer/cstyle.py | 8 +++++--- tinygrad/runtime/ops_dsp.py | 27 ++++++++++++++++++++++++--- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index 7023fd0fe9b3f..31558df9ca1ab 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -1,6 +1,7 @@ from typing import Optional, Any, Callable import functools, itertools, operator from collections import defaultdict +from tinygrad.device import Device from tinygrad.dtype import dtypes, ImageDType, PtrDType from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp @@ -11,11 +12,18 @@ # ***** float4/image store handling ***** def fold_expanded(ex, buf): - if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None new_srcs = dedup(list(ex.src)) old_new_srcs = new_srcs[:] is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType) + # TODO: get the device from the buffer somehow + if Device.DEFAULT == "DSP": + if buf.dtype.base == dtypes.bool: return None + lengths = [128,4] + else: + if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, 
ImageDType): return None + lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])) + # first, extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict) for i,s in enumerate(new_srcs): @@ -30,7 +38,6 @@ def fold_expanded(ex, buf): offsets_rootsrc[root_src][arg] = i # then rewrite everything we can - lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])) used: set[tuple[UOp, UOp]] = set() for rootsrc, offsets in offsets_rootsrc.items(): for o in offsets: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 8c7131d0760c1..fb0dd69e9f459 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -17,7 +17,7 @@ lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"), (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \ - (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")), + (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CLANG', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")), (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"), @@ -49,7 +49,8 @@ (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op]( *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)), (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \ - (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")), + (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CLANG', 'DSP'} else \ + f".{'xyzwabcd'[x.arg[0]]}")), ]) extra_pm = PatternMatcher([ @@ -104,7 +105,8 @@ def render_dtype(self, dt:DType, mutable=True) -> str: if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t" if isinstance(dt, PtrDType): return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*" - return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "") + if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count) + return self.type_map.get(scalar:=dt.scalar(), scalar.name) def __getitem__(self, key): return self.r[key] # hacky helper def render(self, name:str, uops:list[UOp]) -> str: diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index e15f9b12699d4..e323012935377 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -9,9 +9,30 @@ from tinygrad.runtime.autogen import libc, qcom_dsp if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import +from tinygrad.helpers import all_same +from tinygrad.ops import PatternMatcher, UPat, GroupOp + +def revectorize(v:UOp): + if not all_same([x.op for x in v.src]) or any(dtypes.is_bool(x.dtype) for x in 
v.src[0].src): return None + new_srcs = [UOp(Ops.VECTORIZE, v.src[0].src[i].dtype.vec(v.dtype.count), tuple(x.src[i] for x in v.src)) for i in range(len(v.src[0].src))] + return UOp(v.src[0].op, v.dtype, tuple(new_srcs), v.src[0].arg) + +revectorize_pm = PatternMatcher([ + (UPat(Ops.VECTORIZE, src=UPat((*GroupOp.ALU, Ops.ASSIGN, Ops.CAST)), name="v"), revectorize), + # vectorize DEFINE_ACC (similar to expander) + (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC), name="v"), + lambda v: UOp(Ops.DEFINE_ACC, v.dtype, + (UOp.broadcast(UOp.const(v.dtype.scalar(), v.src[0].src[0].arg), v.dtype.count),)+v.src[0].src[1:], v.src[0].arg)), + # vectorize increasing GEPs = nothing (wrong if dtypes don't match!) + (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP), name="v"), + lambda v: v.src[0].src[0] if all_same([x.src for x in v.src]) and \ + [x.arg[0] if len(x.arg) == 1 else None for x in v.src] == list(range(v.dtype.count)) else None), +]) + class DSPRenderer(ClangRenderer): device = "DSP" - supports_float4 = False + supports_float4 = True + extra_matcher = revectorize_pm+ClangRenderer.extra_matcher buffer_suffix = " restrict __attribute__((align_value(128)))" kernel_prefix = "__attribute__((noinline)) " type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" } @@ -233,8 +254,8 @@ def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str # https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html # control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it msrc = ['''static long syscall(long r0, long r1, long r2, long r3, long r4, long r5, long r6) { - long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = #%7; trap0(#1); %0 = r0" : "=r" (retval) - : "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "i" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; } + long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = %7; trap0(#1); %0 = r0" : "=r" (retval) + : "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "r" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; } static int read(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 63); }} static int write(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 64); }} static int exit(int ret) {{ return syscall(ret, 0, 0, 0, 0, 0, 93); }}
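
For readers following the mock-DSP timing change above: a minimal host-side sketch of how the qemu-hexagon stdout can be unpacked, assuming the layout the patch emits (a 4-byte instruction-count delta written first by the generated `_start` stub, then each output buffer verbatim, with the count divided by 1e9 as a pretend 1 GHz clock). The helper name `parse_mock_dsp_output` and the `buf_sizes` argument are illustrative, not part of the patch.

```python
import struct

def parse_mock_dsp_output(stdout: bytes, buf_sizes: list[int]) -> tuple[float, list[bytes]]:
  # first 4 bytes: instruction-count delta the _start stub writes to fd 1 (unsigned int, via inscount())
  insns = struct.unpack("I", stdout[0:4])[0]
  # remaining bytes: each output buffer, written back in order after the counter
  bufs, offset = [], 4
  for sz in buf_sizes:
    bufs.append(stdout[offset:offset+sz])
    offset += sz
  # the runner pretends the DSP runs at 1 GHz, so instruction count / 1e9 stands in for seconds
  return insns / 1e9, bufs

# toy check: one 16-byte output buffer behind a counter of 12345 instructions
t, outs = parse_mock_dsp_output(struct.pack("I", 12345) + bytes(16), [16])
assert t == 12345 / 1e9 and len(outs[0]) == 16
```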