From a3c78d47b3d8a48384d4604752929ea4ff5683e6 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 8 Feb 2025 17:28:52 +0800 Subject: [PATCH 1/4] speed docs + upgrades [pr] (#8964) * add some docs about speed [pr] * better torch gemm * enable locals on llvm/clang * disable locals for beam speed on LLVM/CLANG * 0x20 alignment in llvm allows ymm use --- docs/developer/developer.md | 2 ++ docs/developer/speed.md | 71 +++++++++++++++++++++++++++++++++++++ extra/gemm/torch_gemm.py | 17 ++++++--- mkdocs.yml | 1 + test/test_linearizer.py | 10 ++++++ tinygrad/codegen/kernel.py | 2 ++ tinygrad/device.py | 3 +- tinygrad/renderer/llvmir.py | 3 +- 8 files changed, 103 insertions(+), 6 deletions(-) create mode 100644 docs/developer/speed.md diff --git a/docs/developer/developer.md b/docs/developer/developer.md index 39e9e0901b593..f932f0a935198 100644 --- a/docs/developer/developer.md +++ b/docs/developer/developer.md @@ -7,6 +7,8 @@ The tinygrad framework has four pieces There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-notes/) by Di Zhu that go over tinygrad internals. +There's also a [doc describing speed](../developer/speed.md) + ## Frontend Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md). diff --git a/docs/developer/speed.md b/docs/developer/speed.md new file mode 100644 index 0000000000000..e4801e6418bed --- /dev/null +++ b/docs/developer/speed.md @@ -0,0 +1,71 @@ +# speed in tinygrad + +## Overview + +Speed refers to many different things. To break it down to four, there's: + +- Compile Speed (Python) +- Execution Speed (driver) +- Model Speed (scheduler) +- Kernel Speed (codegen) + +## Compile Speed (Python) + +This is how long the first run of your model takes. It's limited largely by the runtime of the Python doing UOp rewrites. Currently it's a bit slow, but on par with torch.compile. It gets even slower if you are using BEAM, since that's compiling many variants of each kernel. + +This will be improved by writing faster graph_rewrite, doing less graph_rewrite, and better parallelization. + +## Execution Speed (driver) + +After your model is compiled, you are often using the `TinyJIT`. tinygrad has the best execution speed of any framework because it usually bypasses the GPU driver and prebuilds the command queue. It's tons faster than normal CUDA, and often even faster than CUDA Graph. + +There's very little to improve here, as this is almost never the bottleneck. + +## Model Speed (scheduler) + +The scheduler determines how operations are grouped into kernels and which Tensors are written to memory. This is currently a big bottleneck of training speed. + +The decisions are often not obvious. For example, when is it worth recomputing an arithmetic operation instead of storing and loading from memory? Example: + +```python +from tinygrad import Tensor +a = Tensor.rand(100) +b = Tensor.rand(100) +c = Tensor.rand(100) +d = Tensor.rand(100) +out1 = a+b+c +out2 = a+b+d +Tensor.realize(out1, out2) +``` + +The real answer is obvious, compute both `out1` and `out2` in the same kernel. But you can't always do that. If you can't, should `a+b` first be saved to a subbuffer? Or should both the `out1` and `out2` kernels recompute `a+b`? + +In this case: with recompute (6 reads + 2 writes), no recompute (6 reads + 3 writes), so we should probably recompute. However, once you add movement ops and casts this is even harder to figure out. tinygrad doesn't yet have a systematic way to do it. + +## Kernel Speed (codegen) + +Given that you have decided how the model ops will be grouped and what will be written to memory, kernel speed determines how fast that operation is done. This is what BEAM changes, it searches over a set of equivalent kernels which all perform the same operation and finds the one which performs the task the fastest. + +In `kernel.py` we have a set of `OptOps`, these control the parameters of the speed optimizations applied to the kernel. + +### Memory + +The main bottleneck in most kernels is accessing memory. In a freshman algorithms class, you'll learn about cache aware matrix multiplication, and this is all forms of that. While the same math is run, the order in which you run it can have large impacts on the speed depending on if the data you are loading. OptOps will change this order. + +Memory, even cache, is often much slower than accessing the register file. The amount of times data is used in math is called the "arithmetic intensity". For operations like BS=1 GEMV, the arithmetic intensity is 1, but for GEMMs and convs it can be much higher. OptOps like UPCAST and UNROLL can increase this, but be careful of making them too large, as if there's too much register pressure on the GPU the warp scheduler may not be able to fit many warps, or even worse, it could be spilling to local memory. + +4090s have 1 TB/s of ram bandwidth and ~160 TFLOPS of compute, so you need to use each loaded value ~100 times. The L1 cache has around 40 TB/s of bandwidth, so in order to get full compute utilization you need to use each value ~4 times. + +A lot of work can still be done here. For example, we never copy the inputs to on chip SRAM, but this is often quite helpful for kernel speed. Also, we aren't doing a good job with L2 cache awareness (the locals handle L1 quite well) + +### Tensor Cores + +Many accelerators have Tensor Cores / MAC arrays / systolic arrays. The main value of these is that, since they are 2-D, they create an n^2 ratio between the compute and the input data. + +GPUs use Tensor Cores instead of MAC arrays to fit better in the GPU warp paradigm. This is because the output of Tensor Cores is O(n) wrt the input, while the output of MAC arrays like the AMX is O(n^2) + +We have a simple framework in tinygrad for adding these ALU blocks and achieving good performance from them. + +### Indexing + +Indexing determines the address of the memory we need to load. GPUs often have less integer math resources than floating point math, so this can sometimes be the bottleneck. We have a symbolic math engine in our rewrite rules to simplifiy indexing before it's emitted to the kernel. Newer NVIDIA GPUs have a "Tensor Memory Accelerator" to assist with fast indexing, however, this is not supported in tinygrad yet. diff --git a/extra/gemm/torch_gemm.py b/extra/gemm/torch_gemm.py index a87f2757eca43..6dde87198075a 100644 --- a/extra/gemm/torch_gemm.py +++ b/extra/gemm/torch_gemm.py @@ -1,17 +1,26 @@ +import os +os.environ["NVIDIA_TF32_OVERRIDE"] = "0" +os.environ["MKL_NUM_THREADS"] = "1" +os.environ["NUMEXPR_NUM_THREADS"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" import time import torch +torch.set_num_threads(1) +from tinygrad.helpers import getenv +CUDA = getenv("CUDA", 1) -for dtype in [torch.float16, torch.float32]: +for dtype in [torch.float32, torch.float16]: for N in [256, 512, 1024, 2048, 4096]: FLOPS = N*N*N*2 - b = torch.rand((N,N), dtype=dtype).cuda() - c = torch.rand((N,N), dtype=dtype).cuda() + b = torch.rand((N,N), dtype=dtype) + c = torch.rand((N,N), dtype=dtype) + if CUDA: b,c = b.cuda(),c.cuda() def torch_prog(b, c): st = time.perf_counter() a = b@c - torch.cuda.synchronize() + if CUDA: torch.cuda.synchronize() return time.perf_counter() - st tm = min([torch_prog(b, c) for _ in range(20)]) print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}") diff --git a/mkdocs.yml b/mkdocs.yml index 38419a57081d0..a09a4b47fcd8c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -22,6 +22,7 @@ nav: - Runtime: runtime.md - Developer: - Intro: developer/developer.md + - Speed: developer/speed.md - UOp: developer/uop.md - Runtime: - developer/runtime.md diff --git a/test/test_linearizer.py b/test/test_linearizer.py index a5b5519c8a7ff..596f502238a2b 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -981,6 +981,16 @@ def test_reduce_upcast(self): assert len(stores) == 1 assert stores[0].src[-1].dtype == dtypes.float.vec(4) + # NOTE: can reenable, it does work. it just makes BEAM slow + @unittest.expectedFailure + @unittest.skipUnless(Device.DEFAULT == "CLANG", "test only for CLANG") + def test_upcast_with_locals_clang(self): + out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous() + k = Kernel(out.schedule()[-1].ast) + k.apply_opt(Opt(OptOps.LOCAL, axis=0, arg=4)) + prg = k.to_program() + self.assertEqual(len(prg.src.split("for")), 5) + @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index b137bc387c4c7..a9df66b8101fd 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -385,6 +385,8 @@ def apply_opt(self, opt:Opt, append_opt:bool=True): check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}") if opt.op is OptOps.LOCAL: # cyan + # NOTE: LLVM/CLANG can use locals too, but they are treated the same as globals (still helpful for L1 cache) + # it's disabled for now since it makes BEAM slow for little gain check(self.opts.has_local, "target does not support local") check(axis < self.global_dims, "local is for globals") self.shift_to(axis, amt, insert_before=self.first_reduce) diff --git a/tinygrad/device.py b/tinygrad/device.py index 2fd2607c2c775..68ffa9fdc03bf 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -207,7 +207,8 @@ def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None): class _MallocAllocator(LRUAllocator): def _alloc(self, size:int, options:BufferSpec): - return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, 16) + # must be aligned to 0x20 for 256-bit ymm registers + return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, 0x20) def _alloc_aligned(self, size:int, alignment:int): buffer = (ctypes.c_uint8 * (size + alignment))() offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 6a592458759cd..8ba14a7c37881 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -133,7 +133,8 @@ def render(self, name: str, uops: list[UOp]) -> str: if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR): r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}" - args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}") + # NOTE: MallocAllocator promises 0x20 alignment + args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}") elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype) From 9b9c1e14da3dc4b21a660b9deeabc6144d6b5270 Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Sat, 8 Feb 2025 14:29:23 +0500 Subject: [PATCH 2/4] Late MTLCompiler load (#8963) Moved loading MTLCompiler (and trying to load normal llvm before it) to MetalCompiler, like in CPUProgram with helper --- tinygrad/runtime/ops_metal.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 3438e40713bc3..6a7e195f74382 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -4,13 +4,6 @@ from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, cpu_profile, ProfileDeviceEvent, ProfileRangeEvent from tinygrad.renderer.cstyle import MetalRenderer -# Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL -# This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL -# library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there -# doesn't seem to be anything we can do. -with contextlib.suppress(FileNotFoundError): - import tinygrad.runtime.autogen.llvm # noqa: F401 - class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup def __hash__(self): return hash(self.value) def __eq__(self, other): return self.value == other.value @@ -34,14 +27,12 @@ class MTLPipelineOption: libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib") libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal") -compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler") # Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics") libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac libobjc.objc_getClass.restype = objc_id libobjc.sel_registerName.restype = objc_id libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance -compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p libdispatch.dispatch_data_create.restype = objc_instance @functools.lru_cache(None) @@ -102,8 +93,17 @@ def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance: return library class MetalCompiler(Compiler): + # Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL + # This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL + # library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there + # doesn't seem to be anything we can do. + with contextlib.suppress(FileNotFoundError): + import tinygrad.runtime.autogen.llvm # noqa: F401 + support = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler") + support.MTLCodeGenServiceCreate.restype = ctypes.c_void_p + def __init__(self): - self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad")) + self.cgs = ctypes.c_void_p(MetalCompiler.support.MTLCodeGenServiceCreate(b"tinygrad")) super().__init__("compile_metal_direct") def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork def compile(self, src:str) -> bytes: @@ -127,7 +127,7 @@ def callback(blockptr, error, dataPtr, dataLen, errorMessage): # See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout. # Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first # argument and pretend it's a normal callback - compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10)) + MetalCompiler.support.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10)) if isinstance(ret, Exception): raise ret assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}" return ret From e7182bbb2cb78d9f834ec60dee56bb90925a9133 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 8 Feb 2025 11:57:38 +0100 Subject: [PATCH 3/4] fix "fatal bad object" log in process replay [pr] (#8966) --- .github/actions/process-replay/action.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/actions/process-replay/action.yml b/.github/actions/process-replay/action.yml index 35eb6c9c85381..2a05afd13ed07 100644 --- a/.github/actions/process-replay/action.yml +++ b/.github/actions/process-replay/action.yml @@ -7,7 +7,9 @@ runs: shell: bash run: | export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH") - export COMMIT_MESSAGE=$(git show -s --format=%B ${{ github.event.pull_request.head.sha }}) + export CURRENT_SHA=${{ github.event.pull_request && github.event.pull_request.head.sha || github.sha }} + git fetch origin $CURRENT_SHA + export COMMIT_MESSAGE=$(git show -s --format=%B "$CURRENT_SHA") export CURRENT_HEAD=$(git rev-parse HEAD) cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py - git checkout $CURRENT_HEAD # restore to branch \ No newline at end of file + git checkout $CURRENT_HEAD # restore to branch From 0cac941af1005fc7ecc1d61ec56656c53be0aa01 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 8 Feb 2025 10:09:24 -0500 Subject: [PATCH 4/4] move xpow to sym instead of late_rewrite (#8968) does not need to be in late_rewrite and can be simplified further --- tinygrad/codegen/rewriter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index 974581f2d9c20..7023fd0fe9b3f 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -123,7 +123,6 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: def get_late_rewrite_patterns(ops, force_transcendental=False): pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \ ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental] - pat.append((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))) # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1) if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)] @@ -297,6 +296,8 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp): # ** where ** # push cast to branches (UPat.var("s").where(UPat.var("a"), UPat.var("b")).cast().named("cast"), lambda s,a,b,cast: s.where(a.cast(cast.dtype), b.cast(cast.dtype))), + # ** pow ** + ((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))), # ** load/store folding ** (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)), (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),