Skip to content

Commit

Permalink
DI: Reverse dependency structure
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 8, 2025
1 parent 4a5316e commit 00803ee
Show file tree
Hide file tree
Showing 7 changed files with 1,012 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand All @@ -29,6 +31,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[extensions]
EnzymeBFloat16sExt = "BFloat16s"
EnzymeChainRulesCoreExt = "ChainRulesCore"
EnzymeDIExt = ["DifferentiationInterface", "ADTypes", "EnzymeCore"]
EnzymeGPUArraysCoreExt = "GPUArraysCore"
EnzymeLogExpFunctionsExt = "LogExpFunctions"
EnzymeSpecialFunctionsExt = "SpecialFunctions"
Expand All @@ -53,8 +56,10 @@ StaticArrays = "1"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand Down
54 changes: 54 additions & 0 deletions ext/EnzymeDIExt/EnzymeDIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Extension module providing `DifferentiationInterface` (imported as `DI`)
# methods for the `AutoEnzyme` backend type from ADTypes. The method
# definitions themselves live in the files included at the bottom.
module EnzymeDIExt

# Backend selector type used for dispatch in the included files.
using ADTypes: ADTypes, AutoEnzyme
# `Fix1` is used to partially apply `convert` over tuples of tangents.
using Base: Fix1
# Aliased so that overloaded operators are written as `DI.<name>`.
import DifferentiationInterface as DI
# Activity annotations and mode types/constants from EnzymeCore.
using EnzymeCore:
Active,
Annotation,
BatchDuplicated,
BatchDuplicatedNoNeed,
BatchMixedDuplicated,
Combined,
Const,
Duplicated,
DuplicatedNoNeed,
EnzymeCore,
Forward,
ForwardMode,
ForwardWithPrimal,
MixedDuplicated,
Mode,
NoPrimal,
Reverse,
ReverseMode,
ReverseModeSplit,
ReverseSplitNoPrimal,
ReverseSplitWidth,
ReverseSplitWithPrimal,
ReverseWithPrimal,
Split,
WithPrimal
# Differentiation entry points and shadow/zero helpers from Enzyme itself.
using Enzyme:
autodiff,
autodiff_thunk,
create_shadows,
gradient,
gradient!,
guess_activity,
hvp,
hvp!,
jacobian,
make_zero,
make_zero!,
onehot

# `utils.jl` is included first, ahead of the operator files below.
include("utils.jl")

# Forward-mode operator implementations.
include("forward_onearg.jl")
include("forward_twoarg.jl")

# Reverse-mode operator implementations.
include("reverse_onearg.jl")
include("reverse_twoarg.jl")

end # module
267 changes: 267 additions & 0 deletions ext/EnzymeDIExt/forward_onearg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
## Pushforward

# Preparation step for pushforwards with an `AutoEnzyme` backend whose mode
# parameter is a `ForwardMode` or `Nothing`. None of the arguments are
# inspected; they only select this method. The result is always a fresh
# `DI.NoPushforwardPrep()`.
DI.prepare_pushforward(
    f::F,
    ::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C},
) where {F,C} = DI.NoPushforwardPrep()

# Pushforward with a single tangent (`tx` holds exactly one element) that also
# returns the primal value.
#
# Returns `(y, (dy,))`, where `dy` and `y` are the first and second items
# produced by `autodiff`.
# NOTE(review): `forward_withprimal`, `get_f_and_df` and `translate` are
# defined outside this file; presumably they select a primal-returning forward
# mode, annotate `f`, and annotate the contexts respectively - confirm.
function DI.value_and_pushforward(
    f::F,
    ::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple{1},
    contexts::Vararg{DI.Context,C},
) where {F,C}
    fwd_mode = forward_withprimal(backend)
    annotated_f = get_f_and_df(f, backend, fwd_mode)
    # The lone tangent is converted to the type of `x` before being paired
    # with it in a `Duplicated` annotation.
    seeded_x = Duplicated(x, convert(typeof(x), only(tx)))
    ctx_args = translate(backend, fwd_mode, Val(1), contexts...)
    result = autodiff(fwd_mode, annotated_f, seeded_x, ctx_args...)
    dy, y = result
    return y, (dy,)
end

# Batched pushforward (`tx` holds `B` tangents) that also returns the primal
# value.
#
# Returns `(y, values(ty))`, where `ty` and `y` are the first and second items
# produced by `autodiff`.
# NOTE(review): `forward_withprimal`, `get_f_and_df` and `translate` are
# defined outside this file; their exact semantics should be confirmed there.
function DI.value_and_pushforward(
    f::F,
    ::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple{B},
    contexts::Vararg{DI.Context,C},
) where {F,B,C}
    fwd_mode = forward_withprimal(backend)
    annotated_f = get_f_and_df(f, backend, fwd_mode, Val(B))
    # Every tangent is converted to the type of `x` before batching.
    to_x_type = Fix1(convert, typeof(x))
    seeded_x = BatchDuplicated(x, map(to_x_type, tx))
    ctx_args = translate(backend, fwd_mode, Val(B), contexts...)
    result = autodiff(fwd_mode, annotated_f, seeded_x, ctx_args...)
    ty, y = result
    return y, values(ty)
end

# Pushforward with a single tangent (`tx` holds exactly one element), without
# returning the primal value.
#
# Returns the one-element tuple `(dy,)`, where `dy` is the sole item produced
# by `autodiff` (enforced by `only`).
# NOTE(review): `forward_noprimal`, `get_f_and_df` and `translate` are defined
# outside this file; their exact semantics should be confirmed there.
function DI.pushforward(
    f::F,
    ::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple{1},
    contexts::Vararg{DI.Context,C},
) where {F,C}
    fwd_mode = forward_noprimal(backend)
    annotated_f = get_f_and_df(f, backend, fwd_mode)
    # The lone tangent is converted to the type of `x` before being paired
    # with it in a `Duplicated` annotation.
    seeded_x = Duplicated(x, convert(typeof(x), only(tx)))
    ctx_args = translate(backend, fwd_mode, Val(1), contexts...)
    result = autodiff(fwd_mode, annotated_f, seeded_x, ctx_args...)
    return (only(result),)
end

# Batched pushforward (`tx` holds `B` tangents), without returning the primal
# value.
#
# Returns `values(ty)`, where `ty` is the sole item produced by `autodiff`
# (enforced by `only`).
# NOTE(review): `forward_noprimal`, `get_f_and_df` and `translate` are defined
# outside this file; their exact semantics should be confirmed there.
function DI.pushforward(
    f::F,
    ::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple{B},
    contexts::Vararg{DI.Context,C},
) where {F,B,C}
    fwd_mode = forward_noprimal(backend)
    annotated_f = get_f_and_df(f, backend, fwd_mode, Val(B))
    # Every tangent is converted to the type of `x` before batching.
    to_x_type = Fix1(convert, typeof(x))
    seeded_x = BatchDuplicated(x, map(to_x_type, tx))
    ctx_args = translate(backend, fwd_mode, Val(B), contexts...)
    result = autodiff(fwd_mode, annotated_f, seeded_x, ctx_args...)
    return values(only(result))
end

# In-place flavor of `DI.value_and_pushforward`: the output tangents are
# computed out-of-place and then copied element by element into `ty`.
#
# Returns `(y, ty)` with `ty` being the tuple that was passed in.
function DI.value_and_pushforward!(
    f::F,
    ty::NTuple,
    prep::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    # There is no way to hand a preallocated `dy` to the computation anyway,
    # so the out-of-place method does the work.
    y, fresh_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
    for (dest, src) in zip(ty, fresh_ty)
        copyto!(dest, src)
    end
    return y, ty
end

# In-place flavor of `DI.pushforward`: the output tangents are computed
# out-of-place and then copied element by element into `ty`.
#
# Returns the `ty` tuple that was passed in.
function DI.pushforward!(
    f::F,
    ty::NTuple,
    prep::DI.NoPushforwardPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
    x,
    tx::NTuple,
    contexts::Vararg{DI.Context,C},
) where {F,C}
    # There is no way to hand a preallocated `dy` to the computation anyway,
    # so the out-of-place method does the work.
    fresh_ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
    for (dest, src) in zip(ty, fresh_ty)
        copyto!(dest, src)
    end
    return ty
end

## Gradient

# Preparation object for forward-mode gradients.
#
# Type parameters:
# - `B`: carried only in the type; the outer constructor takes it from a
#   `Val{B}`, and the gradient methods pass it back to Enzyme as
#   `chunk=Val(B)`.
# - `O`: concrete type of the stored `shadows`.
struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
# Value forwarded to Enzyme's `gradient` through the `shadows` keyword.
shadows::O
end

# Convenience constructor: lifts `B` out of a `Val{B}` and infers `O` from the
# type of `shadows`, so neither type parameter has to be spelled out.
EnzymeForwardGradientPrep(::Val{B}, shadows::O) where {B,O} =
    EnzymeForwardGradientPrep{B,O}(shadows)

# Preparation step for forward-mode gradients: picks a batch size for `x`
# through `DI.pick_batchsize`, builds the matching shadows with
# `create_shadows`, and stores both in an `EnzymeForwardGradientPrep`.
# Neither `f` nor `contexts` is used.
# NOTE(review): `to_val` is defined outside this file; presumably it turns the
# picked batch size into a `Val` - confirm.
function DI.prepare_gradient(
    f::F,
    backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,C}
    batch_val = to_val(DI.pick_batchsize(backend, x))
    return EnzymeForwardGradientPrep(batch_val, create_shadows(batch_val, x))
end

# Forward-mode gradient without the primal value.
#
# Calls Enzyme's `gradient` with chunk size `Val(B)` and the shadows stored in
# `prep`, then returns the first entry of its result (the one matching `x`,
# the first differentiated argument).
# NOTE(review): `forward_noprimal`, `get_f_and_df` and `translate` are defined
# outside this file; their exact semantics should be confirmed there.
function DI.gradient(
    f::F,
    prep::EnzymeForwardGradientPrep{B},
    backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,B,C}
    fwd_mode = forward_noprimal(backend)
    annotated_f = get_f_and_df(f, backend, fwd_mode)
    ctx_args = translate(backend, fwd_mode, Val(B), contexts...)
    return first(
        gradient(
            fwd_mode, annotated_f, x, ctx_args...; chunk=Val(B), shadows=prep.shadows
        ),
    )
end

# Forward-mode gradient together with the primal value.
#
# Calls Enzyme's `gradient` with chunk size `Val(B)` and the shadows stored in
# `prep`. The result exposes `derivs` and `val` properties; the function
# returns `(val, first(derivs))`.
# NOTE(review): `forward_withprimal`, `get_f_and_df` and `translate` are
# defined outside this file; their exact semantics should be confirmed there.
function DI.value_and_gradient(
    f::F,
    prep::EnzymeForwardGradientPrep{B},
    backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,B,C}
    fwd_mode = forward_withprimal(backend)
    annotated_f = get_f_and_df(f, backend, fwd_mode)
    ctx_args = translate(backend, fwd_mode, Val(B), contexts...)
    outcome = gradient(
        fwd_mode, annotated_f, x, ctx_args...; chunk=Val(B), shadows=prep.shadows
    )
    return outcome.val, first(outcome.derivs)
end

# In-place flavor of `DI.gradient`: the gradient is computed out-of-place and
# then copied into `grad`. Returns whatever `copyto!` returns.
function DI.gradient!(
    f::F,
    grad,
    prep::EnzymeForwardGradientPrep{B},
    backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,B,C}
    fresh_grad = DI.gradient(f, prep, backend, x, contexts...)
    return copyto!(grad, fresh_grad)
end

# In-place flavor of `DI.value_and_gradient`: the gradient is computed
# out-of-place and then copied into `grad`. Returns the primal value together
# with the result of `copyto!`.
function DI.value_and_gradient!(
    f::F,
    grad,
    prep::EnzymeForwardGradientPrep{B},
    backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,B,C}
    primal, fresh_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
    filled = copyto!(grad, fresh_grad)
    return primal, filled
end

## Jacobian

# Preparation object for forward-mode Jacobians of one-argument functions.
#
# Type parameters:
# - `B`: carried only in the type; the outer constructor takes it from a
#   `Val{B}`, and the Jacobian methods pass it back to Enzyme as
#   `chunk=Val(B)`.
# - `O`: concrete type of the stored `shadows`.
struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
# Value forwarded to Enzyme's `jacobian` through the `shadows` keyword.
shadows::O
# `length(y)` of the output computed once in `DI.prepare_jacobian`; used as
# the first size argument when the Jacobian is reshaped.
output_length::Int
end

# Convenience constructor: lifts `B` out of a `Val{B}` and infers `O` from the
# type of `shadows`, so neither type parameter has to be spelled out.
EnzymeForwardOneArgJacobianPrep(
    ::Val{B}, shadows::O, output_length::Integer
) where {B,O} = EnzymeForwardOneArgJacobianPrep{B,O}(shadows, output_length)

# Preparation step for forward-mode Jacobians.
#
# Evaluates `f` once on `x` and the unwrapped contexts to learn the output
# length, picks a batch size for `x` through `DI.pick_batchsize`, builds the
# matching shadows with `create_shadows`, and stores everything in an
# `EnzymeForwardOneArgJacobianPrep`.
# NOTE(review): `to_val` is defined outside this file; presumably it turns the
# picked batch size into a `Val` - confirm.
function DI.prepare_jacobian(
    f::F,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,C}
    raw_contexts = map(DI.unwrap, contexts)
    y = f(x, raw_contexts...)
    batch_val = to_val(DI.pick_batchsize(backend, x))
    shadow_store = create_shadows(batch_val, x)
    return EnzymeForwardOneArgJacobianPrep(batch_val, shadow_store, length(y))
end

# Forward-mode Jacobian without the primal value.
#
# Calls Enzyme's `jacobian` with chunk size `Val(B)` and the shadows stored in
# `prep`, takes the first entry of its result, and passes it to
# `maybe_reshape` with sizes `prep.output_length` and `length(x)`.
# NOTE(review): `forward_noprimal`, `get_f_and_df`, `translate` and
# `maybe_reshape` are defined outside this file; confirm their semantics there.
function DI.jacobian(
    f::F,
    prep::EnzymeForwardOneArgJacobianPrep{B},
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,B,C}
    fwd_mode = forward_noprimal(backend)
    annotated_f = get_f_and_df(f, backend, fwd_mode)
    ctx_args = translate(backend, fwd_mode, Val(B), contexts...)
    all_derivs = jacobian(
        fwd_mode, annotated_f, x, ctx_args...; chunk=Val(B), shadows=prep.shadows
    )
    return maybe_reshape(first(all_derivs), prep.output_length, length(x))
end

# Forward-mode Jacobian together with the primal value.
#
# Calls Enzyme's `jacobian` with chunk size `Val(B)` and the shadows stored in
# `prep`. The result exposes `derivs` and `val` properties; the function
# returns `val` along with `first(derivs)` passed through `maybe_reshape` with
# sizes `prep.output_length` and `length(x)`.
# NOTE(review): `forward_withprimal`, `get_f_and_df`, `translate` and
# `maybe_reshape` are defined outside this file; confirm their semantics there.
function DI.value_and_jacobian(
    f::F,
    prep::EnzymeForwardOneArgJacobianPrep{B},
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,B,C}
    fwd_mode = forward_withprimal(backend)
    annotated_f = get_f_and_df(f, backend, fwd_mode)
    ctx_args = translate(backend, fwd_mode, Val(B), contexts...)
    outcome = jacobian(
        fwd_mode, annotated_f, x, ctx_args...; chunk=Val(B), shadows=prep.shadows
    )
    jac_matrix = maybe_reshape(first(outcome.derivs), prep.output_length, length(x))
    return outcome.val, jac_matrix
end

# In-place flavor of `DI.jacobian`: the Jacobian is computed out-of-place and
# then copied into `jac`. Returns whatever `copyto!` returns.
function DI.jacobian!(
    f::F,
    jac,
    prep::EnzymeForwardOneArgJacobianPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,C}
    fresh_jac = DI.jacobian(f, prep, backend, x, contexts...)
    return copyto!(jac, fresh_jac)
end

# In-place flavor of `DI.value_and_jacobian`: the Jacobian is computed
# out-of-place and then copied into `jac`. Returns the primal value together
# with the result of `copyto!`.
function DI.value_and_jacobian!(
    f::F,
    jac,
    prep::EnzymeForwardOneArgJacobianPrep,
    backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{DI.Constant,C},
) where {F,C}
    primal, fresh_jac = DI.value_and_jacobian(f, prep, backend, x, contexts...)
    filled = copyto!(jac, fresh_jac)
    return primal, filled
end
Loading

0 comments on commit 00803ee

Please sign in to comment.