Commit a8e54df

benchmark single kernel launch (tinygrad#8921)
* benchmark kernel launch
* don't realize unneeded
* faster
* faster metal
* fix mypy
* without sync
* no div 0
* lru cache that
* no sync in the profile
1 parent 3e082d4 commit a8e54df

File tree: 6 files changed, +54 -9 lines

New file (benchmark script, +38 lines):

@@ -0,0 +1,38 @@
+import time
+from tinygrad import Tensor, TinyJit, Device, Context
+from tinygrad.helpers import Profiling, Timing, GlobalCounters
+
+# python3 test/test_speed_v_torch.py TestSpeed.test_add_a
+
+@TinyJit
+def plus(a:Tensor, b:Tensor): return a+b
+
+if __name__ == "__main__":
+  a = Tensor([1]).realize()
+  b = Tensor([1]).realize()
+  for i in range(5):
+    with Timing(prefix=f"{i}:"):
+      c = plus(a,b)
+      Device[c.device].synchronize()
+    assert c.item() == 2
+  for i in range(5):
+    st = time.perf_counter()
+    c = plus(a,b)
+    et = time.perf_counter() - st
+    print(f"nosync {i}: {et*1e6:.2f} us")
+  Device[c.device].synchronize()
+  for i in range(5):
+    st = time.perf_counter()
+    c = plus(a,b)
+    Device[c.device].synchronize()
+    et = time.perf_counter() - st
+    print(f"precise {i}: {et*1e6:.2f} us")
+  assert GlobalCounters.time_sum_s == 0
+  with Context(DEBUG=2):
+    st = time.perf_counter()
+    c = plus(a,b)
+    Device[c.device].synchronize()
+    et = time.perf_counter() - st
+    print(f"kernel {GlobalCounters.time_sum_s*1e3:.2f} ms / full {et*1e3:.2f} ms -- {et/(GlobalCounters.time_sum_s+1e-12):.2f} x")
+  with Profiling():
+    c = plus(a,b)
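
The script times the same jitted add three ways: wall-clock with a device sync inside the Timing block, enqueue-only ("nosync"), and enqueue plus sync ("precise"), then compares the kernel time accumulated under DEBUG=2 against full wall-clock time and ends with a Python-level profile. A minimal sketch of the sync/no-sync measurement pattern, with hypothetical fn and sync callables standing in for the jitted call and Device[c.device].synchronize():

import time

def bench(fn, sync, n=5):
  # "nosync": only the Python-side cost of enqueueing work
  for i in range(n):
    st = time.perf_counter()
    fn()
    print(f"nosync {i}: {(time.perf_counter()-st)*1e6:.2f} us")
  sync()  # drain queued work so it doesn't leak into the next loop
  # "precise": wait for the device, so execution time is included
  for i in range(n):
    st = time.perf_counter()
    fn()
    sync()
    print(f"precise {i}: {(time.perf_counter()-st)*1e6:.2f} us")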

test/test_speed_v_torch.py (+6 -2)

@@ -202,8 +202,12 @@ def test_mul_sum(self):
     def f(a, b): return (a*b).sum()
     helper_test_generic_square('mul_sum', 4096, f, f)
 
-  def test_add(self):
-    for N in [1, 1024, 4096]:
+  def test_add_a(self):
+    def f(a, b): return a + b
+    helper_test_generic_square('add', 1, f, f)
+
+  def test_add_big(self):
+    for N in [1024, 4096]:
       def f(a, b): return a + b
       helper_test_generic_square('add', N, f, f)
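
Splitting test_add keeps the single-element case, where launch overhead dominates, in its own test_add_a, while the large-tensor sweep moves to test_add_big; as the comment in the new benchmark script notes, the small case can be run directly with python3 test/test_speed_v_torch.py TestSpeed.test_add_a.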

tinygrad/engine/jit.py (+1 -1)

@@ -193,7 +193,7 @@ def __call__(self, input_buffers:list[Buffer], var_vals:dict[Variable, int]) ->
 def _prepare_jit_inputs(args, kwargs):
   input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
   names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
-  if tensors: Tensor.realize(*tensors)
+  if len(unrealized_tensors := [x for x in tensors if not x.lazydata.is_realized]): Tensor.realize(*unrealized_tensors)
   # TODO: should we be unpacking multi here?
   lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
   input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
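
Realizing JIT inputs is now skipped when every input is already realized (the "don't realize unneeded" item in the commit message), so Tensor.realize is only called for tensors that still need a buffer. A rough standalone sketch of the guard, reusing the lazydata.is_realized check from the diff in a hypothetical helper:

from tinygrad import Tensor

def maybe_realize(tensors: list[Tensor]) -> None:
  # realize only the inputs whose lazy data is not yet backed by a buffer;
  # if the filtered list is empty, realize (and its scheduling) never runs
  if len(unrealized := [t for t in tensors if not t.lazydata.is_realized]):
    Tensor.realize(*unrealized)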

tinygrad/ops.py (+2 -2)

@@ -512,8 +512,8 @@ def metadata(self): return all_metadata.get(self, None)
 
   @property
   def base(self) -> UOp:
-    if self.op in GroupOp.Movement: return self.src[0].base
-    return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 else self
+    if (self.op is Ops.VIEW and len(self.src) == 1) or self.op in GroupOp.Movement: return self.src[0].base
+    return self
   def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
 
   def _mop(self, op:Ops, arg):
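
The behaviour of base is unchanged: a VIEW with exactly one source and any movement op both resolve to src[0].base, and every other UOp is its own base; the two equivalent branches are simply merged into one condition with a plain return self afterwards.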

tinygrad/runtime/ops_metal.py (+6 -4)

@@ -42,6 +42,7 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id)
   sender.restype = restype
   return sender(ptr, sel(selector), *args)
 
+@functools.lru_cache(None)
 def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
 def from_ns_str(s): return bytes(msg(s, "UTF8String", restype=ctypes.c_char_p)).decode()
 
@@ -146,21 +147,22 @@ def __init__(self, dev:MetalDevice, name:str, lib:bytes):
     self.pipeline_state = msg(self.dev.sysdevice, "newComputePipelineStateWithDescriptor:options:reflection:error:",
       descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()), restype=objc_instance)
     error_check(error_pipeline_creation)
+    # cache these msg calls
+    self.max_total_threads: int = cast(int, msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong))
 
   def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
-    max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
-    if prod(local_size) > cast(int, max_total_threads):
+    if prod(local_size) > self.max_total_threads:
       exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
       memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
-      raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
+      raise RuntimeError(f"local size {local_size} bigger than {self.max_total_threads} with exec width {exec_width} memory length {memory_length}")
     command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
     encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
     msg(encoder, "setComputePipelineState:", self.pipeline_state)
     for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
     for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
     msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
     msg(encoder, "endEncoding")
-    msg(command_buffer, "setLabel:", to_ns_str(self.name))
+    msg(command_buffer, "setLabel:", to_ns_str(self.name)) # TODO: is this always needed?
     msg(command_buffer, "commit")
     self.dev.mtl_buffers_in_flight.append(command_buffer)
     if wait:
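
Both Metal changes pull objc round-trips out of the per-launch path: to_ns_str results are memoized with functools.lru_cache, and maxTotalThreadsPerThreadgroup is read once when the pipeline state is built instead of on every call. A minimal sketch of the memoization idea, with a stand-in conversion instead of the real objc call:

import functools

@functools.lru_cache(maxsize=None)
def to_label(s: str) -> str:
  # stand-in for to_ns_str: the result for a given Python string never changes,
  # so repeated launches of the same kernel become a dict lookup
  return f"<NSString:{s}>"

to_label("add_kernel")
to_label("add_kernel")  # second call hits the cache, no conversion work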

tinygrad/shape/shapetracker.py (+1)

@@ -109,6 +109,7 @@ def var_vals(self) -> dict[Variable, int]: return merge_dicts([dict([v.unbind()]
 
   def unbind(self) -> tuple[ShapeTracker, dict[Variable, int]]:
     unbound_views, var_vals = zip(*[v.unbind() for v in self.views])
+    if all(len(x) == 0 for x in var_vals): return self, {}
     return ShapeTracker(tuple(unbound_views)), merge_dicts(var_vals)
 
   def real_strides(self, ignore_valid=False) -> tuple[Optional[sint], ...]: return views_to_real_strides(self.views, ignore_valid)
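
unbind gains a fast path: when no view has bound variables, every per-view unbind yields an empty dict, and the early return hands back the original ShapeTracker instead of constructing a new one and merging empty dicts on each call.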
