Skip to content

Commit

Permalink
Checkpoint write of main MSv2 subtables
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Mar 27, 2024
1 parent 55cb265 commit 8dd406e
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 54 deletions.
159 changes: 157 additions & 2 deletions daskms/experimental/katdal/msv2_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def select(self, **kwargs):
)
return result

def _xarray_factory(self, scan_index, scan_state, target):
def _main_xarray_factory(self, scan_index, scan_state, target):
# Extract numpy and dask products
dataset = self._dataset
cp_info = self._cp_info
Expand Down Expand Up @@ -176,5 +176,160 @@ def _xarray_factory(self, scan_index, scan_state, target):

def scans(self):
    """Proxies :meth:`katdal.scans`

    Builds one xarray dataset per entry produced by the underlying
    dataset's ``scans()`` and yields their concatenation as a single
    dataset. The concatenation runs along ``"row"`` when
    ``self._row_view`` is truthy and along ``"time"`` otherwise.
    """
    # Only the ``_main_xarray_factory`` name is used here: a preceding
    # assignment through the old ``_xarray_factory`` name was dead code
    # whose result was overwritten straight away.
    xds = [
        self._main_xarray_factory(*scan_data) for scan_data in self._dataset.scans()
    ]
    yield xarray.concat(xds, dim="row" if self._row_view else "time")

def _antenna_xarray_factory(self):
    """Build the ANTENNA subtable as an xarray dataset.

    One row is produced per antenna in ``self._dataset.ants``.
    NAME and STATION both carry the antenna name, POSITION carries
    ``position_ecef`` and DISH_DIAMETER carries ``diameter``. The remaining
    columns are filled with the same constant for every antenna.
    """
    ants = self._dataset.ants
    count = len(ants)

    def repeated(text):
        # Object-dtype column holding the same string for every antenna
        return np.array([text] * count, dtype=object)

    def names():
        # Fresh object-dtype array of antenna names on every call
        return np.asarray([ant.name for ant in ants], dtype=object)

    columns = {}
    columns["NAME"] = ("row", names())
    columns["STATION"] = ("row", names())
    columns["POSITION"] = (
        ("row", "xyz"),
        np.asarray([ant.position_ecef for ant in ants]),
    )
    columns["OFFSET"] = (("row", "xyz"), np.zeros((count, 3)))
    columns["DISH_DIAMETER"] = ("row", np.asarray([ant.diameter for ant in ants]))
    columns["MOUNT"] = ("row", repeated("ALT-AZ"))
    columns["TYPE"] = ("row", repeated("GROUND-BASED"))
    columns["FLAG_ROW"] = ("row", np.zeros(count, dtype=np.int32))
    return xarray.Dataset(columns)

def _spw_xarray_factory(self):
    """Build one single-row SPECTRAL_WINDOW dataset per spectral window.

    Returns a list with one xarray dataset for each entry of
    ``self._dataset.spectral_windows``. Each dataset has a single ``row``.

    The per-channel width columns (CHAN_WIDTH, EFFECTIVE_BW and RESOLUTION)
    all hold the window's ``channel_width`` repeated for every channel,
    and TOTAL_BANDWIDTH is the sum of those channel widths.
    """

    def ref_freq(chan_freqs):
        # The reference frequency is that of the middle channel
        return chan_freqs[len(chan_freqs) // 2].astype(np.float64)

    def chan_widths(spw):
        # A fresh (1, nchan) array per column so that no two columns
        # share memory
        return np.full_like(spw.channel_freqs, spw.channel_width)[np.newaxis, :]

    return [
        xarray.Dataset(
            {
                "NUM_CHAN": (("row",), np.array([spw.num_chans], dtype=np.int32)),
                "CHAN_FREQ": (("row", "chan"), spw.channel_freqs[np.newaxis, :]),
                # RESOLUTION is a per-channel bandwidth, not a frequency,
                # so it is filled with the channel width like CHAN_WIDTH
                "RESOLUTION": (("row", "chan"), chan_widths(spw)),
                "CHAN_WIDTH": (("row", "chan"), chan_widths(spw)),
                "EFFECTIVE_BW": (("row", "chan"), chan_widths(spw)),
                # NOTE(review): 5 is presumably the casacore MFrequency
                # code for TOPO -- confirm
                "MEAS_FREQ_REF": ("row", np.array([5], dtype=np.int32)),
                "REF_FREQUENCY": ("row", [ref_freq(spw.channel_freqs)]),
                "NAME": ("row", np.asarray([f"{spw.band}-band"], dtype=object)),
                "FREQ_GROUP_NAME": (
                    "row",
                    np.asarray([f"{spw.band}-band"], dtype=object),
                ),
                "FREQ_GROUP": ("row", np.zeros(1, dtype=np.int32)),
                "IF_CONV_CHAN": ("row", np.zeros(1, dtype=np.int32)),
                "NET_SIDEBAND": ("row", np.ones(1, dtype=np.int32)),
                # Total bandwidth is the sum of the channel widths;
                # summing the channel frequencies is not a bandwidth
                "TOTAL_BANDWIDTH": ("row", np.asarray([chan_widths(spw).sum()])),
                "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
            }
        )
        for spw in self._dataset.spectral_windows
    ]

def _pol_xarray_factory(self):
    """Build the single-row POLARIZATION subtable as an xarray dataset.

    The correlations are taken from ``self._pols_to_use``, whose entries
    are looked up as two-letter H/V products (e.g. ``"HV"``).
    """
    receptor_index = {"H": 0, "V": 1}
    # MeerKAT only has linear feeds, these map to
    # CASA ["XX", "XY", "YX", "YY"]
    corr_codes = {"HH": 9, "HV": 10, "VH": 11, "VV": 12}

    num_corr = np.array([len(self._pols_to_use)], dtype=np.int32)

    # Receptor index pair for every correlation, wrapped in a
    # leading length-1 row axis
    receptor_pairs = [
        [receptor_index[pol[0]], receptor_index[pol[1]]]
        for pol in self._pols_to_use
    ]
    corr_product = np.array([receptor_pairs], dtype=np.int32)

    corr_type = np.asarray(
        [[corr_codes[pol] for pol in self._pols_to_use]], dtype=np.int32
    )

    columns = {
        "NUM_CORR": ("row", num_corr),
        "CORR_PRODUCT": (("row", "corr", "corrprod_idx"), corr_product),
        "CORR_TYPE": (("row", "corr"), corr_type),
        "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
    }
    return xarray.Dataset(columns)

def _ddid_xarray_factory(self):
    """Build the single-row DATA_DESCRIPTION subtable as an xarray dataset.

    The one row refers to spectral window 0 and polarization 0 and is
    not flagged.
    """
    # Each column is given the explicit "row" dimension. A bare 1-D array
    # would be assigned a dimension named after the variable itself and
    # be promoted to a coordinate, leaving the dataset without data
    # variables and without the "row" dimension the other subtables use.
    return xarray.Dataset(
        {
            "SPECTRAL_WINDOW_ID": ("row", np.zeros(1, dtype=np.int32)),
            "POLARIZATION_ID": ("row", np.zeros(1, dtype=np.int32)),
            "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
        }
    )

def _feed_xarray_factory(self):
    """Build the FEED subtable as an xarray dataset.

    One row is produced per antenna in ``self._dataset.ants``. Every
    row carries the same constant feed description, apart from
    ANTENNA_ID which counts up from zero.
    """
    num_receptors = 2
    nfeeds = len(self._dataset.ants)

    def int_column(fill):
        # Per-feed int32 column holding a single constant
        return np.full(nfeeds, fill, dtype=np.int32)

    def zero_doubles():
        # Per-feed float64 column of zeros
        return np.zeros(nfeeds, dtype=np.float64)

    columns = {}
    # ID of antenna in this array (integer)
    columns["ANTENNA_ID"] = ("row", np.arange(nfeeds, dtype=np.int32))
    # Id for BEAM model (integer)
    columns["BEAM_ID"] = ("row", np.ones(nfeeds, dtype=np.int32))
    # Beam position offset, on sky but in the antenna
    # reference frame (double, 2-dim)
    columns["BEAM_OFFSET"] = (
        ("row", "receptors", "radec"),
        np.zeros((nfeeds, 2, 2), dtype=np.float64),
    )
    # Feed id (integer)
    columns["FEED_ID"] = ("row", np.zeros(nfeeds, dtype=np.int32))
    # Interval for which this set of parameters is accurate (double)
    columns["INTERVAL"] = ("row", zero_doubles())
    # Number of receptors on this feed, probably 1 or 2 (integer)
    columns["NUM_RECEPTORS"] = ("row", int_column(num_receptors))
    # Type of polarisation to which a given RECEPTOR responds
    # (string, 1-dim)
    columns["POLARIZATION_TYPE"] = (
        ("row", "receptors"),
        np.array([["X", "Y"] for _ in range(nfeeds)], dtype=object),
    )
    # D-matrix i.e. leakage between two receptors (complex, 2-dim)
    columns["POL_RESPONSE"] = (
        ("row", "receptors", "receptors-2"),
        np.array([np.eye(2, dtype=np.complex64)] * nfeeds),
    )
    # Position of feed relative to feed reference position
    # (double, 1-dim, shape=(3,))
    columns["POSITION"] = (("row", "xyz"), np.zeros((nfeeds, 3), np.float64))
    # The reference angle for polarisation (double, 1-dim).
    # A parallactic angle of 0 means that V is aligned to x
    # (celestial North), but H is being mapped to x here, hence
    # the correction with a -90 degree rotation.
    columns["RECEPTOR_ANGLE"] = (
        ("row", "receptors"),
        np.full((nfeeds, num_receptors), -np.pi / 2, dtype=np.float64),
    )
    # ID for this spectral window setup (integer)
    columns["SPECTRAL_WINDOW_ID"] = ("row", int_column(-1))
    # Midpoint of time for which this set of parameters is
    # accurate (double)
    columns["TIME"] = ("row", zero_doubles())

    return xarray.Dataset(columns)

def subtables(self):
    """Return the MSv2 subtables as a dict keyed by subtable name.

    ``self.select(reset="")`` is called first. The ANTENNA,
    DATA_DESCRIPTION, SPECTRAL_WINDOW, POLARIZATION and FEED entries are
    then built, in that order, by their respective factory methods.
    """
    self.select(reset="")

    factories = (
        ("ANTENNA", self._antenna_xarray_factory),
        ("DATA_DESCRIPTION", self._ddid_xarray_factory),
        ("SPECTRAL_WINDOW", self._spw_xarray_factory),
        ("POLARIZATION", self._pol_xarray_factory),
        ("FEED", self._feed_xarray_factory),
    )
    return {name: build() for name, build in factories}
60 changes: 8 additions & 52 deletions daskms/experimental/katdal/tests/test_chunkstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,65 +21,21 @@ def test_chunkstore(tmp_path_factory, dataset, auto_corrs, row_dim, out_store):
proxy = MSv2DatasetProxy(dataset, auto_corrs, row_dim)
all_antennas = proxy.ants
xds = list(proxy.scans())
sub_xds = proxy.subtables()

ant_xds = [
xarray.Dataset(
{
"NAME": (("row",), np.asarray([a.name for a in all_antennas])),
"OFFSET": (
("row", "xyz"),
np.asarray(np.zeros((len(all_antennas), 3))),
),
"POSITION": (
("row", "xyz"),
np.asarray([a.position_ecef for a in all_antennas]),
),
"DISH_DIAMETER": (
("row",),
np.asarray([a.diameter for a in all_antennas]),
),
# "FLAG_ROW": (("row","xyz"),
# np.zeros([a.flags for a in all_antennas],np.uint8)
# )
}
)
]

spw = dataset.spectral_windows[dataset.spw]

spw_xds = [
xarray.Dataset(
{
"CHAN_FREQ": (("row", "chan"), dataset.channel_freqs[np.newaxis, :]),
"CHAN_WIDTH": (
("row", "chan"),
np.full_like(dataset.channel_freqs, dataset.channel_width)[
np.newaxis, :
],
),
"EFFECTIVE_BW": (
("row", "chan"),
np.full_like(dataset.channel_freqs, dataset.channel_width)[
np.newaxis, :
],
),
"FLAG_ROW": (("row",), np.zeros(1, dtype=np.int32)),
"NUM_CHAN": (("row",), np.array([spw.num_chans], dtype=np.int32)),
}
)
]

print(spw_xds)
print(ant_xds)
print(xds)
pprint(xds)
pprint(sub_xds)

# Reintroduce the shutil.rmtree and remove the tmp_path_factory
# to test in the local directory
# shutil.rmtree(out_store, ignore_errors=True)
out_store = tmp_path_factory.mktemp("output") / out_store

dask.compute(xds_to_zarr(xds, out_store))
dask.compute(xds_to_zarr(ant_xds, f"{out_store}::ANTENNA"))
writes = [
xds_to_zarr(xds, out_store),
*(xds_to_zarr(ds, f"{out_store}::{k}") for k, ds in sub_xds.items()),
]
dask.compute(writes)

# Compare visibilities, weights and flags
(read_xds,) = dask.compute(xds_from_zarr(out_store))
Expand Down

0 comments on commit 8dd406e

Please sign in to comment.