@@ -42,6 +42,7 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id)
42
42
sender .restype = restype
43
43
return sender (ptr , sel (selector ), * args )
44
44
45
@functools.lru_cache(None)
def to_ns_str(s: str):
  """Create an Objective-C ``NSString`` object from a Python string.

  The string is encoded to bytes and handed to the ``NSString`` class via the
  ``stringWithUTF8String:`` selector; the reply is returned as an
  ``objc_instance``. Results are memoized per distinct ``s`` in an unbounded
  ``lru_cache``, so repeated calls with an equal string return the same object.
  """
  ns_string_class = libobjc.objc_getClass(b"NSString")
  utf8_bytes = s.encode()
  return msg(ns_string_class, "stringWithUTF8String:", utf8_bytes, restype=objc_instance)
46
47
def from_ns_str(s):
  """Convert an Objective-C string object into a Python ``str``.

  Sends the ``UTF8String`` selector to ``s`` with a ``ctypes.c_char_p`` result
  type, copies the result into ``bytes`` and decodes it with the default codec.
  """
  c_string = msg(s, "UTF8String", restype=ctypes.c_char_p)
  raw = bytes(c_string)
  return raw.decode()
47
48
@@ -146,21 +147,22 @@ def __init__(self, dev:MetalDevice, name:str, lib:bytes):
146
147
self .pipeline_state = msg (self .dev .sysdevice , "newComputePipelineStateWithDescriptor:options:reflection:error:" ,
147
148
descriptor , MTLPipelineOption .MTLPipelineOptionNone , None , ctypes .byref (error_pipeline_creation := objc_instance ()), restype = objc_instance )
148
149
error_check (error_pipeline_creation )
150
+ # cache these msg calls
151
+ self .max_total_threads : int = cast (int , msg (self .pipeline_state , "maxTotalThreadsPerThreadgroup" , restype = ctypes .c_ulong ))
149
152
150
153
def __call__ (self , * bufs , global_size :tuple [int ,int ,int ]= (1 ,1 ,1 ), local_size :tuple [int ,int ,int ]= (1 ,1 ,1 ), vals :tuple [int , ...]= (), wait = False ):
151
- max_total_threads = msg (self .pipeline_state , "maxTotalThreadsPerThreadgroup" , restype = ctypes .c_ulong )
152
- if prod (local_size ) > cast (int , max_total_threads ):
154
+ if prod (local_size ) > self .max_total_threads :
153
155
exec_width = msg (self .pipeline_state , "threadExecutionWidth" , restype = ctypes .c_ulong )
154
156
memory_length = msg (self .pipeline_state , "staticThreadgroupMemoryLength" , restype = ctypes .c_ulong )
155
- raise RuntimeError (f"local size { local_size } bigger than { max_total_threads } with exec width { exec_width } memory length { memory_length } " )
157
+ raise RuntimeError (f"local size { local_size } bigger than { self . max_total_threads } with exec width { exec_width } memory length { memory_length } " )
156
158
command_buffer = msg (self .dev .mtl_queue , "commandBuffer" , restype = objc_instance )
157
159
encoder = msg (command_buffer , "computeCommandEncoder" , restype = objc_instance )
158
160
msg (encoder , "setComputePipelineState:" , self .pipeline_state )
159
161
for i ,a in enumerate (bufs ): msg (encoder , "setBuffer:offset:atIndex:" , a .buf , a .offset , i )
160
162
for i ,a in enumerate (vals , start = len (bufs )): msg (encoder , "setBytes:length:atIndex:" , bytes (ctypes .c_int (a )), 4 , i )
161
163
msg (encoder , "dispatchThreadgroups:threadsPerThreadgroup:" , to_struct (* global_size ), to_struct (* local_size ))
162
164
msg (encoder , "endEncoding" )
163
- msg (command_buffer , "setLabel:" , to_ns_str (self .name ))
165
+ msg (command_buffer , "setLabel:" , to_ns_str (self .name )) # TODO: is this always needed?
164
166
msg (command_buffer , "commit" )
165
167
self .dev .mtl_buffers_in_flight .append (command_buffer )
166
168
if wait :
0 commit comments