
Commit a3c78d4

speed docs + upgrades [pr] (tinygrad#8964)
* add some docs about speed [pr]
* better torch gemm
* enable locals on llvm/clang
* disable locals for beam speed on LLVM/CLANG
* 0x20 alignment in llvm allows ymm use
1 parent 5bdd6a1 commit a3c78d4


8 files changed, +103 -6 lines changed


docs/developer/developer.md

+2
@@ -7,6 +7,8 @@ The tinygrad framework has four pieces
 
 There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-notes/) by Di Zhu that go over tinygrad internals.
 
+There's also a [doc describing speed](../developer/speed.md)
+
 ## Frontend
 
 Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md).

docs/developer/speed.md

+71
@@ -0,0 +1,71 @@
# speed in tinygrad

## Overview

Speed refers to many different things. It breaks down into four categories:

- Compile Speed (Python)
- Execution Speed (driver)
- Model Speed (scheduler)
- Kernel Speed (codegen)

## Compile Speed (Python)

This is how long the first run of your model takes. It's largely limited by the runtime of the Python doing UOp rewrites. Currently it's a bit slow, but on par with torch.compile. It gets even slower if you are using BEAM, since that compiles many variants of each kernel.

This will be improved by writing a faster graph_rewrite, doing less graph_rewrite, and better parallelization.

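A minimal sketch of turning BEAM on for a single compile, assuming `Context` from `tinygrad.helpers` (illustrative, not part of this diff):

```python
from tinygrad import Tensor
from tinygrad.helpers import Context

x = Tensor.rand(64, 64)

# BEAM search runs at compile time: the first realize under BEAM takes longer,
# but the kernel it picks is what later runs of the same op will reuse
with Context(BEAM=2):
  (x @ x).relu().sum().realize()
```
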
## Execution Speed (driver)

After your model is compiled, you are often using the `TinyJit`. tinygrad has the best execution speed of any framework because it usually bypasses the GPU driver and prebuilds the command queue. It's tons faster than normal CUDA, and often even faster than CUDA Graph.

There's very little to improve here, as this is almost never the bottleneck.

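A small sketch of what that looks like in user code (the `step` function is made up for illustration, not part of this diff):

```python
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().sum().realize()

# the first couple of calls run normally and capture the kernels;
# after that, calls replay the prebuilt command queue directly
for _ in range(5): step(Tensor.rand(64, 64))
```
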
## Model Speed (scheduler)

The scheduler determines how operations are grouped into kernels and which Tensors are written to memory. This is currently a big bottleneck for training speed.

The decisions are often not obvious. For example, when is it worth recomputing an arithmetic operation instead of storing it to and loading it from memory?

```python
from tinygrad import Tensor
a = Tensor.rand(100)
b = Tensor.rand(100)
c = Tensor.rand(100)
d = Tensor.rand(100)
out1 = a+b+c
out2 = a+b+d
Tensor.realize(out1, out2)
```

The real answer is obvious: compute both `out1` and `out2` in the same kernel. But you can't always do that. If you can't, should `a+b` first be saved to a subbuffer? Or should both the `out1` and `out2` kernels recompute `a+b`?

In this case, recomputing costs 6 reads + 2 writes while not recomputing costs 6 reads + 3 writes, so we should probably recompute. However, once you add movement ops and casts this is even harder to figure out. tinygrad doesn't yet have a systematic way to do it.

## Kernel Speed (codegen)

Given that you have decided how the model ops will be grouped and what will be written to memory, kernel speed determines how fast that operation is done. This is what BEAM changes: it searches over a set of equivalent kernels that all perform the same operation and picks the one that runs fastest.

In `kernel.py` we have a set of `OptOps`. These control the parameters of the speed optimizations applied to the kernel.

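A rough sketch of applying one by hand, mirroring the new `test_upcast_with_locals_clang` test in this commit (it assumes `Kernel`, `Opt`, and `OptOps` are importable from `tinygrad.codegen.kernel`):

```python
from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel, Opt, OptOps

out = Tensor.ones(64, 64).contiguous() @ Tensor.ones(64, 64).contiguous()
k = Kernel(out.schedule()[-1].ast)              # lower the last scheduled kernel's AST
k.apply_opt(Opt(OptOps.LOCAL, axis=0, arg=4))   # split a global axis into a local (workgroup) axis of 4
print(k.to_program().src)                       # render the optimized kernel source
```

Note that `LOCAL` needs a backend with locals; as the kernel.py change below explains, it's currently disabled on LLVM/CLANG.
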
### Memory

The main bottleneck in most kernels is accessing memory. In a freshman algorithms class you'll learn about cache-aware matrix multiplication, and this is all forms of that. While the same math is run, the order in which you run it can have a large impact on speed, depending on whether the data you are loading is already in cache. OptOps change this order.

Memory, even cache, is often much slower than accessing the register file. The number of times loaded data is used in math is called the "arithmetic intensity". For operations like BS=1 GEMV the arithmetic intensity is 1, but for GEMMs and convs it can be much higher. OptOps like UPCAST and UNROLL can increase it, but be careful of making them too large: if there's too much register pressure on the GPU, the warp scheduler may not be able to fit many warps, or even worse, it could spill to local memory.

4090s have 1 TB/s of RAM bandwidth and ~160 TFLOPS of compute, so you need to use each loaded value ~100 times. The L1 cache has around 40 TB/s of bandwidth, so in order to get full compute utilization you need to use each value ~4 times.

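A back-of-the-envelope version of that calculation, using the rough numbers quoted above:

```python
# rough roofline arithmetic for a 4090-class GPU (numbers from the paragraph above)
compute = 160e12   # FLOPS
ram_bw  = 1e12     # bytes/s of RAM bandwidth
l1_bw   = 40e12    # bytes/s of L1 bandwidth

# FLOPs that must be done per byte loaded to keep the ALUs busy
print(compute / ram_bw)  # 160.0 -> on the order of ~100 uses of each value loaded from RAM
print(compute / l1_bw)   # 4.0   -> each L1-resident value only needs a handful of uses
```
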
A lot of work can still be done here. For example, we never copy the inputs to on-chip SRAM, but this is often quite helpful for kernel speed. Also, we aren't doing a good job with L2 cache awareness (the locals handle L1 quite well).

### Tensor Cores

Many accelerators have Tensor Cores / MAC arrays / systolic arrays. The main value of these is that, since they are 2-D, they create an n^2 ratio between the compute and the input data.

GPUs use Tensor Cores instead of MAC arrays to fit better in the GPU warp paradigm. This is because the output of Tensor Cores is O(n) wrt the input, while the output of MAC arrays like the AMX is O(n^2).

We have a simple framework in tinygrad for adding these ALU blocks and achieving good performance from them.

### Indexing

Indexing determines the address of the memory we need to load. GPUs often have fewer integer math resources than floating point math, so this can sometimes be the bottleneck. We have a symbolic math engine in our rewrite rules to simplify indexing before it's emitted to the kernel. Newer NVIDIA GPUs have a "Tensor Memory Accelerator" to assist with fast indexing; however, this is not supported in tinygrad yet.

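As a toy illustration (plain Python, not the actual rewrite rules) of the kind of identity that simplification relies on:

```python
# with 0 <= lidx < 4, the flat index gidx*4 + lidx splits back into (gidx, lidx),
# so an integer div and mod in the generated kernel can be rewritten away
for gidx in range(8):
  for lidx in range(4):
    idx = gidx*4 + lidx
    assert idx // 4 == gidx and idx % 4 == lidx
```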

extra/gemm/torch_gemm.py

+13-4
@@ -1,17 +1,26 @@
1+
import os
2+
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
3+
os.environ["MKL_NUM_THREADS"] = "1"
4+
os.environ["NUMEXPR_NUM_THREADS"] = "1"
5+
os.environ["OMP_NUM_THREADS"] = "1"
16
import time
27
import torch
8+
torch.set_num_threads(1)
9+
from tinygrad.helpers import getenv
10+
CUDA = getenv("CUDA", 1)
311

4-
for dtype in [torch.float16, torch.float32]:
12+
for dtype in [torch.float32, torch.float16]:
513
for N in [256, 512, 1024, 2048, 4096]:
614
FLOPS = N*N*N*2
715

8-
b = torch.rand((N,N), dtype=dtype).cuda()
9-
c = torch.rand((N,N), dtype=dtype).cuda()
16+
b = torch.rand((N,N), dtype=dtype)
17+
c = torch.rand((N,N), dtype=dtype)
18+
if CUDA: b,c = b.cuda(),c.cuda()
1019

1120
def torch_prog(b, c):
1221
st = time.perf_counter()
1322
a = b@c
14-
torch.cuda.synchronize()
23+
if CUDA: torch.cuda.synchronize()
1524
return time.perf_counter() - st
1625
tm = min([torch_prog(b, c) for _ in range(20)])
1726
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS {N:4d}x{N:4d}x{N:4d} matmul in {dtype}")

mkdocs.yml

+1
@@ -22,6 +22,7 @@ nav:
   - Runtime: runtime.md
   - Developer:
     - Intro: developer/developer.md
+    - Speed: developer/speed.md
     - UOp: developer/uop.md
     - Runtime:
       - developer/runtime.md

test/test_linearizer.py

+10
@@ -981,6 +981,16 @@ def test_reduce_upcast(self):
     assert len(stores) == 1
     assert stores[0].src[-1].dtype == dtypes.float.vec(4)
 
+  # NOTE: can reenable, it does work. it just makes BEAM slow
+  @unittest.expectedFailure
+  @unittest.skipUnless(Device.DEFAULT == "CLANG", "test only for CLANG")
+  def test_upcast_with_locals_clang(self):
+    out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
+    k = Kernel(out.schedule()[-1].ast)
+    k.apply_opt(Opt(OptOps.LOCAL, axis=0, arg=4))
+    prg = k.to_program()
+    self.assertEqual(len(prg.src.split("for")), 5)
+
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")

tinygrad/codegen/kernel.py

+2
@@ -385,6 +385,8 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
       check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
 
     if opt.op is OptOps.LOCAL:                # cyan
+      # NOTE: LLVM/CLANG can use locals too, but they are treated the same as globals (still helpful for L1 cache)
+      # it's disabled for now since it makes BEAM slow for little gain
       check(self.opts.has_local, "target does not support local")
       check(axis < self.global_dims, "local is for globals")
       self.shift_to(axis, amt, insert_before=self.first_reduce)

tinygrad/device.py

+2-1
@@ -207,7 +207,8 @@ def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
 
 class _MallocAllocator(LRUAllocator):
   def _alloc(self, size:int, options:BufferSpec):
-    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, 16)
+    # must be aligned to 0x20 for 256-bit ymm registers
+    return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, 0x20)
   def _alloc_aligned(self, size:int, alignment:int):
     buffer = (ctypes.c_uint8 * (size + alignment))()
     offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
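
For context (not part of the diff), a standalone sketch of the over-allocate-then-offset trick `_alloc_aligned` uses, assuming `round_up` from `tinygrad.helpers`:

```python
import ctypes
from tinygrad.helpers import round_up

def alloc_aligned(size: int, alignment: int = 0x20):
  # over-allocate by `alignment` bytes, then start at the next aligned address
  backing = (ctypes.c_uint8 * (size + alignment))()
  offset = round_up(ctypes.addressof(backing), alignment) - ctypes.addressof(backing)
  view = (ctypes.c_uint8 * size).from_address(ctypes.addressof(backing) + offset)
  assert ctypes.addressof(view) % alignment == 0   # safe for 256-bit (ymm) loads
  return backing, view   # keep `backing` alive: from_address doesn't hold a reference

backing, aligned_buf = alloc_aligned(1000)
```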

tinygrad/renderer/llvmir.py

+2-1
@@ -133,7 +133,8 @@ def render(self, name: str, uops: list[UOp]) -> str:
 
       if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
         r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
-        args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
+        # NOTE: MallocAllocator promises 0x20 alignment
+        args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
       elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
       elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
       elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
