1
1
from typing import cast
2
2
import math , struct , sys
3
3
from tinygrad .renderer import Renderer
4
- from tinygrad .renderer .cstyle import ClangRenderer
4
+ from tinygrad .renderer .cstyle import ClangRenderer , AMDRenderer
5
5
from tinygrad .ops import UOp , PatternMatcher , UPat , Ops , GroupOp
6
6
from tinygrad .dtype import dtypes , DType , PtrDType , truncate
7
7
from tinygrad .helpers import prod , AMX
8
8
9
9
def ldt(dt:DType) -> str:
  """Return the LLVM IR type spelling for `dt`.

  - Vector dtypes (`dt.vcount > 1`) render as `<N x elem>`, with the element type rendered recursively.
  - Pointer dtypes render as the base type followed by `*`, or ` addrspace(3)*` when `dt.local` is set.
  - Scalar dtypes are looked up in a fixed table. Signed and unsigned integers of the same width share
    one spelling.

  Raises:
    KeyError: if a scalar dtype is not present in the table.
  """
  if dt.vcount > 1: return f"<{dt.vcount} x {ldt(dt.scalar())}>"
  if isinstance(dt, PtrDType): return ldt(dt.base) + (" addrspace(3)*" if dt.local else "*")
  return {dtypes.void: "void", dtypes.bool: "i1", dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
          dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
          dtypes.float16: "half", dtypes.bfloat16: "bfloat", dtypes.float32: "float", dtypes.float64: "double"}[dt]
15
15
16
16
def lconst (x , dtype :DType ):
17
17
if dtype in dtypes .floats :
@@ -63,7 +63,8 @@ def AMX(op, gpr): return f'call void asm sideeffect ".word (0x201000+($0<<5)+0$1
63
63
f" { ctx [x ]} _yes = load { ldt (x .dtype )} , { ldt (idx .dtype )} { ctx [idx ]} \n "
64
64
f" br label { ctx [x ]} _exit\n { ctx [x ][1 :]} _exit:\n "
65
65
f" { ctx [x ]} = phi { ldt (x .dtype )} [{ ctx [x ]} _yes, { ctx [x ]} _load], [{ ctx [alt ]} , { ctx [x ]} _entry]" ),
66
- (UPat (Ops .LOAD , src = (UPat .var ('idx' ),), name = "x" ), lambda ctx ,x ,idx : f" { ctx [x ]} = load { ldt (x .dtype )} , { ldt (idx .dtype )} { ctx [idx ]} " ),
66
+ (UPat (Ops .LOAD , src = (UPat .var ('idx' ),), allow_any_len = True , name = "x" ),
67
+ lambda ctx ,x ,idx : f" { ctx [x ]} = load { ldt (x .dtype )} , { ldt (idx .dtype )} { ctx [idx ]} " ),
67
68
(UPat (Ops .STORE , name = "x" ), lambda ctx ,x : f" store { ldt (x .src [1 ].dtype )} { ctx [x .src [1 ]]} , { ldt (x .src [0 ].dtype )} { ctx [x .src [0 ]]} " ),
68
69
69
70
# GEP/VECTORIZE/CAST for float4 support
@@ -113,7 +114,7 @@ class LLVMRenderer(Renderer):
113
114
supports_float4 = True
114
115
has_local = False
115
116
has_shared = False
116
- global_max = None
117
+ global_max : tuple [ int , ...] | None = None
117
118
string_rewrite = base_rewrite
118
119
if AMX : tensor_cores = ClangRenderer .amx_tc
119
120
@@ -126,6 +127,12 @@ class LLVMRenderer(Renderer):
126
127
(UPat (Ops .MAX , name = "m" ), lambda m : (m .src [0 ] < m .src [1 ]).where (m .src [1 ], m .src [0 ])),
127
128
# rewrite bf16 CAST(LOAD) to CAST(BITCAST)
128
129
(UPat (Ops .CAST , name = "root" , src = (UPat .load (UPat .index (UPat .var ("buf" ), UPat .var ("idx" )), dtype = dtypes .bfloat16 ),)), llvm_bf16_cast ),
130
+ # copied from cstyle.py, upcast to float32 all the ops that don't support bfloat16
131
+ (UPat ((Ops .SQRT , Ops .EXP2 , Ops .LOG2 , Ops .SIN ), dtype = dtypes .bfloat16 , name = "x" ),
132
+ lambda x : (UOp (x .op , dtypes .float , tuple (vv .cast (dtypes .float ) for vv in x .src ), x .arg ).cast (dtypes .bfloat16 ))),
133
+ # copied from cstyle.py, add float intermediate casting
134
+ (UPat (Ops .CAST , name = "x" , src = UPat .var ("y" , dtypes .bfloat16 )),lambda x ,y : y .cast (dtypes .float ).cast (x .dtype ) if x .dtype != dtypes .float else None ),
135
+ (UPat (Ops .CAST , dtypes .bfloat16 , UPat .var ("x" )),lambda x : x .cast (dtypes .float ).cast (dtypes .bfloat16 ) if x .dtype != dtypes .float else None ),
129
136
])
130
137
131
138
def render (self , uops : list [UOp ]) -> str :
@@ -135,6 +142,7 @@ def render(self, uops: list[UOp]) -> str:
135
142
end_lines : dict [str , None ] = {}
136
143
vc = - 1
137
144
145
+ local_args : list [str ] = []
138
146
acc_to_assign : dict [UOp , UOp ] = {}
139
147
for u in uops :
140
148
if u .op is Ops .ASSIGN : # prealloc all assigns
@@ -158,6 +166,10 @@ def render(self, uops: list[UOp]) -> str:
158
166
r [u ] = f"%data{ u .arg } " if u .op is Ops .DEFINE_GLOBAL else f"%{ u .arg [0 ]} "
159
167
# NOTE: MallocAllocator promises 0x20 alignment
160
168
args .append (f"{ ldt (u .dtype )} { ' noalias align 32' if isinstance (u .dtype , PtrDType ) else '' } { r [u ]} " )
169
+ elif u .op == Ops .DEFINE_LOCAL :
170
+ r [u ] = f"@local_{ u .arg } "
171
+ assert isinstance (u .dtype , PtrDType )
172
+ local_args .append (f"{ r [u ]} = internal unnamed_addr addrspace(3) global [{ u .dtype .size } x { ldt (u .dtype )} ] undef, align 16" )
161
173
elif u .op is Ops .ASSIGN : pass # assign is already handled by the first pass
162
174
elif u .op is Ops .DEFINE_ACC : r [u ] = r [u .src [0 ]] # a define acc can be used and never be assigned to
163
175
elif u .op is Ops .CONST : r [u ] = lconst (u .arg , u .dtype )
@@ -182,11 +194,27 @@ def render(self, uops: list[UOp]) -> str:
182
194
r [x ] = f"%acc{ vc } "
183
195
184
196
# output the function. chr(10) is '\n' (python < 3.12 doesn't support backslashes in f-strings)
185
- return f'''\
197
+ prg = f'''\
186
198
define{ (' ' + self .abi ) if self .abi is not None else '' } void @{ name } ({ ',' .join (args )} ) #0 {{
187
199
{ chr (10 ).join (kernel )}
188
200
ret void
189
201
}}
190
202
{ chr (10 ).join (end_lines .keys ())}
191
203
attributes #0 = {{ nounwind "no-builtins" "no-trapping-math"="true" }}
192
204
'''
205
+ return prg if len (local_args ) == 0 else "\n " .join (local_args )+ f"\n { prg } "
206
+
207
# IR emitted for a workgroup barrier: a release fence, a call to the s.barrier intrinsic, then an acquire fence.
barrier = 'fence syncscope("workgroup") release\ntail call void @llvm.amdgcn.s.barrier()\nfence syncscope("workgroup") acquire\n'
# Builders for the intrinsic call that reads a work id: "g" uses workgroup.id, "l" uses workitem.id.
# chr(120 + int(x)) turns axis 0/1/2 into the intrinsic suffix x/y/z.
code_for_workitem = {"g": lambda x: f"tail call i32 @llvm.amdgcn.workgroup.id.{chr(120+int(x))}()",
                     "l": lambda x: f"tail call i32 @llvm.amdgcn.workitem.id.{chr(120+int(x))}()"}
210
class AMDLLVMRenderer(LLVMRenderer):
  # LLVMRenderer variant for the "AMD" device: turns on local/shared support, reuses AMDRenderer's
  # size limits, and extends the base string rewrite rules with SPECIAL and BARRIER handling.
  device = "AMD"
  has_local = True
  has_shared = True
  # limits are taken from AMDRenderer rather than restated here
  shared_max = AMDRenderer.shared_max
  global_max = AMDRenderer.global_max
  # emitted after `define` in the rendered function header
  abi = "amdgpu_kernel"
  string_rewrite = base_rewrite + PatternMatcher([
    # x.arg[0][0] picks the "g"/"l" builder and x.arg[0][-1] is passed as the axis
    # NOTE(review): presumably x.arg[0] is a name such as "gidx0"/"lidx1" — confirm against the SPECIAL producer
    (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"  {ctx[x]} = {code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; "),
    (UPat(Ops.BARRIER), lambda ctx: barrier),
  ])
0 commit comments