Description
In some situations you end up restructuring a lot when using Flux, for instance if you want to run your batches as separate solves in DiffEqFlux using an `EnsembleProblem`. You then have to use something like a `ComponentArray` to pass the parameters through the solver and to let the adjoint methods do their work in differentiating the solve. But restructuring from a `ComponentArray` is unreasonably(?) slow in Flux. Switching to Lux eliminates the problem, but it seems like something that could be handled better in Flux or ComponentArrays.
Example code:
```julia
using Flux, Profile, PProf, Random, ComponentArrays

layer_size = 256
layer = Flux.Dense(layer_size => layer_size)
params_, re = Flux.destructure(layer)
params = ComponentArray(params_)

# Rebuild the layer from the parameters on every step, as the adjoint
# methods would during a solve. (Renamed from `eval` to avoid shadowing
# the built-in `eval`.)
function eval_steps(steps, input)
    x = input
    for i in 1:steps
        x = re(params)(x)
    end
    return x
end

Profile.clear()
@profile eval_steps(100_000, rand(Float32, layer_size))
pprof(; web=true)
```
Here, about 10% of the time is spent in `sgemv` matrix multiplication, another 10% in the rest of the `Dense` call, and about 75% in the `Restructure`. This gets worse as the network gets smaller. As far as I can read the flame graph, the restructure spends a lot of its time in the GC:
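To confirm the GC pressure independently of the profiler, one could also measure the allocations of the restructure call directly, e.g. with `@allocated` (a sketch, using the same `layer`/`re`/`params` setup as above):

```julia
using Flux, ComponentArrays

layer = Flux.Dense(256 => 256)
params_, re = Flux.destructure(layer)
params = ComponentArray(params_)

re(params)             # warm up so compilation is not counted
@allocated re(params)  # bytes allocated by a single rebuild of the layer
```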
Could there be a way to mitigate this specific problem, in particular when the same parameters are reused across calls? I think this would make some example code a lot faster too.
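For the special case where the parameters stay fixed between steps, one workaround is to restructure once outside the loop and reuse the resulting layer (a sketch, reusing the `re` and `params` names from the example above):

```julia
using Flux, ComponentArrays

layer = Flux.Dense(256 => 256)
params_, re = Flux.destructure(layer)
params = ComponentArray(params_)

function eval_steps_hoisted(steps, input)
    net = re(params)   # rebuild the layer once, outside the loop
    x = input
    for i in 1:steps
        x = net(x)     # reuse it; no per-step restructure
    end
    return x
end
```

This obviously doesn't help in the DiffEqFlux setting, where the solver and adjoint call `re(params)` internally on every evaluation, which is why a fix inside Flux or ComponentArrays seems needed.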
I also opened a Discourse thread about this because I'm not sure whether it's an issue with Flux specifically: https://discourse.julialang.org/t/flux-restructure-for-componentarrays-jl-unreasonably-slow/97849/3