
Commit 84dc331

Refactor async (tinygrad#9126)
1 parent 6a9e559 commit 84dc331


tinygrad/runtime/ops_webgpu.py

Lines changed: 26 additions & 56 deletions
@@ -19,25 +19,25 @@ def from_wgpu_str(string_view): return ctypes.string_at(string_view.data, string
 def to_wgpu_str(_str):
   return webgpu.WGPUStringView(data=ctypes.cast(ctypes.pointer(to_c_string(_str)), ctypes.POINTER(ctypes.c_char)), length=len(_str))
 
-def wgpu_wait(future):
+def _wait(future):
   assert webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future=future), 2**64-1) == webgpu.WGPUWaitStatus_Success, "Future failed"
 
-def create_cb_info(cb_info_type, cb_type, cb): return cb_info_type(nextInChain=None, mode=webgpu.WGPUCallbackMode_WaitAnyOnly, callback=cb_type(cb))
-
 def write_buffer(device, buf, offset, src):
   src = bytearray(src)
   webgpu.wgpuQueueWriteBuffer(webgpu.wgpuDeviceGetQueue(device), buf, offset, (ctypes.c_uint8 * len(src)).from_buffer(src), len(src))
 
-def map_buffer(buf, size):
+def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx, msg_idx, *params):
   result: List[Any] = []
 
-  def cb(status, msg, u1, u2): result[:] = status, from_wgpu_str(msg)
+  def cb(*params):
+    result[:] = params
+    if msg_idx: result[msg_idx] = from_wgpu_str(result[msg_idx])
 
-  cb_info = create_cb_info(webgpu.WGPUBufferMapCallbackInfo2, webgpu.WGPUBufferMapCallback2, cb)
-  wgpu_wait(webgpu.wgpuBufferMapAsync2(buf, webgpu.WGPUMapMode_Read, 0, size, cb_info))
+  cb_info = cb_info_type(nextInChain=None, mode=webgpu.WGPUCallbackMode_WaitAnyOnly, callback=cb_type(cb))
+  _wait(async_fun(*params, cb_info))
 
-  if result[0] != webgpu.WGPUBufferMapAsyncStatus_Success:
-    raise RuntimeError(f"Failed to map buffer: [{webgpu.WGPUBufferMapAsyncStatus__enumvalues[result[0]]}] {result[1]}")
+  if result[0] != 1: raise RuntimeError(f"[{status_enum[result[0]] if status_enum else 'ERROR'}]{result[msg_idx] if msg_idx else ''}")
+  return result[res_idx] if res_idx else None
 
 def copy_buffer_to_buffer(dev, src, src_offset, dst, dst_offset, size):
   encoder = webgpu.wgpuDeviceCreateCommandEncoder(dev, webgpu.WGPUCommandEncoderDescriptor())
@@ -52,21 +52,16 @@ def read_buffer(dev, buf):
   tmp_buffer = webgpu.wgpuDeviceCreateBuffer(dev, webgpu.WGPUBufferDescriptor(size=size,
     usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False))
   copy_buffer_to_buffer(dev, buf, 0, tmp_buffer, 0, size)
-  map_buffer(tmp_buffer, size)
+  _run(webgpu.wgpuBufferMapAsync2, webgpu.WGPUBufferMapCallbackInfo2, webgpu.WGPUBufferMapCallback2, webgpu.WGPUBufferMapAsyncStatus__enumvalues,
+       None, 0, tmp_buffer, webgpu.WGPUMapMode_Read, 0, size)
   void_ptr = ctypes.cast(webgpu.wgpuBufferGetConstMappedRange(tmp_buffer, 0, size), ctypes.c_void_p)
   buf_copy = bytearray((ctypes.c_uint8 * size).from_address(void_ptr.value))
   webgpu.wgpuBufferUnmap(tmp_buffer)
   webgpu.wgpuBufferDestroy(tmp_buffer)
   return memoryview(buf_copy).cast("B")
 
 def pop_error(device):
-  result: List[Any] = []
-
-  def cb(status, err_type, msg, i2): result[:] = [from_wgpu_str(msg)]
-
-  cb_info = create_cb_info(webgpu.WGPUPopErrorScopeCallbackInfo, webgpu.WGPUPopErrorScopeCallback, cb)
-  wgpu_wait(webgpu.wgpuDevicePopErrorScopeF(device, cb_info))
-  return result[0] if len(result) > 0 else ""
+  return _run(webgpu.wgpuDevicePopErrorScopeF, webgpu.WGPUPopErrorScopeCallbackInfo, webgpu.WGPUPopErrorScopeCallback, None, 2, 2, device)
 
 def create_uniform(wgpu_device, val):
   buf = webgpu.wgpuDeviceCreateBuffer(wgpu_device,
@@ -141,16 +136,8 @@ def __call__(self, *bufs, global_size=(1,1,1), local_size=(1,1,1), vals=(), wait
     # Creating compute pipeline
     compute_desc = webgpu.WGPUComputePipelineDescriptor(layout=pipeline_layout,
       compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=to_wgpu_str(self.name)))
-    pipeline_result: List[Any] = []
-
-    def cb(status, compute_pipeline_impl, msg, u1, u2): pipeline_result[:] = status, compute_pipeline_impl, from_wgpu_str(msg)
-
-    cb_info = create_cb_info(webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2, webgpu.WGPUCreateComputePipelineAsyncCallback2, cb)
-    webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
-    wgpu_wait(webgpu.wgpuDeviceCreateComputePipelineAsync2(self.dev, compute_desc, cb_info))
-
-    if pipeline_result[0] != webgpu.WGPUCreatePipelineAsyncStatus_Success:
-      raise RuntimeError(f"{webgpu.WGPUCreatePipelineAsyncStatus__enumvalues[pipeline_result[0]]}: {pipeline_result[2]}, {pop_error(self.dev)}")
+    pipeline_result = _run(webgpu.wgpuDeviceCreateComputePipelineAsync2, webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2,
+      webgpu.WGPUCreateComputePipelineAsyncCallback2, webgpu.WGPUCreatePipelineAsyncStatus__enumvalues, 1, None, self.dev, compute_desc)
 
     command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev, webgpu.WGPUCommandEncoderDescriptor())
     comp_pass_desc = webgpu.WGPUComputePassDescriptor(nextInChain=None)
@@ -164,7 +151,7 @@ def cb(status, compute_pipeline_impl, msg, u1, u2): pipeline_result[:] = status,
 
     # Begin compute pass
     compute_pass = webgpu.wgpuCommandEncoderBeginComputePass(command_encoder, comp_pass_desc)
-    webgpu.wgpuComputePassEncoderSetPipeline(compute_pass, pipeline_result[1])
+    webgpu.wgpuComputePassEncoderSetPipeline(compute_pass, pipeline_result)
     webgpu.wgpuComputePassEncoderSetBindGroup(compute_pass, 0, bind_group, 0, None)
     webgpu.wgpuComputePassEncoderDispatchWorkgroups(compute_pass, *global_size)
     webgpu.wgpuComputePassEncoderEnd(compute_pass)
@@ -204,47 +191,30 @@ def _free(self, opaque, options):
 class WebGpuDevice(Compiled):
   def __init__(self, device:str):
     # Requesting an adapter
-    adapter_result: List[Any] = []
-
-    def adapter_cb(status, adapter, msg, _): adapter_result[:] = status, adapter, from_wgpu_str(msg)
-
-    cb_info = create_cb_info(webgpu.WGPURequestAdapterCallbackInfo, webgpu.WGPURequestAdapterCallback, adapter_cb)
-    wgpu_wait(webgpu.wgpuInstanceRequestAdapterF(instance,
-      webgpu.WGPURequestAdapterOptions(powerPreference=webgpu.WGPUPowerPreference_HighPerformance), cb_info))
-
-    if adapter_result[0] != webgpu.WGPURequestAdapterStatus_Success:
-      raise RuntimeError(f"Error requesting adapter: [{webgpu.WGPURequestAdapterStatus__enumvalues[adapter_result[0]]}] {adapter_result[2]}")
+    adapter_res = _run(webgpu.wgpuInstanceRequestAdapterF, webgpu.WGPURequestAdapterCallbackInfo, webgpu.WGPURequestAdapterCallback,
+      webgpu.WGPURequestAdapterStatus__enumvalues, 1, 2, instance,
+      webgpu.WGPURequestAdapterOptions(powerPreference=webgpu.WGPUPowerPreference_HighPerformance))
 
     # Get supported features
     supported_features = webgpu.WGPUSupportedFeatures()
-    webgpu.wgpuAdapterGetFeatures(adapter_result[1], supported_features)
+    webgpu.wgpuAdapterGetFeatures(adapter_res, supported_features)
     supported = [supported_features.features[i] for i in range(supported_features.featureCount)]
     features = [feat for feat in [webgpu.WGPUFeatureName_TimestampQuery, webgpu.WGPUFeatureName_ShaderF16] if feat in supported]
     dev_desc = webgpu.WGPUDeviceDescriptor(requiredFeatureCount=len(features),requiredFeatures=(webgpu.WGPUFeatureName * len(features))(*features))
 
     # Limits
     supported_limits = webgpu.WGPUSupportedLimits()
-    webgpu.wgpuAdapterGetLimits(adapter_result[1], ctypes.cast(ctypes.pointer(supported_limits),ctypes.POINTER(webgpu.struct_WGPUSupportedLimits)))
+    webgpu.wgpuAdapterGetLimits(adapter_res, ctypes.cast(ctypes.pointer(supported_limits),ctypes.POINTER(webgpu.struct_WGPUSupportedLimits)))
     limits = webgpu.WGPURequiredLimits(limits=supported_limits.limits)
     dev_desc.requiredLimits = ctypes.cast(ctypes.pointer(limits),ctypes.POINTER(webgpu.struct_WGPURequiredLimits))
 
     # Requesting a device
-    device_result: List[Any] = []
-
-    def dev_cb(status, device_impl, msg, _): device_result[:] = status, device_impl, from_wgpu_str(msg)
-
-    cb_info = create_cb_info(webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback, dev_cb)
-    wgpu_wait(webgpu.wgpuAdapterRequestDeviceF(adapter_result[1], dev_desc, cb_info))
-
-    if device_result[0] != webgpu.WGPURequestDeviceStatus_Success:
-      raise RuntimeError(f"Failed to request device: [{webgpu.WGPURequestDeviceStatus__enumvalues[device_result[0]]}] {device_result[2]}")
+    device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
+      webgpu.WGPURequestDeviceStatus__enumvalues, 1, 2, adapter_res, dev_desc)
 
-    super().__init__(device, WebGpuAllocator(device_result[1]), WGSLRenderer(), Compiler(),
-      functools.partial(WebGPUProgram, (device_result[1], webgpu.WGPUFeatureName_TimestampQuery in supported)))
+    super().__init__(device, WebGpuAllocator(device_res), WGSLRenderer(), Compiler(),
+      functools.partial(WebGPUProgram, (device_res, webgpu.WGPUFeatureName_TimestampQuery in supported)))
 
   def synchronize(self):
-    result: List[Any] = []
-    def cb(status, u1, u2): result[:] = [status]
-    cb_info = create_cb_info(webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2, cb)
-    wgpu_wait(webgpu.wgpuQueueOnSubmittedWorkDone2(webgpu.wgpuDeviceGetQueue(self.runtime.args[0][0]), cb_info))
-    if result[0] != webgpu.WGPUQueueWorkDoneStatus_Success: raise RuntimeError(webgpu.WGPUQueueWorkDoneStatus__enumvalues[result[0]])
+    _run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2,
+      webgpu.WGPUQueueWorkDoneStatus__enumvalues, None, None, webgpu.wgpuDeviceGetQueue(self.runtime.args[0][0]))
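
Note on the pattern: the new _run helper collapses the per-call boilerplate (result list, callback, callback-info struct, wait, status check) that each async WebGPU call previously repeated. It is parameterized by the callback/callback-info types, the status enum table, and the indices of the result and message fields in the callback arguments. Below is a minimal, dependency-free sketch of that contract; fake_request_adapter, FAKE_STATUS__enumvalues, and the dict-based cb_info are illustrative stand-ins, not the actual webgpu bindings, and the real helper also decodes the WGPUStringView message via from_wgpu_str and blocks on the returned future via wgpuInstanceWaitAny.

from typing import Any, List

# Hypothetical stand-ins for the generated webgpu bindings: a status enum table
# and an "async" entry point that invokes its callback and returns a future handle.
FAKE_STATUS__enumvalues = {1: "Success", 2: "Error"}

def fake_request_adapter(options, cb_info):
  # The real callback fires once the future completes; here it is called
  # immediately with (status, result, message) to keep the sketch synchronous.
  cb_info["callback"](1, f"adapter-for-{options}", "ok")
  return "future-handle"

def _wait(future):
  # The real code blocks on wgpuInstanceWaitAny; nothing to wait for here.
  assert future is not None

def _run(async_fun, status_enum, res_idx, msg_idx, *params):
  # Same shape as the helper in this commit (callback types omitted): collect
  # whatever the callback delivers, check status == 1, return the result slot.
  result: List[Any] = []
  def cb(*cb_params): result[:] = cb_params
  _wait(async_fun(*params, {"callback": cb}))
  if result[0] != 1:
    raise RuntimeError(f"[{status_enum[result[0]] if status_enum else 'ERROR'}]{result[msg_idx] if msg_idx else ''}")
  return result[res_idx] if res_idx else None

adapter = _run(fake_request_adapter, FAKE_STATUS__enumvalues, 1, 2, "high-performance")
print(adapter)  # adapter-for-high-performance

Passing None for res_idx (e.g. buffer mapping, synchronize) makes _run purely a status check, while None for status_enum falls back to a generic error tag, which is how pop_error reuses the same helper.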
