Skip to content

Commit

Permalink
Update bias_act.jl (#618)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Jan 4, 2025
1 parent be9c1c8 commit 5732b97
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/bias_act.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const RCR = RuleConfig{>:HasReverseMode}
@inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))

# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
# is independent of `x`, as `return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end

"""
Expand Down Expand Up @@ -57,7 +57,7 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
end

# Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
if isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, NotaNumber}))
Ω = bias_act!(σ, x, b) # now x === Ω, when x isa StridedArray{<:AbstractFloat}
function bias_act!_fastback(Δ)
# Tempting to overwrite x again, but only safe if you call pullback at most once,
Expand All @@ -70,7 +70,7 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA

# # Slower path: can't overwrite x, but can use derivatives_given_output
# # This case is WRONG and tests fail, but not sure why
# elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
# elseif isconcretetype(Core.Compiler.return_type(only_derivative, Tuple{T, F, T}))
# Ω2 = fast_act(σ, x).(x) .+ b
# @show σ b
# function bias_act!_back2(Δ)
Expand Down

0 comments on commit 5732b97

Please sign in to comment.