|
14 | 14 |
|
15 | 15 | import numpy as np
|
16 | 16 | import pandas as pd
|
| 17 | +from astropy import units as u |
| 18 | +from astropy.coordinates import SkyCoord |
17 | 19 | from astropy.io import fits
|
18 | 20 | from astroquery.skyview import SkyView
|
19 | 21 | from astroquery.vizier import Vizier
|
@@ -203,6 +205,8 @@ def mask_image(image: Image.Image, mask: Image.Image) -> Image.Image:
|
203 | 205 |
|
204 | 206 | :return: A PIL Image object containing the masked image.
|
205 | 207 | :rtype: Image.Image
|
| 208 | +
|
| 209 | + :raises _ImageMaskDimensionError: If the dimensions of the image and mask do not match. |
206 | 210 | """
|
207 | 211 | image_array = np.array(image)
|
208 | 212 | mask_array = np.array(mask)
|
@@ -235,6 +239,21 @@ def __init__(self, message: str = "Number of images and masks must match and be
|
235 | 239 |
|
236 | 240 |
|
237 | 241 | def mask_image_bulk(image_dir: str, mask_dir: str, masked_dir: str) -> None:
|
| 242 | + """ |
| 243 | + Mask a directory of images with a directory of mask images. |
| 244 | +
|
| 245 | + :param image_dir: The path to the directory containing the images. |
| 246 | + :type image_dir: str |
| 247 | +
|
| 248 | + :param mask_dir: The path to the directory containing the mask images. |
| 249 | + :type mask_dir: str |
| 250 | +
|
| 251 | + :param masked_dir: The path to the directory to save the masked images. |
| 252 | + :type masked_dir: str |
| 253 | +
|
| 254 | + :raises _FileNotFoundError: If no images or masks are found in the directories. |
| 255 | + :raises _ImageMaskCountMismatchError: If the number of images and masks do not match. |
| 256 | + """ |
238 | 257 | image_paths = sorted(Path(image_dir).glob("*.png"))
|
239 | 258 | mask_paths = sorted(Path(mask_dir).glob("*.png"))
|
240 | 259 |
|
@@ -263,3 +282,86 @@ def mask_image_bulk(image_dir: str, mask_dir: str, masked_dir: str) -> None:
|
263 | 282 | masked_image = mask_image(image, mask)
|
264 | 283 |
|
265 | 284 | masked_image.save(Path(masked_dir) / image_path.name)
|
| 285 | + |
| 286 | + |
| 287 | +class _ColumnNotFoundError(Exception): |
| 288 | + """ |
| 289 | + An exception to be raised when a specified column is not found in the catalog. |
| 290 | + """ |
| 291 | + |
| 292 | + def __init__(self, column: str) -> None: |
| 293 | + super().__init__(f"Column {column} not found in the catalog.") |
| 294 | + |
| 295 | + |
| 296 | +def _get_class_labels(catalog: pd.Series, classes: dict, cls_col: str) -> str: |
| 297 | + """ |
| 298 | + Get the class labels for the celestial objects in the catalog. |
| 299 | +
|
| 300 | + :param catalog: A pandas Series representing a row in the catalog of celestial objects. |
| 301 | + :type catalog: pd.Series |
| 302 | +
|
| 303 | + :param classes: A dictionary containing the classes of the celestial objects. |
| 304 | + :type classes: dict |
| 305 | +
|
| 306 | + :param cls_col: The name of the column containing the class labels. |
| 307 | + :type cls_col: str |
| 308 | +
|
| 309 | + :return: Class labels for the celestial objects in the catalog. |
| 310 | + :rtype: str |
| 311 | +
|
| 312 | + :raises _ColumnNotFoundError: If the specified column is not found in the catalog. |
| 313 | + """ |
| 314 | + if cls_col not in catalog.index: |
| 315 | + raise _ColumnNotFoundError(cls_col) |
| 316 | + |
| 317 | + value = catalog[cls_col] |
| 318 | + for key, label in classes.items(): |
| 319 | + if key in value: |
| 320 | + return str(label) |
| 321 | + |
| 322 | + return "" |
| 323 | + |
| 324 | + |
| 325 | +def celestial_capture_bulk( |
| 326 | + catalog: pd.DataFrame, survey: str, img_dir: str, classes: Optional[dict] = None, cls_col: Optional[str] = None |
| 327 | +) -> None: |
| 328 | + """ |
| 329 | + Capture celestial images for a catalog of celestial objects. |
| 330 | +
|
| 331 | + :param catalog: A pandas DataFrame containing the catalog of celestial objects. |
| 332 | + :type catalog: pd.DataFrame |
| 333 | +
|
| 334 | + :param survey: The name of the survey to be used e.g. 'VLA FIRST (1.4 GHz)'. |
| 335 | + :type survey: str |
| 336 | +
|
| 337 | + :param img_dir: The path to the directory to save the images. |
| 338 | + :type img_dir: str |
| 339 | +
|
| 340 | + :param classes: A dictionary containing the classes of the celestial objects. |
| 341 | + :type classes: dict |
| 342 | +
|
| 343 | + :param cls_col: The name of the column containing the class labels. |
| 344 | +
|
| 345 | + :raises _InvalidCoordinatesError: If coordinates are invalid. |
| 346 | + """ |
| 347 | + failed = pd.DataFrame(columns=catalog.columns) |
| 348 | + for _, entry in catalog.iterrows(): |
| 349 | + try: |
| 350 | + tag = celestial_tag(entry) |
| 351 | + coordinate = SkyCoord(tag, unit=(u.hourangle, u.deg)) |
| 352 | + |
| 353 | + right_ascension = coordinate.ra.deg |
| 354 | + declination = coordinate.dec.deg |
| 355 | + |
| 356 | + label = _get_class_labels(entry, classes, cls_col) if classes is not None and cls_col is not None else "" |
| 357 | + |
| 358 | + if "filename" in catalog.columns: |
| 359 | + filename = f'{img_dir}/{label}_{entry["filename"]}.fits' |
| 360 | + else: |
| 361 | + filename = f"{img_dir}/{label}_{tag}.fits" |
| 362 | + |
| 363 | + celestial_capture(survey, right_ascension, declination, filename) |
| 364 | + except Exception as err: |
| 365 | + series = entry.to_frame().T |
| 366 | + failed = pd.concat([failed, series], ignore_index=True) |
| 367 | + print(f"Failed to capture image. {err}") |
0 commit comments