1
1
from __future__ import annotations
2
- from typing import cast , Type , TypeVar , Generic , Any
2
+ from typing import cast , Type , TypeVar , Generic , Any , ClassVar
3
3
import contextlib , decimal , statistics , time , ctypes , array , os , fcntl
4
4
from tinygrad .helpers import PROFILE , from_mv , getenv , to_mv , round_up
5
5
from tinygrad .renderer import Renderer
@@ -203,15 +203,20 @@ def submit(self, dev:DeviceType, var_vals:dict[Variable, int]|None=None):
203
203
def _submit (self , dev :DeviceType ): raise NotImplementedError ("need _submit" )
204
204
205
205
class HCQSignal (Generic [DeviceType ]):
206
- def __init__ (self , base_addr :sint = 0 , value :int = 0 , timeline_for_device :DeviceType | None = None , timestamp_divider = 1 , value_off = 0 , timestamp_off = 8 ):
207
- self .base_addr , self .value_addr , self .timestamp_addr = base_addr , base_addr + value_off , base_addr + timestamp_off
206
+ def __init__ (self , base_addr :sint | None = None , value :int = 0 , dev_t :Type [DeviceType ]| None = None , timeline_for_device :DeviceType | None = None ,
207
+ timestamp_divider = 1 , value_off = 0 , timestamp_off = 8 ):
208
+ self .base_addr = dev_t ._alloc_signal_addr () if dev_t is not None and base_addr is None else base_addr
209
+ self .value_addr , self .timestamp_addr , self .dev_t = self .base_addr + value_off , self .base_addr + timestamp_off , dev_t
208
210
self .timestamp_divider :decimal .Decimal = decimal .Decimal (timestamp_divider )
209
211
self .timeline_for_device :DeviceType | None = timeline_for_device
210
212
211
- if isinstance (base_addr , int ):
213
+ if isinstance (self . base_addr , int ):
212
214
self .value_mv , self .timestamp_mv = to_mv (self .value_addr , 8 ).cast ('Q' ), to_mv (self .timestamp_addr , 8 ).cast ('Q' )
213
215
self .value_mv [0 ] = value
214
216
217
+ def __del__ (self ):
218
+ if isinstance (self .base_addr , int ) and self .dev_t is not None : self .dev_t .signal_pool .append (self .base_addr )
219
+
215
220
@property
216
221
def value (self ) -> int : return self .value_mv [0 ]
217
222
@@ -332,23 +337,29 @@ class HCQCompiled(Compiled, Generic[SignalType]):
332
337
"""
333
338
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
334
339
"""
335
- devices : list [HCQCompiled ] = []
340
+ devices : ClassVar [list [HCQCompiled ]] = []
341
+ signal_pages : ClassVar [list [Any ]] = []
342
+ signal_pool : ClassVar [list [int ]] = []
336
343
337
344
def __init__ (self , device :str , allocator :HCQAllocatorBase , renderer :Renderer , compiler :Compiler , runtime , signal_t :Type [SignalType ],
338
345
comp_queue_t :Type [HWQueue ], copy_queue_t :Type [HWQueue ]| None ):
339
346
self .device_id :int = int (device .split (":" )[1 ]) if ":" in device else 0
347
+
348
+ from tinygrad .runtime .graph .hcq import HCQGraph
349
+ super ().__init__ (device , allocator , renderer , compiler , runtime , HCQGraph )
350
+
351
+ # Map signals if any
352
+ for sig_page in self .signal_pages : cast (HCQAllocator , self .allocator ).map (sig_page )
353
+ self .devices .append (self )
354
+
340
355
self .signal_t , self .hw_compute_queue_t , self .hw_copy_queue_t = signal_t , comp_queue_t , copy_queue_t
341
356
self .timeline_value :int = 1
342
357
self .timeline_signal :SignalType = self .signal_t (value = 0 , timeline_for_device = self )
343
358
self ._shadow_timeline_signal :SignalType = self .signal_t (value = 0 , timeline_for_device = self )
344
359
self .sig_prof_records :list [tuple [HCQSignal , HCQSignal , str , bool ]] = []
345
360
346
- from tinygrad .runtime .graph .hcq import HCQGraph
347
- super ().__init__ (device , allocator , renderer , compiler , runtime , HCQGraph )
348
-
349
361
self .kernargs_page :HCQBuffer = self .allocator .alloc (16 << 20 , BufferSpec (cpu_access = True ))
350
362
self .kernargs_allocator :BumpAllocator = BumpAllocator (self .kernargs_page .size , base = cast (int , self .kernargs_page .va_addr ), wrap = True )
351
- self .devices .append (self )
352
363
353
364
def synchronize (self ):
354
365
try : self .timeline_signal .wait (self .timeline_value - 1 )
@@ -361,6 +372,14 @@ def synchronize(self):
361
372
Compiled .profile_events += [ProfileRangeEvent (self .device , name , st .timestamp , en .timestamp , cp ) for st ,en ,name ,cp in self .sig_prof_records ]
362
373
self .sig_prof_records = []
363
374
375
+ @classmethod
376
+ def _alloc_signal_addr (cls ) -> int :
377
+ if not cls .signal_pool :
378
+ cls .signal_pages .append (alc := cls .devices [0 ].allocator .alloc (0x1000 , BufferSpec (host = True , uncached = True , cpu_access = True )))
379
+ cls .signal_pool += [alc .va_addr + off for off in range (0 , alc .size , 16 )]
380
+ for dev in cls .devices : cast (HCQAllocator , dev .allocator ).map (alc )
381
+ return cls .signal_pool .pop ()
382
+
364
383
def _at_profile_finalize (self ):
365
384
def _sync (d :HCQCompiled , q_t :Type [HWQueue ]):
366
385
q_t ().timestamp (d .timeline_signal ).signal (d .timeline_signal , d .timeline_value ).submit (d )
0 commit comments