Detection trainer fails with In_channels>3 #2749

Open
robmarkcole opened this issue Apr 22, 2025 · 3 comments · May be fixed by #2752
Labels
trainers (PyTorch Lightning trainers)

Comments

@robmarkcole
Contributor

Description

For example, when using a 4-channel dataset, the following error is raised:

  File "/usr/local/lib/python3.11/site-packages/torchvision/models/detection/transform.py", line 141, in forward
    image = self.normalize(image)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/site-packages/torchvision/models/detection/transform.py", line 169, in normalize
    return (image - mean[:, None, None]) / std[:, None, None]
            ~~~~~~^~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

This happens because torchvision's GeneralizedRCNNTransform normalizes with a 3-element mean and std by default, so it only supports 3-channel input. We should either normalize with Kornia instead or disable this torchvision behaviour.
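The mismatch can be reproduced without the trainer. The sketch below copies the normalization expression from the traceback into a standalone helper; the `normalize` function here is hypothetical, and the 3-element mean and std are assumed to be the ImageNet statistics that torchvision's detection models use by default. It is an illustration of the broadcasting failure, not the torchvision implementation itself.

```python
import torch


def normalize(image, mean, std):
    # Same expression as in the traceback: per-channel stats broadcast over H and W
    mean = torch.as_tensor(mean, dtype=image.dtype)
    std = torch.as_tensor(std, dtype=image.dtype)
    return (image - mean[:, None, None]) / std[:, None, None]


image = torch.rand(4, 32, 32)  # 4-channel image, C x H x W

# Assumed defaults: 3-element ImageNet mean and std
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

try:
    normalize(image, imagenet_mean, imagenet_std)
    mismatch_error = None
except RuntimeError as err:
    # Dimension 0 has size 4 on the image and 3 on the stats
    mismatch_error = str(err)

# Stats with one entry per channel broadcast cleanly; zeros and ones make it a no-op
out = normalize(image, [0.0] * 4, [1.0] * 4)
```

With one mean and std entry per input channel the expression broadcasts as intended, which suggests the failure comes from the length of the statistics rather than from the normalization step as such.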

Steps to reproduce

I'm using a proprietary 4-channel dataset with the following config:

model:
  class_path: ObjectDetectionTask
  init_args:
    model: faster-rcnn
    backbone: resnet18
    weights: True
    lr: 5e-4
    in_channels: 4

Version

torchgeo==0.7.0

@robmarkcole
Contributor Author

As a workaround, I patch the method in my script:

from torchvision.models.detection import transform as detection_transform

def dummy_normalize(self, image):
    # Return the image unchanged; assumes it was already normalized upstream
    return image

# Patch normalize on the class, so every GeneralizedRCNNTransform instance skips it
detection_transform.GeneralizedRCNNTransform.normalize = dummy_normalize

adamjstewart added this to the 0.7.1 milestone on Apr 22, 2025
adamjstewart added the trainers (PyTorch Lightning trainers) label on Apr 22, 2025
@adamjstewart
Collaborator

Glad someone actually tested this. Likely affects instance segmentation too. Wish we had some non-RGB datasets in TorchGeo to test this properly.

@isaaccorley
Collaborator

isaaccorley commented Apr 22, 2025

Are we okay with just replacing this transform with nn.Identity by default? One thing I dislike about the torchvision RCNN models is that they bake the transform into the model. I've been screwed over by this in the past as well.

adamjstewart removed this from the 0.7.1 milestone on May 1, 2025
3 participants