Skip to content

Commit a832e62

Browse files
feat(utils): add functions to detect sources and generate mask
This commit adds new functions to the `rgc.utils.data` module. The `generate_mask` function is added to detect sources in an image and generate a mask. The `generate_mask_bulk` function is added to generate masks for a catalog of celestial objects. These functions are useful for image processing and analysis tasks. The `generate_mask` function takes various parameters such as the image path, mask directory, frequency, beam size, dilation factor, and threshold values for pixel and island detection. It uses the `bdsf` library to process the image and export the generated mask. The `generate_mask_bulk` function takes a pandas DataFrame containing the catalog of celestial objects, the image directory, mask directory, frequency, and beam size. It iterates over the catalog entries and calls the `generate_mask` function for each entry to generate masks for all the images in the catalog. These new functions enhance the functionality of the `rgc` package and provide convenient tools for working with astronomical images.
1 parent 258e559 commit a832e62

File tree

6 files changed

+498
-94
lines changed

6 files changed

+498
-94
lines changed

CHANGELOG.md

+93-93
Large diffs are not rendered by default.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ classifiers=[
1414
dependencies = [
1515
"astropy>=5.2.2",
1616
"astroquery>=0.4.7",
17+
"bdsf>=1.11.1",
1718
"numpy>=1.24.4",
1819
"pandas>=2.0.3",
1920
"pillow>=10.4.0",

rgc/utils/data.py

+102
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pathlib import Path
1313
from typing import Optional, cast
1414

15+
import bdsf
1516
import numpy as np
1617
import pandas as pd
1718
import torch
@@ -415,3 +416,104 @@ def remove_artifacts(folder: str, extension: list[str]) -> None:
415416
os.remove(os.path.join(folder, file))
416417

417418
print(f"Artifacts removed from {folder} with extensions {', '.join(extension)}")
419+
420+
421+
def generate_mask(
    image_path: str,
    mask_dir: str,
    freq: float,
    beam: tuple[float, float, float],
    dilation: int,
    threshold_pixel: float = 5.0,
    threshold_island: float = 3.0,
) -> None:
    """
    Run bdsf source detection on one image and export its island mask.

    The mask is written into ``mask_dir`` under the same file name as the input
    image; the target directory is created if it does not exist. Any exception
    raised along the way is swallowed and reported with a single printed line.

    :param image_path: Path to the image file
    :type image_path: str

    :param mask_dir: Path to the directory to save the mask
    :type mask_dir: str

    :param freq: Frequency of the image, forwarded unchanged as bdsf's ``frequency``
        option. NOTE(review): previously documented as MHz, while PyBDSF documents
        ``frequency`` in Hz -- confirm the unit callers pass.
    :type freq: float

    :param beam: Beam of the image, forwarded unchanged as bdsf's ``beam`` option.
        NOTE(review): previously documented as arcsec, while PyBDSF documents
        ``beam`` as (major, minor, position angle) in degrees -- confirm.
    :type beam: tuple

    :param dilation: Dilation factor for the mask (bdsf ``mask_dilation``)
    :type dilation: int

    :param threshold_pixel: Threshold for island peak in number of sigma above the mean
    :type threshold_pixel: float

    :param threshold_island: Threshold for island detection in number of sigma above the mean
    :type threshold_island: float
    """
    detection_options = {
        "frequency": freq,
        "beam": beam,
        "thresh_pix": threshold_pixel,
        "thresh_isl": threshold_island,
    }

    try:
        # Detection runs first, so a failing image never creates the output directory.
        processed = bdsf.process_image(image_path, **detection_options)

        # The mask keeps the input's file name, just relocated under mask_dir.
        destination = Path(mask_dir) / Path(image_path).name
        Path(destination).parent.mkdir(parents=True, exist_ok=True)

        processed.export_image(
            outfile=destination,
            img_type="island_mask",
            mask_dilation=dilation,
            clobber=True,
        )
    except Exception:
        # Best-effort by design: report the failure and let the caller carry on.
        print("Failed to generate mask.")

    return None
476+
477+
478+
def generate_mask_bulk(
    catalog: pd.DataFrame, img_dir: str, mask_dir: str, freq: float, beam: tuple[float, float, float]
) -> None:
    """
    Generate masks for a catalog of celestial objects.

    Each catalog row is handled independently: the image ``<img_dir>/<filename>.fits``
    is passed to :func:`generate_mask` together with the row's own dilation and
    thresholds. A row that fails (for example because a column is missing) is
    reported with a printed message and the remaining rows are still processed.

    :param catalog: A pandas DataFrame containing the catalog of celestial objects.
        The columns read from every row are ``filename`` (image name without the
        ``.fits`` suffix), ``dilation``, ``background sigma`` (passed on as
        ``threshold_pixel``) and ``foreground sigma`` (passed on as
        ``threshold_island``).
    :type catalog: pd.DataFrame

    :param img_dir: The path to the directory containing the images.
    :type img_dir: str

    :param mask_dir: The path to the directory to save the masks.
    :type mask_dir: str

    :param freq: Frequency of the images, forwarded unchanged to :func:`generate_mask`.
    :type freq: float

    :param beam: Beam of the images, forwarded unchanged to :func:`generate_mask`.
    :type beam: tuple
    """
    for _, entry in catalog.iterrows():
        try:
            filename = entry["filename"]
            image_path = os.path.join(img_dir, f"{filename}.fits")

            generate_mask(
                image_path,
                mask_dir,
                freq,
                beam,
                entry["dilation"],
                entry["background sigma"],
                entry["foreground sigma"],
            )
        except Exception as err:
            # One bad row must not abort the batch: report it and move on to the next.
            print(f"Failed to generate mask. {err}")

tests/test_generate_mask.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import unittest
2+
from pathlib import Path
3+
from unittest.mock import MagicMock, patch
4+
5+
from rgc.utils.data import generate_mask
6+
7+
8+
class TestGenerateMask(unittest.TestCase):
    """Unit tests for ``rgc.utils.data.generate_mask`` with bdsf and ``print`` mocked out."""

    # Stacked ``patch`` decorators inject their mocks bottom-up, which is why the
    # parameter list starts with ``mock_print`` and ends with ``mock_process_image``.
    @patch("rgc.utils.data.bdsf.process_image", return_value=MagicMock())
    @patch("rgc.utils.data.Path.mkdir")
    @patch("rgc.utils.data.Path.name", new_callable=MagicMock(return_value="image.fits"))
    @patch("rgc.utils.data.Path", autospec=True)
    @patch("rgc.utils.data.print")
    def test_generate_mask_success(self, mock_print, mock_path, mock_name, mock_mkdir, mock_process_image):
        # Happy path: nothing is printed, the output path is built from the image
        # name, the parent directory is created, and bdsf is driven with the
        # expected arguments.

        # Setup mocks
        mock_path_instance = mock_path.return_value
        # Ensure that the truediv operator returns the correct Path
        mock_path_instance.__truediv__.return_value = Path("path/to/mask_dir/image.fits")
        mock_path_instance.parent.mkdir.return_value = None
        mock_path_instance.name = "image.fits"  # Correct the name attribute

        # Setup test parameters
        image_path = "path/to/image.fits"
        mask_dir = "path/to/mask_dir"
        freq = 1400.0
        beam = (5.0, 5.0, 5.0)
        dilation = 2
        threshold_pixel = 5.0
        threshold_island = 3.0

        # Run the function
        generate_mask(
            image_path=image_path,
            mask_dir=mask_dir,
            freq=freq,
            beam=beam,
            dilation=dilation,
            threshold_pixel=threshold_pixel,
            threshold_island=threshold_island,
        )

        # Verify that print was not called
        mock_print.assert_not_called()

        # Verify that the mask file path was generated correctly
        # (``Path`` here is the real pathlib class imported by this test module,
        # so this equals the value configured on ``__truediv__`` above.)
        expected_mask_file = Path(mask_dir) / "image.fits"
        mock_path_instance.__truediv__.assert_called_once_with("image.fits")

        # Verify that the parent directory was created
        mock_path_instance.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True)

        # Verify that process_image was called correctly
        mock_process_image.assert_called_once_with(
            image_path, beam=beam, thresh_isl=threshold_island, thresh_pix=threshold_pixel, frequency=freq
        )

        # Verify that export_image is called on the mock returned by process_image
        mock_process_image.return_value.export_image.assert_called_once_with(
            img_type="island_mask",
            outfile=expected_mask_file,
            clobber=True,
            mask_dilation=dilation,
        )

    # NOTE(review): ``rgc.utils.data.Path`` is itself replaced by an autospec mock in
    # this stack, so the ``Path.mkdir`` / ``Path.parent`` / ``Path.name`` patches may
    # land on that mock rather than on pathlib -- confirm which object actually
    # raises the error this test relies on.
    @patch("rgc.utils.data.bdsf.process_image", return_value=MagicMock())
    @patch("rgc.utils.data.Path.mkdir", side_effect=PermissionError("Permission denied"))
    @patch("rgc.utils.data.Path.parent", new_callable=MagicMock)
    @patch("rgc.utils.data.Path.name", new_callable=MagicMock(return_value="image.fits"))
    @patch("rgc.utils.data.Path", autospec=True)
    @patch("rgc.utils.data.print")
    def test_generate_mask_permission_error(
        self, mock_print, mock_path, mock_name, mock_parent, mock_mkdir, mock_process_image
    ):
        # Directory creation is meant to fail; the function must swallow the error
        # and print the generic failure message exactly once.

        # Setup mocks
        # Calling the patched ``Path`` now yields a real pathlib path.
        mock_path.return_value = Path("/mock/path")
        mock_parent.return_value = mock_path
        mock_process_image.return_value = MagicMock()

        # Setup test parameters
        image_path = "path/to/image.fits"
        mask_dir = "path/to/mask_dir"
        freq = 1400.0
        beam = (5.0, 5.0, 5.0)
        dilation = 2
        threshold_pixel = 5.0
        threshold_island = 3.0

        # Run the function
        generate_mask(
            image_path=image_path,
            mask_dir=mask_dir,
            freq=freq,
            beam=beam,
            dilation=dilation,
            threshold_pixel=threshold_pixel,
            threshold_island=threshold_island,
        )

        # Verify that print was called with the correct error message
        mock_print.assert_called_once_with("Failed to generate mask.")

    @patch("rgc.utils.data.bdsf.process_image", side_effect=Exception("Process failed"))
    @patch("rgc.utils.data.print")
    def test_generate_mask_failure(self, mock_print, mock_process_image):
        # bdsf itself raises; the function must not propagate the exception and
        # must print the generic failure message exactly once.

        # Setup test parameters
        image_path = "path/to/image.fits"
        mask_dir = "path/to/mask_dir"
        freq = 1400.0
        beam = (5.0, 5.0, 5.0)
        dilation = 2
        threshold_pixel = 5.0
        threshold_island = 3.0

        # Run the function
        generate_mask(
            image_path=image_path,
            mask_dir=mask_dir,
            freq=freq,
            beam=beam,
            dilation=dilation,
            threshold_pixel=threshold_pixel,
            threshold_island=threshold_island,
        )

        # Verify that print was called with the correct error message
        mock_print.assert_called_once_with("Failed to generate mask.")

tests/test_generate_mask_bulk.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import unittest
2+
from unittest.mock import call, patch
3+
4+
import pandas as pd
5+
6+
from rgc.utils.data import generate_mask_bulk
7+
8+
9+
class TestGenerateMaskBulk(unittest.TestCase):
    """Unit tests for ``rgc.utils.data.generate_mask_bulk`` with ``generate_mask`` mocked out."""

    # Stacked ``patch`` decorators inject their mocks bottom-up, hence the
    # ``mock_print, mock_join, mock_generate_mask`` parameter order.
    @patch("rgc.utils.data.generate_mask")
    @patch("rgc.utils.data.os.path.join", return_value="mock_image_path.fits")
    @patch("rgc.utils.data.print")  # Mock print to suppress output during the test
    def test_generate_mask_bulk_success(self, mock_print, mock_join, mock_generate_mask):
        # Two catalog rows must produce two ``generate_mask`` calls, in row order,
        # each carrying that row's dilation and sigma values.

        # Create a mock catalog DataFrame
        data = {
            "filename": ["image1", "image2"],
            "dilation": [2, 3],
            "background sigma": [5.0, 6.0],
            "foreground sigma": [3.0, 4.0],
        }
        catalog = pd.DataFrame(data)

        # Test parameters
        img_dir = "mock/img/dir"
        mask_dir = "mock/mask/dir"
        freq = 1400.0
        beam = (5.0, 5.0, 5.0)

        # Run the function
        generate_mask_bulk(
            catalog=catalog,
            img_dir=img_dir,
            mask_dir=mask_dir,
            freq=freq,
            beam=beam,
        )

        # Check that os.path.join was called with the correct parameters
        expected_calls = [call(img_dir, "image1.fits"), call(img_dir, "image2.fits")]
        mock_join.assert_has_calls(expected_calls, any_order=False)

        # Check that generate_mask was called twice with the correct parameters
        # (the image path is identical in both calls because os.path.join is
        # mocked to return a fixed string)
        expected_mask_calls = [
            call(
                "mock_image_path.fits",
                mask_dir,
                freq,
                beam,
                2,  # dilation for image1
                5.0,  # background sigma for image1
                3.0,  # foreground sigma for image1
            ),
            call(
                "mock_image_path.fits",
                mask_dir,
                freq,
                beam,
                3,  # dilation for image2
                6.0,  # background sigma for image2
                4.0,  # foreground sigma for image2
            ),
        ]
        mock_generate_mask.assert_has_calls(expected_mask_calls, any_order=False)

        # Ensure no errors were printed
        mock_print.assert_not_called()

    @patch("rgc.utils.data.generate_mask", side_effect=Exception("Mocked error"))
    @patch("rgc.utils.data.os.path.join", return_value="mock_image_path.fits")
    @patch("rgc.utils.data.print")
    def test_generate_mask_bulk_failure(self, mock_print, mock_join, mock_generate_mask):
        # When ``generate_mask`` raises, the bulk helper must catch the exception
        # and print the failure message with the error text appended.
        # NOTE(review): the catalog has a single row, so this test does not show
        # whether later rows are still processed after a failure.

        # Create a mock catalog DataFrame
        data = {
            "filename": ["image1"],
            "dilation": [2],
            "background sigma": [5.0],
            "foreground sigma": [3.0],
        }
        catalog = pd.DataFrame(data)

        # Test parameters
        img_dir = "mock/img/dir"
        mask_dir = "mock/mask/dir"
        freq = 1400.0
        beam = (5.0, 5.0, 5.0)

        # Run the function
        generate_mask_bulk(
            catalog=catalog,
            img_dir=img_dir,
            mask_dir=mask_dir,
            freq=freq,
            beam=beam,
        )

        # Verify that the error message was printed
        mock_print.assert_called_once_with("Failed to generate mask. Mocked error")

        # Check that generate_mask was called once before the exception
        mock_generate_mask.assert_called_once()
101+
102+
103+
# Allow running this test module directly as a script, in addition to via a test runner.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)