GPU acceleration of the reproject package #489
I don't know anything about using GPUs (and I wish I did), but I'm very interested in this, as are a few of my colleagues. I'd be really interested in seeing how you're doing this, and I'm happy to help in whatever (perhaps limited) ways I can!
So I rewrote the pixel_to_pixel algorithm using cupy and got a 33% reduction in time. I also implemented it in torch and got a 45% reduction. This was the most time-consuming function that didn't involve changing the WCS. The next step for me is to try to tackle some of the WCS computations using torch. Here is the implementation in torch:

```python
import numpy as np
import torch
from astropy.wcs.wcsapi import BaseHighLevelWCS


def pixel_to_pixel_gpu(wcs_in: BaseHighLevelWCS, wcs_out: BaseHighLevelWCS, *inputs):
    """
    GPU version: Transform pixel coordinates using PyTorch, optimized to
    reduce transfer overhead.
    """
    # Automatically select device (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Scalar inputs: fall back to a plain CPU round trip
    if np.isscalar(inputs[0]):
        world_outputs = wcs_in.pixel_to_world(*inputs)
        if not isinstance(world_outputs, (tuple, list)):
            world_outputs = (world_outputs,)
        return wcs_out.world_to_pixel(*world_outputs)

    outputs = [None] * wcs_out.pixel_n_dim

    # Ensure inputs are torch tensors and move them to the selected device
    pixel_inputs = [
        arr.to(device) if isinstance(arr, torch.Tensor) else torch.tensor(arr).to(device)
        for arr in inputs
    ]
    pixel_inputs = torch.broadcast_tensors(*pixel_inputs)

    # Compute world outputs on the CPU using the WCS functions
    # (astropy expects NumPy arrays, so pull the tensors back off the device)
    world_outputs_cpu = wcs_in.pixel_to_world(
        *[arr.cpu().numpy() for arr in pixel_inputs]
    )
    if not isinstance(world_outputs_cpu, (tuple, list)):
        world_outputs_cpu = (world_outputs_cpu,)

    pixel_outputs_cpu = wcs_out.world_to_pixel(*world_outputs_cpu)
    if wcs_out.pixel_n_dim == 1:
        pixel_outputs_cpu = (pixel_outputs_cpu,)

    # world_to_pixel already returns NumPy arrays, so no transfer back is needed
    for i in range(wcs_out.pixel_n_dim):
        outputs[i] = pixel_outputs_cpu[i]

    return outputs[0] if wcs_out.pixel_n_dim == 1 else outputs
```
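For reference, here is a minimal usage sketch of the function above. The toy TAN headers are placeholders made up for illustration, not headers from the thread; any pair of 2-D celestial WCSes would do:

```python
import numpy as np
from astropy.wcs import WCS

# Two toy celestial WCSes (placeholder headers, assumed for illustration)
wcs_in = WCS(naxis=2)
wcs_in.wcs.ctype = 'RA---TAN', 'DEC--TAN'
wcs_out = WCS(naxis=2)
wcs_out.wcs.ctype = 'RA---TAN', 'DEC--TAN'
wcs_out.wcs.crval = 10, 10

# Pixel grid to transform from wcs_in's frame to wcs_out's frame
y, x = np.mgrid[0:1024, 0:1024].astype(float)
xp, yp = pixel_to_pixel_gpu(wcs_in, wcs_out, x, y)
```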
Thanks for posting your code! I tried out your function on my computer and I'm afraid I'm not seeing any speedup. Here's how I explored this, on a computer with an RTX 3070. The GPU version was faster for small inputs of just a thousand coordinates, but the difference could come down to just having fewer lines of Python code; I saw a similar speedup from a stripped-down CPU-only version. For a larger case of 4k x 4k input coordinates, where any speedup could be really valuable, I have the GPU version taking longer (I'm guessing from the overhead of moving data to the GPU and back). Does it work differently on your computer? Or do you have different inputs that show a difference? I don't know anything about PyTorch, but it looks like your function only uses the GPU to broadcast the input arrays, and then still uses the CPU for the coordinate conversions (the `pixel_to_world` and `world_to_pixel` calls). Let me know if I'm missing anything!
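A timing harness consistent with the comparison described above might look like the sketch below, reusing the toy `wcs_in`/`wcs_out` from the previous sketch. The actual test script and exact grid sizes weren't posted, so these are assumptions:

```python
import time

import numpy as np
from astropy.wcs.utils import pixel_to_pixel  # astropy's CPU implementation

for n in (32, 4096):  # ~a thousand coordinates vs. a 4k x 4k grid
    y, x = np.mgrid[0:n, 0:n].astype(float)

    # Warm-up call so the GPU timing excludes one-time CUDA initialization
    pixel_to_pixel_gpu(wcs_in, wcs_out, x, y)

    t0 = time.perf_counter()
    pixel_to_pixel(wcs_in, wcs_out, x, y)
    t_cpu = time.perf_counter() - t0

    t0 = time.perf_counter()
    pixel_to_pixel_gpu(wcs_in, wcs_out, x, y)
    t_gpu = time.perf_counter() - t0

    print(f"{n}x{n}: CPU {t_cpu:.3f}s, GPU {t_gpu:.3f}s")
```

Note that host-to-device transfer time is deliberately included in the GPU measurement here, since it happens inside `pixel_to_pixel_gpu` and is part of the cost being debated.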
Hello, you are absolutely correct in your assessment of what PyTorch is doing! My goal with this little test was to begin looking at where functions could be rewritten to speed things up. I also profiled the functionality and found that the
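For anyone wanting to reproduce this kind of profiling, a sketch using the standard library's `cProfile` (again reusing the toy WCSes from earlier; the specific hotspot the commenter found isn't recorded here):

```python
import cProfile
import pstats

import numpy as np
from astropy.wcs.utils import pixel_to_pixel

y, x = np.mgrid[0:2048, 0:2048].astype(float)

profiler = cProfile.Profile()
profiler.enable()
pixel_to_pixel(wcs_in, wcs_out, x, y)
profiler.disable()

# Show the ten most time-consuming calls, sorted by cumulative time
pstats.Stats(profiler).sort_stats("cumulative").print_stats(10)
```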
I'm also getting a big speedup with your code:
I tried filling in the details of the WCSes, to make sure we're getting a representative workload for the coordinate conversions:

```python
wcs_in = WCS(naxis=2)
wcs_out = WCS(naxis=2)

wcs_in.wcs.crpix = 500, 500
wcs_in.wcs.crval = 0, 10
wcs_in.wcs.ctype = 'RA---CAR', 'DEC--CAR'
wcs_in.wcs.cdelt = 0.05, 0.05

wcs_out.wcs.crpix = 500, 500
wcs_out.wcs.crval = 30, 12
wcs_out.wcs.ctype = 'RA---AZP', 'DEC--AZP'
wcs_out.wcs.cdelt = 0.05, 0.05
```

That didn't really change the speedup factor though. I explored some and found that the execution time drops a lot if I use
I think the problem is that in the cut-down CPU version in your file, this loop

```python
for i in range(wcs_out.pixel_n_dim):
    pixel_inputs = np.broadcast_arrays(*inputs)
    world_outputs = wcs_in.pixel_to_world(*pixel_inputs)
    ...
```

has it run the whole thing twice. I think that came from the astropy version, where, if the coordinates are independent (e.g. longitude depends only on x and latitude depends only on y), it tries to transform the x and y coordinates separately without broadcasting them together (and unbroadcasting them if necessary!), potentially saving a lot of compute time. The cut-down version removes that independence check but keeps a modified loop, creating this bug that inflates the execution time. I really hope you're successful getting
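To make the distinction concrete: a stripped-down CPU version without the bug broadcasts and transforms exactly once, while astropy's real `pixel_to_pixel` additionally consults the WCS `axis_correlation_matrix` to decide when axes can be transformed separately on smaller, unbroadcast arrays. A rough sketch, simplified from (not identical to) the astropy implementation:

```python
import numpy as np


def pixel_to_pixel_simple(wcs_in, wcs_out, *inputs):
    """Broadcast once, transform once: no per-axis loop re-running the
    full conversion wcs_out.pixel_n_dim times."""
    pixel_inputs = np.broadcast_arrays(*inputs)
    world_outputs = wcs_in.pixel_to_world(*pixel_inputs)
    if not isinstance(world_outputs, (tuple, list)):
        world_outputs = (world_outputs,)
    return wcs_out.world_to_pixel(*world_outputs)


# Inspect which world axes depend on which pixel axes; astropy's optimized
# version uses this matrix to group axes and skip unnecessary broadcasting.
print(wcs_in.axis_correlation_matrix)
```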
Thanks for explaining that! Instead of trying to wrangle astropy (which is obviously an amazing package), I wrote a standalone package using torch that computes reprojections. If it is acceptable, I'll post the link to the GitHub repository when it becomes public.
If anyone has tested

Hello,

I also tried

array_utils.txt
This is a copy of a question I posed on the Slack channel open to all users:

Hello! I am working on speeding up the reproject package using the GPU. I've already updated the pixel-to-pixel functionality for a 30% reduction in computation time for that algorithm. I'm going to be working on updating the other functions in the package to run on the GPU. I wanted to know if anyone has already done this, or if someone is currently working on something similar. According to the Roadmap (or at least this is how I understood it), there is a need for someone to work on this type of implementation. Would it be possible to talk to anyone on the dev team about this? I'm going to be working on this acceleration either way, so I'd like to be able to contribute if the community thinks it would be helpful.