Skip to content

Commit

Permalink
feat: vjp utility
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Feb 16, 2025
1 parent 4a5316e commit 6bf6299
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 0 deletions.
73 changes: 73 additions & 0 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1160,3 +1160,76 @@ grad
return nothing
end


"""
seeded_autodiff_thunk(
rmode::ReverseModeSplit,
dresult,
f,
ReturnActivity,
annotated_args...
)
Call [`autodiff_thunk`](@ref), execute the forward pass, increment output tangent with `dresult`, then execute the reverse pass.
Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
"""
function seeded_autodiff_thunk(
rmode::ReverseModeSplit{ReturnPrimal},
dresult,
f::FA,
::Type{RA},
args::Vararg{Annotation,N},
) where {ReturnPrimal,FA<:Annotation,RA<:Annotation,N}
forward, reverse = autodiff_thunk(rmode, FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
if RA <: Active
dinputs = only(reverse(f, args..., dresult, tape))
else
shadow_result .+= dresult # TODO: generalize beyond arrays
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end

"""
batch_seeded_autodiff_thunk(
rmode::ReverseModeSplit,
dresults::NTuple,
f,
ReturnActivity,
annotated_args...
)
Call [`autodiff_thunk`](@ref), execute the forward pass, increment each output tangent with the corresponding element from `dresults`, then execute the reverse pass.
Useful for computing pullbacks / VJPs for functions whose output is not a scalar.
"""
function batch_seeded_autodiff_thunk(
rmode::ReverseModeSplit{ReturnPrimal},
dresults::NTuple{B},
f::FA,
::Type{RA},
args::Vararg{Annotation,N},
) where {ReturnPrimal,B,FA<:Annotation,RA<:Annotation,N}
rmode_rightwidth = ReverseSplitWidth(rmode, Val(B))
forward, reverse = autodiff_thunk(rmode_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
if RA <: Active
dinputs = only(reverse(f, args..., dresults, tape))
else
foreach(shadow_results, dresults) do d0, d
d0 .+= d # TODO: generalize beyond arrays
end
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end
72 changes: 72 additions & 0 deletions test/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,75 @@ end
# @show J_r_3(u, A, x)
# @show J_f_3(u, A, x)
end

using Enzyme: seeded_autodiff_thunk, batch_seeded_autodiff_thunk

@testset "seeded_autodiff_thunk" begin

f(x::Vector{Float64}, y::Float64) = sum(abs2, x) * y
g(x::Vector{Float64}, y::Float64) = [f(x, y)]

x = [1.0, 2.0, 3.0]
y = 4.0
dx = similar(x)
dresult = 5.0
dxs = (similar(x), similar(x))
dresults = (5.0, 7.0)

@testset "simple" begin
for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = seeded_autodiff_thunk(mode, dresult, Const(f), Active, Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = seeded_autodiff_thunk(mode, [dresult], Const(g), Duplicated, Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

@testset "batch" begin
for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = batch_seeded_autodiff_thunk(mode, dresults, Const(f), Active, BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = batch_seeded_autodiff_thunk(mode, ([dresults[1]], [dresults[2]]), Const(g), BatchDuplicated, BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if mode == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

end

0 comments on commit 6bf6299

Please sign in to comment.