From 450a241c45218904a1fdf01846c6a077b2856a48 Mon Sep 17 00:00:00 2001 From: plamen-alcatraz <56071366+plamen-alcatraz@users.noreply.github.com> Date: Tue, 31 Dec 2024 12:24:14 +0200 Subject: [PATCH] Fix RandomRain crash when no images in the batch are augmented (#3097) (#3103) --- kornia/augmentation/random_generator/_2d/random_rain.py | 2 +- tests/augmentation/test_augmentation.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/kornia/augmentation/random_generator/_2d/random_rain.py b/kornia/augmentation/random_generator/_2d/random_rain.py index 50332f380f..d3be616690 100644 --- a/kornia/augmentation/random_generator/_2d/random_rain.py +++ b/kornia/augmentation/random_generator/_2d/random_rain.py @@ -64,7 +64,7 @@ def forward(self, batch_shape: tuple[int, ...], same_on_batch: bool = False) -> device=_device, dtype=torch.long ) coordinates_factor = _adapted_rsampling( - (batch_size, int(number_of_drops_factor.max().item()), 2), + (batch_size, int(number_of_drops_factor.max().item()) if number_of_drops_factor.numel() > 0 else 0, 2), self.coordinates_sampler, same_on_batch=same_on_batch, ).to(device=_device) diff --git a/tests/augmentation/test_augmentation.py b/tests/augmentation/test_augmentation.py index ca3b862d58..0f859807e7 100644 --- a/tests/augmentation/test_augmentation.py +++ b/tests/augmentation/test_augmentation.py @@ -5150,6 +5150,11 @@ def test_exception(self, device, dtype): assert err_msg in str(errinfo) + def test_zero_probability(self, device): + input_data = torch.rand(10, 3, 8, 8, device=device) + aug = RandomRain(p=0.0, drop_height=(2, 3), drop_width=(2, 3), number_of_drops=(1, 3)) + aug(input_data) + class TestMultiprocessing: torch.manual_seed(0) # for random reproductibility