Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] master from tinygrad:master #192

Merged
merged 4 commits into from
Mar 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/actions/setup-tinygrad/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ runs:
if: inputs.webgpu == 'true' && runner.os == 'macOS'
shell: bash
run: |
sudo mkdir -p /usr/local/lib
sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.dylib -o /usr/local/lib/libwebgpu_dawn.dylib
brew tap wpmed92/dawn
brew install dawn

# **** LLVM ****

Expand Down
7 changes: 4 additions & 3 deletions autogen_stubs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,11 @@ generate_sqtt() {
}

generate_webgpu() {
clang2py -l /usr/local/lib/libwebgpu_dawn.so extra/webgpu/webgpu.h -o $BASE/webgpu.py
clang2py extra/webgpu/webgpu.h -o $BASE/webgpu.py
fixup $BASE/webgpu.py
sed -i 's/import ctypes/import ctypes, ctypes.util/g' $BASE/webgpu.py
sed -i "s|ctypes.CDLL('/usr/local/lib/libwebgpu_dawn.so')|ctypes.CDLL(ctypes.util.find_library('webgpu_dawn'))|g" $BASE/webgpu.py
sed -i "s/FIXME_STUB/webgpu/g" "$BASE/webgpu.py"
sed -i "s/FunctionFactoryStub()/ctypes.CDLL(webgpu_support.WEBGPU_PATH)/g" "$BASE/webgpu.py"
sed -i "s/import ctypes/import ctypes, tinygrad.runtime.support.webgpu as webgpu_support/g" "$BASE/webgpu.py"
python3 -c "import tinygrad.runtime.autogen.webgpu"
}

Expand Down
3 changes: 1 addition & 2 deletions docs/developer/hcq.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,8 @@ HCQ-compatible devices use a global timeline signal for synchronizing all operat
```python
HWQueue().wait(your_device.timeline_signal, your_device.timeline_value - 1) \
.exec(...)
.signal(your_device.timeline_signal, your_device.timeline_value) \
.signal(your_device.timeline_signal, your_device.next_timeline()) \
.submit(your_device)
your_device.timeline_value += 1

# Optionally wait for execution
your_device.timeline_signal.wait(your_device.timeline_value - 1)
Expand Down
4 changes: 3 additions & 1 deletion test/external/external_test_am.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import unittest
from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraverseContext
from tinygrad.runtime.support.am.ip import AM_GMC
from tinygrad.helpers import mv_address

class FakeGMC:
class FakeGMC(AM_GMC):
def __init__(self):
self.vm_base = 0x0
self.address_space_mask = (1 << 44) - 1
def init_hw(self): pass
def flush_tlb(self, *args, **kwargs): pass

class FakePCIDev:
Expand Down
537 changes: 273 additions & 264 deletions tinygrad/runtime/autogen/webgpu.py

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions tinygrad/runtime/graph/hcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ def __call__(self, input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]
self.comp_queues[dev].submit(dev, hcq_var_vals)
if (copy_queue:=self.copy_queues.get(dev, None)) is not None: copy_queue.submit(dev, hcq_var_vals)

self.last_timeline[dev] = (dev.timeline_signal, dev.timeline_value)
dev.timeline_value += 1
self.last_timeline[dev] = (dev.timeline_signal, dev.next_timeline())

if wait:
st = time.perf_counter()
Expand Down
6 changes: 2 additions & 4 deletions tinygrad/runtime/ops_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,7 @@ def _ensure_has_local_memory(self, required):
self.max_private_segment_size = required

def invalidate_caches(self):
AMDComputeQueue().memory_barrier().signal(self.timeline_signal, self.timeline_value).submit(self)
self.timeline_value += 1
AMDComputeQueue().memory_barrier().signal(self.timeline_signal, self.next_timeline()).submit(self)
self.synchronize()

def on_device_hang(self): self.dev_iface.on_device_hang()
Expand All @@ -761,8 +760,7 @@ def _at_profile_finalize(self):
if self.sqtt_enabled:
wptrs_buf = self.allocator.alloc(round_up(len(self.sqtt_buffers), 0x1000), BufferSpec(cpu_access=True, nolru=True))
wptrs = to_mv(wptrs_buf.va_addr, wptrs_buf.size)
AMDComputeQueue().stop_trace(len(self.sqtt_buffers), wptrs_buf).signal(self.timeline_signal, self.timeline_value).submit(self)
self.timeline_value += 1
AMDComputeQueue().stop_trace(len(self.sqtt_buffers), wptrs_buf).signal(self.timeline_signal, self.next_timeline()).submit(self)
self.synchronize()
if DEBUG>=2: print('Saving SQTT in profile...')
for i,buf0 in enumerate(self.sqtt_buffers):
Expand Down
3 changes: 1 addition & 2 deletions tinygrad/runtime/ops_nv.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,7 @@ def _ensure_has_local_memory(self, required):

cast(NVComputeQueue, NVComputeQueue().wait(self.timeline_signal, self.timeline_value - 1)) \
.setup(local_mem=self.shader_local_mem.va_addr, local_mem_tpc_bytes=bytes_per_tpc) \
.signal(self.timeline_signal, self.timeline_value).submit(self)
self.timeline_value += 1
.signal(self.timeline_signal, self.next_timeline()).submit(self)

def invalidate_caches(self):
rmctrl.fb_flush_gpu_cache(self.fd_ctl, self.root, self.subdevice,
Expand Down
8 changes: 2 additions & 6 deletions tinygrad/runtime/ops_webgpu.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import functools, struct
from tinygrad.device import Compiled, Allocator, Compiler
from tinygrad.renderer.wgsl import WGSLRenderer
from tinygrad.helpers import round_up, OSX
from tinygrad.helpers import round_up
from tinygrad.runtime.autogen import webgpu
from typing import List, Any
import ctypes
import os

backend_types = {v: k for k, v in webgpu.WGPUBackendType__enumvalues.items() }

try:
instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
except AttributeError:
raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else
"sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so"))
instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))

def to_c_string(_str): return ctypes.create_string_buffer(_str.encode('utf-8'))

Expand Down
11 changes: 3 additions & 8 deletions tinygrad/runtime/support/am/amdev.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,7 @@ def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.entries, self.l

def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"

f = (am.AMDGPU_PTE_VALID if valid else 0) | ((am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE) if not table else 0) \
| am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if not table and self.lv != am.AMDGPU_VM_PTB else 0) \
| ((am.AMDGPU_PTE_SYSTEM) if system else 0) | ((am.AMDGPU_PTE_SNOOPED) if snooped else 0) \
| (am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC) if uncached else 0)
self.entries[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f
self.entries[entry_id] = self.adev.gmc.get_pte_flags(self.lv, table, frag, uncached, system, snooped, valid, extra=(paddr & 0x0000FFFFFFFFF000))

class AMPageTableTraverseContext:
def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False):
Expand All @@ -130,7 +125,7 @@ def level_down(self):
pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True)
entry = pt.entries[pte_idx]

assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
assert not self.adev.gmc.is_pte_huge_page(entry), f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)

self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
Expand All @@ -156,7 +151,7 @@ def next(self, size:int, off=0):
if self.create_pts:
while pte_covers > size: pt, pte_idx, pte_covers = self.level_down()
else:
while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()
while pt.lv!=am.AMDGPU_VM_PTB and not self.adev.gmc.is_pte_huge_page(pt.entries[pte_idx]): pt, pte_idx, pte_covers = self.level_down()

entries = min(size // pte_covers, 512 - pte_idx)
assert entries > 0, "Invalid entries"
Expand Down
7 changes: 7 additions & 0 deletions tinygrad/runtime/support/am/ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def init_hub(self, ip:Literal["MM", "GC"]):
for eng_i in range(18): self.adev.wreg_pair(f"reg{ip}VM_INVALIDATE_ENG{eng_i}_ADDR_RANGE", "_LO32", "_HI32", 0x1fffffffff)
self.hub_initted[ip] = True

def get_pte_flags(self, pte_lv, is_table, frag, uncached, system, snooped, valid, extra=0):
  """Assemble the flag bits of a page-table entry and OR them into `extra`.

  pte_lv:   level of the page table holding the entry; compared against am.AMDGPU_VM_PTB.
  is_table: truthy when the entry points at a child page table; such entries get neither
            the read/write/execute bits nor am.AMDGPU_PDE_PTE.
  frag:     value handed to am.AMDGPU_PTE_FRAG.
  uncached: selects am.MTYPE_UC as the memory type, otherwise 0.
  system, snooped, valid: each multiplies its AMDGPU_PTE_* constant, so a falsy value contributes nothing.
  extra:    bits to start from; the result is `extra` with all computed flags OR-ed in.
  """
  mtype = am.MTYPE_UC if uncached else 0

  bits = extra
  bits |= am.AMDGPU_PTE_SYSTEM * system
  bits |= am.AMDGPU_PTE_SNOOPED * snooped
  bits |= am.AMDGPU_PTE_VALID * valid
  bits |= am.AMDGPU_PTE_FRAG(frag)
  bits |= am.AMDGPU_PTE_MTYPE_NV10(0, mtype)

  # entries that point at another table carry no access bits and are never marked PDE_PTE
  if is_table: return bits

  bits |= am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE
  # a non-table entry above the last level (PTB) is tagged with PDE_PTE
  if pte_lv != am.AMDGPU_VM_PTB: bits |= am.AMDGPU_PDE_PTE
  return bits
def is_pte_huge_page(self, pte):
  """Return the am.AMDGPU_PDE_PTE bit of `pte`: non-zero when the bit is set, 0 otherwise."""
  pde_pte_bit = am.AMDGPU_PDE_PTE
  return pte & pde_pte_bit

def on_interrupt(self):
for ip in ["MM", "GC"]:
st, va = self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_STATUS').read(), self.adev.reg(f'reg{ip}VM_L2_PROTECTION_FAULT_ADDR_LO32').read()
Expand Down
37 changes: 16 additions & 21 deletions tinygrad/runtime/support/hcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,16 +261,14 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Type[HWQueue]|None=No
if enabled and queue is not None: queue.timestamp(st)
elif enabled:
assert queue_type is not None
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
dev.timeline_value += 1
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.next_timeline()).submit(dev)

try: yield (st, en)
finally:
if enabled and queue is not None: queue.timestamp(en)
elif enabled:
assert queue_type is not None
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
dev.timeline_value += 1
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.next_timeline()).submit(dev)

if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))

Expand Down Expand Up @@ -329,8 +327,7 @@ def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), loca
with hcq_profile(self.dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
q.exec(self, kernargs, global_size, local_size)

q.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
self.dev.timeline_value += 1
q.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)

if wait: self.dev.synchronize()
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
Expand Down Expand Up @@ -374,6 +371,10 @@ def synchronize(self):
Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
self.sig_prof_records = []

def next_timeline(self):
  """Advance `timeline_value` by one and return the value it held before the increment."""
  reserved = self.timeline_value
  self.timeline_value = reserved + 1
  return reserved

@classmethod
def _alloc_signal_addr(cls) -> int:
if not cls.signal_pool:
Expand All @@ -384,8 +385,7 @@ def _alloc_signal_addr(cls) -> int:

def _at_profile_finalize(self):
def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
d.timeline_value += 1
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.next_timeline()).submit(d)
st = time.perf_counter_ns()
d.timeline_signal.wait(d.timeline_value - 1) # average of the two
et = time.perf_counter_ns()
Expand Down Expand Up @@ -439,9 +439,8 @@ def _copyin(self, dest:HCQBuffer, src:memoryview):
ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
self.b_timeline[self.b_next] = self.dev.timeline_value
self.dev.timeline_value += 1
.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
self.b_timeline[self.b_next] = self.dev.timeline_value - 1

def copy_from_disk(self, dest:HCQBuffer, src, size):
def _get_temp_buf():
Expand All @@ -456,9 +455,8 @@ def _get_temp_buf():
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
self.b_timeline[batch_info[1]] = self.dev.timeline_value
self.dev.timeline_value += 1
.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
self.b_timeline[batch_info[1]] = self.dev.timeline_value - 1

def _copyout(self, dest:memoryview, src:HCQBuffer):
self.dev.synchronize()
Expand All @@ -468,9 +466,8 @@ def _copyout(self, dest:memoryview, src:HCQBuffer):
for i in range(0, dest.nbytes, self.b[0].size):
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
self.dev.timeline_signal.wait(self.dev.timeline_value)
self.dev.timeline_value += 1
.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
self.dev.timeline_signal.wait(self.dev.timeline_value - 1)

ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)

Expand All @@ -482,11 +479,9 @@ def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:DeviceType, d
src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.copy(dest.va_addr, src.va_addr, sz) \
.signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
src_dev.timeline_value += 1
.signal(src_dev.timeline_signal, src_dev.next_timeline()).submit(src_dev)

if src_dev != dest_dev:
dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
dest_dev.timeline_value += 1
.signal(dest_dev.timeline_signal, dest_dev.next_timeline()).submit(dest_dev)
3 changes: 1 addition & 2 deletions tinygrad/runtime/support/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
raise FileNotFoundError('LLVM not found, you can install it with `winget install LLVM.LLVM` or point at a custom dll with LLVM_PATH')
elif OSX:
# Will raise FileNotFoundError if brew is not installed
brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
# `brew --prefix` will return even if formula is not installed
if not os.path.exists(brew_prefix):
if not os.path.exists(brew_prefix:=subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()):
raise FileNotFoundError('LLVM not found, you can install it with `brew install llvm`')
LLVM_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libLLVM.dylib')
else:
Expand Down
11 changes: 11 additions & 0 deletions tinygrad/runtime/support/webgpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import ctypes, ctypes.util, os, subprocess
from tinygrad.helpers import OSX

# Resolve WEBGPU_PATH, the location of the Dawn WebGPU shared library, once at import time.
# Raises FileNotFoundError with an install hint when the library cannot be located.
if OSX:
  try:
    # check_output itself raises FileNotFoundError if the `brew` executable is missing
    brew_prefix = subprocess.check_output(['brew', '--prefix', 'dawn']).decode().strip()
  except subprocess.CalledProcessError:
    # brew exited non-zero (presumably the formula is unknown because the wpmed92/dawn tap was
    # never added); treat it like a missing install so the user still gets the install hint
    brew_prefix = None
  # `brew --prefix` can also succeed while the returned directory does not exist
  if brew_prefix is None or not os.path.exists(brew_prefix):
    raise FileNotFoundError('dawn library not found. Install it with `brew tap wpmed92/dawn && brew install dawn`')
  WEBGPU_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libwebgpu_dawn.dylib')
else:
  # find_library returns None when no matching shared library is found
  if (WEBGPU_PATH:=ctypes.util.find_library('webgpu_dawn')) is None:
    raise FileNotFoundError("dawn library not found. " +
      "Install it with `sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so`")
Loading