Skip to content

Commit b1ddb2a

Browse files
authored
fix win32 CPUProgram missing cache flush (tinygrad#9171)
* win32: fix missing inst cache flush, rename ptr->self.mem for consistency with posix code * fix types, remove assert * fix memory leak * rm whitespace
1 parent 1bb9d78 commit b1ddb2a

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tinygrad/device.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,13 @@ def __init__(self, name:str, lib:bytes):
236236
PAGE_EXECUTE_READWRITE = 0x40
237237
MEM_COMMIT = 0x1000
238238
MEM_RESERVE = 0x2000
239-
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_uint64
240-
ptr = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_int(0), ctypes.c_int(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
241-
ctypes.memmove(ptr, lib, len(lib))
242-
self.fxn = ctypes.CFUNCTYPE(None)(ptr)
239+
ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
240+
self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
241+
ctypes.memmove(self.mem, lib, len(lib))
242+
ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
243+
proc = ctypes.windll.kernel32.GetCurrentProcess()
244+
ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
245+
self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
243246
else:
244247
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
245248
# On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
@@ -268,6 +271,9 @@ def __call__(self, *bufs, vals=(), wait=False):
268271
if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
269272
return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
270273

274+
def __del__(self):
275+
if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
276+
271277
# **************** for Compiled Devices ****************
272278

273279
class CompileError(Exception): pass

0 commit comments

Comments
 (0)