Skip to content

Commit b0cf376

Browse files
abdulrahman-riyadpre-commit-ci[bot]
andauthoredMar 27, 2025
Splitting the tests in test_augmentation_3d.py and test_container.py into individual test files. (kornia#3143)
* Split augmentation_3d and container tests into individual files * Fix linting error: replace assert False with raise AssertionError * removved the non-modular test files test_container.py and test_augmentation_3d.py * fixing naming conflicts * fixing naming conflicts and imports * Organizing augmentation tests into the subdirectories _3d and container and moving utility function (which got renamed from reproducability_test.py to utils.py) to testing/augmentation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4a5202d commit b0cf376

16 files changed

+1595
-1448
lines changed
 

‎testing/augmentation/utils.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# LICENSE HEADER MANAGED BY add-license-header
2+
#
3+
# Copyright 2018 Kornia Team
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import torch
19+
20+
from testing.base import assert_close
21+
22+
23+
def reproducibility_test(input, seq):
24+
"""Any tests failed here indicate the output cannot be reproduced by the same params."""
25+
if isinstance(input, (tuple, list)):
26+
output_1 = seq(*input)
27+
output_2 = seq(*input, params=seq._params)
28+
else:
29+
output_1 = seq(input)
30+
output_2 = seq(input, params=seq._params)
31+
32+
if isinstance(output_1, (tuple, list)) and isinstance(output_2, (tuple, list)):
33+
[
34+
assert_close(o1, o2)
35+
for o1, o2 in zip(output_1, output_2)
36+
if isinstance(o1, (torch.Tensor,)) and isinstance(o2, (torch.Tensor,))
37+
]
38+
elif isinstance(output_1, (tuple, list)) and isinstance(output_2, (torch.Tensor,)):
39+
assert_close(output_1[0], output_2)
40+
elif isinstance(output_2, (tuple, list)) and isinstance(output_1, (torch.Tensor,)):
41+
assert_close(output_1, output_2[0])
42+
elif isinstance(output_2, (torch.Tensor,)) and isinstance(output_1, (torch.Tensor,)):
43+
assert_close(output_1, output_2, msg=f"{seq._params}")
44+
else:
45+
raise AssertionError(f"cannot compare {type(output_1)} and {type(output_2)}")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# LICENSE HEADER MANAGED BY add-license-header
2+
#
3+
# Copyright 2018 Kornia Team
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import torch
19+
20+
from kornia.augmentation import CenterCrop3D
21+
22+
from testing.base import BaseTester
23+
24+
25+
class TestCenterCrop3D(BaseTester):
26+
def test_no_transform(self, device, dtype):
27+
inp = torch.rand(1, 2, 4, 4, 4, device=device, dtype=dtype)
28+
out = CenterCrop3D(2)(inp)
29+
assert out.shape == (1, 2, 2, 2, 2)
30+
31+
def test_transform(self, device, dtype):
32+
inp = torch.rand(1, 2, 5, 4, 8, device=device, dtype=dtype)
33+
aug = CenterCrop3D(2)
34+
out = aug(inp)
35+
assert out.shape == (1, 2, 2, 2, 2)
36+
assert aug.transform_matrix.shape == (1, 4, 4)
37+
38+
def test_no_transform_tuple(self, device, dtype):
39+
inp = torch.rand(1, 2, 5, 4, 8, device=device, dtype=dtype)
40+
out = CenterCrop3D((3, 4, 5))(inp)
41+
assert out.shape == (1, 2, 3, 4, 5)
42+
43+
def test_gradcheck(self, device):
44+
input_tensor = torch.rand(1, 2, 3, 4, 5, device=device, dtype=torch.float64)
45+
self.gradcheck(CenterCrop3D(3), (input_tensor,))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# LICENSE HEADER MANAGED BY add-license-header
2+
#
3+
# Copyright 2018 Kornia Team
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import pytest
19+
import torch
20+
21+
from kornia.augmentation import RandomAffine3D
22+
23+
from testing.base import BaseTester
24+
25+
26+
class TestRandomAffine3D(BaseTester):
27+
def test_batch_random_affine_3d(self, device, dtype):
28+
# TODO(jian): cuda and fp64
29+
if "cuda" in str(device) and dtype == torch.float64:
30+
pytest.skip("AssertionError: assert tensor(False, device='cuda:0')")
31+
32+
f = RandomAffine3D((0, 0, 0), p=1.0) # No rotation
33+
tensor = torch.tensor(
34+
[[[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]], device=device, dtype=dtype
35+
) # 1 x 1 x 1 x 3 x 3
36+
37+
expected = torch.tensor(
38+
[[[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]], device=device, dtype=dtype
39+
) # 1 x 1 x 1 x 3 x 3
40+
41+
expected_transform = torch.tensor(
42+
[[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
43+
device=device,
44+
dtype=dtype,
45+
) # 1 x 4 x 4
46+
47+
tensor = tensor.repeat(5, 3, 1, 1, 1) # 5 x 3 x 3 x 3 x 3
48+
expected = expected.repeat(5, 3, 1, 1, 1) # 5 x 3 x 3 x 3 x 3
49+
expected_transform = expected_transform.repeat(5, 1, 1) # 5 x 4 x 4
50+
51+
self.assert_close(f(tensor), expected)
52+
self.assert_close(f.transform_matrix, expected_transform)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# LICENSE HEADER MANAGED BY add-license-header
2+
#
3+
# Copyright 2018 Kornia Team
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import pytest
19+
import torch
20+
21+
import kornia
22+
from kornia.augmentation import RandomCrop, RandomCrop3D
23+
24+
from testing.base import BaseTester
25+
26+
27+
class TestRandomCrop3D(BaseTester):
28+
# TODO: improve and implement more meaningful smoke tests e.g check for a consistent
29+
# return values such a torch.Tensor variable.
30+
@pytest.mark.xfail(reason="might fail under windows OS due to printing preicision.")
31+
def test_smoke(self):
32+
f = RandomCrop3D(size=(2, 3, 4), padding=(0, 1, 2), fill=10, pad_if_needed=False, p=1.0)
33+
repr = (
34+
"RandomCrop3D(crop_size=(2, 3, 4), padding=(0, 1, 2), fill=10, pad_if_needed=False, "
35+
"padding_mode=constant, resample=BILINEAR, p=1.0, p_batch=1.0, same_on_batch=False, "
36+
"return_transform=None)"
37+
)
38+
assert str(f) == repr
39+
40+
@pytest.mark.parametrize("batch_size", [1, 2])
41+
def test_no_padding(self, batch_size, device, dtype):
42+
torch.manual_seed(42)
43+
input_tensor = torch.tensor(
44+
[
45+
[
46+
[
47+
[
48+
[0.0, 1.0, 2.0, 3.0, 4.0],
49+
[5.0, 6.0, 7.0, 8.0, 9.0],
50+
[10, 11, 12, 13, 14],
51+
[15, 16, 17, 18, 19],
52+
[20, 21, 22, 23, 24],
53+
]
54+
]
55+
]
56+
],
57+
device=device,
58+
dtype=dtype,
59+
).repeat(batch_size, 1, 5, 1, 1)
60+
f = RandomCrop3D(size=(2, 3, 4), padding=None, align_corners=True, p=1.0)
61+
out = f(input_tensor)
62+
if batch_size == 1:
63+
expected = torch.tensor(
64+
[[[[[11, 12, 13, 14], [16, 17, 18, 19], [21, 22, 23, 24]]]]], device=device, dtype=dtype
65+
).repeat(batch_size, 1, 2, 1, 1)
66+
if batch_size == 2:
67+
expected = torch.tensor(
68+
[
69+
[
70+
[
71+
[
72+
[6.0000, 7.0000, 8.0000, 9.0000],
73+
[11.0000, 12.0000, 13.0000, 14.0000],
74+
[16.0000, 17.0000, 18.0000, 19.0000],
75+
],
76+
[
77+
[6.0000, 7.0000, 8.0000, 9.0000],
78+
[11.0000, 12.0000, 13.0000, 14.0000],
79+
[16.0000, 17.0000, 18.0000, 19.0000],
80+
],
81+
]
82+
],
83+
[
84+
[
85+
[
86+
[11.0000, 12.0000, 13.0000, 14.0000],
87+
[16.0000, 17.0000, 18.0000, 19.0000],
88+
[21.0000, 22.0000, 23.0000, 24.0000],
89+
],
90+
[
91+
[11.0000, 12.0000, 13.0000, 14.0000],
92+
[16.0000, 17.0000, 18.0000, 19.0000],
93+
[21.0000, 22.0000, 23.0000, 24.0000],
94+
],
95+
]
96+
],
97+
],
98+
device=device,
99+
dtype=dtype,
100+
)
101+
102+
self.assert_close(out, expected, atol=1e-4, rtol=1e-4)
103+
104+
def test_same_on_batch(self, device, dtype):
105+
f = RandomCrop3D(size=(2, 3, 4), padding=None, align_corners=True, p=1.0, same_on_batch=True)
106+
input_tensor = (
107+
torch.eye(6, device=device, dtype=dtype)
108+
.unsqueeze(dim=0)
109+
.unsqueeze(dim=0)
110+
.unsqueeze(dim=0)
111+
.repeat(2, 3, 5, 1, 1)
112+
)
113+
res = f(input_tensor)
114+
self.assert_close(res[0], res[1])
115+
116+
@pytest.mark.parametrize("padding", [1, (1, 1, 1), (1, 1, 1, 1, 1, 1)])
117+
def test_padding_batch(self, padding, device, dtype):
118+
torch.manual_seed(42)
119+
batch_size = 2
120+
input_tensor = torch.tensor(
121+
[[[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]]], device=device, dtype=dtype
122+
).repeat(batch_size, 1, 3, 1, 1)
123+
expected = torch.tensor(
124+
[
125+
[
126+
[
127+
[[0.0, 1.0, 2.0, 10.0], [3.0, 4.0, 5.0, 10.0], [6.0, 7.0, 8.0, 10.0]],
128+
[[0.0, 1.0, 2.0, 10.0], [3.0, 4.0, 5.0, 10.0], [6.0, 7.0, 8.0, 10.0]],
129+
]
130+
],
131+
[
132+
[
133+
[[3.0, 4.0, 5.0, 10.0], [6.0, 7.0, 8.0, 10.0], [10, 10, 10, 10.0]],
134+
[[3.0, 4.0, 5.0, 10.0], [6.0, 7.0, 8.0, 10.0], [10, 10, 10, 10.0]],
135+
]
136+
],
137+
],
138+
device=device,
139+
dtype=dtype,
140+
)
141+
f = RandomCrop3D(size=(2, 3, 4), fill=10.0, padding=padding, align_corners=True, p=1.0)
142+
out = f(input_tensor)
143+
144+
self.assert_close(out, expected, atol=1e-4, rtol=1e-4)
145+
146+
def test_pad_if_needed(self, device, dtype):
147+
torch.manual_seed(42)
148+
input_tensor = torch.tensor([[[0.0, 1.0, 2.0]]], device=device, dtype=dtype)
149+
expected = torch.tensor(
150+
[
151+
[
152+
[
153+
[[9.0, 9.0, 9.0, 9.0], [9.0, 9.0, 9.0, 9.0], [9.0, 9.0, 9.0, 9.0]],
154+
[[0.0, 1.0, 2.0, 9.0], [9.0, 9.0, 9.0, 9.0], [9.0, 9.0, 9.0, 9.0]],
155+
]
156+
]
157+
],
158+
device=device,
159+
dtype=dtype,
160+
)
161+
rc = RandomCrop3D(size=(2, 3, 4), pad_if_needed=True, fill=9, align_corners=True, p=1.0)
162+
out = rc(input_tensor)
163+
164+
self.assert_close(out, expected, atol=1e-4, rtol=1e-4)
165+
166+
def test_gradcheck(self, device):
167+
torch.manual_seed(0) # for random reproductibility
168+
input_tensor = torch.rand((3, 3, 3), device=device, dtype=torch.float64) # 3 x 3
169+
self.gradcheck(RandomCrop3D(size=(3, 3, 3), p=1.0), (input_tensor,))
170+
171+
@pytest.mark.skip("Need to fix Union type")
172+
def test_jit(self, device, dtype):
173+
# Define script
174+
op = RandomCrop(size=(3, 3), p=1.0).forward
175+
op_script = torch.jit.script(op)
176+
img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)
177+
178+
actual = op_script(img)
179+
expected = kornia.geometry.transform.center_crop3d(img)
180+
self.assert_close(actual, expected)
181+
182+
@pytest.mark.skip("Need to fix Union type")
183+
def test_jit_trace(self, device, dtype):
184+
# Define script
185+
op = RandomCrop(size=(3, 3), p=1.0).forward
186+
op_script = torch.jit.script(op)
187+
# 1. Trace op
188+
img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)
189+
190+
op_trace = torch.jit.trace(op_script, (img,))
191+
192+
# 2. Generate new input
193+
img = torch.ones(1, 1, 5, 6, device=device, dtype=dtype)
194+
195+
# 3. Evaluate
196+
actual = op_trace(img)
197+
expected = op(img)
198+
self.assert_close(actual, expected)

0 commit comments

Comments
 (0)