Skip to content

Commit b69e0ef

Browse files
Merge pull request #17 from mirsazzathossain/dev
feat(utils): add functions to compute dataset statistics and save HTM…
2 parents c8ac95b + 168dc31 commit b69e0ef

6 files changed

+361
-0
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies = [
1717
"numpy>=1.24.4",
1818
"pandas>=2.0.3",
1919
"pillow>=10.4.0",
20+
"torch>=2.4.1",
2021
]
2122

2223
[project.urls]

rgc/utils/data.py

+50
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
import torch
1718
from astropy import units as u
1819
from astropy.coordinates import SkyCoord
1920
from astropy.io import fits
@@ -365,3 +366,52 @@ def celestial_capture_bulk(
365366
series = entry.to_frame().T
366367
failed = pd.concat([failed, series], ignore_index=True)
367368
print(f"Failed to capture image. {err}")
369+
370+
371+
def dataframe_to_html(catalog: pd.DataFrame, save_dir: str) -> None:
    """
    Write the catalog to ``catalog.html`` inside ``save_dir``.

    The destination directory is created first, together with any missing
    parent directories; an already existing directory is not an error.

    :param catalog: Catalog of the astronomical objects
    :type catalog: pd.DataFrame
    :param save_dir: Path to the directory to save the HTML file
    :type save_dir: str
    """
    # Make sure the destination exists before pandas tries to open the file.
    destination = Path(save_dir)
    destination.mkdir(parents=True, exist_ok=True)

    html_file = os.path.join(save_dir, "catalog.html")
    catalog.to_html(html_file)
382+
383+
384+
def compute_mean_std(dataloader: torch.utils.data.DataLoader) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the mean and standard deviation of the dataset.

    The first element of every batch yielded by ``dataloader`` is collected,
    all of them are concatenated along dimension 0, and the statistics are
    reduced over dimensions 0, 2 and 3. The result therefore holds one value
    per index of dimension 1 (presumably the channel axis of ``(N, C, H, W)``
    image batches -- confirm against the dataset in use).

    :param dataloader: The dataloader for the dataset.
    :type dataloader: torch.utils.data.DataLoader

    :raises ValueError: If the dataloader yields no batches.

    :return: The mean and standard deviation of the dataset.
    :rtype: tuple[torch.Tensor, torch.Tensor]
    """
    # Gather first and concatenate once: calling torch.cat inside the loop
    # re-copies everything accumulated so far on every iteration.
    batches = [batch[0] for batch in dataloader]
    if not batches:
        raise ValueError("Cannot compute mean and std: the dataloader yielded no batches.")

    data = torch.cat(batches, 0)

    # Promote against float32 so integer inputs (e.g. uint8 images) can be
    # averaged, while float64 inputs keep their precision.
    data = data.to(torch.promote_types(data.dtype, torch.float32))

    mean = torch.mean(data, dim=(0, 2, 3))
    std = torch.std(data, dim=(0, 2, 3))

    return mean, std
402+
403+
404+
def remove_artifacts(folder: str, extension: list[str]) -> None:
    """
    Remove every file from a folder except those with the given extensions.

    Only the direct entries of ``folder`` are examined. An entry whose name
    ends with one of the given extensions is kept; sub-directories are left
    untouched; everything else is deleted.

    :param folder: Path to the folder to clear
    :type folder: str
    :param extension: Extensions (file name suffixes) of the files to keep.
        A single string is treated as one extension.
    :type extension: list
    """
    # str.endswith needs a tuple of suffixes. A bare string must be wrapped,
    # otherwise tuple("txt") would turn it into single-character suffixes.
    keep = (extension,) if isinstance(extension, str) else tuple(extension)

    for name in os.listdir(folder):
        if name.endswith(keep):
            continue

        path = os.path.join(folder, name)
        # os.remove cannot delete directories; skip real ones. Symlinks are
        # still removed as plain entries, even when they point at a directory.
        if os.path.isdir(path) and not os.path.islink(path):
            continue

        os.remove(path)

    print(f"Artifacts removed from {folder} with extensions {', '.join(keep)}")

tests/test_compute_mean_std.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
from torch.utils.data import DataLoader, TensorDataset
3+
4+
from rgc.utils.data import compute_mean_std
5+
6+
7+
def test_compute_mean_std():
    """Check the statistics returned for a small, hand-computed 3-channel dataset."""
    # Three samples, each holding 3 channels of 2x2 values
    samples = torch.tensor([
        [[[1.0, 2.0], [3.0, 4.0]], [[2.0, 4.0], [6.0, 8.0]], [[0.5, 1.0], [1.5, 2.0]]],  # Sample 1
        [[[5.0, 6.0], [7.0, 8.0]], [[10.0, 12.0], [14.0, 16.0]], [[2.5, 3.0], [3.5, 4.0]]],  # Sample 2
        [[[9.0, 10.0], [11.0, 12.0]], [[18.0, 20.0], [22.0, 24.0]], [[4.5, 5.0], [5.5, 6.0]]],  # Sample 3
    ])
    labels = torch.tensor([0, 1, 2])  # Dummy target labels

    # batch_size=2 splits the three samples into two batches
    loader = DataLoader(TensorDataset(samples, labels), batch_size=2)

    mean, std = compute_mean_std(loader)

    # Per-channel reference values over all samples
    expected_mean = torch.tensor([6.5000, 13.0000, 3.2500])
    expected_std = torch.tensor([3.6056, 7.2111, 1.8028])

    mean_matches = torch.allclose(mean, expected_mean, atol=1e-4)
    assert mean_matches, f"Expected mean {expected_mean}, but got {mean}"

    std_matches = torch.allclose(std, expected_std, atol=1e-4)
    assert std_matches, f"Expected std {expected_std}, but got {std}"

tests/test_dataframe_to_html.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from unittest.mock import patch
2+
3+
import pandas as pd
4+
5+
from rgc.utils.data import dataframe_to_html
6+
7+
8+
@patch("rgc.utils.data.Path.mkdir")
@patch("rgc.utils.data.pd.DataFrame.to_html")
@patch("rgc.utils.data.os.path.join", return_value="/mocked/path/catalog.html")
def test_dataframe_to_html(mock_join, mock_to_html, mock_mkdir):
    """Verify the calls made by dataframe_to_html with all filesystem access mocked."""
    # Small catalog to export
    frame = pd.DataFrame({"object_name": ["Object1", "Object2"], "ra": [10.5, 20.3], "dec": [-30.1, 45.2]})
    target_dir = "/mocked/directory"

    dataframe_to_html(frame, target_dir)

    # The output path is built from the directory and the fixed file name
    mock_join.assert_called_once_with(target_dir, "catalog.html")

    # The output directory is created, tolerating an existing one
    mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)

    # The catalog is written to the path returned by the mocked join
    mock_to_html.assert_called_once_with("/mocked/path/catalog.html")

tests/test_remove_artifacts.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import os
2+
import tempfile
3+
import unittest
4+
from pathlib import Path
5+
6+
from rgc.utils.data import remove_artifacts
7+
8+
9+
class TestRemoveArtifacts(unittest.TestCase):
    """Exercise remove_artifacts against real files in a temporary directory."""

    def setUp(self):
        # Fresh scratch directory for every test
        self.test_dir = tempfile.TemporaryDirectory()

        # Populate it with empty files of several extensions
        self.test_files = ["file1.txt", "file2.jpg", "file3.txt", "file4.png", "file5.csv"]
        for name in self.test_files:
            target = Path(os.path.join(self.test_dir.name, name))
            target.touch()

    def tearDown(self):
        # Dispose of the scratch directory and whatever is left inside
        self.test_dir.cleanup()

    def test_remove_artifacts(self):
        # Extensions whose files must survive
        keep = [".txt", ".jpg"]

        remove_artifacts(self.test_dir.name, keep)

        survivors = os.listdir(self.test_dir.name)

        # Only the .txt and .jpg files should be left
        expected = ["file1.txt", "file2.jpg", "file3.txt"]
        self.assertEqual(sorted(survivors), sorted(expected))

        # Files with any other extension must be gone
        for removed in ("file4.png", "file5.csv"):
            self.assertNotIn(removed, survivors)


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)