Skip to content

Commit e4c6832

Browse files
committed
add grayscale tests to detection
1 parent 0b0ffb9 commit e4c6832

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/trainers/test_detection.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,15 @@ def test_freeze_backbone(self, model_name: str) -> None:
125125
assert not all([param.requires_grad for param in model.model.parameters()])
126126

127127
@pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet'])
128-
def test_multispectral_support(self, model_name: str) -> None:
129-
channels = 4
128+
@pytest.mark.parametrize('in_channels', [1, 4])
129+
def test_multispectral_support(self, model_name: str, in_channels: int) -> None:
130130
model = ObjectDetectionTask(
131-
model=model_name, backbone='resnet18', num_classes=2, in_channels=channels
131+
model=model_name,
132+
backbone='resnet18',
133+
num_classes=2,
134+
in_channels=in_channels,
132135
)
133136
model.eval()
134-
sample = [torch.randn(channels, 224, 224)]
137+
sample = [torch.randn(in_channels, 224, 224)]
135138
with torch.inference_mode():
136139
model(sample)

0 commit comments

Comments
 (0)