ObjectDetection/InstanceSegmentationTask: fix support for non-RGB images #2752
Conversation
Pull Request Overview
This PR fixes multispectral support in the ObjectDetectionTask by overriding the default transform parameters in the detector models. The changes update the initialization of FasterRCNN, FCOS, and RetinaNet with custom parameters (min_size, max_size, image_mean, and image_std) that enable multispectral inputs, and a new test is added to validate this functionality.
- Updated transform parameters for multispectral support in three detection model constructors.
- Added a new test in tests/trainers/test_detection.py to check multispectral behavior.
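As a rough sketch of the approach (parameter values and backbone choice here are illustrative, not the exact PR diff), the torchvision detector constructors accept overrides for the internal transform's behaviour:

```python
import torch
from torch import nn
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

in_channels = 4  # hypothetical 4-band (e.g. RGB + NIR) input

# Build a backbone and adapt its first conv layer to accept extra bands.
backbone = resnet_fpn_backbone(backbone_name='resnet50', weights=None)
backbone.body.conv1 = nn.Conv2d(
    in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
)

# Override the internal transform: no-op normalization, permissive resize.
model = FasterRCNN(
    backbone,
    num_classes=2,
    min_size=1,                       # don't upscale everything to >= 800
    max_size=4096,                    # don't cap inputs at 1333
    image_mean=[0.0] * in_channels,   # subtract 0 -> identity
    image_std=[1.0] * in_channels,    # divide by 1 -> identity
)
```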
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| torchgeo/trainers/detection.py | Updated model constructors to override transform parameters for multispectral data. |
| tests/trainers/test_detection.py | Added a test case to validate multispectral support with a non-RGB input channel. |
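The new test presumably looks something along these lines (a hedged sketch; the trainer arguments are assumed from torchgeo's public API, not copied from the diff):

```python
import torch
from torchgeo.trainers import ObjectDetectionTask

def test_non_rgb_images() -> None:
    # Assumed argument names; the real test lives in tests/trainers/test_detection.py.
    task = ObjectDetectionTask(
        model='faster-rcnn', backbone='resnet18', in_channels=4, num_classes=2
    )
    task.model.eval()
    images = [torch.rand(4, 224, 224)]  # 4 channels, previously rejected
    with torch.no_grad():
        predictions = task.model(images)
    assert {'boxes', 'labels', 'scores'} <= set(predictions[0])
```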
Can you also check instance segmentation?
FYI I confirmed no issues with OD on my 4-channel dataset.
OK, whilst the error is resolved, the loss for the OD models I train is always zero. Are there validated results I can reproduce? Note this might just be my datasets, which I recently updated for the new format.
This transform by default was resizing all imagery to a minimum size of 800. What transforms are you using to preprocess your imagery?
@adamjstewart fixed the instance segmentation task. It had the same issue. |
@isaaccorley can you elaborate on this, given that typically chip_size = 224?
Torchvision Faster-RCNN and MaskRCNN have a built-in resize transform. One trick that works well for object detection in remote sensing is to simply resize your small patches to be larger. This may be why you're getting poor performance.
@isaaccorley good to know! Perhaps we should document this? |
This PR basically removes this transform, so users can decide which Kornia normalize and resize transforms they want to apply themselves.
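For example (a sketch with made-up band statistics; for training-time augmentation the boxes would need to be transformed too):

```python
import kornia.augmentation as K
import torch

# Hypothetical per-band statistics for a 4-band sensor.
mean = torch.tensor([0.485, 0.456, 0.406, 0.5])
std = torch.tensor([0.229, 0.224, 0.225, 0.25])

preprocess = K.AugmentationSequential(
    K.Normalize(mean=mean, std=std),
    K.Resize((800, 800)),  # optionally upscale small chips, per the trick above
    data_keys=['input'],   # add 'bbox_xyxy' to resize boxes alongside images
)

batch = torch.rand(8, 4, 224, 224)
batch = preprocess(batch)
```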
I've ruled out issues with my dataset, and the remaining differences I see between my legacy implementation and this implementation are details such as the anchor sizes I've utilised. I suggest we merge this approach and then, as a follow-up (and pending a suitable test dataset), work on further optimisations in another PR.
LGTM, thanks for the fix!
Except for the tests...
One of the wonderful surprises of torchvision's detector models is that a `GeneralizedRCNNTransform` gets added under the hood, which defaults to ImageNet RGB mean/std normalization plus dynamic resizing in the range (800, 1333). This PR fixes this by loading pretrained weights but overriding the transform to simply subtract 0 and divide by 1 (a no-op), and changing the dynamic resize to allow a min/max input shape in the range (1, 4096).
Alternatives considered:
I attempted to simply replace `model.transform` with `nn.Identity()`, but this doesn't work because the detection models pass multiple arguments to the transform, which throws an error.

Fixes #2749