Showing 7 changed files with 546 additions and 0 deletions.
@@ -0,0 +1,4 @@
try:
    import katdal
except ImportError as e:
    raise ImportError("katdal is not installed\n" "pip install dask-ms[katdal]") from e
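
This is the usual optional-dependency guard: importing the katdal subpackage in an environment without katdal fails immediately with an actionable message. A minimal sketch of the expected behaviour; the module path daskms.experimental.katdal is taken from the imports in the test file below, and placing the guard in that package's __init__ is an assumption:

# Hypothetical: run in an environment where katdal is NOT installed.
try:
    import daskms.experimental.katdal  # noqa: F401  (assumed guard location)
except ImportError as e:
    print(e)  # "katdal is not installed\npip install dask-ms[katdal]"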
@@ -0,0 +1,47 @@
from collections import namedtuple

import numpy as np

CPInfo = namedtuple("CPInfo", "ant1_index ant2_index ant1 ant2 cp_index")


def corrprod_index(dataset, pols_to_use, include_auto_corrs=False):
    """The correlator product index (with -1 representing missing indices)."""
    corrprod_to_index = {tuple(cp): n for n, cp in enumerate(dataset.corr_products)}

    # ==========================================
    # Generate per-baseline antenna pairs and
    # correlator product indices
    # ==========================================

    def _cp_index(a1, a2, pol):
        """Create correlator product index from antenna pair and pol."""
        a1 = a1.name + pol[0].lower()
        a2 = a2.name + pol[1].lower()
        return corrprod_to_index.get((a1, a2), -1)

    # Generate baseline antenna pairs
    auto_corrs = 0 if include_auto_corrs else 1
    ant1_index, ant2_index = np.triu_indices(len(dataset.ants), auto_corrs)
    ant1_index, ant2_index = (a.astype(np.int32) for a in (ant1_index, ant2_index))

    # Order as similarly to the input as possible, which gives better performance
    # in permute_baselines.
    bl_indices = list(zip(ant1_index, ant2_index))
    bl_indices.sort(
        key=lambda ants: _cp_index(
            dataset.ants[ants[0]], dataset.ants[ants[1]], pols_to_use[0]
        )
    )
    # Undo the zip
    ant1_index[:] = [bl[0] for bl in bl_indices]
    ant2_index[:] = [bl[1] for bl in bl_indices]
    ant1 = [dataset.ants[a1] for a1 in ant1_index]
    ant2 = [dataset.ants[a2] for a2 in ant2_index]

    # Create the actual correlator product index
    cp_index = [_cp_index(a1, a2, p) for a1, a2 in zip(ant1, ant2) for p in pols_to_use]
    cp_index = np.array(cp_index, dtype=np.int32)
    cp_index = cp_index.reshape(-1, len(pols_to_use))

    return CPInfo(ant1_index, ant2_index, ant1, ant2, cp_index)
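
For context, corrprod_index maps each (baseline, polarisation) combination onto katdal's flat corr_products axis. A minimal usage sketch; ds here is hypothetical and stands for any katdal-style dataset exposing ants and corr_products, with the polarisation list mirroring the test file below:

# ds is a katdal-style dataset (e.g. FakeDataset from the test file below).
cp_info = corrprod_index(ds, ["HH", "HV", "VH", "VV"], include_auto_corrs=False)
# cp_index has one row per baseline and one column per polarisation;
# -1 marks correlator products absent from ds.corr_products.
assert cp_info.cp_index.shape == (len(cp_info.ant1), 4)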
@@ -0,0 +1,70 @@
MEERKAT_ANTENNA_DESCRIPTIONS = [
    "m000, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -8.264 -207.290 8.597",
    "m001, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1.121 -171.762 8.471",
    "m002, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -32.113 -224.236 8.645",
    "m003, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -66.518 -202.276 8.285",
    "m004, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -123.624 -252.946 8.513",
    "m005, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -102.088 -283.120 8.875",
    "m006, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -18.232 -295.428 9.188",
    "m007, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -89.592 -402.732 9.769",
    "m008, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -93.527 -535.026 10.445",
    "m009, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 32.357 -371.056 10.140",
    "m010, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 88.095 -511.872 11.186",
    "m011, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 84.012 -352.078 10.151",
    "m012, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 140.019 -368.267 10.449",
    "m013, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 236.792 -393.460 11.124",
    "m014, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 280.669 -285.792 10.547",
    "m015, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 210.644 -219.142 9.738",
    "m016, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 288.159 -185.873 9.795",
    "m017, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 199.624 -112.263 8.955",
    "m018, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 105.727 -245.870 9.529",
    "m019, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 170.787 -285.223 10.071",
    "m020, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 97.016 -299.638 9.877",
    "m021, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -295.966 -327.241 8.117",
    "m022, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -373.002 0.544 5.649",
    "m023, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -322.306 -142.185 6.825",
    "m024, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -351.046 150.088 4.845",
    "m025, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -181.978 225.617 5.068",
    "m026, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -99.004 17.045 6.811",
    "m027, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 40.475 -23.112 7.694",
    "m028, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -51.179 -87.170 7.636",
    "m029, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -88.762 -124.111 7.700",
    "m030, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 171.281 113.949 7.278",
    "m031, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 246.567 93.756 7.469",
    "m032, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 461.275 175.505 7.367",
    "m033, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 580.678 863.959 3.600",
    "m034, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 357.811 -28.308 8.972",
    "m035, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 386.152 -180.894 10.290",
    "m036, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 388.257 -290.759 10.812",
    "m037, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 380.286 -459.309 12.172",
    "m038, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 213.308 -569.080 11.946",
    "m039, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 253.748 -592.147 12.441",
    "m040, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -26.858 -712.219 11.833",
    "m041, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -287.545 -661.678 9.949",
    "m042, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -361.714 -460.318 8.497",
    "m043, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -629.853 -128.326 5.264",
    "m044, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -896.164 600.497 -0.640",
    "m045, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1832.860 266.750 0.108",
    "m046, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1467.341 1751.923 -7.078",
    "m047, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -578.296 -517.297 7.615",
    "m048, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -2805.653 2686.863 -9.755",
    "m049, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -3605.957 436.462 2.696",
    "m050, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -2052.336 -843.715 5.338",
    "m051, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -850.255 -769.359 7.614",
    "m052, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -593.192 -1148.652 10.550",
    "m053, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 9.365 -1304.462 15.032",
    "m054, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 871.980 -499.812 13.364",
    "m055, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1201.780 96.492 10.023",
    "m056, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1598.403 466.668 6.990",
    "m057, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 294.645 3259.915 -10.637",
    "m058, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 2805.764 2686.873 -3.660",
    "m059, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 3686.427 758.895 11.822",
    "m060, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 3419.683 -1840.478 23.697",
    "m061, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -16.409 -2323.779 21.304",
    "m062, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1440.632 -2503.773 21.683",
    "m063, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -3419.585 -1840.480 16.383",
]

assert all(
    int(a.split(",")[0][1:]) == i for i, a in enumerate(MEERKAT_ANTENNA_DESCRIPTIONS)
)
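
Each entry appears to follow katpoint's antenna description format: name, reference latitude, longitude and altitude, dish diameter in metres, then east-north-up offsets in metres from the array reference position; the trailing assert checks that list position matches the antenna number. A small sketch of consuming an entry, using the same katpoint.Antenna constructor the test file applies via map(Antenna, antennas):

from katpoint import Antenna

# Antenna parses the comma-separated description string directly.
ant = Antenna(MEERKAT_ANTENNA_DESCRIPTIONS[0])
assert ant.name == "m000"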
@@ -0,0 +1,252 @@
from functools import partial

import dask.array as da

from daskms.experimental.zarr import xds_to_zarr

from katdal.chunkstore_npy import NpyFileChunkStore
from katdal.dataset import Subarray
from katdal.lazy_indexer import DaskLazyIndexer
from katdal.spectral_window import SpectralWindow
from katdal.vis_flags_weights import ChunkStoreVisFlagsWeights
from katdal.test.test_vis_flags_weights import put_fake_dataset
from katdal.test.test_dataset import MinimalDataSet
from katpoint import Antenna, Target, Timestamp

import numba
import numpy as np
import pytest
import xarray

from daskms.experimental.katdal.meerkat_antennas import MEERKAT_ANTENNA_DESCRIPTIONS
from daskms.experimental.katdal.transpose import transpose
from daskms.experimental.katdal.corr_products import corrprod_index
from daskms.experimental.katdal.uvw import uvw_coords

SPW = SpectralWindow(
    centre_freq=1284e6, channel_width=0, num_chans=16, sideband=1, bandwidth=856e6
)

class FakeDataset(MinimalDataSet):
    def __init__(
        self,
        path,
        targets,
        timestamps,
        antennas=MEERKAT_ANTENNA_DESCRIPTIONS,
        spw=SPW,
    ):
        antennas = list(map(Antenna, antennas))
        corr_products = [
            (a1.name + c1, a2.name + c2)
            for i, a1 in enumerate(antennas)
            for a2 in antennas[i:]
            for c1 in ("h", "v")
            for c2 in ("h", "v")
        ]

        subarray = Subarray(antennas, corr_products)
        assert len(subarray.ants) > 0

        store = NpyFileChunkStore(str(path))
        shape = (len(timestamps), spw.num_chans, len(corr_products))
        # data, chunk_info = put_fake_dataset(store, "cb1", shape)
        data, chunk_info = put_fake_dataset(
            store,
            "cb1",
            shape,
            chunk_overrides={
                "correlator_data": (1, spw.num_chans, len(corr_products)),
                "flags": (1, spw.num_chans, len(corr_products)),
                "weights": (1, spw.num_chans, len(corr_products)),
            },
        )
        self._vfw = ChunkStoreVisFlagsWeights(store, chunk_info)
        self._vis = None
        self._weights = None
        self._flags = None
        super().__init__(targets, timestamps, subarray, spw)

    def _set_keep(
        self,
        time_keep=None,
        freq_keep=None,
        corrprod_keep=None,
        weights_keep=None,
        flags_keep=None,
    ):
        super()._set_keep(time_keep, freq_keep, corrprod_keep, weights_keep, flags_keep)
        stage1 = (time_keep, freq_keep, corrprod_keep)
        self._vis = DaskLazyIndexer(self._vfw.vis, stage1)
        self._weights = DaskLazyIndexer(self._vfw.weights, stage1)
        self._flags = DaskLazyIndexer(self._vfw.flags, stage1)

    @property
    def vis(self):
        if self._vis is None:
            raise ValueError("Selection has not yet been performed")
        return self._vis

    @property
    def flags(self):
        if self._flags is None:
            raise ValueError("Selection has not yet been performed")
        return self._flags

    @property
    def weights(self):
        if self._weights is None:
            raise ValueError("Selection has not yet been performed")

        return self._weights

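# FakeDataset stitches katdal's own test helpers together: put_fake_dataset
# fills an NpyFileChunkStore with synthetic correlator data, and
# ChunkStoreVisFlagsWeights exposes it through the same vis/flags/weights
# DaskLazyIndexer interface that a real katdal observation provides.
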
@pytest.fixture(scope="session")
def dataset(request, tmp_path_factory):
    path = tmp_path_factory.mktemp("chunks")
    targets = [
        # It would have been nice to have radec = 19:39, -63:42 but then
        # selection by description string does not work because the catalogue's
        # description string pads it out to radec = 19:39:00.00, -63:42:00.0.
        # (XXX Maybe fix Target comparison in katpoint to support this?)
        Target("J1939-6342 | PKS1934-638, radec bpcal, 19:39:25.03, -63:42:45.6"),
        Target("J1939-6342, radec gaincal, 19:39:25.03, -63:42:45.6"),
        Target("J0408-6545 | PKS 0408-65, radec bpcal, 4:08:20.38, -65:45:09.1"),
        Target("J1346-6024 | Cen B, radec, 13:46:49.04, -60:24:29.4"),
    ]
    # Ensure that len(timestamps) is an integer multiple of len(targets)
    timestamps = 1234667890.0 + 8.0 * np.arange(20)

    spw = SpectralWindow(
        centre_freq=1284e6, channel_width=0, num_chans=4096, sideband=1, bandwidth=856e6
    )

    return FakeDataset(
        path, targets, timestamps, antennas=MEERKAT_ANTENNA_DESCRIPTIONS[:16], spw=spw
    )

@pytest.mark.parametrize("include_auto_corrs", [False])
@pytest.mark.parametrize("row_dim", [True, False])
@pytest.mark.parametrize("out_store", ["output.zarr"])
def test_chunkstore(tmp_path_factory, dataset, include_auto_corrs, row_dim, out_store):
    cp_info = corrprod_index(dataset, ["HH", "HV", "VH", "VV"], include_auto_corrs)
    all_antennas = dataset.ants

    xds = []

    for scan_index, scan_state, target in dataset.scans():
        # Extract numpy and dask products
        time_utc = dataset.timestamps
        t_chunks, chan_chunks, cp_chunks = dataset.vis.dataset.chunks

        # Modified Julian Date in Seconds
        time_mjds = np.asarray(
            [t.to_mjd() * 24 * 60 * 60 for t in map(Timestamp, time_utc)]
        )

        # Create a dask chunking transform
        rechunk = partial(da.rechunk, chunks=(t_chunks, chan_chunks, cp_chunks))

        # Transpose from (time, chan, corrprod) to (time, bl, chan, corr)
        cpi = cp_info.cp_index
        flag_transpose = partial(
            transpose,
            cp_index=cpi,
            data_type=numba.literally("flags"),
            row=row_dim,
        )
        weight_transpose = partial(
            transpose,
            cp_index=cpi,
            data_type=numba.literally("weights"),
            row=row_dim,
        )
        vis_transpose = partial(
            transpose, cp_index=cpi, data_type=numba.literally("vis"), row=row_dim
        )
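
        # transpose (from daskms.experimental.katdal.transpose) scatters the
        # flat correlator-product axis into separate baseline and correlation
        # axes via cp_index; numba.literally pins data_type to a compile-time
        # literal so the kernel can be specialised per product type (an
        # inference from this diff, not a documented contract).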
        flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose))
        weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose))
        vis = DaskLazyIndexer(dataset.vis, (), transforms=(vis_transpose,))

        time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1))
        ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0]))
        ant2 = da.from_array(cp_info.ant2_index[None, :], chunks=(1, cpi.shape[0]))

        uvw = uvw_coords(
            target,
            da.from_array(time_utc, chunks=t_chunks),
            all_antennas,
            cp_info,
            row=row_dim,
        )

        time, ant1, ant2 = da.broadcast_arrays(time, ant1, ant2)

        if row_dim:
            primary_dims = ("row",)
            time = time.ravel().rechunk({0: vis.dataset.chunks[0]})
            ant1 = ant1.ravel().rechunk({0: vis.dataset.chunks[0]})
            ant2 = ant2.ravel().rechunk({0: vis.dataset.chunks[0]})
        else:
            primary_dims = ("time", "baseline")

        xds.append(
            xarray.Dataset(
                {
                    # Primary indexing columns
                    "TIME": (primary_dims, time),
                    "ANTENNA1": (primary_dims, ant1),
                    "ANTENNA2": (primary_dims, ant2),
                    "FEED1": (primary_dims, da.zeros_like(ant1)),
                    "FEED2": (primary_dims, da.zeros_like(ant1)),
                    # TODO(sjperkins)
                    # Fill these in with real values
                    "DATA_DESC_ID": (primary_dims, da.zeros_like(ant1)),
                    "FIELD_ID": (primary_dims, da.zeros_like(ant1)),
                    "STATE_ID": (primary_dims, da.zeros_like(ant1)),
                    "ARRAY_ID": (primary_dims, da.zeros_like(ant1)),
                    "OBSERVATION_ID": (primary_dims, da.zeros_like(ant1)),
                    "PROCESSOR_ID": (primary_dims, da.ones_like(ant1)),
                    "SCAN_NUMBER": (primary_dims, da.full_like(ant1, scan_index)),
                    "TIME_CENTROID": (primary_dims, time),
                    "INTERVAL": (primary_dims, da.full_like(time, dataset.dump_period)),
                    "EXPOSURE": (primary_dims, da.full_like(time, dataset.dump_period)),
                    "UVW": (primary_dims + ("uvw",), uvw),
                    "DATA": (primary_dims + ("chan", "corr"), vis.dataset),
                    "FLAG": (primary_dims + ("chan", "corr"), flags.dataset),
                    "WEIGHT_SPECTRUM": (
                        primary_dims + ("chan", "corr"),
                        weights.dataset,
                    ),
                    # Estimated RMS noise per frequency channel
                    # note this column is used when computing calibration weights
                    # in CASA - WEIGHT_SPECTRUM may be modified based on the
                    # values in this column. See
                    # https://casadocs.readthedocs.io/en/stable/notebooks/data_weights.html
                    # for further details
                    "SIGMA_SPECTRUM": (
                        primary_dims + ("chan", "corr"),
                        weights.dataset**-0.5,
                    ),
                }
            )
        )

    xds = xarray.concat(xds, dim=primary_dims[0])

    import dask
    import shutil

    # out_store = tmp_path_factory.mktemp(out_store)

    print(xds)

    shutil.rmtree(out_store, ignore_errors=True)
    dask.compute(xds_to_zarr(xds, out_store))

    # auto_corrs = 0 if include_auto_corrs else 1
    # ant1, ant2 = np.triu_indices(len(dataset.ants), auto_corrs)
    # ant1, ant2 = (a.astype(np.int32) for a in (ant1, ant2))
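
As a round-trip sanity check, the store written above can be read back. A minimal sketch, assuming xds_from_zarr, the reading counterpart of xds_to_zarr in daskms.experimental.zarr:

from daskms.experimental.zarr import xds_from_zarr

# xds_from_zarr returns a list of xarray Datasets, one per group in the store.
(ds,) = xds_from_zarr("output.zarr")
print(ds.DATA.dims)  # ("row", "chan", "corr") when row_dim is True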