Skip to content

Commit b022bea

Browse files
Merge pull request #16 from mirsazzathossain/dev
feat(utils): add bulk image download from catalog
2 parents ab48994 + cb436ba commit b022bea

File tree

3 files changed

+178
-0
lines changed

3 files changed

+178
-0
lines changed

rgc/utils/data.py

+102
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
from astropy import units as u
18+
from astropy.coordinates import SkyCoord
1719
from astropy.io import fits
1820
from astroquery.skyview import SkyView
1921
from astroquery.vizier import Vizier
@@ -203,6 +205,8 @@ def mask_image(image: Image.Image, mask: Image.Image) -> Image.Image:
203205
204206
:return: A PIL Image object containing the masked image.
205207
:rtype: Image.Image
208+
209+
:raises _ImageMaskDimensionError: If the dimensions of the image and mask do not match.
206210
"""
207211
image_array = np.array(image)
208212
mask_array = np.array(mask)
@@ -235,6 +239,21 @@ def __init__(self, message: str = "Number of images and masks must match and be
235239

236240

237241
def mask_image_bulk(image_dir: str, mask_dir: str, masked_dir: str) -> None:
242+
"""
243+
Mask a directory of images with a directory of mask images.
244+
245+
:param image_dir: The path to the directory containing the images.
246+
:type image_dir: str
247+
248+
:param mask_dir: The path to the directory containing the mask images.
249+
:type mask_dir: str
250+
251+
:param masked_dir: The path to the directory to save the masked images.
252+
:type masked_dir: str
253+
254+
:raises _FileNotFoundError: If no images or masks are found in the directories.
255+
:raises _ImageMaskCountMismatchError: If the number of images and masks do not match.
256+
"""
238257
image_paths = sorted(Path(image_dir).glob("*.png"))
239258
mask_paths = sorted(Path(mask_dir).glob("*.png"))
240259

@@ -263,3 +282,86 @@ def mask_image_bulk(image_dir: str, mask_dir: str, masked_dir: str) -> None:
263282
masked_image = mask_image(image, mask)
264283

265284
masked_image.save(Path(masked_dir) / image_path.name)
285+
286+
287+
class _ColumnNotFoundError(Exception):
    """
    An exception to be raised when a specified column is not found in the catalog.

    The offending column name is kept on the instance as ``column`` so that handlers can inspect it
    without parsing the message.

    :param column: The name of the column that could not be found.
    :type column: str
    """

    def __init__(self, column: str) -> None:
        self.column = column
        super().__init__(f"Column {column} not found in the catalog.")

    def __reduce__(self) -> tuple:
        # The default exception reduction re-invokes ``__init__`` with ``self.args`` (the already
        # formatted message), which would wrap the message in the template a second time when the
        # exception is copied or pickled. Rebuild from the raw column name instead.
        return (type(self), (self.column,))
294+
295+
296+
def _get_class_labels(catalog: pd.Series, classes: dict, cls_col: str) -> str:
    """
    Get the class label for a single celestial object in the catalog.

    The value stored under ``cls_col`` is tested against each key of ``classes`` with the ``in``
    operator (a substring test when the value is a string), in the dictionary's iteration order.
    The label of the first key that matches is returned.

    :param catalog: A pandas Series representing a row in the catalog of celestial objects.
    :type catalog: pd.Series

    :param classes: A dictionary mapping class keys to the labels to return for them.
    :type classes: dict

    :param cls_col: The name of the column containing the class information.
    :type cls_col: str

    :return: The label of the first matching key converted to ``str``, or an empty string if no key
        matches or the cell holds a missing value.
    :rtype: str

    :raises _ColumnNotFoundError: If the specified column is not found in the catalog.
    """
    if cls_col not in catalog.index:
        raise _ColumnNotFoundError(cls_col)

    value = catalog[cls_col]

    # A missing cell (None / NaN / pd.NA) carries no class information. Without this guard the
    # ``in`` test below raises TypeError because such scalars are not containers.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""

    for key, label in classes.items():
        if key in value:
            return str(label)

    return ""
323+
324+
325+
def celestial_capture_bulk(
    catalog: pd.DataFrame, survey: str, img_dir: str, classes: Optional[dict] = None, cls_col: Optional[str] = None
) -> None:
    """
    Capture celestial images for a catalog of celestial objects.

    Each row is turned into a coordinate tag with ``celestial_tag``, parsed with ``SkyCoord`` using
    ``(hourangle, deg)`` units, and handed to ``celestial_capture`` together with the target file
    name ``{img_dir}/{label}_{name}.fits``. ``name`` is the row's ``filename`` value when the catalog
    has a ``filename`` column, otherwise the coordinate tag. ``label`` is looked up with
    ``_get_class_labels`` when both ``classes`` and ``cls_col`` are given, otherwise it is empty.

    Processing is best-effort: any ``Exception`` raised while handling a row is reported with
    ``print`` and the loop continues with the next row, so no per-row error propagates to the caller.

    :param catalog: A pandas DataFrame containing the catalog of celestial objects.
    :type catalog: pd.DataFrame

    :param survey: The name of the survey to be used e.g. 'VLA FIRST (1.4 GHz)'.
    :type survey: str

    :param img_dir: The path to the directory to save the images.
    :type img_dir: str

    :param classes: A dictionary containing the classes of the celestial objects.
    :type classes: Optional[dict]

    :param cls_col: The name of the column containing the class labels.
    :type cls_col: Optional[str]
    """
    # Loop-invariant: whether rows supply their own file name or the coordinate tag is used.
    has_filename = "filename" in catalog.columns

    for _, entry in catalog.iterrows():
        try:
            tag = celestial_tag(entry)
            coordinate = SkyCoord(tag, unit=(u.hourangle, u.deg))

            right_ascension = coordinate.ra.deg
            declination = coordinate.dec.deg

            label = _get_class_labels(entry, classes, cls_col) if classes is not None and cls_col is not None else ""

            name = entry["filename"] if has_filename else tag
            filename = f"{img_dir}/{label}_{name}.fits"

            celestial_capture(survey, right_ascension, declination, filename)
        except Exception as err:
            # Deliberate best-effort boundary: one bad row must not abort the whole bulk run.
            print(f"Failed to capture image. {err}")

tests/test_celestial_capture_bulk.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from unittest.mock import MagicMock, patch
2+
3+
import pandas as pd
4+
5+
from rgc.utils.data import celestial_capture_bulk
6+
7+
8+
@patch("rgc.utils.data.celestial_tag")
@patch("rgc.utils.data.SkyCoord")
@patch("rgc.utils.data.celestial_capture")
def test_celestial_capture_bulk(mock_celestial_capture, mock_SkyCoord, mock_celestial_tag):
    """
    Check the happy path and the failure path of ``celestial_capture_bulk`` without a ``filename`` column.

    ``celestial_tag``, ``SkyCoord`` and ``celestial_capture`` are patched in ``rgc.utils.data`` so no
    coordinates are parsed and nothing is downloaded.
    """
    # Mock data
    mock_celestial_tag.return_value = "10h00m00s +10d00m00s"
    mock_SkyCoord.return_value = MagicMock(ra=MagicMock(deg=10), dec=MagicMock(deg=20))

    catalog = pd.DataFrame({"label": ["WAT"], "object_name": ["test"]})
    classes = {"WAT": 100, "NAT": 200}
    img_dir = "/path/to/images"

    # Run the function
    celestial_capture_bulk(catalog, "VLA FIRST (1.4 GHz)", img_dir, classes, "label")

    # Check that celestial_capture was called with the expected arguments:
    # the file name is built from the class label and the coordinate tag.
    mock_celestial_capture.assert_called_once_with(
        "VLA FIRST (1.4 GHz)", 10, 20, "/path/to/images/100_10h00m00s +10d00m00s.fits"
    )

    # Test failure handling: the tag lookup raises, so the row must be reported and skipped.
    mock_celestial_capture.reset_mock()
    mock_celestial_tag.side_effect = Exception("Test exception")

    with patch("builtins.print") as mock_print:
        celestial_capture_bulk(catalog, "VLA FIRST (1.4 GHz)", img_dir, classes, "object_name")
        mock_print.assert_called_once_with("Failed to capture image. Test exception")

    # The mock was reset above precisely so this can be checked: no capture for a failed row.
    mock_celestial_capture.assert_not_called()
35+
36+
37+
@patch("rgc.utils.data.celestial_tag")
@patch("rgc.utils.data.SkyCoord")
@patch("rgc.utils.data.celestial_capture")
def test_celestial_capture_bulk_with_filename(mock_celestial_capture, mock_SkyCoord, mock_celestial_tag):
    """
    When the catalog has a ``filename`` column, that value (not the coordinate tag) names the file.
    """
    # Stub out tag generation and coordinate parsing.
    mock_celestial_tag.return_value = "10h00m00s +10d00m00s"
    fake_coordinate = MagicMock(ra=MagicMock(deg=10), dec=MagicMock(deg=20))
    mock_SkyCoord.return_value = fake_coordinate

    # Single-row catalog that carries its own file name.
    rows = pd.DataFrame({"label": ["WAT"], "filename": ["image1"], "object_name": ["test"]})
    label_codes = {"WAT": 100, "NAT": 200}
    target_dir = "/path/to/images"

    celestial_capture_bulk(rows, "VLA FIRST (1.4 GHz)", target_dir, label_codes, "label")

    # Exactly one capture, named "<label>_<filename>.fits" inside the target directory.
    expected_path = "/path/to/images/100_image1.fits"
    mock_celestial_capture.assert_called_once_with("VLA FIRST (1.4 GHz)", 10, 20, expected_path)

tests/test_get_class_label.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pandas as pd
2+
import pytest
3+
4+
from rgc.utils.data import _ColumnNotFoundError, _get_class_labels
5+
6+
7+
def test_get_class_labels():
    """
    Exercise ``_get_class_labels`` for a matching key, an unknown column and a non-matching value.
    """
    row = pd.Series({"object_name": "Object1", "class_col": "Galaxy"})
    label_map = {"Galaxy": "Galactic", "Star": "Stellar"}

    # Valid column whose value contains a known key.
    matched = _get_class_labels(row, label_map, "class_col")
    assert matched == "Galactic", "Should return 'Galactic' for 'Galaxy'"

    # A column that is not part of the row must raise.
    with pytest.raises(_ColumnNotFoundError):
        _get_class_labels(row, label_map, "invalid_col")

    # Valid column whose value contains none of the keys.
    unmatched = _get_class_labels(row, label_map, "object_name")
    assert unmatched == "", "Should return '' if no matching key is found"

0 commit comments

Comments
 (0)