ObjectDetection/InstanceSegmentationTask: fix support for non-RGB images #2752

Open: isaaccorley wants to merge 3 commits into main from trainers/multispectral-object-detection

isaaccorley
Collaborator

isaaccorley commented Apr 23, 2025

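One of the wonderful surprises of torchvision's detector models is that a GeneralizedRCNNTransform gets added under the hood, which defaults to ImageNet RGB mean/std normalization plus dynamic resizing (shorter side scaled to 800, longer side capped at 1333). Because the ImageNet mean/std contain exactly three values, this breaks for multispectral inputs with more than three bands.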

This PR fixes this by still loading the pretrained weights but overriding this transform to simply subtract 0 and divide by 1, which is a no-op, and by changing the dynamic resize to allow min/max input sizes in the range (1, 4096).

Alternatives considered:

I attempted to simply replace model.transform with nn.Identity(), but this doesn't work: the detection models call the transform with multiple arguments (the images and the targets), which throws an error.
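A quick sketch of that failure, assuming the GeneralizedRCNN-style call pattern transform(images, targets). nn.Identity.forward accepts a single input, so the two-argument call raises a TypeError. Even if it did not, these models also expect an ImageList back and call transform.postprocess(), which nn.Identity does not provide. The image and target shapes are arbitrary placeholders.

```python
import torch
from torch import nn

# Detection models call their transform as transform(images, targets)
identity = nn.Identity()
images = [torch.rand(4, 64, 64)]
targets = [{"boxes": torch.zeros(0, 4), "labels": torch.zeros(0, dtype=torch.int64)}]

failed = False
try:
    identity(images, targets)  # nn.Identity.forward takes only one input
except TypeError as err:
    failed = True
    print(f"TypeError: {err}")

# The models also call transform.postprocess(...) on the outputs
print(hasattr(identity, "postprocess"))
```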

Fixes #2749

isaaccorley self-assigned this Apr 23, 2025

Copilot AI left a comment

Pull Request Overview

This PR fixes multispectral support in the ObjectDetectionTask by overriding the default transform parameters in the detector models. The changes update the initialization of FasterRCNN, FCOS, and RetinaNet with custom parameters (min_size, max_size, image_mean, and image_std) that enable multispectral inputs, and a new test is added to validate this functionality.

  • Updated transform parameters for multispectral support in three detection model constructors.
  • Added a new test in tests/trainers/test_detection.py to check multispectral behavior.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

Files reviewed:
  • torchgeo/trainers/detection.py: Updated model constructors to override transform parameters for multispectral data.
  • tests/trainers/test_detection.py: Added a test case to validate multispectral support with non-RGB input channels.

github-actions bot added the testing (Continuous integration testing) and trainers (PyTorch Lightning trainers) labels Apr 23, 2025
@adamjstewart
Collaborator

Can you also check instance segmentation?

@robmarkcole
Contributor

FYI, I confirmed no issues with OD on my 4-channel dataset.

@robmarkcole
Contributor

robmarkcole commented Apr 23, 2025

OK, whilst the error is resolved, the loss for the OD models I train is always zero. Are there validated results I can reproduce? Note this might just be my datasets, which I recently updated for the new format.

@isaaccorley
Collaborator Author

By default, this was resizing all imagery to a minimum size of 800. What transforms are you using to preprocess your imagery?

@isaaccorley
Collaborator Author

isaaccorley commented Apr 24, 2025

@adamjstewart fixed the instance segmentation task. It had the same issue.

@robmarkcole
Contributor

@isaaccorley can you elaborate on "This by default was resizing all imagery to a min of 800"?
Example transforms:

        # Requires: import kornia.augmentation as Ka; from typing import List
        self.train_aug = Ka.AugmentationSequential(
            Ka.Normalize(mean=self.mean, std=self.std),
            Ka.Resize(self.chip_size),
            Ka.RandomHorizontalFlip(p=0.5),
            Ka.RandomVerticalFlip(p=0.5),
            Ka.RandomRotation(degrees=(90.0, 90.0), p=0.25),
            data_keys=None,
            keepdim=True,
        )

        test_transforms: List[Ka.AugmentationBase2D] = [
            Ka.Normalize(mean=self.mean, std=self.std),
            Ka.Resize(self.chip_size),
        ]

where typically chip_size = 224

adamjstewart changed the title from "Fix Multispectral Support in ObjectDetectionTask" to "ObjectDetection/InstanceSegmentationTask: fix support for non-RGB images" Apr 24, 2025
@isaaccorley
Collaborator Author

Torchvision's Faster R-CNN and Mask R-CNN have a GeneralizedRCNNTransform module inside the model itself that normalizes and then resizes images to a minimum size of 800. So any image you pass in with a side shorter than 800 pixels gets upsampled to 800 (unless that would push the longer side past max_size, 1333 by default). See the code here.
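The scale rule can be sketched in plain Python. This mirrors the eval-mode computation in torchvision's GeneralizedRCNNTransform as described above; the function name is hypothetical, and it ignores the fixed_size option and the padding to a multiple of 32 that the transform applies afterwards.

```python
import math


def rcnn_resized_shape(height, width, min_size=800, max_size=1333):
    """Hypothetical helper: (H, W) after GeneralizedRCNNTransform-style resizing.

    The shorter side is scaled to min_size unless that would push the longer
    side past max_size, in which case the longer side is scaled to max_size.
    """
    scale = min(min_size / min(height, width), max_size / max(height, width))
    return math.floor(height * scale), math.floor(width * scale)


print(rcnn_resized_shape(224, 224))  # a typical 224 chip is upsampled
print(rcnn_resized_shape(224, 448))  # max_size caps the longer side
print(rcnn_resized_shape(1024, 1024))  # larger chips are downsampled too
```

With chip_size = 224, every training chip is therefore enlarged roughly 3.6x before it reaches the backbone.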

One trick that works well for object detection in remote sensing is to simply resize your small patches to be larger. This may be why you're getting poor performance.
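A minimal sketch of that trick in plain PyTorch, assuming (C, H, W) chips and (N, 4) boxes in absolute-pixel xyxy format (torchvision's convention). The helper name and the 2x factor are hypothetical; masks for instance segmentation would need the same resize.

```python
import torch
import torch.nn.functional as F


def upscale_chip(image, boxes, scale=2.0):
    """Hypothetical helper: upsample a (C, H, W) chip and rescale its xyxy boxes."""
    _, height, width = image.shape
    new_h, new_w = round(height * scale), round(width * scale)
    resized = F.interpolate(
        image[None], size=(new_h, new_w), mode="bilinear", align_corners=False
    )[0]
    # x coordinates scale with width, y coordinates with height
    sx, sy = new_w / width, new_h / height
    return resized, boxes * torch.tensor([sx, sy, sx, sy])


chip = torch.rand(4, 224, 224)  # 4-band stand-in chip
boxes = torch.tensor([[10.0, 20.0, 50.0, 60.0]])
big_chip, big_boxes = upscale_chip(chip, boxes)
print(big_chip.shape, big_boxes)
```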

@robmarkcole
Contributor

@isaaccorley good to know! Perhaps we should document this?

@isaaccorley
Collaborator Author

This PR basically removes this transform, so users can decide for themselves which Kornia normalize and resize transforms to apply.

@robmarkcole
Contributor

I've ruled out issues with my dataset, and the remaining differences I see between my legacy implementation and this implementation are details such as the anchor sizes I've utilised. I suggest we merge this approach and then, as a follow-up (pending a suitable test dataset), work on further optimisations in another PR.

adamjstewart added this to the 0.7.1 milestone Apr 28, 2025
Collaborator

adamjstewart left a comment

LGTM, thanks for the fix!

@adamjstewart
Collaborator

Except for the tests...

isaaccorley force-pushed the trainers/multispectral-object-detection branch from 7c10564 to e4c6832 May 1, 2025 19:07
Development

Successfully merging this pull request may close these issues.

Detection trainer fails with In_channels>3
3 participants