Skip to content

Commit 823bf18

Browse files
authored
Merge pull request #224 from JuliaGPU/tb/version_bars
Emit CUDA version info in the LLVM module.
2 parents e5658e5 + a6f98a1 commit 823bf18

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

Diff for: src/ptx.jl

+22-6
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,31 @@ runtime_slug(@nospecialize(job::CompilerJob{PTXCompilerTarget})) =
8585
"-exitable=$(job.target.exitable)"
8686

8787
function process_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), mod::LLVM.Module)
88+
ctx = context(mod)
89+
8890
# calling convention
8991
if LLVM.version() >= v"8"
9092
for f in functions(mod)
9193
# JuliaGPU/GPUCompiler.jl#97
9294
#callconv!(f, LLVM.API.LLVMPTXDeviceCallConv)
9395
end
9496
end
97+
98+
# emit the device capability and ptx isa version as constants in the module. this makes
99+
# it possible to 'query' these in device code, relying on LLVM to optimize the checks
100+
# away and generate static code. note that we only do so if there are actual uses of these
101+
# variables; unconditionally creating a gvar would result in duplicate declarations.
102+
for (name, value) in ["sm_major" => job.target.cap.major,
103+
"sm_minor" => job.target.cap.minor,
104+
"ptx_major" => job.target.ptx.major,
105+
"ptx_minor" => job.target.ptx.minor]
106+
if haskey(globals(mod), name)
107+
gv = globals(mod)[name]
108+
initializer!(gv, ConstantInt(LLVM.Int32Type(ctx), value))
109+
# change the linkage so that we can inline the value
110+
linkage!(gv, LLVM.API.LLVMPrivateLinkage)
111+
end
112+
end
95113
end
96114

97115
function process_entry!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
@@ -142,12 +160,6 @@ function process_entry!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
142160
# calling convention
143161
callconv!(entry, LLVM.API.LLVMPTXKernelCallConv)
144162
end
145-
else
146-
# we can't look up device functions using the CUDA APIs, so alias them to a global
147-
gv = GlobalVariable(mod, llvmtype(entry), LLVM.name(entry) * "_slot")
148-
initializer!(gv, entry)
149-
linkage!(gv, LLVM.API.LLVMLinkOnceODRLinkage)
150-
set_used!(mod, gv)
151163
end
152164

153165
return entry
@@ -161,6 +173,10 @@ function add_lowering_passes!(@nospecialize(job::CompilerJob{PTXCompilerTarget})
161173

162174
# even if we support `unreachable`, we still prefer `exit` to `trap`
163175
add!(pm, ModulePass("HideTrap", hide_trap!))
176+
177+
# we emit properties (of the device and ptx isa) as private global constants,
178+
# so run the global optimizer here to inline them before the rest of the optimizer runs.
179+
global_optimizer!(pm)
164180
end
165181

166182
function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),

0 commit comments

Comments
 (0)