|
| 1 | +import torch |
| 2 | +from torch.utils.data import DataLoader, TensorDataset |
| 3 | + |
| 4 | +from rgc.utils.data import compute_mean_std |
| 5 | + |
| 6 | + |
| 7 | +def test_compute_mean_std(): |
| 8 | + # Create a mock dataset with 3 channels |
| 9 | + data = torch.tensor([ |
| 10 | + [[[1.0, 2.0], [3.0, 4.0]], [[2.0, 4.0], [6.0, 8.0]], [[0.5, 1.0], [1.5, 2.0]]], # Batch 1, 3 channels |
| 11 | + [[[5.0, 6.0], [7.0, 8.0]], [[10.0, 12.0], [14.0, 16.0]], [[2.5, 3.0], [3.5, 4.0]]], # Batch 2, 3 channels |
| 12 | + [[[9.0, 10.0], [11.0, 12.0]], [[18.0, 20.0], [22.0, 24.0]], [[4.5, 5.0], [5.5, 6.0]]], # Batch 3, 3 channels |
| 13 | + ]) |
| 14 | + |
| 15 | + targets = torch.tensor([0, 1, 2]) # Dummy target labels |
| 16 | + dataset = TensorDataset(data, targets) |
| 17 | + dataloader = DataLoader(dataset, batch_size=2) |
| 18 | + |
| 19 | + # Run the function |
| 20 | + mean, std = compute_mean_std(dataloader) |
| 21 | + |
| 22 | + # Expected mean and std for each channel based on the dataset |
| 23 | + expected_mean = torch.tensor([6.5000, 13.0000, 3.2500]) # Mean across all batches for each channel |
| 24 | + expected_std = torch.tensor([3.6056, 7.2111, 1.8028]) # Standard deviation across all batches for each channel |
| 25 | + |
| 26 | + # Check the mean and std are as expected |
| 27 | + assert torch.allclose(mean, expected_mean, atol=1e-4), f"Expected mean {expected_mean}, but got {mean}" |
| 28 | + assert torch.allclose(std, expected_std, atol=1e-4), f"Expected std {expected_std}, but got {std}" |
0 commit comments