
Adds full gelu without approximation #629

Open · se-schmitt wants to merge 5 commits into master

Conversation

se-schmitt (Author)

Adds the full gelu without approximation as gelu(x) and moves the tanh approximation used before to gelu_fast. See #628 for details.
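
For reference, a minimal sketch of the two variants under the names this description uses (gelu exact via erf, gelu_fast the previous tanh approximation). This is an illustration with the standard GELU formulas, not NNlib's exact code; oftf mirrors NNlib's internal helper:

using SpecialFunctions  # provides erf

oftf(x, y) = oftype(float(x), y)  # convert y to the float type of x

# Exact GELU: x * Φ(x), where Φ(x) = (1 + erf(x/√2)) / 2 is the standard normal CDF
gelu(x) = x / 2 * (1 + erf(x / sqrt(oftf(x, 2))))

# The tanh-based approximation previously called gelu, renamed gelu_fast here
gelu_fast(x) = x / 2 * (1 + tanh(oftf(x, √(2 / π)) * (x + oftf(x, 0.044715) * x^3)))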

PR Checklist

  • Tests are added
  • Documentation, if applicable

ToucheSir linked an issue Feb 6, 2025 that may be closed by this pull request
ToucheSir (Member) commented Feb 6, 2025

https://github.com/FluxML/NNlib.jl/actions/runs/13177183802/job/36779142351?pr=629#step:7:842 is a real test failure. I think Flux's Nil + outputsize machinery needs to be adjusted to understand SpecialFunctions.erf. The question is how, so I've opened FluxML/Flux.jl#2588 to track this.
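
(Background: Flux's outputsize traces a model by calling it on placeholder Nil numbers, so every scalar function used in an activation needs a method accepting Nil; SpecialFunctions.erf has none, hence the failure.)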

mcabbott (Member) commented Feb 7, 2025

Quick look at how different these functions are:

julia> using SpecialFunctions, NNlib

julia> oftf(x, y) = oftype(float(x), y);

julia> new_gelu(x) = x/2*(1 + erf(x/sqrt(oftf(x,2))));

julia> rel(x) = (new_gelu(x) - gelu(x)) / new_gelu(x);

julia> rel.(-3:0.2f0:1)
21-element Vector{Float32}:
   0.101809666
   0.06506126
   0.038386323
   0.02019246
   0.008700018
   0.0021531228
  -0.0010194147
  -0.0021005166
  -0.0020575877
  -0.001547056
  -0.0009626968
  -0.00049560843
  -0.0002001288
  -5.4488235f-5
  -6.02081f-6
 NaN
   4.4374747f-6
   2.876006f-5
   7.55584f-5
   0.00013319723
   0.00018150361

julia> rel_eps(x) = (new_gelu(x) - gelu(x)) / eps(new_gelu(x));

julia> Int.(rel_eps.(-3:0.2f0:1))
21-element Vector{Int64}:
 -885402
 -999594
 -499514
 -213282
 -142868
  -26298
    8849
   24719
   31223
   14336
   10250
    5637
    2210
     504
      68
       0
      69
     253
    1104
    1409
    2562
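
(The NaN sits at x = 0, where both definitions return exactly zero, so the relative difference is 0/0; the rel_eps column correspondingly shows 0 there.)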

se-schmitt (Author)

I modified it so that gelu remains the same and added the full version as gelu_full, as discussed in #628. This avoids breaking changes and the test failure above; however, gelu_full is still not compatible with Flux's outputsize function.

se-schmitt (Author)

@ToucheSir Can this be merged? If not, what is required to make it mergeable?

mcabbott (Member)

It's a little sad that NNlib must depend on SpecialFunctions... maybe not so expensive?

julia> @time_imports using SpecialFunctions
      8.7 ms  IrrationalConstants
               ┌ 0.0 ms DocStringExtensions.__init__() 
     46.6 ms  DocStringExtensions 97.36% compilation time
      0.6 ms  LogExpFunctions
               ┌ 2.5 ms OpenLibm_jll.__init__() 
      4.2 ms  OpenLibm_jll
      0.4 ms  JLLWrappers
               ┌ 9.2 ms CompilerSupportLibraries_jll.__init__() 
     11.1 ms  CompilerSupportLibraries_jll
               ┌ 6.0 ms OpenSpecFun_jll.__init__() 93.49% compilation time
      6.5 ms  OpenSpecFun_jll 86.17% compilation time
      3.2 ms  SpecialFunctions

Re name bike-shedding, some chance we should use neutral names like gelu_erf and gelu_tanh, with both names available immediately but const gelu = gelu_tanh for now to be non-breaking. (I do not think either should be called gelu_fast, as the point of tanh_fast is that we sometimes automatically replace tanh with that, but there is no plan to automatically replace one of these with the other.)
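
A hypothetical sketch of that scheme (the definitions themselves are as in the sketch near the top of the thread; only the neutral names and the alias matter here):

using SpecialFunctions  # provides erf

oftf(x, y) = oftype(float(x), y)

gelu_tanh(x) = x / 2 * (1 + tanh(oftf(x, √(2 / π)) * (x + oftf(x, 0.044715) * x^3)))
gelu_erf(x)  = x / 2 * (1 + erf(x / sqrt(oftf(x, 2))))

const gelu = gelu_tanh  # non-breaking: gelu keeps its current tanh behavior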

se-schmitt (Author)

@mcabbott I modified the code as suggested:

  • Instead of the SpecialFunctions.jl package, only OpenLibm_jll.jl is used now, and the erf function is defined via ccall, as in SpecialFunctions.jl (see the sketch after this list).
  • I also renamed the functions to gelu_tanh and gelu_erf with const gelu = gelu_tanh, and made this transparent in the documentation.
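
A sketch of that ccall-based binding (OpenLibm exports erf for Float64 and erff for Float32; _erf is the internal helper name used in this PR):

using OpenLibm_jll: libopenlibm

# Bind directly to OpenLibm's erf/erff, as SpecialFunctions.jl does internally
_erf(x::Float64) = ccall((:erf,  libopenlibm), Float64, (Float64,), x)
_erf(x::Float32) = ccall((:erff, libopenlibm), Float32, (Float32,), x)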

mcabbott (Member) left a comment

This looks basically fine to me, thanks.

Maybe avoiding SpecialFunctions is a rabbit-hole, sorry.

One question is: How well do these variants work on the GPU? Presumably ccall((:erff, libopenlibm), Float32, (Float32,), x) won't work... does SpecialFunctions have code to make erf.(cu(rand(10))) work by another path?

Reviewed code (src/activations.jl):

Φ + x/SQRT2*exp(-(x^2)/2)/sqrt(oftf(x,π))
end

_erf(x::Number) = _erf(float(x))
mcabbott (Member):

This intends to catch integers but isn't a safe pattern, as there are other weird numbers out there:

StackOverflowError:
  Stacktrace:
   [1] _erf(x::ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#141#142"{var"#170#178"{typeof(gelu_erf)}}, Float64}, Float64, 1}) (repeats 79984 times)
     @ NNlib ~/work/NNlib.jl/NNlib.jl/src/activations.jl:373

This particular case could be allowed via NNlibForwardDiffExt
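
(For context: float on a ForwardDiff.Dual returns another Dual, so this ::Number fallback dispatches straight back to itself. A guarded fallback, sketched hypothetically below, throws a MethodError instead of overflowing the stack:)

function _erf(x::Number)
    xf = float(x)
    # If float(x) does not change the type, no more specific method exists;
    # fail loudly rather than recurse (e.g. for ForwardDiff.Dual inputs).
    typeof(xf) === typeof(x) && throw(MethodError(_erf, (x,)))
    return _erf(xf)
end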

ToucheSir (Member) commented Feb 14, 2025

does SpecialFunctions have code to make erf.(cu(rand(10))) work by another path?

Yes, CUDA.jl defines its own overloads at https://github.com/JuliaGPU/CUDA.jl/blob/master/ext/SpecialFunctionsExt.jl.

If we want to talk load times, Flux already has a direct dependency on SpecialFunctions. If import latency is a pressing concern, we could define a stub (function gelu_erf end) in NNlib and add the method for that function in a SpecialFunctionsExt.
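
That stub-plus-extension pattern could look roughly like this (hypothetical file and module names; the extension would also be declared under [extensions] in NNlib's Project.toml):

# In NNlib itself: declare the function with no methods.
function gelu_erf end

# In ext/NNlibSpecialFunctionsExt.jl, loaded automatically once the user's
# environment also loads SpecialFunctions:
module NNlibSpecialFunctionsExt

using NNlib, SpecialFunctions

NNlib.gelu_erf(x) = x / 2 * (1 + erf(x / sqrt(oftype(float(x), 2))))

end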

se-schmitt (Author)

The implementation via OpenLibm_jll was naive, sorry... I have added the missing AD rules locally, which makes it compatible with ForwardDiff, Zygote, and Enzyme. However, compatibility with other AD and GPU packages would need further modifications.

I would prefer an option that includes SpecialFunctions.jl, as this seems much cleaner to me, either as a direct dependency or as an extension (would the extension route be a problem for Lux.jl, where SpecialFunctions is not a direct dependency?).

What would you prefer? @mcabbott @ToucheSir

ToucheSir (Member)

Can you try the stub function + extension approach I suggested above? If that turns out to be a dead end, I'm fine with going back to the original plan and having SpecialFunctions as a direct dep. There's already an outsized chance any user of NNlib will have it in their environment, so we wouldn't lose much by including it.

Successfully merging this pull request may close these issues.

"Full" gelu without approximation