Autodiff Deferred Thunk is broken #1417

Closed
wsmoses opened this issue May 7, 2024 · 3 comments

@wsmoses
Member

wsmoses commented May 7, 2024

@vchuravy @michel2323 I'm not sure what's happening here. The non-deferred version succeeds. This is a minimization of the segfault from JuliaGPU/KernelAbstractions.jl#476:

using KernelAbstractions
using Test
using Enzyme
using EnzymeCore
# using KernelAbstractions.EnzymeExt
Enzyme.API.printall!(true)

@kernel function square!(A)
    I = @index(Global, Linear)
    @inbounds A[I] *= A[I]
end


A = Array{Float64}(undef, 64)
dA = Array{Float64}(undef, 64)

A .= (1:1:64)
dA .= 1

import KernelAbstractions: Kernel, StaticSize, launch_config, __groupsize, __groupindex, blocks, mkcontext, CompilerMetadata, CPU

function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT}
    TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
    @show TapeType
    forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)

    # Non deferred works
    # forward, reverse = EnzymeCore.autodiff_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
    forward(Const(f), Const(ctx), args...)[1]
    return nothing
end

ndrange = (1,)
workgroupsize = nothing
func = Enzyme.Const(square!(CPU()))
kernel = func.val
f = kernel.f

ndrange, workgroupsize, iterspace, dynamic = launch_config(kernel, ndrange, workgroupsize)

# TODO autodiff_deferred on the func.val
ModifiedBetween = Val((false, false, true))

aug_kernel = similar(kernel, aug_fwd)

aug_kernel(f, ModifiedBetween, Duplicated(A, dA); ndrange, workgroupsize)

@vchuravy
Member

vchuravy commented May 7, 2024

Do you have a backtrace for the fault?

@wsmoses
Member Author

wsmoses commented May 7, 2024



[307161] signal (11.1): Segmentation fault
in expression starting at /home/wmoses/git/KernelAbstractions.jl/test.jl:47
cpu_square! at /home/wmoses/git/KernelAbstractions.jl/src/macros.jl:285 [inlined]
cpu_square! at ./none:0 [inlined]
cpu_square! at ./none:0 [inlined]
augmented_julia_cpu_square__3572_inner_1wrap at ./none:0
macro expansion at /home/wmoses/git/Enzyme.jl/src/compiler.jl:5656 [inlined]
enzyme_call at /home/wmoses/git/Enzyme.jl/src/compiler.jl:5334 [inlined]
AugmentedForwardThunk at /home/wmoses/git/Enzyme.jl/src/compiler.jl:5223 [inlined]
aug_fwd at /home/wmoses/git/KernelAbstractions.jl/test.jl:26 [inlined]
__thread_run at /home/wmoses/git/KernelAbstractions.jl/src/cpu.jl:115
unknown function (ip: 0x7d8f18dc929e)
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
__run at /home/wmoses/git/KernelAbstractions.jl/src/cpu.jl:82
unknown function (ip: 0x7d903bc8900d)
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
#_#16 at /home/wmoses/git/KernelAbstractions.jl/src/cpu.jl:44
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/builtins.c:768
Kernel at /home/wmoses/git/KernelAbstractions.jl/src/cpu.jl:37
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_call at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:126
eval_value at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:617
jl_interpret_toplevel_thunk at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
include_string at ./loading.jl:2076
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
_include at ./loading.jl:2136
include at ./Base.jl:495
jfptr_include_46403.1 at /home/wmoses/git/Enzyme.jl/julia-1.10.2/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
exec_options at ./client.jl:318
_start at ./client.jl:552
jfptr__start_82738.1 at /home/wmoses/git/Enzyme.jl/julia-1.10.2/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at /cache/build/builder-amdci5-1/julialang/julia-release-1-dot-10/cli/loader_exe.c:58
unknown function (ip: 0x7d903d029d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 24285604 (Pool: 24248754; Big: 36850); GC: 36
Segmentation fault (core dumped)

@vchuravy
Member

KA-free reproducer:

function kernel(len, A)
    for i in 1:len
        A[i] *= A[i]
    end
end

using Enzyme, EnzymeCore

A = Array{Float64}(undef, 64)
dA = Array{Float64}(undef, 64)

A .= (1:1:64)
dA .= 1

function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, args...) where {ModifiedBetween, FT}
    TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
    forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
    
    # Non deferred works
    # forward, reverse = EnzymeCore.autodiff_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
    forward(Const(f), Const(ctx), args...)[1]
    return nothing
end

ModifiedBetween = Val((false, false, true))

aug_fwd(64, kernel, ModifiedBetween, Duplicated(A, dA))
