Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error in exception handling with custom reverse rule #2290

Closed
SouthEndMusic opened this issue Feb 1, 2025 · 1 comment
Closed

Error in exception handling with custom reverse rule #2290

SouthEndMusic opened this issue Feb 1, 2025 · 1 comment

Comments

@SouthEndMusic
Copy link

MWE:

using SplineGrids
using Enzyme
using .EnzymeRules
using Random

function SplineGrids.evaluate!(
    spline_grid::SplineGrid,
    control_points::SplineGrids.AbstractControlPointArray;
    kwargs...
)::Nothing
    evaluate!(
        spline_grid;
        control_points,
        kwargs...
    )
    return nothing
end

function EnzymeRules.augmented_primal(
    config::RevConfigWidth{1},
    ::Const{typeof(evaluate!)},
    ::Type{RT},
    spline_grid::Duplicated{<:SplineGrid},
    control_points::Duplicated{<:SplineGrids.AbstractControlPointArray};
    kwargs...
) where {RT}
    evaluate!(spline_grid.val, control_points.val; kwargs...)
    primal = if needs_primal(config)
        spline_grid.val
    else
        nothing
    end
    shadow = if needs_shadow(config)
        spline_grid.dval
    else
        nothing
    end
    EnzymeRules.AugmentedReturn(primal, shadow, kwargs)
end

function EnzymeRules.reverse(
    ::RevConfigWidth{1},
    ::Const{typeof(evaluate!)},
    ::Type{RT},
    kwargs,
    spline_grid::Duplicated{<:SplineGrid},
    control_points::Duplicated{<:SplineGrids.AbstractControlPointArray}
) where {RT}
    evaluate_adjoint!(spline_grid.dval; control_points=control_points.dval, kwargs...)
    (nothing, nothing)
end

Random.seed!(1)

n_control_points = (10, 10)
degree = (2, 2)
n_sample_points = (50, 50)
Nout = 2

spline_dimensions = SplineDimension.(n_control_points, degree, n_sample_points)
spline_grid = SplineGrid(spline_dimensions, Nout)

function loss(control_points_flat, spline_grid)
    evaluate!(
        spline_grid,
        reshape(control_points_flat, size(spline_grid.control_points))
    )
    return sum(spline_grid.eval)
end

control_points_flat = rand(
    Float32,
    length(spline_grid.control_points)
)

dcontrol_points_flat = Duplicated(control_points_flat, make_zero(control_points_flat))
dspline_grid = Duplicated(spline_grid, make_zero(spline_grid))

autodiff(
    Reverse,
    loss,
    Active,
    dcontrol_points_flat,
    dspline_grid
)

env:

  [7da242da] Enzyme v0.13.28
  [59c446ea] SplineGrids v0.1.1

Error + Stacktrace:

ERROR: MethodError: no method matching my_methodinstance(::Type{typeof(Enzyme.Compiler.custom_rule_method_error)}, ::Type{Tuple{…}}, ::UInt64)

Closest candidates are:
  my_methodinstance(::Core.Compiler.AbstractInterpreter, ::Type, ::Type)
   @ Enzyme C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\utils.jl:266
  my_methodinstance(::Core.Compiler.AbstractInterpreter, ::Type, ::Type, ::Union{Nothing, Base.RefValue{UInt64}})
   @ Enzyme C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\utils.jl:266
  my_methodinstance(::Core.Compiler.AbstractInterpreter, ::Type, ::Type, ::Union{Nothing, Base.RefValue{UInt64}}, ::Union{Nothing, Base.RefValue{UInt64}})
   @ Enzyme C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\utils.jl:266
  ...

Stacktrace:
  [1] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{…}, shadowR::Ptr{…}, tape::LLVM.UndefValue)
    @ Enzyme.Compiler C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\rules\customrules.jl:992
  [2] enzyme_custom_rev(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, tape::Union{…})
    @ Enzyme.Compiler C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\rules\customrules.jl:1516
  [3] enzyme_custom_rev_cfunc(B::Ptr{…}, OrigCI::Ptr{…}, gutils::Ptr{…}, tape::Ptr{…})
    @ Enzyme.Compiler C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\rules\llvmrules.jl:48
  [4] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\api.jl:268
  [5] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…} where N, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\compiler.jl:1706
  [6] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\compiler.jl:4550
  [7] codegen
    @ C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\compiler.jl:3353 [inlined]
  [8] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)     
    @ Enzyme.Compiler C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\compiler.jl:5410
  [9] _thunk
    @ C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\compiler.jl:5410 [inlined]
 [10] cached_compilation
    @ C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\compiler.jl:5462 [inlined]
 [11] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, edges::Vector{…})
    @ Enzyme.Compiler C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\compiler.jl:5573
 [12] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::Tuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\compiler.jl:5758
 [13] autodiff
    @ C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\Enzyme.jl:485 [inlined]
 [14] autodiff(::ReverseMode{…}, ::typeof(loss), ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
    @ Enzyme C:\Users\konin_bt\.julia\packages\Enzyme\R6sE8\src\Enzyme.jl:524
@wsmoses
Copy link
Member

wsmoses commented Feb 2, 2025

Okay error message fixed by #2291

now shows

       )
ERROR: MethodError: no method matching augmented_primal(::RevConfigWidth{1, false, false, (false, false, false), false}, ::Const{typeof(evaluate!)}, ::Type{Const{Nothing}}, ::MixedDuplicated{SplineGrid{2, 2, Float32, Int32, SplineDimension{Float32, Int32, KnotVector{}, Vector{}, Vector{}, Array{}}, DefaultControlPoints{2, 2, Float32, Array{}}, Nothing, Array{Float32, 3}, false}}, ::Duplicated{Array{Float32, 3}})

Closest candidates are:
  augmented_primal(::RevConfig, ::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, ::BodyTy, ::Any, Annotation...) where {BodyTy, N}
   @ Enzyme ~/git/Enzyme.jl/src/internal_rules.jl:337
  augmented_primal(::RevConfigWidth{1}, ::Const{typeof(evaluate!)}, ::Type{RT}, ::Duplicated{<:SplineGrid}, ::Duplicated{<:Union{SplineGrids.AbstractControlPoints{Nin, Nout, Tv}, AbstractArray{Tv}} where {Nin, Nout, Tv}}; kwargs...) where RT
   @ Main REPL[7]:1
  augmented_primal(::RevConfig, ::Const{<:KernelAbstractions.Kernel}, ::Type{Const{Nothing}}, ::Any...; ndrange, workgroupsize) where N
   @ EnzymeExt ~/.julia/packages/KernelAbstractions/mD0Rj/ext/EnzymeCore08Ext.jl:214
  ...

Stacktrace:
  [1] custom_rule_method_error
    @ ~/git/Enzyme.jl/src/rules/customrules.jl:445 [inlined]
  [2] loss
    @ ./REPL[16]:2 [inlined]
  [3] loss
    @ ./REPL[16]:0 [inlined]
  [4] diffejulia_loss_1454_inner_1wrap
    @ ./REPL[16]:0
  [5] macro expansion
    @ ~/git/Enzyme.jl/src/compiler.jl:5340 [inlined]
  [6] enzyme_call
    @ ~/git/Enzyme.jl/src/compiler.jl:4878 [inlined]
  [7] CombinedAdjointThunk
    @ ~/git/Enzyme.jl/src/compiler.jl:4750 [inlined]
  [8] autodiff
    @ ~/git/Enzyme.jl/src/Enzyme.jl:503 [inlined]
  [9] autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::typeof(loss), ::Type{Active}, ::Duplicated{Vector{Float32}}, ::Duplicated{SplineGrid{2, 2, Float32, Int32, SplineDimension{Float32, Int32, KnotVector{Float32, Int32, Vector{Float32}, Vector{Int32}}, Vector{Float32}, Vector{Int32}, Array{Float32, 3}}, DefaultControlPoints{2, 2, Float32, Array{Float32, 3}}, Nothing, Array{Float32, 3}, false}})
    @ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:524
 [10] top-level scope
    @ REPL[20]:1
Some type information was truncated. Use `show(err)` to see complete types.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants