XLA duplicates input arguments memory because of wrapped_transpose #26987
Comments
This may be somewhat related to |
I can't think of a simple fix (other than pre-transposing the weights), but here's what happens. Some ops, like dynamic slices, require the tensor to be transposed a certain way, for example this one:
It's located inside a loop (actually, two nested loops, right?), so the optimizer decides to hoist that transposition out of the loop. As a result, the transposed copy of the weights has to stay alive for the entire duration of the loop. And that's the case for all weights. One thing that sometimes helps is AUTO layout (e.g. https://github.com/jax-ml/jax/blob/main/tests/layout_test.py#L36), which would do the transposes on the host, but I'm not sure how easy it is to make it work.
Yes, this is indeed a double while loop (gradient accumulation and scan over layers). |
Here's also an HLO with --xla_dump_hlo_pass_re=.* |
In the "with_re" archive, I don't see any wrapped transposes, is it the same model (also the file itself is smaller)? Could you try Other options to try meanwhile are |
Yeah, the issue with previous HLO was in compilation cache. Here's one with it. Couldn't attach it to github because of file size restrictions. |
I've tried every flag, but none of them had any effect on memory usage.
This appears to be caused by layout assignment, which introduces a transposing copy in the top level. Not sure yet why that's happening, we're having some trouble with VLOGing. |
Ah yeah, that makes sense, we've seen that quite a few times. Layout assignment goes from inner computations to outer ones, and then in the entry computation, when there's a clash between the parameter layout and the nested computation's parameter layout, a transpose is inserted. There's no easy way to fix it correctly, as all computations are processed in isolation (and sometimes it's the model parameter that can "absorb" the transpose by being auto layout, sometimes it's an instruction in the nested computation that doesn't care), but I'll think about workarounds. For the OOM issue in particular, I believe we should have a pass quite late in the pipeline that would sink operations into the loop body to reduce the lifetime of buffers (though that certainly isn't a quick workaround).
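To make the buffer-lifetime argument concrete, here is a minimal sketch of a toy memory model. It is not XLA's buffer assignment; the layer count and sizes are made-up placeholders, and `peak_bytes` is a hypothetical helper. It only compares the peak footprint when every transposed copy is produced before the loop against the case where each copy is produced inside the loop body and dies at the end of its iteration.

```python
def peak_bytes(layer_bytes, hoisted):
    """Toy model: peak live bytes for original weights plus transposed copies."""
    originals = sum(layer_bytes)  # input parameters stay live for the whole step
    live_copies = 0
    peak = originals
    for size in layer_bytes:
        live_copies += size  # materialize the transposed copy of this layer
        peak = max(peak, originals + live_copies)
        if not hoisted:
            # Sunk into the loop body: the copy is dead after its iteration.
            live_copies -= size
        # Hoisted: the copy was made before the loop and survives until it ends.
    return peak


# Hypothetical model: 80 layers of 100 units each (arbitrary units).
layers = [100] * 80
print("hoisted:", peak_bytes(layers, hoisted=True))
print("sunk:   ", peak_bytes(layers, hoisted=False))
```

In this model the hoisted variant needs roughly twice the weight memory, while the sunk variant needs the weights plus one layer's copy. The trade-off is that a sunk transpose is recomputed on every iteration.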
AFAICT, the reason is all these all-gathers with dimension=1, like this:
Because the all-gather dimension has to be major, gpu_layout_assignment thinks that they need to get a layout of @mooskagh I see a couple of ways to fix this, e.g. by getting rid of these degenerate dims or by attempting to make gpu_layout_assignment smarter. The former seems easier to me, WDYT? |
4c817c5 This gets rid of most of the copies, a handful are still there. Looking into the remaining ones now. |
…hers.

Imported from GitHub PR #27096

`GpuLayoutAssignment` requires all-gathers to gather on the major dimension (i.e., the gathered tensor is contiguous). When we have a pattern like `all-gather(dynamic-slice)` and the slice's major dimension is degenerate, the all-gather gets a funky non-standard layout, which is then propagated through the slice, ultimately all the way to the parameter, where we eventually introduce a transposing copy. See #26987 for details.

This should fix about 2/3 of that bug. The remaining transposing copies are caused by all-gathers that actually need to be transposes, but the fix for those is more complicated.

Copybara import of the project:

-- 4c817c5 by Johannes Reifferscheid <jreiffers@nvidia.com>:
[WIP] Add a pass that removes degenerate dimensions from all-gathers. Not really tested yet.

-- a3f042a by Johannes Reifferscheid <jreiffers@nvidia.com>:
Add tests.

-- 66a336e by Johannes Reifferscheid <jreiffers@nvidia.com>:
Fix unused status.

-- 6c5dbb4 by Johannes Reifferscheid <jreiffers@nvidia.com>:
Fix missing newline at end of file.

Merging this change closes #27096

FUTURE_COPYBARA_INTEGRATE_REVIEW=#27096 from jreiffers:all-gather-layout 6c5dbb4
PiperOrigin-RevId: 764137188
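The layout argument in this commit message can be illustrated with a small, self-contained sketch. It assumes XLA's `minor_to_major` convention (the first listed dimension varies fastest in memory). The shapes `(1, 8)` and `(4, 8)` are hypothetical stand-ins for a dynamic-slice result and the stacked parameter; the real shapes in the dump are not reproduced here, and both helper functions are illustrative only.

```python
from itertools import product


def linear_offset(index, shape, minor_to_major):
    """Element offset of `index` in a dense buffer with the given layout."""
    offset, stride = 0, 1
    for dim in minor_to_major:  # most-minor dimension first
        offset += index[dim] * stride
        stride *= shape[dim]
    return offset


def same_physical_order(shape, layout_a, layout_b):
    """True if both layouts place every element at the same offset."""
    return all(
        linear_offset(idx, shape, layout_a) == linear_offset(idx, shape, layout_b)
        for idx in product(*(range(n) for n in shape))
    )


ROW_MAJOR = (1, 0)         # default: dim 0 is major
GATHER_DIM_MAJOR = (0, 1)  # dim 1 is major, as an all-gather on dimension 1 needs

slice_shape = (1, 8)  # hypothetical slice with a degenerate major dimension
param_shape = (4, 8)  # hypothetical parameter the layout propagates back to

print(same_physical_order(slice_shape, ROW_MAJOR, GATHER_DIM_MAJOR))
print(same_physical_order(param_shape, ROW_MAJOR, GATHER_DIM_MAJOR))
```

For the degenerate slice the two layouts describe the same bytes, so the non-standard layout is harmless there. Once it is propagated to the non-degenerate parameter, the element order really differs, which is where the transposing copy comes from. With the size-1 dimension removed, the gather dimension is the only dimension and is trivially major.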
Hi, I have a simple LLM training step that computes the forward and backward pass and updates the weights. It is LoRA training, so only a very small subset of parameters is trainable (70B total parameters, ~600M trainable). I checked the memory analysis and found that actual memory consumption is much higher than expected. The cause is that XLA transposes the input weights in a very odd way, which effectively doubles the memory used by the 70B model parameters.
Can I somehow disable this optimization or work around the issue in some other way?
I've tested it on both JAX 0.4.34 and the latest version (JAX 0.6.1).
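For a rough sense of scale, here is a back-of-the-envelope sketch. The parameter counts come from the report above, but the 2 bytes per parameter (bf16) is an assumption, and the totals are summed over all devices since the sharding is not stated.

```python
def weights_gib(num_params: int, bytes_per_param: int) -> float:
    """Size of a parameter set in GiB."""
    return num_params * bytes_per_param / 2**30


TOTAL_PARAMS = 70_000_000_000    # from the report
TRAINABLE_PARAMS = 600_000_000   # from the report (approximate)
BYTES_PER_PARAM = 2              # ASSUMPTION: bf16 weights

frozen = weights_gib(TOTAL_PARAMS, BYTES_PER_PARAM)
with_transposed_copy = 2 * frozen  # every weight kept alongside a transposed copy
trainable = weights_gib(TRAINABLE_PARAMS, BYTES_PER_PARAM)

print(f"model weights:           {frozen:.1f} GiB")
print(f"with transposed copies:  {with_transposed_copy:.1f} GiB")
print(f"trainable (LoRA) params: {trainable:.1f} GiB")
```

Under that assumption the duplicated weights dominate everything else in the step, since the trainable parameters are under 1% of the total.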
Here's the attached HLO dump:
hlo.tgz
Some of the transposed buffer names: