Skip to content

Commit 499c0da

Browse files
authored
Support the new have_fma intrinsic on Julia 1.8. (#270)
1 parent b02de57 commit 499c0da

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

src/interface.jl

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ function julia_datalayout(@nospecialize(target::AbstractCompilerTarget))
4343
DataLayout(string(dl) * "-ni:10:11:12:13")
4444
end
4545

46+
# Generic fallback for the FMA query: for any target and any type `T` the answer is
# `false`. Back-ends opt in by adding a method for their own target type.
function have_fma(@nospecialize(target::AbstractCompilerTarget), T::Type)
    return false
end
47+
4648

4749
## params
4850

src/optim.jl

+52
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ function addOptimizationPasses!(pm, opt_level=2)
2424
constant_merge!(pm)
2525

2626
if opt_level < 2
27+
cpu_features!(pm)
28+
if opt_level == 1
29+
instruction_simplify!(pm)
30+
end
2731
if LLVM.version() >= v"12"
2832
cfgsimplification!(pm; hoist_common_insts=true)
2933
else
@@ -72,6 +76,7 @@ function addOptimizationPasses!(pm, opt_level=2)
7276
else
7377
cfgsimplification!(pm)
7478
end
79+
cpu_features!(pm)
7580
scalar_repl_aggregates!(pm)
7681
instruction_simplify!(pm)
7782
jump_threading!(pm)
@@ -237,6 +242,53 @@ end
237242

238243

239244
## lowering intrinsics
245+
# Wrap the module-level `cpu_features!` lowering in a module pass named
# "LowerCPUFeatures" and append it to the pass manager `pm`.
function cpu_features!(pm::PassManager)
    pass = ModulePass("LowerCPUFeatures", cpu_features!)
    return add!(pm, pass)
end
246+
# Lower the `julia.cpu.have_fma.*` intrinsics found in `mod` to constants.
#
# Every function in the module whose name starts with `julia.cpu.have_fma.` is treated
# as an intrinsic. The name suffix (`f32` or `f64`) selects the Julia type that is passed,
# together with the target of the current compiler job, to `have_fma`; unknown suffixes
# are lowered to `false`. Each user of the intrinsic has its uses replaced by the
# resulting constant and is then erased, after which the intrinsic itself is removed
# from the module.
#
# Returns `true` when at least one intrinsic was lowered (i.e. the module was modified)
# and `false` otherwise, so the pass reports modifications to the pass manager.
function cpu_features!(mod::LLVM.Module)
    job = current_job::CompilerJob
    changed = false

    prefix = "julia.cpu.have_fma."
    argtyps = Dict(
        "f32" => Float32,
        "f64" => Float64,
    )

    # have_fma
    # NOTE(review): `f` is removed from `mod` while `functions(mod)` is being iterated,
    # exactly as before this change; presumably the iterator tolerates that -- confirm.
    for f in functions(mod)
        fn = LLVM.name(f)
        startswith(fn, prefix) || continue
        ft = eltype(llvmtype(f))

        # everything after the prefix names the floating-point type being queried
        typnam = fn[(ncodeunits(prefix) + 1):end]

        # determine whether this back-end supports FMA on this type
        # (unknown type names are conservatively reported as unsupported; warn?)
        typ = get(argtyps, typnam, nothing)
        supported = typ === nothing ? false : have_fma(job.target, typ)
        replacement = ConstantInt(return_type(ft), supported)

        # substitute all uses of the intrinsic with a constant
        materialized = LLVM.Value[]
        for use in uses(f)
            val = user(use)
            replace_uses!(val, replacement)
            push!(materialized, val)
        end

        # remove the intrinsic and its uses
        for val in materialized
            @assert isempty(uses(val))
            unsafe_delete!(LLVM.parent(val), val)
        end
        @assert isempty(uses(f))
        unsafe_delete!(mod, f)

        # the module was modified: at the very least the intrinsic itself was removed
        changed = true
    end

    return changed
end
240292

241293
# lower object allocations to PTX malloc
242294
#

src/ptx.jl

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ llvm_datalayout(target::PTXCompilerTarget) = Int===Int64 ?
6161
"e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64"*
6262
"-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
6363

64+
# FMA query for the PTX back-end: reported as available for every type `T`.
function have_fma(@nospecialize(target::PTXCompilerTarget), T::Type)
    return true
end
65+
6466

6567
## job
6668

0 commit comments

Comments
 (0)