diff --git a/daskms/experimental/katdal/msv2_proxy.py b/daskms/experimental/katdal/msv2_proxy.py
index 1aebe076..468a4adb 100644
--- a/daskms/experimental/katdal/msv2_proxy.py
+++ b/daskms/experimental/katdal/msv2_proxy.py
@@ -72,7 +72,7 @@ def select(self, **kwargs):
         )
         return result
 
-    def _xarray_factory(self, scan_index, scan_state, target):
+    def _main_xarray_factory(self, scan_index, scan_state, target):
         # Extract numpy and dask products
         dataset = self._dataset
         cp_info = self._cp_info
@@ -176,5 +176,200 @@ def _xarray_factory(self, scan_index, scan_state, target):
 
     def scans(self):
         """Proxies :meth:`katdal.scans`"""
-        xds = [self._xarray_factory(*scan_data) for scan_data in self._dataset.scans()]
+        xds = [
+            self._main_xarray_factory(*scan_data) for scan_data in self._dataset.scans()
+        ]
         yield xarray.concat(xds, dim="row" if self._row_view else "time")
+
+    def _antenna_xarray_factory(self):
+        """Build the ANTENNA subtable from the katdal dataset's antennas.
+
+        Returns
+        -------
+        xarray.Dataset
+            One row per antenna in ``self._dataset.ants``.
+        """
+        antennas = self._dataset.ants
+        nant = len(antennas)
+        # The antenna name is reused as the station name
+        names = np.asarray([a.name for a in antennas], dtype=object)
+        return xarray.Dataset(
+            {
+                "NAME": ("row", names),
+                "STATION": ("row", names.copy()),
+                "POSITION": (
+                    ("row", "xyz"),
+                    np.asarray([a.position_ecef for a in antennas]),
+                ),
+                "OFFSET": (("row", "xyz"), np.zeros((nant, 3))),
+                "DISH_DIAMETER": ("row", np.asarray([a.diameter for a in antennas])),
+                "MOUNT": ("row", np.array(["ALT-AZ"] * nant, dtype=object)),
+                "TYPE": ("row", np.array(["GROUND-BASED"] * nant, dtype=object)),
+                "FLAG_ROW": ("row", np.zeros(nant, dtype=np.int32)),
+            }
+        )
+
+    def _spw_xarray_factory(self):
+        """Build one SPECTRAL_WINDOW dataset per katdal spectral window.
+
+        Returns
+        -------
+        list of xarray.Dataset
+            A single-row dataset for each entry of
+            ``self._dataset.spectral_windows``.
+        """
+
+        def spw_dataset(spw):
+            chan_freq = spw.channel_freqs
+            # Every channel is assigned the same width
+            chan_width = np.full_like(chan_freq, spw.channel_width)
+            ref_freq = chan_freq[len(chan_freq) // 2].astype(np.float64)
+            band_name = np.asarray([f"{spw.band}-band"], dtype=object)
+
+            return xarray.Dataset(
+                {
+                    "NUM_CHAN": ("row", np.array([spw.num_chans], dtype=np.int32)),
+                    "CHAN_FREQ": (("row", "chan"), chan_freq[np.newaxis, :]),
+                    # RESOLUTION is a per-channel width, not a frequency
+                    "RESOLUTION": (("row", "chan"), chan_width[np.newaxis, :]),
+                    "CHAN_WIDTH": (("row", "chan"), chan_width[np.newaxis, :]),
+                    "EFFECTIVE_BW": (("row", "chan"), chan_width[np.newaxis, :]),
+                    # NOTE(review): 5 is presumably casacore's TOPO frame code -- confirm
+                    "MEAS_FREQ_REF": ("row", np.array([5], dtype=np.int32)),
+                    "REF_FREQUENCY": ("row", [ref_freq]),
+                    "NAME": ("row", band_name),
+                    "FREQ_GROUP_NAME": ("row", band_name.copy()),
+                    "FREQ_GROUP": ("row", np.zeros(1, dtype=np.int32)),
+                    "IF_CONV_CHAN": ("row", np.zeros(1, dtype=np.int32)),
+                    "NET_SIDEBAND": ("row", np.ones(1, dtype=np.int32)),
+                    # Total bandwidth is the sum of the channel widths
+                    "TOTAL_BANDWIDTH": ("row", np.asarray([chan_width.sum()])),
+                    "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
+                }
+            )
+
+        return [spw_dataset(spw) for spw in self._dataset.spectral_windows]
+
+    def _pol_xarray_factory(self):
+        """Build the single-row POLARIZATION subtable.
+
+        Returns
+        -------
+        xarray.Dataset
+            Correlation products and types for ``self._pols_to_use``.
+        """
+        pol_num = {"H": 0, "V": 1}
+        # MeerKAT only has linear feeds, these map to
+        # CASA ["XX", "XY", "YX", "YY"]
+        pol_types = {"HH": 9, "HV": 10, "VH": 11, "VV": 12}
+        return xarray.Dataset(
+            {
+                "NUM_CORR": ("row", np.array([len(self._pols_to_use)], dtype=np.int32)),
+                "CORR_PRODUCT": (
+                    ("row", "corr", "corrprod_idx"),
+                    np.array(
+                        [[[pol_num[p[0]], pol_num[p[1]]] for p in self._pols_to_use]],
+                        dtype=np.int32,
+                    ),
+                ),
+                "CORR_TYPE": (
+                    ("row", "corr"),
+                    np.asarray(
+                        [[pol_types[p] for p in self._pols_to_use]], dtype=np.int32
+                    ),
+                ),
+                "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
+            }
+        )
+
+    def _ddid_xarray_factory(self):
+        """Build the single-row DATA_DESCRIPTION subtable.
+
+        Returns
+        -------
+        xarray.Dataset
+            One row referencing spectral window 0 and polarization 0.
+        """
+        # Explicit "row" dims are needed: xarray turns a bare 1-D array
+        # into a dimension coordinate named after the variable itself
+        return xarray.Dataset(
+            {
+                "SPECTRAL_WINDOW_ID": ("row", np.zeros(1, dtype=np.int32)),
+                "POLARIZATION_ID": ("row", np.zeros(1, dtype=np.int32)),
+                "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
+            }
+        )
+
+    def _feed_xarray_factory(self):
+        """Build the FEED subtable with one dual-receptor feed per antenna.
+
+        Returns
+        -------
+        xarray.Dataset
+            One row per antenna in ``self._dataset.ants``.
+        """
+        nfeeds = len(self._dataset.ants)
+        NRECEPTORS = 2
+
+        return xarray.Dataset(
+            {
+                # ID of antenna in this array (integer)
+                "ANTENNA_ID": ("row", np.arange(nfeeds, dtype=np.int32)),
+                # Id for BEAM model (integer)
+                "BEAM_ID": ("row", np.ones(nfeeds, dtype=np.int32)),
+                # Beam position offset (on sky but in antenna reference frame): (double, 2-dim)
+                "BEAM_OFFSET": (
+                    ("row", "receptors", "radec"),
+                    np.zeros((nfeeds, NRECEPTORS, 2), dtype=np.float64),
+                ),
+                # Feed id (integer)
+                "FEED_ID": ("row", np.zeros(nfeeds, dtype=np.int32)),
+                # Interval for which this set of parameters is accurate (double)
+                "INTERVAL": ("row", np.zeros(nfeeds, dtype=np.float64)),
+                # Number of receptors on this feed (probably 1 or 2) (integer)
+                "NUM_RECEPTORS": ("row", np.full(nfeeds, NRECEPTORS, dtype=np.int32)),
+                # Type of polarisation to which a given RECEPTOR responds (string, 1-dim)
+                "POLARIZATION_TYPE": (
+                    ("row", "receptors"),
+                    np.array([["X", "Y"]] * nfeeds, dtype=object),
+                ),
+                # D-matrix i.e. leakage between two receptors (complex, 2-dim)
+                "POL_RESPONSE": (
+                    ("row", "receptors", "receptors-2"),
+                    np.tile(np.eye(NRECEPTORS, dtype=np.complex64), (nfeeds, 1, 1)),
+                ),
+                # Position of feed relative to feed reference position (double, 1-dim, shape=(3,))
+                "POSITION": (("row", "xyz"), np.zeros((nfeeds, 3), np.float64)),
+                # The reference angle for polarisation (double, 1-dim). A parallactic angle of
+                # 0 means that V is aligned to x (celestial North), but we are mapping H to x
+                # so we have to correct with a -90 degree rotation.
+                "RECEPTOR_ANGLE": (
+                    ("row", "receptors"),
+                    np.full((nfeeds, NRECEPTORS), -np.pi / 2, dtype=np.float64),
+                ),
+                # ID for this spectral window setup (integer)
+                "SPECTRAL_WINDOW_ID": ("row", np.full(nfeeds, -1, dtype=np.int32)),
+                # Midpoint of time for which this set of parameters is accurate (double)
+                "TIME": ("row", np.zeros(nfeeds, dtype=np.float64)),
+            }
+        )
+
+    def subtables(self):
+        """Return the MSv2 subtables derived from the katdal dataset.
+
+        Returns
+        -------
+        dict
+            Maps subtable name to an :class:`xarray.Dataset`, or to a
+            list of them in the case of ``SPECTRAL_WINDOW``.
+        """
+        # NOTE(review): reset="" presumably re-runs select without clearing the current selection -- confirm
+        self.select(reset="")
+
+        return {
+            "ANTENNA": self._antenna_xarray_factory(),
+            "DATA_DESCRIPTION": self._ddid_xarray_factory(),
+            "SPECTRAL_WINDOW": self._spw_xarray_factory(),
+            "POLARIZATION": self._pol_xarray_factory(),
+            "FEED": self._feed_xarray_factory(),
+        }
diff --git a/daskms/experimental/katdal/tests/test_chunkstore.py b/daskms/experimental/katdal/tests/test_chunkstore.py
index 90e5b111..a292426d 100644
--- a/daskms/experimental/katdal/tests/test_chunkstore.py
+++ b/daskms/experimental/katdal/tests/test_chunkstore.py
@@ -21,65 +21,21 @@ def test_chunkstore(tmp_path_factory, dataset, auto_corrs, row_dim, out_store):
     proxy = MSv2DatasetProxy(dataset, auto_corrs, row_dim)
     all_antennas = proxy.ants
     xds = list(proxy.scans())
+    sub_xds = proxy.subtables()
 
-    ant_xds = [
-        xarray.Dataset(
-            {
-                "NAME": (("row",), np.asarray([a.name for a in all_antennas])),
-                "OFFSET": (
-                    ("row", "xyz"),
-                    np.asarray(np.zeros((len(all_antennas), 3))),
-                ),
-                "POSITION": (
-                    ("row", "xyz"),
-                    np.asarray([a.position_ecef for a in all_antennas]),
-                ),
-                "DISH_DIAMETER": (
-                    ("row",),
-                    np.asarray([a.diameter for a in all_antennas]),
-                ),
-                # "FLAG_ROW": (("row","xyz"),
-                #     np.zeros([a.flags for a in all_antennas],np.uint8)
-                # )
-            }
-        )
-    ]
-
-    spw = dataset.spectral_windows[dataset.spw]
-
-    spw_xds = [
-        xarray.Dataset(
-            {
-                "CHAN_FREQ": (("row", "chan"), dataset.channel_freqs[np.newaxis, :]),
-                "CHAN_WIDTH": (
-                    ("row", "chan"),
-                    np.full_like(dataset.channel_freqs, dataset.channel_width)[
-                        np.newaxis, :
-                    ],
-                ),
-                "EFFECTIVE_BW": (
-                    ("row", "chan"),
-                    np.full_like(dataset.channel_freqs, dataset.channel_width)[
-                        np.newaxis, :
-                    ],
-                ),
-                "FLAG_ROW": (("row",), np.zeros(1, dtype=np.int32)),
-                "NUM_CHAN": (("row",), np.array([spw.num_chans], dtype=np.int32)),
-            }
-        )
-    ]
-
-    print(spw_xds)
-    print(ant_xds)
-    print(xds)
+    pprint(xds)
+    pprint(sub_xds)
 
     # Reintroduce the shutil.rmtree and remote the tmp_path_factory
     # to test in the local directory
     # shutil.rmtree(out_store, ignore_errors=True)
     out_store = tmp_path_factory.mktemp("output") / out_store
 
-    dask.compute(xds_to_zarr(xds, out_store))
-    dask.compute(xds_to_zarr(ant_xds, f"{out_store}::ANTENNA"))
+    writes = [
+        xds_to_zarr(xds, out_store),
+        *(xds_to_zarr(ds, f"{out_store}::{k}") for k, ds in sub_xds.items()),
+    ]
+    dask.compute(writes)
 
     # Compare visibilities, weights and flags
     (read_xds,) = dask.compute(xds_from_zarr(out_store))