Skip to content

Commit a2faa5e

Browse files
authored
am: fix pt free (tinygrad#8810)
1 parent 9df8e34 commit a2faa5e

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

test/external/external_test_am.py

+17
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,22 @@ def test_try_bad_unmap(self):
139139
with self.assertRaises(AssertionError):
140140
mm0.unmap_range(0x10000, 0x3000)
141141

142+
def test_free_pt(self):
143+
mm0 = self.d[0].mm
144+
145+
# offset from start
146+
for off in [0, 0x3000, 0x10000]:
147+
mm0.map_range(0x1000000 + off, (2 << 20) - off, paddrs=[(0x10000, 0x1000)] * (512 - off // 0x1000))
148+
mm0.unmap_range(0x1000000 + off, (2 << 20) - off)
149+
mm0.map_range(0x1000000, 2 << 20, paddrs=[(0x10000, 2 << 20)])
150+
mm0.unmap_range(0x1000000, 2 << 20)
151+
152+
# offset from end
153+
for off in [0x1000, 0x20000]:
154+
mm0.map_range(0x1000000, (2 << 20) - off, paddrs=[(0x10000, 0x1000)] * (512 - off // 0x1000))
155+
mm0.unmap_range(0x1000000, (2 << 20) - off)
156+
mm0.map_range(0x1000000, 2 << 20, paddrs=[(0x10000, 2 << 20)])
157+
mm0.unmap_range(0x1000000, 2 << 20)
158+
142159
if __name__ == "__main__":
143160
unittest.main()

tinygrad/runtime/support/am/amdev.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def _try_free_pt(self) -> bool:
144144
return False
145145

146146
def level_up(self):
147-
while self.pt_stack[-1][1] == 512 or self._try_free_pt():
147+
while self._try_free_pt() or self.pt_stack[-1][1] == 512:
148148
_, pt_cnt, _ = self.pt_stack.pop()
149149
if pt_cnt == 512: self.pt_stack[-1] = (self.pt_stack[-1][0], self.pt_stack[-1][1] + 1, self.pt_stack[-1][2])
150150

@@ -174,6 +174,8 @@ def __init__(self, adev:AMDev, vram_size:int):
174174
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
175175

176176
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
177+
if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
178+
177179
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
178180

179181
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True)
@@ -190,7 +192,7 @@ def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=
190192
return AMMapping(vaddr, size, paddrs, uncached=uncached, system=system, snooped=snooped)
191193

192194
def unmap_range(self, vaddr:int, size:int):
193-
if AM_DEBUG >= 2: print(f"Unmapping {vaddr=:#x} ({size=:#x})")
195+
if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: unmapping {vaddr=:#x} ({size=:#x})")
194196

195197
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
196198
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):

0 commit comments

Comments
 (0)