I am getting the following error somewhere deep in a fairly large jit on TPU:
jaxlib._jax.XlaRuntimeError: UNIMPLEMENTED: While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %all-reduce.0 = c64[16,...]{0,...} all-reduce(%transpose.18), channel_id=1, replica_groups={...}, use_global_device_ids=true, to_apply=%region_21.4425, metadata={...}
Unfortunately, I have not been able to create a shareable minimal example: in all of my smaller attempts the compiler eliminates the all-reduce over c64 and splits it into all-reduces over the equivalent fp32 values. I am not certain which of the following is the actual problem: the compiler should have handled the c64 in my larger jit the same way but does not; the c64 is misclassified and should not be targeted by the X64 rewriting in the first place, since it is really two fp32 values; or the rewriting needs to be extended to support all-reduce.
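As a possible workaround while the root cause is unclear, the complex all-reduce can be written explicitly as two fp32 all-reduces, which is what the compiler appears to do on its own in the smaller cases. This is only a minimal sketch: it assumes the reduction is a sum (the error message only shows `to_apply=%region_21.4425`, so that is a guess), `psum_complex_as_real` is a hypothetical helper name, and `jax.pmap` is used purely to give `lax.psum` a named axis for the demonstration. I have not verified that this avoids the error on TPU.

```python
import jax
import jax.numpy as jnp
from jax import lax


def psum_complex_as_real(x, axis_name):
    """Sum a complex array across `axis_name` using two real-valued psums.

    Hypothetical helper: the real and imaginary parts are reduced
    separately, so the explicitly requested collectives are over fp32
    (for c64 input) rather than over the complex type.
    """
    real_sum = lax.psum(jnp.real(x), axis_name)
    imag_sum = lax.psum(jnp.imag(x), axis_name)
    return lax.complex(real_sum, imag_sum)


# Demonstration over however many local devices are available (1 is fine).
n = jax.local_device_count()
base = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
x = lax.complex(base, 2.0 * base)  # complex64, one row per device

out = jax.pmap(lambda v: psum_complex_as_real(v, "i"), axis_name="i")(x)
```

Every row of `out` should hold the same complex64 sum over the device axis, matching `x.sum(axis=0)`. Inside a larger jit the same helper would replace the `lax.psum` call on the complex value, assuming that call site can be located.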
The same setup appears to work on GPU, so this looks like a TPU-specific issue. I am happy to take another shot at extracting a minimal example if I can get a bit of guidance on where to look. Is this most likely an issue of 1. the rewriting failing to detect and handle the c64 all-reduce in my larger jit, meaning size and complexity are the issue, 2. the c64 being targeted by the rewriting by mistake, or 3. the rewriting lacking c64 all-reduce support?
If it's 2., a temporary band-aid could be to disable the specific rewriting pass with --xla_disable_hlo_passes. Since libtpu is closed source, it is not obvious to me how to identify the name of that pass. Do you have any advice on how to find it, or what the name might be?
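One way to narrow this down might be to dump the HLO around each pass and search the dumps for complex-typed all-reduces. XLA has the flags `--xla_dump_to=<dir>` and `--xla_dump_hlo_pass_re=<regex>`, and the per-pass dump file names include the pass name. Whether libtpu honors these flags and exposes the X64 rewriting pass this way is an assumption I have not verified. The sketch below only covers the search step; `find_complex_all_reduces` and `scan_dump_dir` are hypothetical helper names, and the dump directory is whatever was passed to `--xla_dump_to`.

```python
import re
from pathlib import Path

# Matches HLO instructions whose result type is c64/c128 and whose opcode
# is all-reduce (or all-reduce-start), e.g.
#   %all-reduce.0 = c64[16,8]{1,0} all-reduce(%transpose.18), ...
_COMPLEX_ALL_REDUCE = re.compile(
    r"=\s*\(?c(?:64|128)\[.*?\ball-reduce(?:-start)?\("
)


def find_complex_all_reduces(hlo_text):
    """Return the HLO lines that look like an all-reduce producing c64/c128."""
    return [
        line.strip()
        for line in hlo_text.splitlines()
        if _COMPLEX_ALL_REDUCE.search(line)
    ]


def scan_dump_dir(dump_dir):
    """Map dump file name -> matching lines, for every *.txt file in dump_dir."""
    hits = {}
    for path in sorted(Path(dump_dir).glob("*.txt")):
        matches = find_complex_all_reduces(path.read_text(errors="replace"))
        if matches:
            hits[path.name] = matches
    return hits


# Small self-check on made-up HLO lines modeled on the error message.
sample = "\n".join([
    "%all-reduce.0 = c64[16,8]{1,0} all-reduce(%transpose.18), channel_id=1",
    "%all-reduce.1 = f32[16,8]{1,0} all-reduce(%transpose.19), channel_id=2",
])
sample_hits = find_complex_all_reduces(sample)
```

If the dumps are produced, the first file in which a c64 all-reduce appears would point at the pass that introduced it, and the last file before the failure would show the module the rewriting choked on.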