From cc4b9eb2c4f068a3371c2088036328eafb25a4db Mon Sep 17 00:00:00 2001 From: se-schmitt Date: Thu, 6 Feb 2025 10:10:04 +0100 Subject: [PATCH 1/5] Add SpecialFunctions as dependency --- Project.toml | 2 ++ src/NNlib.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index 7cbd10d2..f07b9dfb 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] @@ -41,6 +42,7 @@ GPUArraysCore = "0.1, 0.2" KernelAbstractions = "0.9.2" LinearAlgebra = "<0.0.1, 1" Random = "<0.0.1, 1" +SpecialFunctions = "2.5.0" Statistics = "1" cuDNN = "1" julia = "1.9" diff --git a/src/NNlib.jl b/src/NNlib.jl index 687206fc..493eee9d 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -15,6 +15,7 @@ using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose using Random using Statistics using Statistics: mean +using SpecialFunctions const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number} From 86256b2966226207ead5abb50b05fdc0ec6a8d8d Mon Sep 17 00:00:00 2001 From: se-schmitt Date: Thu, 6 Feb 2025 10:33:33 +0100 Subject: [PATCH 2/5] add full `gelu`, changes old `gelu` -> `gelu_fast` --- src/activations.jl | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/activations.jl b/src/activations.jl index 4ed58622..f8dbd787 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -5,7 +5,7 @@ ACTIVATIONS = [ :σ, :hardσ, :hardtanh, :relu, - :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :hardswish, :selu, + :leakyrelu, :relu6, :rrelu, :elu, :gelu, :gelu_fast, :swish, :hardswish, :selu, :celu, :softplus, :softsign, :logσ, :logcosh, :mish, :tanhshrink, :softshrink, :trelu, :lisht, :tanh_fast, :sigmoid_fast, @@ -301,7 +301,7 @@ elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1)) deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α)) """ - gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) + gelu(x) = xΦ(x) = 0.5x * (1 + erf(x/√2)) Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). @@ -335,7 +335,20 @@ julia> lineplot!(ans, swish) ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ -function gelu(x) +gelu(x) = x/2*(1 + erf(x/sqrt(oftf(x,2)))) + +function deriv_gelu(x) + SQRT2 = sqrt(oftf(x,2)) + Φ = (1 + erf(x/SQRT2))/2 + Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π)) +end + +""" + gelu_fast(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) + +Fast approximation of [`gelu`](@ref) activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). +""" +function gelu_fast(x) α = oftf(x, 0.044715) # λ = oftf(x, gelu_λ) # x/2 * (1 + tanh(λ * (x + α * x^3))) # Standard implementation, for reference @@ -346,7 +359,7 @@ end const gelu_λ = √(2 / π) const gelu_2λ = √(8 / π) -function deriv_gelu(x) +function deriv_gelu_fast(x) α = oftf(x, 0.044715) α2 = oftf(x, 0.08943) λλ = oftf(x, gelu_2λ) @@ -875,6 +888,7 @@ UNARY_ACTS = [ # f, dfdx # rrelu is random, can't write a rule. 
(:elu, :(deriv_elu(Ω))), (:gelu, :(deriv_gelu(x))), + (:gelu_fast, :(deriv_gelu_fast(x))), (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))), (:hardswish, :(deriv_hardswish(x))), # lisht From 8275b185e90f1f33a00a5854987b0ea795cc52e0 Mon Sep 17 00:00:00 2001 From: se-schmitt Date: Thu, 6 Feb 2025 10:50:21 +0100 Subject: [PATCH 3/5] Add tests and docs --- docs/src/reference.md | 1 + test/activations.jl | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/src/reference.md b/docs/src/reference.md index 5edde719..0304d9ee 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -10,6 +10,7 @@ Non-linearities that go between layers of your model. Note that, unless otherwis celu elu gelu +gelu_fast hardsigmoid sigmoid_fast hardtanh diff --git a/test/activations.jl b/test/activations.jl index 3a14bfde..9b3a5891 100644 --- a/test/activations.jl +++ b/test/activations.jl @@ -12,6 +12,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test rrelu(0.0) == 0.0 @test elu(0.0) == 0.0 @test gelu(0.0) == 0.0 +@test gelu_fast(0.0) == 0.0 @test swish(0.0) == 0.0 @test hardswish(0.0) == 0.0 @test lisht(0.0) == 0.0 @@ -35,7 +36,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test relu6(1.0) == 1.0 @test rrelu(1.0) == 1.0 @test elu(1.0) == 1.0 -@test gelu(1.0) == 0.8411919906082768 +@test gelu(1.0) == 0.8413447460685429 +@test gelu_fast(1.0) == 0.8411919906082768 @test swish(1.0) == sigmoid(1.0) @test hardswish(1.0) == hardsigmoid(1.0) @test lisht(1.0) ≈ 1.0 * tanh(1.0) @@ -57,7 +59,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test relu6(-1.0) == 0.0 @test -1/3.0 <= rrelu(-1.0) <= -1/8.0 @test elu(-1.0) == exp(-1.0) - 1.0 -@test gelu(-1.0) ≈ -0.15880800939172324 +@test gelu(-1.0) == -0.15865525393145707 +@test gelu_fast(-1.0) ≈ -0.15880800939172324 @test swish(-1.0) == -sigmoid(-1.0) @test hardswish(-1.0) == -hardsigmoid(-1.0) @test lisht(-1.0) ≈ -1.0 * tanh(-1.0) @@ -114,7 +117,7 @@ end a == softsign && continue @test !isnan(a(Inf32)) - a in [gelu, swish, hardswish, logcosh, mish] && continue + a in [gelu, gelu_fast, swish, hardswish, logcosh, mish] && continue @test !isnan(a(-Inf32)) end end From bdfa0f046afd45def91f0472d94c3d3fb307d1f6 Mon Sep 17 00:00:00 2001 From: se-schmitt Date: Sun, 9 Feb 2025 16:10:29 +0100 Subject: [PATCH 4/5] Change names: `gelu_fast` -> `gelu`, `gelu` -> `gelu_full` --- docs/src/reference.md | 2 +- src/activations.jl | 38 +++++++++++++++++++------------------- test/activations.jl | 12 ++++++------ 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/docs/src/reference.md b/docs/src/reference.md index 0304d9ee..88375047 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -10,7 +10,7 @@ Non-linearities that go between layers of your model. 
Note that, unless otherwis celu elu gelu -gelu_fast +gelu_full hardsigmoid sigmoid_fast hardtanh diff --git a/src/activations.jl b/src/activations.jl index f8dbd787..b1981ed3 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -5,7 +5,7 @@ ACTIVATIONS = [ :σ, :hardσ, :hardtanh, :relu, - :leakyrelu, :relu6, :rrelu, :elu, :gelu, :gelu_fast, :swish, :hardswish, :selu, + :leakyrelu, :relu6, :rrelu, :elu, :gelu, :gelu_full, :swish, :hardswish, :selu, :celu, :softplus, :softsign, :logσ, :logcosh, :mish, :tanhshrink, :softshrink, :trelu, :lisht, :tanh_fast, :sigmoid_fast, @@ -301,9 +301,9 @@ elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1)) deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α)) """ - gelu(x) = xΦ(x) = 0.5x * (1 + erf(x/√2)) + gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) -Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). +Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) (see also [`gelu_full`](@ref)). ```julia-repl julia> lineplot(gelu, -2, 2, height=7) @@ -335,20 +335,7 @@ julia> lineplot!(ans, swish) ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ -gelu(x) = x/2*(1 + erf(x/sqrt(oftf(x,2)))) - -function deriv_gelu(x) - SQRT2 = sqrt(oftf(x,2)) - Φ = (1 + erf(x/SQRT2))/2 - Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π)) -end - -""" - gelu_fast(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) - -Fast approximation of [`gelu`](@ref) activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). -""" -function gelu_fast(x) +function gelu(x) α = oftf(x, 0.044715) # λ = oftf(x, gelu_λ) # x/2 * (1 + tanh(λ * (x + α * x^3))) # Standard implementation, for reference @@ -359,7 +346,7 @@ end const gelu_λ = √(2 / π) const gelu_2λ = √(8 / π) -function deriv_gelu_fast(x) +function deriv_gelu(x) α = oftf(x, 0.044715) α2 = oftf(x, 0.08943) λλ = oftf(x, gelu_2λ) @@ -370,6 +357,19 @@ function deriv_gelu_fast(x) muladd(dσ * λλ * muladd(x2, α2, t), x, Ω) end +""" + gelu_full(x) = xΦ(x) = 0.5x * (1 + erf(x/√2)) + +Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) without approximation. +""" +gelu_full(x) = x/2*(1 + erf(x/sqrt(oftf(x,2)))) + +function deriv_gelu_full(x) + SQRT2 = sqrt(oftf(x,2)) + Φ = (1 + erf(x/SQRT2))/2 + Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π)) +end + """ swish(x) = x * σ(x) @@ -888,7 +888,7 @@ UNARY_ACTS = [ # f, dfdx # rrelu is random, can't write a rule. 
(:elu, :(deriv_elu(Ω))), (:gelu, :(deriv_gelu(x))), - (:gelu_fast, :(deriv_gelu_fast(x))), + (:gelu_full, :(deriv_gelu_full(x))), (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))), (:hardswish, :(deriv_hardswish(x))), # lisht diff --git a/test/activations.jl b/test/activations.jl index 9b3a5891..2d1f33ef 100644 --- a/test/activations.jl +++ b/test/activations.jl @@ -12,7 +12,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test rrelu(0.0) == 0.0 @test elu(0.0) == 0.0 @test gelu(0.0) == 0.0 -@test gelu_fast(0.0) == 0.0 +@test gelu_full(0.0) == 0.0 @test swish(0.0) == 0.0 @test hardswish(0.0) == 0.0 @test lisht(0.0) == 0.0 @@ -36,8 +36,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test relu6(1.0) == 1.0 @test rrelu(1.0) == 1.0 @test elu(1.0) == 1.0 -@test gelu(1.0) == 0.8413447460685429 -@test gelu_fast(1.0) == 0.8411919906082768 +@test gelu(1.0) == 0.8411919906082768 +@test gelu_full(1.0) == 0.8413447460685429 @test swish(1.0) == sigmoid(1.0) @test hardswish(1.0) == hardsigmoid(1.0) @test lisht(1.0) ≈ 1.0 * tanh(1.0) @@ -59,8 +59,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test relu6(-1.0) == 0.0 @test -1/3.0 <= rrelu(-1.0) <= -1/8.0 @test elu(-1.0) == exp(-1.0) - 1.0 -@test gelu(-1.0) == -0.15865525393145707 -@test gelu_fast(-1.0) ≈ -0.15880800939172324 +@test gelu(-1.0) ≈ -0.15880800939172324 +@test gelu_full(-1.0) == -0.15865525393145707 @test swish(-1.0) == -sigmoid(-1.0) @test hardswish(-1.0) == -hardsigmoid(-1.0) @test lisht(-1.0) ≈ -1.0 * tanh(-1.0) @@ -117,7 +117,7 @@ end a == softsign && continue @test !isnan(a(Inf32)) - a in [gelu, gelu_fast, swish, hardswish, logcosh, mish] && continue + a in [gelu, gelu_full, swish, hardswish, logcosh, mish] && continue @test !isnan(a(-Inf32)) end end From 44ee4f545a708fd433144f1f087efb704445a743 Mon Sep 17 00:00:00 2001 From: se-schmitt Date: Thu, 13 Feb 2025 11:10:06 +0100 Subject: [PATCH 5/5] Rename gelus and use OpenLibm_jll instead of SpecialFunctions.jl for erf --- Project.toml | 4 ++-- docs/src/reference.md | 3 ++- src/NNlib.jl | 2 +- src/activations.jl | 50 ++++++++++++++++++++++++++++++------------- test/activations.jl | 11 ++++++---- 5 files changed, 47 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index f07b9dfb..415a75fd 100644 --- a/Project.toml +++ b/Project.toml @@ -9,8 +9,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] @@ -41,8 +41,8 @@ ForwardDiff = "0.10.36" GPUArraysCore = "0.1, 0.2" KernelAbstractions = "0.9.2" LinearAlgebra = "<0.0.1, 1" +OpenLibm_jll = "0.8.1" Random = "<0.0.1, 1" -SpecialFunctions = "2.5.0" Statistics = "1" cuDNN = "1" julia = "1.9" diff --git a/docs/src/reference.md b/docs/src/reference.md index 88375047..1b1f7827 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -10,7 +10,8 @@ Non-linearities that go between layers of your model. 
Note that, unless otherwis celu elu gelu -gelu_full +gelu_tanh +gelu_erf hardsigmoid sigmoid_fast hardtanh diff --git a/src/NNlib.jl b/src/NNlib.jl index 493eee9d..17e9e9dc 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -15,7 +15,7 @@ using LinearAlgebra: AdjOrTransAbsMat, Adjoint, BlasFloat, Transpose using Random using Statistics using Statistics: mean -using SpecialFunctions +using OpenLibm_jll const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number} diff --git a/src/activations.jl b/src/activations.jl index b1981ed3..56adf1d4 100644 --- a/src/activations.jl +++ b/src/activations.jl @@ -5,7 +5,7 @@ ACTIVATIONS = [ :σ, :hardσ, :hardtanh, :relu, - :leakyrelu, :relu6, :rrelu, :elu, :gelu, :gelu_full, :swish, :hardswish, :selu, + :leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_erf, :swish, :hardswish, :selu, :celu, :softplus, :softsign, :logσ, :logcosh, :mish, :tanhshrink, :softshrink, :trelu, :lisht, :tanh_fast, :sigmoid_fast, @@ -301,14 +301,14 @@ elu(x, α=1) = ifelse(x ≥ 0, float(x), @fastmath oftf(x, α) * (exp(x) - 1)) deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α)) """ - gelu(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) + gelu_tanh(x) = 0.5x * (1 + tanh(√(2/π) * (x + 0.044715x^3))) -Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) (see also [`gelu_full`](@ref)). +Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) using tanh approximation. ```julia-repl -julia> lineplot(gelu, -2, 2, height=7) +julia> lineplot(gelu_tanh, -2, 2, height=7) ┌────────────────────────────────────────┐ - 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu(x) + 2 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⡠⠔⠊│ gelu_tanh(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⠔⠊⠁⠀⠀⠀│ │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⠀⠀⠀⠀⠀⣀⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⠀⠀⠀⣀⡠⠤⠒⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ @@ -319,11 +319,11 @@ julia> lineplot(gelu, -2, 2, height=7) ⠀-2⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀2⠀ ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ -julia> lineplot(gelu, -5, 0, height=7); +julia> lineplot(gelu_tanh, -5, 0, height=7); julia> lineplot!(ans, swish) ┌────────────────────────────────────────┐ - 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu(x) + 0 │⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠒⠒⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸│ gelu_tanh(x) │⠑⠒⠢⠤⣄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠓⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇│ swish(x) │⠀⠀⠀⠀⠀⠈⠉⠒⠤⣀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢆⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠁│ f(x) │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠉⠒⢄⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠑⢄⠀⠀⠀⠀⠀⠀⠀⠀⢠⡇⠀│ @@ -335,7 +335,7 @@ julia> lineplot!(ans, swish) ⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀x⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀ ``` """ -function gelu(x) +function gelu_tanh(x) α = oftf(x, 0.044715) # λ = oftf(x, gelu_λ) # x/2 * (1 + tanh(λ * (x + α * x^3))) # Standard implementation, for reference @@ -346,7 +346,7 @@ end const gelu_λ = √(2 / π) const gelu_2λ = √(8 / π) -function deriv_gelu(x) +function deriv_gelu_tanh(x) α = oftf(x, 0.044715) α2 = oftf(x, 0.08943) λλ = oftf(x, gelu_2λ) @@ -358,18 +358,38 @@ function deriv_gelu(x) end """ - gelu_full(x) = xΦ(x) = 0.5x * (1 + erf(x/√2)) + gelu_erf(x) = xΦ(x) = 0.5x * (1 + erf(x/√2)) Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) without approximation. 
""" -gelu_full(x) = x/2*(1 + erf(x/sqrt(oftf(x,2)))) +gelu_erf(x) = x/2*(1 + _erf(x/sqrt(oftf(x,2)))) -function deriv_gelu_full(x) +function deriv_gelu_erf(x) SQRT2 = sqrt(oftf(x,2)) - Φ = (1 + erf(x/SQRT2))/2 + Φ = (1 + _erf(x/SQRT2))/2 Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π)) end +_erf(x::Number) = _erf(float(x)) +_erf(x::Float64) = ccall((:erf, libopenlibm), Float64, (Float64,), x) +_erf(x::Float32) = ccall((:erff, libopenlibm), Float32, (Float32,), x) +_erf(x::Float16) = Float16(_erf(Float32(x))) +_erf(x::BigFloat) = begin + z = BigFloat(x) + ccall((:mpfr_erf, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Int32), z, x, Base.MPFR.ROUNDING_MODE[]) + return z +end + +""" + gelu(x) = gelu_tanh(x) + +Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415). +See [`gelu_tanh`](@ref). +""" +const gelu = gelu_tanh +export gelu +const deriv_gelu = deriv_gelu_tanh + """ swish(x) = x * σ(x) @@ -887,8 +907,8 @@ UNARY_ACTS = [ # f, dfdx (:relu6, :((Ω>0) & (Ω<6))), # rrelu is random, can't write a rule. (:elu, :(deriv_elu(Ω))), - (:gelu, :(deriv_gelu(x))), - (:gelu_full, :(deriv_gelu_full(x))), + (:gelu_tanh, :(deriv_gelu_tanh(x))), + (:gelu_erf, :(deriv_gelu_erf(x))), (:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))), (:hardswish, :(deriv_hardswish(x))), # lisht diff --git a/test/activations.jl b/test/activations.jl index 2d1f33ef..0bb5047c 100644 --- a/test/activations.jl +++ b/test/activations.jl @@ -12,7 +12,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test rrelu(0.0) == 0.0 @test elu(0.0) == 0.0 @test gelu(0.0) == 0.0 -@test gelu_full(0.0) == 0.0 +@test gelu_tanh(0.0) == 0.0 +@test gelu_erf(0.0) == 0.0 @test swish(0.0) == 0.0 @test hardswish(0.0) == 0.0 @test lisht(0.0) == 0.0 @@ -37,7 +38,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test rrelu(1.0) == 1.0 @test elu(1.0) == 1.0 @test gelu(1.0) == 0.8411919906082768 -@test gelu_full(1.0) == 0.8413447460685429 +@test gelu_tanh(1.0) == 0.8411919906082768 +@test gelu_erf(1.0) == 0.8413447460685429 @test swish(1.0) == sigmoid(1.0) @test hardswish(1.0) == hardsigmoid(1.0) @test lisht(1.0) ≈ 1.0 * tanh(1.0) @@ -60,7 +62,8 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI @test -1/3.0 <= rrelu(-1.0) <= -1/8.0 @test elu(-1.0) == exp(-1.0) - 1.0 @test gelu(-1.0) ≈ -0.15880800939172324 -@test gelu_full(-1.0) == -0.15865525393145707 +@test gelu_tanh(-1.0) ≈ -0.15880800939172324 +@test gelu_erf(-1.0) == -0.15865525393145707 @test swish(-1.0) == -sigmoid(-1.0) @test hardswish(-1.0) == -hardsigmoid(-1.0) @test lisht(-1.0) ≈ -1.0 * tanh(-1.0) @@ -117,7 +120,7 @@ end a == softsign && continue @test !isnan(a(Inf32)) - a in [gelu, gelu_full, swish, hardswish, logcosh, mish] && continue + a in [gelu, gelu_tanh, gelu_erf, swish, hardswish, logcosh, mish] && continue @test !isnan(a(-Inf32)) end end