Skip to content

Commit f90001e

Browse files
authored
amd llvm render (no_comgr prereq) (tinygrad#9543)
* amd llvm render * skip test_div_rounding_mode --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com>
1 parent 4f5e03b commit f90001e

File tree

5 files changed

+45
-12
lines changed

5 files changed

+45
-12
lines changed

.github/workflows/test.yml

+4-2
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ jobs:
574574
run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/test_hcq.py test/external/external_test_am.py --durations=20
575575
- name: Run pytest (amd with llvm backend)
576576
if: matrix.backend=='amd'
577-
run: python -m pytest -n=auto test/test_amd_llvm.py --durations=20
577+
run: AMD_LLVM=1 python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/test_hcq.py test/external/external_test_am.py test/test_amd_llvm.py --durations=20
578578
- name: Run TRANSCENDENTAL math
579579
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
580580
- name: Run process replay tests
@@ -630,8 +630,10 @@ jobs:
630630
env:
631631
MOCKGPU: 1
632632
AMD: 1
633+
AMD_LLVM: 1
634+
FORWARD_ONLY: 1
633635
run: |
634-
python -m pytest -n=auto test/test_amd_llvm.py --durations=20
636+
python -m pytest -n=auto test/test_hcq.py test/test_tiny.py test/test_amd_llvm.py --durations=20
635637
- name: Run pytest (ptx)
636638
env:
637639
MOCKGPU: 1

test/test_dtype_alu.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
binary_operations = [operator.add, operator.sub, operator.mul, operator.lt, operator.eq]
2424

2525
# TODO: LLVM comparing with nan is incorrect
26-
if Device.DEFAULT == "LLVM":
26+
if Device.DEFAULT == "LLVM" or getenv("AMD_LLVM", 0):
2727
binary_operations.remove(operator.lt)
2828

2929
integer_binary_operations = binary_operations + [(Tensor.bitwise_xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),

test/test_ops.py

+1
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def test_div(self):
534534
helper_test_op([(45,65), (45,65)], lambda x,y: x/y)
535535
helper_test_op([(), ()], lambda x,y: x/y)
536536

537+
@unittest.skipIf(getenv("AMD_LLVM", 0), "AMD with LLVM backend generate rcp in FP division causes trunc/floor errors")
537538
def test_div_rounding_mode(self):
538539
for denominator in [-10, -5, -3, -2, -1, 1, 2, 3, 5, 10]:
539540
# int numerator

tinygrad/renderer/llvmir.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from typing import cast
22
import math, struct, sys
33
from tinygrad.renderer import Renderer
4-
from tinygrad.renderer.cstyle import ClangRenderer
4+
from tinygrad.renderer.cstyle import ClangRenderer, AMDRenderer
55
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
66
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
77
from tinygrad.helpers import prod, AMX
88

99
def ldt(dt:DType):
1010
if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
11-
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
12-
return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
11+
if isinstance(dt, PtrDType): return ldt(dt.base) + (" addrspace(3)*" if dt.local else "*")
12+
return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
1313
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
14-
dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
14+
dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
1515

1616
def lconst(x, dtype:DType):
1717
if dtype in dtypes.floats:
@@ -63,7 +63,8 @@ def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1
6363
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
6464
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
6565
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
66-
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
66+
(UPat(Ops.LOAD, src=(UPat.var('idx'),), allow_any_len=True, name="x"),
67+
lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
6768
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
6869

6970
# GEP/VECTORIZE/CAST for float4 support
@@ -113,7 +114,7 @@ class LLVMRenderer(Renderer):
113114
supports_float4 = True
114115
has_local = False
115116
has_shared = False
116-
global_max = None
117+
global_max: tuple[int, ...] | None = None
117118
string_rewrite = base_rewrite
118119
if AMX: tensor_cores = ClangRenderer.amx_tc
119120

@@ -126,6 +127,12 @@ class LLVMRenderer(Renderer):
126127
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
127128
# rewrite bf16 CAST(LOAD) to CAST(BITCAST)
128129
(UPat(Ops.CAST, name="root", src=(UPat.load(UPat.index(UPat.var("buf"), UPat.var("idx")), dtype=dtypes.bfloat16),)), llvm_bf16_cast),
130+
# copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
131+
(UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"),
132+
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))),
133+
# copied from cstyle.py, add float intermediate casting
134+
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
135+
(UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
129136
])
130137

131138
def render(self, uops: list[UOp]) -> str:
@@ -135,6 +142,7 @@ def render(self, uops: list[UOp]) -> str:
135142
end_lines: dict[str, None] = {}
136143
vc = -1
137144

145+
local_args: list[str] = []
138146
acc_to_assign: dict[UOp, UOp] = {}
139147
for u in uops:
140148
if u.op is Ops.ASSIGN: # prealloc all assigns
@@ -158,6 +166,10 @@ def render(self, uops: list[UOp]) -> str:
158166
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
159167
# NOTE: MallocAllocator promises 0x20 alignment
160168
args.append(f"{ldt(u.dtype)}{' noalias align 32' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
169+
elif u.op == Ops.DEFINE_LOCAL:
170+
r[u] = f"@local_{u.arg}"
171+
assert isinstance(u.dtype, PtrDType)
172+
local_args.append(f"{r[u]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
161173
elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
162174
elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
163175
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
@@ -182,11 +194,27 @@ def render(self, uops: list[UOp]) -> str:
182194
r[x] = f"%acc{vc}"
183195

184196
# output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
185-
return f'''\
197+
prg = f'''\
186198
define{(' '+self.abi) if self.abi is not None else ''} void @{name}({','.join(args)}) #0 {{
187199
{chr(10).join(kernel)}
188200
ret void
189201
}}
190202
{chr(10).join(end_lines.keys())}
191203
attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
192204
'''
205+
return prg if len(local_args) == 0 else "\n".join(local_args)+f"\n{prg}"
206+
207+
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
208+
code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
209+
"l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
210+
class AMDLLVMRenderer(LLVMRenderer):
211+
device = "AMD"
212+
has_local = True
213+
has_shared = True
214+
shared_max = AMDRenderer.shared_max
215+
global_max = AMDRenderer.global_max
216+
abi = "amdgpu_kernel"
217+
string_rewrite = base_rewrite + PatternMatcher([
218+
(UPat(Ops.SPECIAL, name="x"), lambda ctx, x: f" {ctx[x]} = " + f"{ code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
219+
(UPat(Ops.BARRIER), lambda ctx: barrier),
220+
])

tinygrad/runtime/ops_amd.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from tinygrad.device import Compiled, ProfileEvent, BufferSpec, CPUProgram, PROFILE
99
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, mv_address, DEBUG, OSX
1010
from tinygrad.renderer.cstyle import AMDRenderer
11+
from tinygrad.renderer.llvmir import AMDLLVMRenderer
1112
from tinygrad.runtime.autogen import kfd, hsa, amd_gpu, libc, pci, vfio, sqtt
1213
from tinygrad.runtime.autogen.am import am, gc_11_0_0
13-
from tinygrad.runtime.support.compiler_amd import HIPCompiler
14+
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
1415
from tinygrad.runtime.support.elf import elf_loader
1516
from tinygrad.runtime.support.am.amdev import AMDev, AMMapping
1617
if getenv("IOCTL"): import extra.hip_gpu_driver.hip_ioctl # noqa: F401 # pylint: disable=unused-import
@@ -706,7 +707,8 @@ def __init__(self, device:str=""):
706707

707708
self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x800000)
708709

709-
super().__init__(device, AMDAllocator(self), AMDRenderer(self.arch), HIPCompiler(self.arch), functools.partial(AMDProgram, self),
710+
super().__init__(device, AMDAllocator(self), AMDLLVMRenderer() if getenv("AMD_LLVM", 0) else AMDRenderer(self.arch),
711+
AMDLLVMCompiler(self.arch) if getenv("AMD_LLVM", 0) else HIPCompiler(self.arch), functools.partial(AMDProgram, self),
710712
AMDSignal, AMDComputeQueue, AMDCopyQueue)
711713

712714
# Scratch setup

0 commit comments

Comments
 (0)