
Commit 052722a

torch hook: address comments (tinygrad#9295)
* torch hook: address comments
* failed test
1 parent d657d5f commit 052722a

File tree: 2 files changed, +28 -6 lines changed


extra/torch_backend/test.py

Lines changed: 7 additions & 0 deletions

@@ -87,5 +87,12 @@ def test_str(self):
     a = torch.ones(4, device=device)
     print(str(a))
 
+  @unittest.skip("failed")
+  def test_floor_div(self):
+    a = torch.tensor([10., 7., 5.], device=device)
+    b = torch.tensor([3., 2., 2.], device=device)
+    result = a // b
+    np.testing.assert_equal(result.cpu().numpy(), [3., 3., 2.])
+
 if __name__ == "__main__":
   unittest.main()
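
Note: the new test is skipped because floor division does not yet pass on the tiny backend. The expected values follow ordinary floor-division semantics; a backend-independent NumPy check of the same numbers, shown purely as an illustration:

import numpy as np

a = np.array([10., 7., 5.])
b = np.array([3., 2., 2.])
# floor division rounds each quotient toward negative infinity: 10/3 -> 3, 7/2 -> 3, 5/2 -> 2
np.testing.assert_equal(np.floor_divide(a, b), [3., 3., 2.])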

extra/torch_hook/hook_torch.py

Lines changed: 21 additions & 6 deletions
@@ -2,6 +2,7 @@
 from hexdump import hexdump
 from tinygrad.device import Device
 from tinygrad import Tensor
+from tinygrad.dtype import _from_torch_dtype
 from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
 
 import extra.torch_hook.hook_cuda as hook_cuda
@@ -14,6 +15,7 @@
 TINY_MIRROR = getenv("TINY_MIRROR", 1) # should mirror aten ops to tiny backend
 RUN_ONLY = getenv("RUN_ONLY", -1) # run only a specific aten call
 REALIZE = getenv("REALIZE", 1) # realize and wait each aten call
+WRAP_TINY = getenv("WRAP_TINY", 1) # reuse cuda tensors
 FULL_KERN_NAME = getenv("FULL_KERN_NAME", 0) # print full kernel name
 
 print("importing torch...")
@@ -39,15 +41,24 @@ def __torch_dispatch__(self, func, types, args, kwargs=None):
     def can_print_arg(arg):
       return args is None or isinstance(arg, str) or isinstance(arg, int) or isinstance(arg, float) or isinstance(arg, bool)
 
+    def create_tiny_mapping(arg):
+      if WRAP_TINY:
+        tt = Tensor.from_blob(arg.data_ptr(), arg.shape, dtype=_from_torch_dtype(arg.dtype))
+        cuda_to_tiny_mappings[arg] = tiny_torch.wrap(tt)
+
     for i,arg in enumerate(args):
       if torch.is_tensor(arg):
-        if arg.device.type == "cuda": should_call_tiny = True
+        if arg.device.type == "cuda":
+          should_call_tiny = True
+          if WRAP_TINY: create_tiny_mapping(arg)
         txt_args.append(f"tensor({arg.shape} {arg.device} {arg.dtype})")
       elif can_print_arg(arg): txt_args.append(f'{arg}')
       else: txt_args.append(f"{type(arg)}")
     for k,v in (kwargs or {}).items():
       if torch.is_tensor(v):
-        if arg.device.type == "cuda": should_call_tiny = True
+        if arg.device.type == "cuda":
+          should_call_tiny = True
+          if WRAP_TINY: create_tiny_mapping(arg)
         txt_args.append(f"{k}:tensor({v.shape} {v.device} {v.dtype})")
       elif can_print_arg(arg): txt_args.append(f'{k}:{arg}"')
       else: txt_args.append(f"{type(arg)}")
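
Note: create_tiny_mapping is the heart of WRAP_TINY: rather than copying a cuda tensor, the hook wraps the tensor's existing device memory into a tinygrad Tensor with Tensor.from_blob and caches the wrapped result in cuda_to_tiny_mappings. A rough standalone sketch of the same idea, assuming a CUDA build of torch and a tinygrad default device that can dereference that pointer (variable names are illustrative, not part of the hook):

import torch
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype

cuda_arg = torch.ones(4, device="cuda")   # an existing torch cuda tensor
# view the same device memory from tinygrad; no copy is made
tt = Tensor.from_blob(cuda_arg.data_ptr(), cuda_arg.shape, dtype=_from_torch_dtype(cuda_arg.dtype))
print(tt.shape, tt.dtype)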
@@ -68,7 +79,7 @@ def print_events(evs, name, out_addr):
       for param in ev.params:
         if isinstance(param, hook_cuda.HookTensorParamEvent):
           is_out = param.cuda_address == out_addr
-          txt_params += [f"{'out' if is_out else 'in'} tensor{param.enum}({param.cuda_address:#x}, off={param.offset:#x})"]
+          txt_params += [f"{'result ' if is_out else ''}Tensor{param.enum}({param.cuda_address:#x})"]
 
       just_kern_name = ev.name
       if not FULL_KERN_NAME:
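
Note: the param line now only marks the result tensor and drops the offset. A tiny illustration of the new f-string, with made-up values:

is_out, enum, cuda_address = True, 0, 0x7f00dead0000
print(f"{'result ' if is_out else ''}Tensor{enum}({cuda_address:#x})")
# prints: result Tensor0(0x7f00dead0000)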
@@ -98,11 +109,15 @@ def print_events(evs, name, out_addr):
 
     # TODO: this is a hack, any way to do this better?
     if REALIZE:
-      tiny_x.cpu()
+      out_addr = 0x0
+      if torch.is_tensor(tiny_x):
+        tt = tiny_torch.unwrap(tiny_x).realize()
+        try: out_addr = tt.lazydata.buffer._buf.value
+        except Exception: pass
       tiny_events = hook_cuda.collect_events(clear=True)
-      print_events(tiny_events, colored("tiny", "magenta"), 0x0)
+      print_events(tiny_events, colored("tiny", "magenta"), out_addr)
 
-    cuda_to_tiny_mappings[orig_x] = tiny_x
+    if not WRAP_TINY: cuda_to_tiny_mappings[orig_x] = tiny_x
 
     hook_cuda.pop_ignore_dispatch()
     return orig_x
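
Note: with REALIZE enabled, the hook now realizes the mirrored tiny tensor and tries to read its raw device buffer address, so print_events can tag the matching kernel argument as the result. A sketch of that pointer lookup outside the hook; _buf is a backend-internal field, so this may differ across devices and tinygrad versions:

from tinygrad import Tensor

t = (Tensor.ones(4) + 1).realize()        # force allocation of the output buffer
addr = 0x0
try: addr = t.lazydata.buffer._buf.value  # raw device pointer on CUDA-like backends
except Exception: pass                    # other backends store buffers differently
print(hex(addr))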
