GPU acceleration of the reproject package #489

Open
crhea93 opened this issue Feb 3, 2025 · 10 comments

Comments


crhea93 commented Feb 3, 2025

This is a copy of a question I posted on the Slack channel open to all users.

Hello! I am working on speeding up the reproject package using the GPU. I've already updated the pixel_to_pixel functionality for a 30% reduction in computation time for that algorithm, and I'm going to be working on updating the other functions in the package to run on the GPU as well. I wanted to know if anyone has already done this or is currently working on something similar. According to the Roadmap (or at least as I understood it), there is a need for someone to work on this type of implementation. Would it be possible to talk to someone on the dev team about this? I'm going to be working on this acceleration either way, so I'd like to be able to contribute if the community thinks it would be helpful.


svank commented Feb 7, 2025

I don't know anything about using GPUs (and I wish I did), but I'm very interested in this, as are a few of my colleagues. I'd be really interested in seeing how you're doing this, and I'm happy to help in whatever (perhaps limited) ways I can!


crhea93 commented Feb 10, 2025

So I rewrote the pixel_to_pixel algorithm using CuPy and got a 33% reduction in time. I also implemented it in torch and got a 45% reduction. This was the most time-consuming function that didn't involve changing the WCS. The next step for me is to try to tackle some of the WCS computations using torch.

Here is the implementation in torch:

import numpy as np
import torch
from astropy.wcs.wcsapi import BaseHighLevelWCS

def pixel_to_pixel_gpu(wcs_in: BaseHighLevelWCS, wcs_out: BaseHighLevelWCS, *inputs):
    """
    GPU version: Transform pixel coordinates using PyTorch, optimized to reduce transfer overhead.
    """
    # Automatically select device (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if np.isscalar(inputs[0]):
        world_outputs = wcs_in.pixel_to_world(*inputs)
        if not isinstance(world_outputs, (tuple, list)):
            world_outputs = (world_outputs,)
        return wcs_out.world_to_pixel(*world_outputs)

    outputs = [None] * wcs_out.pixel_n_dim

    # Ensure inputs are torch tensors and move them to the selected device
    pixel_inputs = [
        torch.as_tensor(arr).to(device) for arr in inputs
    ]
    pixel_inputs = torch.broadcast_tensors(*pixel_inputs)

    # Compute world outputs on the CPU using the WCS functions
    world_outputs_cpu = wcs_in.pixel_to_world(*[arr.cpu().numpy() for arr in pixel_inputs])

    if not isinstance(world_outputs_cpu, (tuple, list)):
        world_outputs_cpu = (world_outputs_cpu,)

    pixel_outputs_cpu = wcs_out.world_to_pixel(*world_outputs_cpu)

    if wcs_out.pixel_n_dim == 1:
        pixel_outputs_cpu = (pixel_outputs_cpu,)

    # world_to_pixel already returns NumPy arrays, so no device transfer is needed here
    for i in range(wcs_out.pixel_n_dim):
        outputs[i] = pixel_outputs_cpu[i]

    return outputs[0] if wcs_out.pixel_n_dim == 1 else outputs


svank commented Feb 19, 2025

Thanks for posting your code! I tried out your function on my computer and I'm afraid I'm not seeing any speedup. Here's how I explored this, on a computer with an RTX 3070. The GPU version was faster for small inputs of just a thousand coordinates, but the difference could come down to simply having fewer lines of Python code: I saw a similar speedup from a stripped-down CPU-only version. For a larger case of 4k x 4k input coordinates, where any speedup could be really valuable, the GPU version takes longer (I'm guessing from the overhead of moving data to the GPU and back). Does it work differently on your computer? Or do you have different inputs that show a difference?

I don't know anything about PyTorch, but it looks like your function only uses the GPU to broadcast the input arrays, and still uses the CPU for the coordinate conversions (pixel_to_world and world_to_pixel). Is that right? When I profile pixel_to_pixel (at the bottom of that link), basically all the compute time is spent in those two functions, so I suspect big speedups will only come from accelerating them (which I'm guessing would be very involved).

Let me know if I'm missing anything!
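A minimal sketch of the profiling approach described above, using Python's built-in cProfile. The stub functions here are hypothetical placeholders for the real wcs.pixel_to_world / wcs.world_to_pixel calls; swap those in to profile the actual workload.

```python
import cProfile
import io
import pstats

import numpy as np

# Stand-in transforms (hypothetical placeholders for wcs.pixel_to_world
# and wcs.world_to_pixel); replace with real WCS calls to profile reproject.
def stub_pixel_to_world(x, y):
    return np.sin(x) * 0.05, np.cos(y) * 0.05

def stub_world_to_pixel(lon, lat):
    return lon / 0.05, lat / 0.05

def pixel_to_pixel_stub(x, y):
    lon, lat = stub_pixel_to_world(x, y)
    return stub_world_to_pixel(lon, lat)

x, y = np.meshgrid(np.arange(1000.0), np.arange(1000.0))

profiler = cProfile.Profile()
profiler.enable()
px, py = pixel_to_pixel_stub(x, y)
profiler.disable()

# Sort by cumulative time: the transform functions dominate the report.
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(5)
report = stream.getvalue()
print(report.splitlines()[0])
```

With the real astropy WCS calls substituted in, the report makes it easy to confirm where the compute time actually goes before deciding what to port to the GPU.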


crhea93 commented Feb 19, 2025

Hello,
Thanks for posting this! I reran your code on my machine (RTX 4060) and, unfortunately, I get the same results as you. I'm including my test that showed me the non-negligible (45%) speedup for using the GPU on 4000x6000 images. I'll need to figure out why it does so much more poorly on the test you provided. Perhaps it is because of my inputs...

You are absolutely correct in your assessment of what PyTorch is doing! My goal with this little test was to begin looking at where functions could be rewritten to speed things up. I also profiled the functionality and found that the pixel_to_world and world_to_pixel functions are the main bottlenecks. I've slowly started updating these to run using PyTorch, since we lose a lot of potential speedup having to recast torch objects to NumPy objects for these calculations.


svank commented Feb 20, 2025

I'm also getting a big speedup with your code:

Running CPU version...
CPU execution time: 4.105450 seconds
Running GPU version...
GPU execution time: 2.498859 seconds
GPU speedup: 39.13%
Results match within tolerance.

I tried filling in the details of the WCSes, to make sure we're getting a representative workload for the coordinate conversions:

wcs_in = WCS(naxis=2)
wcs_out = WCS(naxis=2)
wcs_in.wcs.crpix = 500, 500
wcs_in.wcs.crval = 0, 10
wcs_in.wcs.ctype = 'RA---CAR', 'DEC--CAR'
wcs_in.wcs.cdelt = 0.05, 0.05
wcs_out.wcs.crpix = 500, 500
wcs_out.wcs.crval = 30, 12
wcs_out.wcs.ctype = 'RA---AZP', 'DEC--AZP'
wcs_out.wcs.cdelt = 0.05, 0.05

That didn't really change the speedup factor though. I explored some and found that the execution time drops a lot if I use astropy.wcs.utils.pixel_to_pixel for the CPU version:

Running CPU version...
CPU execution time: 19.076000 seconds
Running astropy version...
astropy execution time: 9.292996 seconds
Running GPU version...
GPU execution time: 9.689164 seconds
GPU speedup: 49.21%
Results match within tolerance.

I think the problem is that in the cut-down CPU version in your file, this loop

    for i in range(wcs_out.pixel_n_dim):
        pixel_inputs = np.broadcast_arrays(*inputs)
        world_outputs = wcs_in.pixel_to_world(*pixel_inputs)
        ...

has it run the whole thing twice. I think that came from the astropy version, where, if the coordinates are independent (e.g. longitude depends only on x and latitude depends only on y), it tries to transform the x and y coordinates separately without broadcasting them together (and unbroadcasting them if necessary!), potentially saving a lot of compute time. The cut-down version removes that independence check but keeps a modified loop, creating this bug, which inflates the execution time.
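The duplicated work is easy to demonstrate with a counting stand-in for the transform (pure NumPy; `transform` here is a hypothetical placeholder for the full WCS round trip):

```python
import numpy as np

calls = {"n": 0}

# Hypothetical stand-in for the WCS round trip; each call counts as one
# complete pass over the broadcast coordinate grids.
def transform(x, y):
    calls["n"] += 1
    return x * 0.05, y * 0.05

x = np.arange(4000.0)
y = np.arange(4000.0)

# Cut-down loop from the benchmark: re-runs the whole transform once
# per output dimension, so a 2-D WCS does everything twice.
calls["n"] = 0
for i in range(2):
    px, py = np.broadcast_arrays(x[np.newaxis, :], y[:, np.newaxis])
    out_i = transform(px, py)[i]
n_buggy = calls["n"]

# Fixed version: transform once, then index into the result.
calls["n"] = 0
px, py = np.broadcast_arrays(x[np.newaxis, :], y[:, np.newaxis])
outputs = transform(px, py)
n_fixed = calls["n"]

print(n_buggy, n_fixed)  # 2 1
```

The factor-of-two call count matches the roughly doubled execution time seen in the cut-down CPU benchmark.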

I really hope you're successful getting pixel_to_world and world_to_pixel to run on the GPU! The pipeline for the PUNCH mission reprojects a lot of large files, so we could save a lot of runtime with this sort of optimization!


crhea93 commented Feb 20, 2025

Thanks for explaining that!

Instead of trying to wrangle astropy (which is obviously an amazing package), I wrote a standalone package, using torch, that computes reprojections. If it is acceptable, I'll post the link to the GitHub repository when it becomes public.


Eririf commented Mar 19, 2025

Has anyone tested cupyx.scipy.ndimage.map_coordinates?
scipy.ndimage.map_coordinates seems to be much slower than cv2.remap for FITS files larger than 8k×8k, so GPU acceleration may be preferable here, but copying the cp.array from the GPU back to the CPU costs a lot of time.

# cv2
import cv2
import numpy as np
src = np.random.rand(8000, 8000).astype(np.float32)
x, y = np.meshgrid(np.arange(2000), np.arange(2000))
map_x = (x + 0.5).astype(np.float32)
map_y = (y + 0.5).astype(np.float32)
result = cv2.remap(src, map_x, map_y, cv2.INTER_LINEAR)

# scipy
import numpy as np
from scipy.ndimage import map_coordinates
src = np.random.rand(8000, 8000)
coords = np.indices((8000, 8000)).astype(float)
coords += 0.5
result = map_coordinates(src, coords, order=1, mode='constant', cval=0)

# cuda + scipy
import cupy as cp
from cupyx.scipy.ndimage import map_coordinates
src_gpu = cp.random.rand(8000, 8000).astype(cp.float32)
coords = cp.indices((8000, 8000)).astype(cp.float32)
coords += 0.5  # example offset
result_gpu = map_coordinates(src_gpu, coords, order=1, mode='constant')
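As a sanity check before benchmarking, it is worth pinning down what the +0.5 shift in the snippets above actually computes: with order=1 (bilinear) interpolation, sampling half a pixel past each pixel center is just the mean of each 2x2 neighborhood. A small SciPy-only parity check (tiny array, so it runs anywhere):

```python
import numpy as np
from scipy.ndimage import map_coordinates

src = np.arange(16.0).reshape(4, 4)

# Sample at (i + 0.5, j + 0.5): halfway between pixel centers.
coords = np.indices((3, 3)).astype(float) + 0.5
out = map_coordinates(src, coords, order=1)

# Bilinear at (+0.5, +0.5) is the mean of each 2x2 neighborhood.
expected = (src[:-1, :-1] + src[1:, :-1] + src[:-1, 1:] + src[1:, 1:]) / 4
print(np.allclose(out, expected))  # True
```

Any GPU variant (cupyx, torch, or cv2.remap) should reproduce this same result for a fair comparison.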


crhea93 commented Mar 19, 2025

Hello,
I haven't tried that function, but I was able to get an order-of-magnitude speedup over map_coordinates by using torch.


Eririf commented Mar 20, 2025

I also tried PyTorch, but not with map_coordinates; I used F.grid_sample and found it a bit slower than cupyx. Either way, both the tensor and the cupy.array need to be sent back to the CPU from the GPU, which costs some time.

import torch
import torch.nn.functional as F

src_tensor = torch.rand(1, 1, 8000, 8000).cuda()

grid = torch.meshgrid(
    torch.linspace(-1, 1, 8000),
    torch.linspace(-1, 1, 8000),
    indexing='ij',  # matches the old default; avoids the deprecation warning
)
grid = torch.stack(grid, dim=-1).unsqueeze(0).cuda() + 0.5 / 2000

result_tensor = F.grid_sample(
    src_tensor,
    grid,
    mode='bilinear',
    padding_mode='zeros',
    align_corners=False,
)
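One subtlety with F.grid_sample is its normalized coordinate convention: with align_corners=False, -1 and +1 refer to the outer edges of the image, so the center of pixel p on an axis of size N sits at (2p + 1)/N - 1. A NumPy-only sketch of that mapping (the helper name is mine, not part of any library):

```python
import numpy as np

def to_grid_sample_coords(pix, size):
    # grid_sample with align_corners=False: normalized -1/+1 are the
    # outer image edges, so pixel centers map to (2*pix + 1)/size - 1.
    return (2.0 * np.asarray(pix, dtype=float) + 1.0) / size - 1.0

first = to_grid_sample_coords(0, 8000)    # just inside -1
last = to_grid_sample_coords(7999, 8000)  # just inside +1
print(first, last)
```

Building the grid from pixel coordinates through this mapping, rather than a raw linspace(-1, 1, N), keeps the sampling aligned with what map_coordinates computes and makes cross-library comparisons meaningful.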


Eririf commented Mar 24, 2025

array_utils.txt
common.txt
core.txt
I managed to write a CuPy GPU 'acceleration' version of the reproject_interp function (attached above). But the speed still needs to be carefully tuned... Currently ~8 s for an 8K×8K image with parallel=8, order=1, block_size=(2048, 2048). Sadly, my 12 GB of GPU memory limits the maximum number of parallel workers here.
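For sizing block_size against a fixed memory budget, a back-of-the-envelope model can help. This is a crude sketch: the four-arrays-per-block count is an assumption, and real usage will be higher because of interpolation temporaries and CuPy's memory pool.

```python
import numpy as np

def block_memory_mb(block_shape, n_arrays=4, dtype=np.float32):
    # Rough per-worker footprint: a source block, two coordinate planes,
    # and an output block (n_arrays total), all in `dtype`.
    bytes_per_array = int(np.prod(block_shape)) * np.dtype(dtype).itemsize
    return n_arrays * bytes_per_array / 1024**2

# A 12 GB card split across 8 parallel workers leaves ~1.5 GB each.
per_worker_budget_mb = 12 * 1024 / 8
mb_2048 = block_memory_mb((2048, 2048))
print(mb_2048, per_worker_budget_mb)  # 64.0 1536.0
```

Under this model a (2048, 2048) block is nowhere near the per-worker budget, which suggests the observed memory pressure comes from temporaries rather than the blocks themselves; profiling actual allocations would be the next step.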
