@@ -2,6 +2,7 @@
 from hexdump import hexdump
 from tinygrad.device import Device
 from tinygrad import Tensor
+from tinygrad.dtype import _from_torch_dtype
 from tinygrad.helpers import to_mv, DEBUG, getenv, colored, time_to_str
 
 import extra.torch_hook.hook_cuda as hook_cuda
@@ -14,6 +15,7 @@
 TINY_MIRROR = getenv("TINY_MIRROR", 1) # should mirror aten ops to tiny backend
 RUN_ONLY = getenv("RUN_ONLY", -1) # run only a specific aten call
 REALIZE = getenv("REALIZE", 1) # realize and wait each aten call
+WRAP_TINY = getenv("WRAP_TINY", 1) # reuse cuda tensors by wrapping them as tinygrad tensors (zero-copy)
 FULL_KERN_NAME = getenv("FULL_KERN_NAME", 0) # print full kernel name
 
 print("importing torch...")
@@ -39,15 +41,24 @@ def __torch_dispatch__(self, func, types, args, kwargs=None):
     def can_print_arg(arg):
       return arg is None or isinstance(arg, str) or isinstance(arg, int) or isinstance(arg, float) or isinstance(arg, bool)
 
+    def create_tiny_mapping(arg):
+      if WRAP_TINY:
+        tt = Tensor.from_blob(arg.data_ptr(), arg.shape, dtype=_from_torch_dtype(arg.dtype))
+        cuda_to_tiny_mappings[arg] = tiny_torch.wrap(tt)
+
     for i,arg in enumerate(args):
       if torch.is_tensor(arg):
-        if arg.device.type == "cuda": should_call_tiny = True
+        if arg.device.type == "cuda":
+          should_call_tiny = True
+          if WRAP_TINY: create_tiny_mapping(arg)
         txt_args.append(f"tensor({arg.shape} {arg.device} {arg.dtype})")
       elif can_print_arg(arg): txt_args.append(f'{arg}')
       else: txt_args.append(f"{type(arg)}")
     for k,v in (kwargs or {}).items():
       if torch.is_tensor(v):
-        if arg.device.type == "cuda": should_call_tiny = True
+        if v.device.type == "cuda":
+          should_call_tiny = True
+          if WRAP_TINY: create_tiny_mapping(v)
         txt_args.append(f"{k}:tensor({v.shape} {v.device} {v.dtype})")
       elif can_print_arg(v): txt_args.append(f'{k}:{v}')
       else: txt_args.append(f"{type(v)}")
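
The WRAP_TINY path above avoids any copy: create_tiny_mapping aliases the torch CUDA allocation directly as a tinygrad Tensor and caches the wrapped handle. A standalone sketch of that pattern (it assumes the default tinygrad device is CUDA so the raw pointer is valid; tiny_torch.wrap and cuda_to_tiny_mappings are defined elsewhere in this file):

# wrap an existing torch CUDA buffer as a tinygrad Tensor without copying
import torch
from tinygrad import Tensor
from tinygrad.dtype import _from_torch_dtype

x = torch.randn(4, 4, device="cuda")
tt = Tensor.from_blob(x.data_ptr(), x.shape, dtype=_from_torch_dtype(x.dtype))  # zero-copy view
print(tt.numpy())  # reads the same device memory torch wrote
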
@@ -68,7 +79,7 @@ def print_events(evs, name, out_addr):
     for param in ev.params:
       if isinstance(param, hook_cuda.HookTensorParamEvent):
         is_out = param.cuda_address == out_addr
-        txt_params += [f"{'out ' if is_out else 'in'} tensor {param.enum}({param.cuda_address:#x}, off={param.offset:#x})"]
+        txt_params += [f"{'result ' if is_out else ''}Tensor{param.enum}({param.cuda_address:#x})"]
 
     just_kern_name = ev.name
     if not FULL_KERN_NAME:
@@ -98,11 +109,15 @@ def print_events(evs, name, out_addr):
 
     # TODO: this is a hack, any way to do this better?
     if REALIZE:
-      tiny_x.cpu()
+      out_addr = 0x0
+      if torch.is_tensor(tiny_x):
+        tt = tiny_torch.unwrap(tiny_x).realize()
+        try: out_addr = tt.lazydata.buffer._buf.value
+        except Exception: pass
       tiny_events = hook_cuda.collect_events(clear=True)
-      print_events(tiny_events, colored("tiny", "magenta"), 0x0)
+      print_events(tiny_events, colored("tiny", "magenta"), out_addr)
 
-    cuda_to_tiny_mappings[orig_x] = tiny_x
+    if not WRAP_TINY: cuda_to_tiny_mappings[orig_x] = tiny_x
 
     hook_cuda.pop_ignore_dispatch()
     return orig_x
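
The REALIZE branch now resolves the mirrored output's device address so print_events can tag the matching kernel parameter as the result. A standalone sketch of that lookup (lazydata.buffer._buf is internal and backend-specific, hence the same try/except guard as in the diff):

from tinygrad import Tensor

t = (Tensor.ones(16) + 1).realize()
out_addr = 0x0
try: out_addr = t.lazydata.buffer._buf.value  # CUDA exposes the device pointer's integer value
except Exception: pass                        # other backends may not have a ctypes .value
print(hex(out_addr))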