Restructure makes a copy #146

Open

@linusheck

In some situations you have to restructure a lot when using Flux, for instance when you run your batches as separate solves in DiffEqFlux via an EnsembleProblem. You have to use something like a ComponentArray to pass the parameters through the solver and to let the adjoint methods do their work in differentiating the solve. But restructuring from a ComponentArray is unreasonably(?) slow in Flux. Switching to Lux eliminates the problem, but it seems like something that could be handled better in Flux or ComponentArrays.
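For context, here is a minimal sketch of the pattern I mean. This is my own toy setup, not the code I profile below, and it assumes OrdinaryDiffEq is installed; the adjoint/sensitivity part is omitted. The point is that the right-hand side calls re(p) on every evaluation, and the EnsembleProblem runs one solve per batch element:

using Flux, ComponentArrays, OrdinaryDiffEq

n = 16
ps0, re = Flux.destructure(Flux.Dense(n => n, tanh))
p = ComponentArray(ps0)

# The ODE right-hand side rebuilds the layer from the flat parameters
# on every evaluation.
dudt(u, p, t) = re(p)(u)

prob = ODEProblem(dudt, rand(Float32, n), (0f0, 1f0), p)
ens = EnsembleProblem(prob;
    prob_func = (prob, i, repeat) -> remake(prob; u0 = rand(Float32, n)))
sols = solve(ens, Tsit5(), EnsembleThreads(); trajectories = 4)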

Example code:

using Flux, Profile, PProf, ComponentArrays

layer_size = 256
layer = Flux.Dense(layer_size => layer_size)
params_, re = Flux.destructure(layer)
ps = ComponentArray(params_)  # named `ps` so it does not clash with the exported Flux.params

# Rebuild the layer from the flat parameters on every step, as the solver
# would do on every right-hand-side evaluation.
function run_steps(steps, input)
    x = input
    for _ in 1:steps
        x = re(ps)(x)
    end
    return x
end

Profile.clear()
@profile run_steps(100_000, rand(Float32, layer_size))
pprof(; web=true)

Here, we spend about 10% of the time in the sgemv matrix multiplication, another 10% in the rest of the Dense call, and roughly 75% in the Restructure. The smaller the network, the worse this ratio gets. As far as I can read the flame graph, the restructure spends a lot of its time in the GC:

[PProf flame graph screenshot: most of the profile sits under Restructure, much of it in GC/allocation]
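To isolate the restructure cost from the forward pass, one can also benchmark the two pieces separately. A quick check, assuming BenchmarkTools is available and using the same `re`, `ps`, and `layer_size` as above:

using BenchmarkTools

# Each re(ps) builds a fresh Dense layer, copying the flat parameter vector
# into new weight/bias arrays.
@btime $re($ps);

# The forward pass of an already-rebuilt layer, for comparison.
layer_rebuilt = re(ps)
@btime $layer_rebuilt($(rand(Float32, layer_size)));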

Could there be a way to mitigate this specific problem, in particular when the same parameters are reused across calls? I think it would also make some of the example code a lot faster.
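For the same-parameters case, a user-side workaround is to skip the rebuild when the parameter array has not changed. A minimal sketch, reusing `re`, `ps`, and `layer_size` from the example above; `make_cached_re` is a hypothetical helper, not part of Flux or ComponentArrays, and the mutable closure state makes it unsuitable for Zygote adjoints as written:

# Memoize the last restructure: repeated calls with the *same* parameter
# array object reuse the previously rebuilt model instead of copying again.
function make_cached_re(re)
    last_ps = nothing
    last_model = nothing
    return function (ps)
        if ps !== last_ps
            last_model = re(ps)
            last_ps = ps
        end
        return last_model
    end
end

cached_re = make_cached_re(re)
x = rand(Float32, layer_size)
for _ in 1:1_000
    x = cached_re(ps)(x)  # Restructure runs only on the first call
end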

I also opened a Discourse thread about this because I'm not sure whether it's an issue with Flux specifically: https://discourse.julialang.org/t/flux-restructure-for-componentarrays-jl-unreasonably-slow/97849/3
