Commit 43e6091

init torch hooking (tinygrad#9284)
* smth
* mv
* prof wk
* revert and move
* fix
* nvprof
* fix and no print much
1 parent 387ea41 commit 43e6091

File tree: 2 files changed, +348 -0 lines

extra/torch_hook/hook_cuda.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
import ctypes, struct, platform, pathlib, os, binascii, itertools
from hexdump import hexdump
from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
from tinygrad.runtime.autogen import libc, cuda
from tinygrad.device import CPUProgram, Device
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.ops_cuda import cu_time_execution

print(f"hooking CUDA runtime, running with {Device.DEFAULT}")

# TODO: regen and make cuda 12 default?
cuda.cuFuncGetParamInfo = cuda._libraries['libcuda.so'].cuFuncGetParamInfo
cuda.cuFuncGetParamInfo.restype = cuda.CUresult
cuda.cuFuncGetParamInfo.argtypes = [cuda.CUfunction, cuda.size_t, ctypes.POINTER(ctypes.c_uint64), ctypes.POINTER(ctypes.c_uint64)]

ignore_dispatch = [False] # default value is False
def push_ignore_dispatch(val):
  global ignore_dispatch
  ignore_dispatch.append(val)

def pop_ignore_dispatch():
  global ignore_dispatch
  ignore_dispatch.pop()

hooked = {}
def _hook(fxn_address_value, tramp):
  page_address = (fxn_address_value//0x1000)*0x1000
  ret = libc.mprotect(page_address, 0x2000, 7)
  assert ret == 0
  libc.memcpy(fxn_address_value, tramp, len(tramp))
  ret = libc.mprotect(page_address, 0x2000, 5)
  assert ret == 0
  CPUProgram.rt_lib["__clear_cache"](fxn_address_value, fxn_address_value + len(tramp))

def install_hook(c_function, python_function):
  python_function_addr = ctypes.cast(ctypes.byref(python_function), ctypes.POINTER(ctypes.c_ulong)).contents.value
  # build a trampoline that jumps to the Python hook
  if (processor:=platform.processor()) == "aarch64":
    # 0x0000000000000000: 70 00 00 10   adr x16, #0xc
    # 0x0000000000000004: 10 02 40 F9   ldr x16, [x16]
    # 0x0000000000000008: 00 02 1F D6   br x16
    tramp = b"\x70\x00\x00\x10\x10\x02\x40\xf9\x00\x02\x1f\xd6"
    tramp += struct.pack("Q", python_function_addr)
  elif processor == "x86_64":
    # 0x0000000000000000: 49 BB aa aa aa aa aa aa aa aa   movabs r11, <address>
    # 0x000000000000000a: 41 FF E3                        jmp r11
    tramp = b"\x49\xBB" + struct.pack("Q", python_function_addr) + b"\x41\xFF\xE3"
  else:
    raise Exception(f"processor {processor} not supported")
  tramp = ctypes.create_string_buffer(tramp)

  # get real function address
  fxn_address = ctypes.cast(ctypes.byref(c_function), ctypes.POINTER(ctypes.c_ulong))
  fxn_address_value = fxn_address.contents.value
  #print(f"** hooking function at 0x{fxn_address_value}")

  orig_save = (ctypes.c_char*len(tramp))()
  libc.memcpy(orig_save, fxn_address_value, len(tramp))
  _hook(fxn_address_value, tramp)

  def original(*args):
    _hook(fxn_address_value, orig_save)
    ret = c_function(*args)
    _hook(fxn_address_value, tramp)
    return ret
  return original

allocated_memory_enum = 0
allocated_memory = {}
function_names = {}
tiny_devs = {}

seen_modules = set()

global_events = []
class HookEvent: pass
class HookMemAllocEvent(HookEvent):
  def __init__(self, cuda_address, bytesize, enum): self.cuda_address, self.bytesize, self.enum = cuda_address, bytesize, enum
  def __repr__(self): return f"tensor alloc: {self.enum}: {self.cuda_address:#x} - {self.bytesize:#x} bytes"
class HookConstParamEvent(HookEvent):
  def __init__(self, value): self.value = value
  def __repr__(self): return f"const({self.value:#x})"
class HookTensorParamEvent(HookEvent):
  def __init__(self, cuda_address, offset, enum): self.cuda_address, self.offset, self.enum = cuda_address, offset, enum
  def __repr__(self): return f"tensor{self.enum}({self.cuda_address:#x}, {self.offset=:#x})"
class HookKernelCallEvent(HookEvent):
  def __init__(self, grid, block, tm, ptm, name, params): self.grid, self.block, self.tm, self.ptm, self.name, self.params = grid, block, tm, ptm, name, params
  def __repr__(self): return f"kernel call <<{self.grid}>> <<{self.block}>> {self.ptm}\n | {self.params}\n | {self.name}"

def collect_events(clear=False):
  global global_events
  x = global_events
  if clear: global_events = []
  return x

@ctypes.CFUNCTYPE(*([cuda.cuDeviceGet.restype] + cuda.cuDeviceGet.argtypes))
def cuDeviceGet(device, ordinal):
  tiny_devs[ordinal] = Device[f"{Device.DEFAULT}:{ordinal}"]
  device.contents.value = ordinal
  return cuda.CUDA_SUCCESS

@ctypes.CFUNCTYPE(*([cuda.cuMemHostAlloc.restype] + cuda.cuMemHostAlloc.argtypes))
def cuMemHostAlloc(pp, bytesize, flags):
  print(f"cuMemHostAlloc {bytesize}")
  return hooked["cuMemHostAlloc"](pp, bytesize, flags)

@ctypes.CFUNCTYPE(*([cuda.cuModuleLoadData.restype] + cuda.cuModuleLoadData.argtypes))
def cuModuleLoadData(module, image):
  ret = hooked["cuModuleLoadData"](module, image)
  module_address = ctypes.addressof(module.contents.contents)
  seen_modules.add(module_address)
  return ret

@ctypes.CFUNCTYPE(*([cuda.cuModuleGetFunction.restype] + cuda.cuModuleGetFunction.argtypes))
def cuModuleGetFunction(hfunc, hmod, name):
  ret = hooked["cuModuleGetFunction"](hfunc, hmod, name)
  python_name = ctypes.string_at(name).decode()

  # pip install git+https://github.com/wbenny/pydemangler.git
  import pydemangler
  demangled_name = pydemangler.demangle(python_name)
  if demangled_name is not None: python_name = demangled_name

  # print(f"called cuModuleGetFunction 0x{ctypes.addressof(hmod.contents):X} {python_name}")
  function_names[ctypes.addressof(hfunc.contents.contents)] = python_name
  return ret

@ctypes.CFUNCTYPE(*([cuda.cuMemAlloc_v2.restype] + cuda.cuMemAlloc_v2.argtypes))
def cuMemAlloc_v2(dptr, bytesize):
  global allocated_memory_enum, text_prefix

  ret = hooked["cuMemAlloc_v2"](dptr, bytesize)
  cuda_address = dptr.contents.value
  allocated_memory[cuda_address] = (bytesize, allocated_memory_enum)

  global_events.append(HookMemAllocEvent(cuda_address, bytesize, allocated_memory_enum))
  if DEBUG >= 3: print(global_events[-1])

  allocated_memory_enum += 1
  return ret

@ctypes.CFUNCTYPE(*([cuda.cuLaunchKernel.restype] + cuda.cuLaunchKernel.argtypes))
def cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra):
  global ignore_dispatch

  name = function_names[ctypes.addressof(f.contents)]
  if ignore_dispatch[-1]:
    if DEBUG >= 4: print(f"ignoring dispatch {name}")
    return 0

  tm = cu_time_execution(lambda:
    hooked["cuLaunchKernel"](f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra), True)

  ptm = colored(time_to_str(tm, w=9), "yellow" if tm > 0.01 else "green")

  params = []
  while True:
    ret = cuda.cuFuncGetParamInfo(f, len(params), ctypes.byref(paramOffset:=ctypes.c_size_t()), ctypes.byref(paramSize:=ctypes.c_size_t()))
    if ret != 0: break
    params.append((paramOffset.value, paramSize.value))

  ev_params = []
  if extra: params_ptr = to_mv(extra, 5*8).cast("Q")
  else: params_ptr = to_mv(kernelParams, len(params)*8).cast("Q")

  for i,(off,sz) in enumerate(params):
    sz_to_let = {1: 'B', 2: 'H', 4: 'I', 8: 'Q'}
    if sz >= 8:
      for j in range(sz//8):
        if extra: value = to_mv(params_ptr[1] + off, sz).cast("Q")[0]
        else: value = to_mv(params_ptr[i] + j*8, 8).cast('Q')[0]

        has_in_allocated_mem, lcoff, alnum = False, 0, -1
        for taddr, (tsz, talnum) in allocated_memory.items():
          if taddr <= value < taddr + tsz:
            has_in_allocated_mem = True
            lcoff = value - taddr
            alnum = talnum
            break

        if has_in_allocated_mem: ev_params.append(HookTensorParamEvent(value, lcoff, alnum))
        else: ev_params.append(HookConstParamEvent(value))
    else:
      if extra: value = to_mv(params_ptr[1] + off, sz).cast(sz_to_let[sz])[0]
      else: value = to_mv(params_ptr[i], sz).cast(sz_to_let[sz])[0]
      ev_params.append(HookConstParamEvent(value))

  global_events.append(HookKernelCallEvent((gridDimX, gridDimY, gridDimZ), (blockDimX, blockDimY, blockDimZ), tm, ptm, name, ev_params))
  if DEBUG >= 3: print(global_events[-1])

  return 0

def create_hook(func_name, restype, argtypes):
  def hook_template(*args):
    # print(func_name, flush=True)
    return hooked[func_name](*args)
  return ctypes.CFUNCTYPE(restype, *argtypes)(hook_template)

def install_hooks():
  hooked['cuModuleGetFunction'] = install_hook(cuda.cuModuleGetFunction, cuModuleGetFunction)
  hooked['cuLaunchKernel'] = install_hook(cuda.cuLaunchKernel, cuLaunchKernel)

  # memory stuff
  hooked['cuMemAlloc_v2'] = install_hook(cuda.cuMemAlloc_v2, cuMemAlloc_v2)
  hooked['cuMemHostAlloc'] = install_hook(cuda.cuMemHostAlloc, cuMemHostAlloc)

  # module loading + not used module loading
  hooked['cuModuleLoadData'] = install_hook(cuda.cuModuleLoadData, cuModuleLoadData)

NVPROFILER = os.environ.get("NV_COMPUTE_PROFILER_PERFWORKS_DIR", None) # set when running under the Nsight Compute profiler
if NVPROFILER is None: install_hooks()
else:
  print("Detected NSIGHT profiler, hooking is not available.")
  cuda._libraries['libcuda.so'] = ctypes.CDLL(NVPROFILER + "/libcuda-injection.so")
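
For reference, a minimal sketch (not part of the diff) of how the event API above might be consumed on its own. It assumes a CUDA build of torch, a working tinygrad install, and this checkout on PYTHONPATH; importing the module installs the hooks as a side effect unless Nsight Compute is detected.

import extra.torch_hook.hook_cuda as hook_cuda  # hooks libcuda on import; same order as hook_torch.py: hooks first, torch second
import torch

a = torch.randn(64, 64, device="cuda")
b = torch.randn(64, 64, device="cuda")
(a @ b).cpu()  # launches a kernel and waits for the result

# drain the recorded HookMemAllocEvent / HookKernelCallEvent entries
for ev in hook_cuda.collect_events(clear=True):
  print(ev)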

extra/torch_hook/hook_torch.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
import ctypes, struct, platform, pathlib, os, binascii, itertools
from hexdump import hexdump
from tinygrad.device import Device
from tinygrad import Tensor
from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str

import extra.torch_hook.hook_cuda as hook_cuda

# settings to profile gemm in the __main__ example: TINY_MIRROR=1 CUDA=1 RUN_ONLY=9
# nvprof sample command (this will sample all kernels):
# ncu --export ~/nvprof_data --force-overwrite --rule AchievedOccupancy --rule Compute --rule LaunchConfiguration --rule Memory --rule PMSamplingData --rule SOLBottleneck --rule TheoreticalOccupancy --rule WorkloadImbalance python3 extra/torch_hook/hook_torch.py
# or just run nsight compute from the host against the machine.

TINY_MIRROR = getenv("TINY_MIRROR", 1)        # mirror aten ops to the tiny backend
RUN_ONLY = getenv("RUN_ONLY", -1)             # run only a specific aten call
REALIZE = getenv("REALIZE", 1)                # realize and wait for each aten call
FULL_KERN_NAME = getenv("FULL_KERN_NAME", 0)  # print the full kernel name

print("importing torch...")
import torch
print("importing torch done:", torch.__version__, torch.__file__)

if TINY_MIRROR:
  print("importing tiny torch")
  import extra.torch_backend.backend as tiny_torch
  print("importing tiny torch done")

torch.set_default_device("cuda")

cuda_to_tiny_mappings = {}

enumerator_aten_calls = itertools.count(0)
from torch.utils._python_dispatch import TorchDispatchMode
class DispatchLog(TorchDispatchMode):
  def __torch_dispatch__(self, func, types, args, kwargs=None):
    txt_args = []
    should_call_tiny = (kwargs or {}).get('device') is not None and kwargs['device'].type == "cuda"

    def can_print_arg(arg):
      return arg is None or isinstance(arg, (str, int, float, bool))

    for i,arg in enumerate(args):
      if torch.is_tensor(arg):
        if arg.device.type == "cuda": should_call_tiny = True
        txt_args.append(f"tensor({arg.shape} {arg.device} {arg.dtype})")
      elif can_print_arg(arg): txt_args.append(f'{arg}')
      else: txt_args.append(f"{type(arg)}")
    for k,v in (kwargs or {}).items():
      if torch.is_tensor(v):
        if v.device.type == "cuda": should_call_tiny = True
        txt_args.append(f"{k}:tensor({v.shape} {v.device} {v.dtype})")
      elif can_print_arg(v): txt_args.append(f'{k}:{v}')
      else: txt_args.append(f"{type(v)}")

    # calls mirrored to the tiny backend are printed in magenta, cuda-only calls in cyan
    aten_id = next(enumerator_aten_calls)
    should_call_tiny = TINY_MIRROR and should_call_tiny
    print(colored(f"#{aten_id} {func}", "magenta" if should_call_tiny else "cyan") + "("+", ".join(txt_args)+")", flush=True)

    # ignore dispatches if needed
    hook_cuda.push_ignore_dispatch(RUN_ONLY >= 0 and RUN_ONLY != aten_id)
    orig_x = func(*args, **(kwargs or {}))

    def print_events(evs, name, out_addr):
      for ev in evs:
        if isinstance(ev, hook_cuda.HookKernelCallEvent):
          txt_params = []
          for param in ev.params:
            if isinstance(param, hook_cuda.HookTensorParamEvent):
              is_out = param.cuda_address == out_addr
              txt_params += [f"{'out' if is_out else 'in'} tensor{param.enum}({param.cuda_address:#x}, off={param.offset:#x})"]

          just_kern_name = ev.name
          if not FULL_KERN_NAME:
            just_kern_name = ev.name.replace("(anonymous namespace)", "").replace("void ", "").split("<")[0].split("(")[0].split("::")[-1]
          print(f"\t {name} kernel {just_kern_name} {ev.grid} {ev.block} {ev.ptm}\n\t\t({', '.join(txt_params)})")
        else: print("\t", name, ev)

    if REALIZE:
      torch.cuda.synchronize()
      cuda_events = hook_cuda.collect_events(clear=True)
      print_events(cuda_events, colored("cuda", "cyan"), orig_x.data_ptr() if torch.is_tensor(orig_x) else 0x0)

    if should_call_tiny:
      # replace cuda tensors with their tiny mirrors
      tiny_args, tiny_kwargs = [], {}
      for arg in args:
        if torch.is_tensor(arg): tiny_args.append(cuda_to_tiny_mappings[arg])
        else: tiny_args.append(arg)

      for k,v in (kwargs or {}).items():
        if torch.is_tensor(v): tiny_kwargs[k] = cuda_to_tiny_mappings[v]
        else: tiny_kwargs[k] = v
      if 'device' in tiny_kwargs and kwargs['device'].type == "cuda":
        tiny_kwargs['device'] = torch.device("tiny")

      tiny_x = func(*tiny_args, **tiny_kwargs)

      # TODO: this is a hack, any way to do this better?
      if REALIZE:
        tiny_x.cpu()
        tiny_events = hook_cuda.collect_events(clear=True)
        print_events(tiny_events, colored("tiny", "magenta"), 0x0)

      cuda_to_tiny_mappings[orig_x] = tiny_x

    hook_cuda.pop_ignore_dispatch()
    return orig_x
DispatchLog().__enter__()

if __name__ == "__main__":
  if getenv("RESNET"):
    import torchvision.models as models
    model = models.resnet18(pretrained=True)
    model = model.cuda()
    model.eval()

    if getenv("COMPILE"): model = torch.compile(model)

    X = torch.rand(getenv("BS", 1), 3, 288, 288, device='cuda')
    model(X)

    print("\n\n\n****** second run ******\n")
    model(X)
  else:
    a = torch.randn(64, 64)
    b = torch.randn(64, 64)
    a += 1
    b += 2
    a = a.exp2()
    b = b.exp2()
    a += b
    c = a @ b
    print("tensor math done", c.cpu().numpy())
