-
-
Notifications
You must be signed in to change notification settings - Fork 125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds full gelu without approximation #629
base: master
Are you sure you want to change the base?
Conversation
https://github.com/FluxML/NNlib.jl/actions/runs/13177183802/job/36779142351?pr=629#step:7:842 is a real test failure. I think Flux's |
Quick look at how different these functions are: julia> using SpecialFunctions, NNlib
julia> oftf(x, y) = oftype(float(x), y);
julia> new_gelu(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))));
julia> rel(x) = (new_gelu(x) - gelu(x)) / new_gelu(x);
julia> rel.(-3:0.2f0:1)
21-element Vector{Float32}:
0.101809666
0.06506126
0.038386323
0.02019246
0.008700018
0.0021531228
-0.0010194147
-0.0021005166
-0.0020575877
-0.001547056
-0.0009626968
-0.00049560843
-0.0002001288
-5.4488235f-5
-6.02081f-6
NaN
4.4374747f-6
2.876006f-5
7.55584f-5
0.00013319723
0.00018150361
julia> rel_eps(x) = (new_gelu(x) - gelu(x)) / eps(new_gelu(x));
julia> Int.(rel_eps.(-3:0.2f0:1))
21-element Vector{Int64}:
-885402
-999594
-499514
-213282
-142868
-26298
8849
24719
31223
14336
10250
5637
2210
504
68
0
69
253
1104
1409
2562 |
I modified it such that |
@ToucheSir Can this be merged? If not, what is required to make it mergeable? |
It's a little sad that NNlib must depend on SpecialFunctions... maybe not so expensive?
Re name bike-shedding, some chance we should use neutral names like |
@mcabbott I modified the code as suggested:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks basically fine to me, thanks.
Maybe avoiding SpecialFunctions is a rabbit-hole, sorry.
One question is: How well do these variants work on the GPU? Presumably ccall((:erff, libopenlibm), Float32, (Float32,), x)
won't work... does SpecialFunctions have code to make erf.(cu(rand(10)))
work by another path?
Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π)) | ||
end | ||
|
||
_erf(x::Number) = _erf(float(x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This intends to catch integers but isn't a safe pattern, as there are other weird numbers out there:
StackOverflowError:
Stacktrace:
[1] _erf(x::ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#141#142"{var"#170#178"{typeof(gelu_erf)}}, Float64}, Float64, 1}) (repeats 79984 times)
@ NNlib ~/work/NNlib.jl/NNlib.jl/src/activations.jl:373
This particular case could be allowed via NNlibForwardDiffExt
Yes, CUDA.jl defines its own overloads at https://github.com/JuliaGPU/CUDA.jl/blob/master/ext/SpecialFunctionsExt.jl. If we want to talk load times, Flux has a direct dep on SpecialFunctions already. If import latency is a pressing concern, we could define a stub |
The implementation via OpenLibm_jll was naive, sorry... I added the missing rules for AD locally which made it compatible with ForwardDiff, Zygote and Enzyme. However, compatibility with other AD and the GPU packages would need further modifications. I would prefer an option that includes SpecialFunctions.jl as this seems much cleaner to me (either as direct dependency or extension (is this a problem for Lux.jl where SpecialFun is not a direct dependency?)). What would you prefer? @mcabbott @ToucheSir |
Can you try the stub function + extension approach I suggested above? If that turns out to be a dead end, I'm fine with going back to the original plan and having SpecialFunctions as a direct dep. There's already an outsized chance any user of NNlib will have it in their environment, so we wouldn't lose much by including it. |
Adds the full gelu without approximation as
gelu(x)
and moves the tanh approximation used before togelu_fast
. See #628 for details.PR Checklist