diff --git a/HISTORY.rst b/HISTORY.rst index 68026ec3..5c22615c 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,7 @@ History X.Y.Z (YYYY-MM-DD) ------------------ +* Add a ``dask-ms katdal import`` application for exporting SARAO archive data directly to zarr (:pr:`315`) * Define dask-ms command line applications with click (:pr:`317`) * Make poetry dev and docs groups optional (:pr:`316`) * Only test Github Action Push events on master (:pr:`313`) diff --git a/daskms/apps/entrypoint.py b/daskms/apps/entrypoint.py index c910851a..79713865 100644 --- a/daskms/apps/entrypoint.py +++ b/daskms/apps/entrypoint.py @@ -3,9 +3,10 @@ import click from daskms.apps.convert import convert +from daskms.apps.katdal_import import katdal -@click.group() +@click.group(name="dask-ms") @click.pass_context @click.option("--debug/--no-debug", default=False) def main(ctx, debug): @@ -15,3 +16,4 @@ def main(ctx, debug): main.add_command(convert) +main.add_command(katdal) diff --git a/daskms/apps/katdal_import.py b/daskms/apps/katdal_import.py new file mode 100644 index 00000000..e5c629a5 --- /dev/null +++ b/daskms/apps/katdal_import.py @@ -0,0 +1,67 @@ +import click + + +@click.group() +@click.pass_context +def katdal(ctx): + """subgroup for katdal commands""" + pass + + +class PolarisationListType(click.ParamType): + name = "polarisation list" + VALID = {"HH", "HV", "VH", "VV"} + + def convert(self, value, param, ctx): + if isinstance(value, str): + value = [p.strip() for p in value.split(",")] + else: + raise TypeError( + f"{value} should be a comma separated string of polarisations" + ) + + if not set(value).issubset(self.VALID): + raise ValueError(f"{set(value)} is not a subset of {self.VALID}") + + return value + + +@katdal.command(name="import") +@click.pass_context +@click.argument("rdb_url", required=True) +@click.option( + "-a", + "--no-auto", + flag_value=True, + default=False, + help="Exclude auto-correlation data", +) +@click.option( + "-o", + "--output-store", + help="Output store name. Will be derived from the rdb url if not provided.", + default=None, +) +@click.option( + "-p", + "--pols-to-use", + default="HH,HV,VH,VV", + help="Select polarisation products to include in MS as " + "a comma-separated list, containing values from [HH, HV, VH, VV].", + type=PolarisationListType(), +) +@click.option( + "--applycal", + default="", + help="List of calibration solutions to apply to data as " + "a string of comma-separated names, e.g. 'l1' or " + "'K,B,G'. 
Use 'default' for L1 + L2 and 'all' for "
+    "all available products.",
+)
+def _import(ctx, rdb_url, no_auto, pols_to_use, applycal, output_store):
+    """Export an observation in the SARAO archive to zarr format
+
+    RDB_URL is the SARAO archive link"""
+    from daskms.experimental.katdal import katdal_import
+
+    katdal_import(rdb_url, output_store, no_auto, applycal)
diff --git a/daskms/experimental/katdal/__init__.py b/daskms/experimental/katdal/__init__.py
new file mode 100644
index 00000000..8ece4e0f
--- /dev/null
+++ b/daskms/experimental/katdal/__init__.py
@@ -0,0 +1 @@
+from daskms.experimental.katdal.katdal_import import katdal_import
diff --git a/daskms/experimental/katdal/corr_products.py b/daskms/experimental/katdal/corr_products.py
new file mode 100644
index 00000000..ea4674a6
--- /dev/null
+++ b/daskms/experimental/katdal/corr_products.py
@@ -0,0 +1,68 @@
+# Creation of the correlation product index is derived from
+# https://github.com/ska-sa/katdal/blob/v0.22/scripts/mvftoms.py
+# under the following license
+#
+# ################################################################################
+# Copyright (c) 2011-2023, National Research Foundation (SARAO)
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use
+# this file except in compliance with the License. You may obtain a copy
+# of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+
+from collections import namedtuple
+
+import numpy as np
+
+CPInfo = namedtuple("CPInfo", "ant1_index ant2_index ant1 ant2 cp_index")
+
+
+def corrprod_index(dataset, pols_to_use, include_auto_corrs=False):
+    """The correlator product index (with -1 representing missing indices)."""
+    corrprod_to_index = {tuple(cp): n for n, cp in enumerate(dataset.corr_products)}
+
+    # ==========================================
+    # Generate per-baseline antenna pairs and
+    # correlator product indices
+    # ==========================================
+
+    def _cp_index(a1, a2, pol):
+        """Create correlator product index from antenna pair and pol."""
+        a1 = a1.name + pol[0].lower()
+        a2 = a2.name + pol[1].lower()
+        return corrprod_to_index.get((a1, a2), -1)
+
+    # Generate baseline antenna pairs
+    auto_corrs = 0 if include_auto_corrs else 1
+    ant1_index, ant2_index = np.triu_indices(len(dataset.ants), auto_corrs)
+    ant1_index, ant2_index = (a.astype(np.int32) for a in (ant1_index, ant2_index))
+
+    # Order as similarly to the input as possible, which gives better performance
+    # in permute_baselines.
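+    # As a hypothetical illustration (not part of the katdal API): with three
+    # antennas and auto-correlations excluded, np.triu_indices(3, 1) yields
+    # the baseline pairs (0, 1), (0, 2) and (1, 2), and each pair contributes
+    # one row of len(pols_to_use) entries to cp_index.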
+    bl_indices = list(zip(ant1_index, ant2_index))
+    bl_indices.sort(
+        key=lambda ants: _cp_index(
+            dataset.ants[ants[0]], dataset.ants[ants[1]], pols_to_use[0]
+        )
+    )
+    # Undo the zip
+    ant1_index[:] = [bl[0] for bl in bl_indices]
+    ant2_index[:] = [bl[1] for bl in bl_indices]
+    ant1 = [dataset.ants[a1] for a1 in ant1_index]
+    ant2 = [dataset.ants[a2] for a2 in ant2_index]
+
+    # Create actual correlator product index
+    cp_index = [_cp_index(a1, a2, p) for a1, a2 in zip(ant1, ant2) for p in pols_to_use]
+    cp_index = np.array(cp_index, dtype=np.int32)
+    cp_index = cp_index.reshape(-1, len(pols_to_use))
+
+    return CPInfo(ant1_index, ant2_index, ant1, ant2, cp_index)
diff --git a/daskms/experimental/katdal/katdal_import.py b/daskms/experimental/katdal/katdal_import.py
new file mode 100644
index 00000000..85898d94
--- /dev/null
+++ b/daskms/experimental/katdal/katdal_import.py
@@ -0,0 +1,53 @@
+import os
+import urllib.parse
+
+import dask
+
+from daskms.utils import requires
+
+try:
+    import katdal
+    from katdal.dataset import DataSet
+
+    from daskms.experimental.katdal.msv2_facade import XarrayMSV2Facade
+    from daskms.experimental.zarr import xds_to_zarr
+except ImportError as e:
+    import_error = e
+else:
+    import_error = None
+
+
+def default_output_name(url):
+    url_parts = urllib.parse.urlparse(url, scheme="file")
+    # Create zarr dataset in current working directory (strip off directories)
+    dataset_filename = os.path.basename(url_parts.path)
+    # Get rid of the ".full" bit on RDB files (it's the same dataset)
+    full_rdb_ext = ".full.rdb"
+    if dataset_filename.endswith(full_rdb_ext):
+        dataset_basename = dataset_filename[: -len(full_rdb_ext)]
+    else:
+        dataset_basename = os.path.splitext(dataset_filename)[0]
+    return f"{dataset_basename}.zarr"
+
+
+@requires("pip install dask-ms[katdal]", import_error)
+def katdal_import(url: str, out_store: str, no_auto: bool, applycal: str):
+    if isinstance(url, str):
+        dataset = katdal.open(url, applycal=applycal)
+    elif isinstance(url, DataSet):
+        dataset = url
+    else:
+        raise TypeError(f"{url} must be a string or a katdal DataSet")
+
+    facade = XarrayMSV2Facade(dataset, no_auto=no_auto)
+    main_xds, subtable_xds = facade.xarray_datasets()
+
+    if not out_store:
+        out_store = default_output_name(url)
+
+    writes = [
+        xds_to_zarr(main_xds, out_store),
+        *(xds_to_zarr(ds, f"{out_store}::{k}") for k, ds in subtable_xds.items()),
+    ]
+
+    dask.compute(writes)
diff --git a/daskms/experimental/katdal/meerkat_antennas.py b/daskms/experimental/katdal/meerkat_antennas.py
new file mode 100644
index 00000000..eb457bb8
--- /dev/null
+++ b/daskms/experimental/katdal/meerkat_antennas.py
@@ -0,0 +1,70 @@
+MEERKAT_ANTENNA_DESCRIPTIONS = [
+    "m000, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -8.264 -207.290 8.597",
+    "m001, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1.121 -171.762 8.471",
+    "m002, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -32.113 -224.236 8.645",
+    "m003, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -66.518 -202.276 8.285",
+    "m004, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -123.624 -252.946 8.513",
+    "m005, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -102.088 -283.120 8.875",
+    "m006, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -18.232 -295.428 9.188",
+    "m007, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -89.592 -402.732 9.769",
+    "m008, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -93.527 -535.026 10.445",
+    "m009, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 32.357 -371.056 10.140",
+    "m010, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 88.095 -511.872 11.186",
+    "m011, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 
84.012 -352.078 10.151", + "m012, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 140.019 -368.267 10.449", + "m013, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 236.792 -393.460 11.124", + "m014, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 280.669 -285.792 10.547", + "m015, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 210.644 -219.142 9.738", + "m016, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 288.159 -185.873 9.795", + "m017, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 199.624 -112.263 8.955", + "m018, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 105.727 -245.870 9.529", + "m019, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 170.787 -285.223 10.071", + "m020, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 97.016 -299.638 9.877", + "m021, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -295.966 -327.241 8.117", + "m022, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -373.002 0.544 5.649", + "m023, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -322.306 -142.185 6.825", + "m024, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -351.046 150.088 4.845", + "m025, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -181.978 225.617 5.068", + "m026, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -99.004 17.045 6.811", + "m027, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 40.475 -23.112 7.694", + "m028, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -51.179 -87.170 7.636", + "m029, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -88.762 -124.111 7.700", + "m030, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 171.281 113.949 7.278", + "m031, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 246.567 93.756 7.469", + "m032, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 461.275 175.505 7.367", + "m033, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 580.678 863.959 3.600", + "m034, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 357.811 -28.308 8.972", + "m035, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 386.152 -180.894 10.290", + "m036, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 388.257 -290.759 10.812", + "m037, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 380.286 -459.309 12.172", + "m038, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 213.308 -569.080 11.946", + "m039, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 253.748 -592.147 12.441", + "m040, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -26.858 -712.219 11.833", + "m041, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -287.545 -661.678 9.949", + "m042, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -361.714 -460.318 8.497", + "m043, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -629.853 -128.326 5.264", + "m044, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -896.164 600.497 -0.640", + "m045, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1832.860 266.750 0.108", + "m046, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1467.341 1751.923 -7.078", + "m047, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -578.296 -517.297 7.615", + "m048, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -2805.653 2686.863 -9.755", + "m049, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -3605.957 436.462 2.696", + "m050, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -2052.336 -843.715 5.338", + "m051, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -850.255 -769.359 7.614", + "m052, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -593.192 -1148.652 10.550", + "m053, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 9.365 -1304.462 15.032", + "m054, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 871.980 -499.812 13.364", + "m055, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1201.780 96.492 10.023", + "m056, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1598.403 466.668 6.990", + "m057, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 294.645 3259.915 -10.637", + "m058, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 2805.764 2686.873 -3.660", + "m059, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 3686.427 758.895 11.822", + 
"m060, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 3419.683 -1840.478 23.697", + "m061, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -16.409 -2323.779 21.304", + "m062, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1440.632 -2503.773 21.683", + "m063, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -3419.585 -1840.480 16.383", +] + +assert all( + int(a.split(",")[0][1:]) == i for i, a in enumerate(MEERKAT_ANTENNA_DESCRIPTIONS) +) diff --git a/daskms/experimental/katdal/mock_dataset.py b/daskms/experimental/katdal/mock_dataset.py new file mode 100644 index 00000000..3b163500 --- /dev/null +++ b/daskms/experimental/katdal/mock_dataset.py @@ -0,0 +1,88 @@ +from katdal.lazy_indexer import DaskLazyIndexer +from katdal.chunkstore_npy import NpyFileChunkStore +from katdal.dataset import Subarray +from katdal.spectral_window import SpectralWindow +from katdal.vis_flags_weights import ChunkStoreVisFlagsWeights +from katdal.test.test_vis_flags_weights import put_fake_dataset +from katdal.test.test_dataset import MinimalDataSet +from katpoint import Antenna + + +from daskms.experimental.katdal.meerkat_antennas import MEERKAT_ANTENNA_DESCRIPTIONS + +SPW = SpectralWindow( + centre_freq=1284e6, channel_width=0, num_chans=16, sideband=1, bandwidth=856e6 +) + + +class MockDataset(MinimalDataSet): + def __init__( + self, + path, + targets, + timestamps, + antennas=MEERKAT_ANTENNA_DESCRIPTIONS, + spw=SPW, + ): + antennas = list(map(Antenna, antennas)) + corr_products = [ + (a1.name + c1, a2.name + c2) + for i, a1 in enumerate(antennas) + for a2 in antennas[i:] + for c1 in ("h", "v") + for c2 in ("h", "v") + ] + + subarray = Subarray(antennas, corr_products) + assert len(subarray.ants) > 0 + + store = NpyFileChunkStore(str(path)) + shape = (len(timestamps), spw.num_chans, len(corr_products)) + self._test_data, chunk_info = put_fake_dataset( + store, + "cb1", + shape, + chunk_overrides={ + "correlator_data": (1, spw.num_chans, len(corr_products)), + "flags": (1, spw.num_chans, len(corr_products)), + "weights": (1, spw.num_chans, len(corr_products)), + }, + ) + self._vfw = ChunkStoreVisFlagsWeights(store, chunk_info) + self._vis = None + self._weights = None + self._flags = None + super().__init__(targets, timestamps, subarray, spw) + + def _set_keep( + self, + time_keep=None, + freq_keep=None, + corrprod_keep=None, + weights_keep=None, + flags_keep=None, + ): + super()._set_keep(time_keep, freq_keep, corrprod_keep, weights_keep, flags_keep) + stage1 = (time_keep, freq_keep, corrprod_keep) + self._vis = DaskLazyIndexer(self._vfw.vis, stage1) + self._weights = DaskLazyIndexer(self._vfw.weights, stage1) + self._flags = DaskLazyIndexer(self._vfw.flags, stage1) + + @property + def vis(self): + if self._vis is None: + raise ValueError("Selection has not yet been performed") + return self._vis + + @property + def flags(self): + if self._flags is None: + raise ValueError("Selection has not yet been performed") + return self._flags + + @property + def weights(self): + if self._weights is None: + raise ValueError("Selection has not yet been performed") + + return self._weights diff --git a/daskms/experimental/katdal/msv2_facade.py b/daskms/experimental/katdal/msv2_facade.py new file mode 100644 index 00000000..9c47fa0b --- /dev/null +++ b/daskms/experimental/katdal/msv2_facade.py @@ -0,0 +1,482 @@ +# Much of the subtable generation code is derived from +# https://github.com/ska-sa/katdal/blob/v0.22/katdal/ms_extra.py +# under the following license +# +# ################################################################################ +# 
Copyright (c) 2011-2023, National Research Foundation (SARAO) +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use +# this file except in compliance with the License. You may obtain a copy +# of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from functools import partial + +import dask.array as da +import numpy as np + +from katdal.dataset import DataSet +from katdal.lazy_indexer import DaskLazyIndexer +from katpoint import Timestamp +import numba +import xarray + +from daskms.constants import DASKMS_PARTITION_KEY +from daskms.experimental.katdal.corr_products import corrprod_index +from daskms.experimental.katdal.transpose import transpose +from daskms.experimental.katdal.uvw import uvw_coords + +TAG_TO_INTENT = { + "gaincal": "CALIBRATE_PHASE,CALIBRATE_AMPLI", + "bpcal": "CALIBRATE_BANDPASS,CALIBRATE_FLUX", + "target": "TARGET", +} + + +# Partitioning columns +GROUP_COLS = ["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"] + +# No partitioning, applies to many subtables +EMPTY_PARTITION_SCHEMA = {DASKMS_PARTITION_KEY: ()} + +# katdal datasets only have one spectral window +# and one polarisation. Thus, there +# is only one DATA_DESC_ID and it is zero +DATA_DESC_ID = 0 + + +def to_mjds(timestamp: Timestamp): + """Converts a katpoint Timestamp to Modified Julian Date Seconds""" + return timestamp.to_mjd() * 24 * 60 * 60 + + +class XarrayMSV2Facade: + """Provides a simplified xarray Dataset view over a katdal dataset""" + + def __init__(self, dataset: DataSet, no_auto: bool = True, row_view: bool = True): + self._dataset = dataset + self._no_auto = no_auto + self._row_view = row_view + self._pols_to_use = ["HH", "HV", "VH", "VV"] + # Reset the dataset selection + self._dataset.select(reset="") + self._cp_info = corrprod_index(dataset, self._pols_to_use, not no_auto) + + @property + def cp_info(self): + return self._cp_info + + def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, target): + # Extract numpy and dask products + dataset = self._dataset + cp_info = self._cp_info + time_utc = dataset.timestamps + t_chunks, chan_chunks, cp_chunks = dataset.vis.dataset.chunks + + # Modified Julian Date in Seconds + time_mjds = np.asarray([to_mjds(t) for t in map(Timestamp, time_utc)]) + + # Create a dask chunking transform + rechunk = partial(da.rechunk, chunks=(t_chunks, chan_chunks, cp_chunks)) + + # Transpose from (time, chan, corrprod) to (time, bl, chan, corr) + cpi = cp_info.cp_index + flag_transpose = partial( + transpose, + cp_index=cpi, + data_type=numba.literally("flags"), + row=self._row_view, + ) + weight_transpose = partial( + transpose, + cp_index=cpi, + data_type=numba.literally("weights"), + row=self._row_view, + ) + vis_transpose = partial( + transpose, + cp_index=cpi, + data_type=numba.literally("vis"), + row=self._row_view, + ) + + flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose)) + weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose)) + vis = DaskLazyIndexer(dataset.vis, (), transforms=(vis_transpose,)) + + time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1)) + ant1 = 
da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0])) + ant2 = da.from_array(cp_info.ant2_index[None, :], chunks=(1, cpi.shape[0])) + + uvw = uvw_coords( + target, + da.from_array(time_utc, chunks=t_chunks), + dataset.ants, + cp_info, + row=self._row_view, + ) + + time, ant1, ant2 = da.broadcast_arrays(time, ant1, ant2) + + if self._row_view: + primary_dims = ("row",) + time = time.ravel().rechunk({0: vis.dataset.chunks[0]}) + ant1 = ant1.ravel().rechunk({0: vis.dataset.chunks[0]}) + ant2 = ant2.ravel().rechunk({0: vis.dataset.chunks[0]}) + else: + primary_dims = ("time", "baseline") + + data_vars = { + # Primary indexing columns + "TIME": (primary_dims, time), + "ANTENNA1": (primary_dims, ant1), + "ANTENNA2": (primary_dims, ant2), + "FEED1": (primary_dims, da.zeros_like(ant1)), + "FEED2": (primary_dims, da.zeros_like(ant1)), + "DATA_DESC_ID": (primary_dims, da.full_like(ant1, DATA_DESC_ID)), + "FIELD_ID": (primary_dims, da.full_like(ant1, field_id)), + "STATE_ID": (primary_dims, da.full_like(ant1, state_id)), + "ARRAY_ID": (primary_dims, da.zeros_like(ant1)), + "OBSERVATION_ID": (primary_dims, da.zeros_like(ant1)), + "PROCESSOR_ID": (primary_dims, da.ones_like(ant1)), + "SCAN_NUMBER": (primary_dims, da.full_like(ant1, scan_index)), + "TIME_CENTROID": (primary_dims, time), + "INTERVAL": (primary_dims, da.full_like(time, dataset.dump_period)), + "EXPOSURE": (primary_dims, da.full_like(time, dataset.dump_period)), + "UVW": (primary_dims + ("uvw",), uvw), + "DATA": (primary_dims + ("chan", "corr"), vis.dataset), + "FLAG": (primary_dims + ("chan", "corr"), flags.dataset), + "WEIGHT_SPECTRUM": ( + primary_dims + ("chan", "corr"), + weights.dataset, + ), + # Estimated RMS noise per frequency channel + # note this column is used when computing calibration weights + # in CASA - WEIGHT_SPECTRUM may be modified based on the + # values in this column. 
See
+            # https://casadocs.readthedocs.io/en/stable/notebooks/data_weights.html
+            # for further details
+            "SIGMA_SPECTRUM": (
+                primary_dims + ("chan", "corr"),
+                weights.dataset**-0.5,
+            ),
+        }
+
+        attrs = {
+            DASKMS_PARTITION_KEY: tuple(
+                (c, data_vars[c][-1].dtype.name) for c in GROUP_COLS
+            ),
+            "FIELD_ID": field_id,
+            "DATA_DESC_ID": DATA_DESC_ID,
+            "SCAN_NUMBER": scan_index,
+        }
+
+        assert (set(GROUP_COLS) & set(attrs)) == set(GROUP_COLS)
+
+        return xarray.Dataset(data_vars, attrs=attrs)
+
+    def _antenna_xarray_factory(self):
+        antennas = self._dataset.ants
+        nant = len(antennas)
+        return xarray.Dataset(
+            {
+                "NAME": ("row", np.asarray([a.name for a in antennas], dtype=object)),
+                "STATION": (
+                    "row",
+                    np.asarray([a.name for a in antennas], dtype=object),
+                ),
+                "POSITION": (
+                    ("row", "xyz"),
+                    np.asarray([a.position_ecef for a in antennas]),
+                ),
+                "OFFSET": (("row", "xyz"), np.zeros((nant, 3))),
+                "DISH_DIAMETER": ("row", np.asarray([a.diameter for a in antennas])),
+                "MOUNT": ("row", np.array(["ALT-AZ"] * nant, dtype=object)),
+                "TYPE": ("row", np.array(["GROUND-BASED"] * nant, dtype=object)),
+                "FLAG_ROW": ("row", np.zeros(nant, dtype=np.int32)),
+            }
+        )
+
+    def _spw_xarray_factory(self):
+        def ref_freq(chan_freqs):
+            return chan_freqs[len(chan_freqs) // 2].astype(np.float64)
+
+        return [
+            xarray.Dataset(
+                {
+                    "NUM_CHAN": (("row",), np.array([spw.num_chans], dtype=np.int32)),
+                    "CHAN_FREQ": (("row", "chan"), spw.channel_freqs[np.newaxis, :]),
+                    "RESOLUTION": (
+                        ("row", "chan"),
+                        np.full_like(
+                            spw.channel_freqs[np.newaxis, :], spw.channel_width
+                        ),
+                    ),
+                    "CHAN_WIDTH": (
+                        ("row", "chan"),
+                        np.full_like(
+                            spw.channel_freqs[np.newaxis, :], spw.channel_width
+                        ),
+                    ),
+                    "EFFECTIVE_BW": (
+                        ("row", "chan"),
+                        np.full_like(
+                            spw.channel_freqs[np.newaxis, :], spw.channel_width
+                        ),
+                    ),
+                    "MEAS_FREQ_REF": ("row", np.array([5], dtype=np.int32)),
+                    "REF_FREQUENCY": ("row", [ref_freq(spw.channel_freqs)]),
+                    "NAME": ("row", np.asarray([f"{spw.band}-band"], dtype=object)),
+                    "FREQ_GROUP_NAME": (
+                        "row",
+                        np.asarray([f"{spw.band}-band"], dtype=object),
+                    ),
+                    "FREQ_GROUP": ("row", np.zeros(1, dtype=np.int32)),
+                    "IF_CONV_CHAN": ("row", np.zeros(1, dtype=np.int32)),
+                    "NET_SIDEBAND": ("row", np.ones(1, dtype=np.int32)),
+                    "TOTAL_BANDWIDTH": (
+                        "row",
+                        np.asarray([spw.num_chans * spw.channel_width]),
+                    ),
+                    "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
+                }
+            )
+            for spw in self._dataset.spectral_windows
+        ]
+
+    def _pol_xarray_factory(self):
+        pol_num = {"H": 0, "V": 1}
+        # MeerKAT only has linear feeds, these map to
+        # CASA ["XX", "XY", "YX", "YY"]
+        pol_types = {"HH": 9, "HV": 10, "VH": 11, "VV": 12}
+        return xarray.Dataset(
+            {
+                "NUM_CORR": ("row", np.array([len(self._pols_to_use)], dtype=np.int32)),
+                "CORR_PRODUCT": (
+                    ("row", "corr", "corrprod_idx"),
+                    np.array(
+                        [[[pol_num[p[0]], pol_num[p[1]]] for p in self._pols_to_use]],
+                        dtype=np.int32,
+                    ),
+                ),
+                "CORR_TYPE": (
+                    ("row", "corr"),
+                    np.asarray(
+                        [[pol_types[p] for p in self._pols_to_use]], dtype=np.int32
+                    ),
+                ),
+                "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
+            }
+        )
+
+    def _ddid_xarray_factory(self):
+        return xarray.Dataset(
+            {
+                "SPECTRAL_WINDOW_ID": ("row", np.zeros(1, dtype=np.int32)),
+                "POLARIZATION_ID": ("row", np.zeros(1, dtype=np.int32)),
+                "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)),
+            }
+        )
+
+    def _feed_xarray_factory(self):
+        nfeeds = len(self._dataset.ants)
+        NRECEPTORS = 2
+
+        return xarray.Dataset(
+            {
+                # ID of antenna in this array (integer)
+                "ANTENNA_ID": ("row", np.arange(nfeeds, dtype=np.int32)),
+                # Id for BEAM model (integer)
"BEAM_ID": ("row", np.ones(nfeeds, dtype=np.int32)), + # Beam position offset (on sky but in antenna reference frame): (double, 2-dim) + "BEAM_OFFSET": ( + ("row", "receptors", "radec"), + np.zeros((nfeeds, 2, 2), dtype=np.float64), + ), + # Feed id (integer) + "FEED_ID": ("row", np.zeros(nfeeds, dtype=np.int32)), + # Interval for which this set of parameters is accurate (double) + "INTERVAL": ("row", np.zeros(nfeeds, dtype=np.float64)), + # Number of receptors on this feed (probably 1 or 2) (integer) + "NUM_RECEPTORS": ("row", np.full(nfeeds, NRECEPTORS, dtype=np.int32)), + # Type of polarisation to which a given RECEPTOR responds (string, 1-dim) + "POLARIZATION_TYPE": ( + ("row", "receptors"), + np.array([["X", "Y"]] * nfeeds, dtype=object), + ), + # D-matrix i.e. leakage between two receptors (complex, 2-dim) + "POL_RESPONSE": ( + ("row", "receptors", "receptors-2"), + np.array([np.eye(2, dtype=np.complex64) for _ in range(nfeeds)]), + ), + # Position of feed relative to feed reference position (double, 1-dim, shape=(3,)) + "POSITION": (("row", "xyz"), np.zeros((nfeeds, 3), np.float64)), + # The reference angle for polarisation (double, 1-dim). A parallactic angle of + # 0 means that V is aligned to x (celestial North), but we are mapping H to x + # so we have to correct with a -90 degree rotation. + "RECEPTOR_ANGLE": ( + ("row", "receptors"), + np.full((nfeeds, NRECEPTORS), -np.pi / 2, dtype=np.float64), + ), + # ID for this spectral window setup (integer) + "SPECTRAL_WINDOW_ID": ("row", np.full(nfeeds, -1, dtype=np.int32)), + # Midpoint of time for which this set of parameters is accurate (double) + "TIME": ("row", np.zeros(nfeeds, dtype=np.float64)), + } + ) + + def _field_xarray_factory(self, field_data): + fields = [ + xarray.Dataset( + { + "NAME": ("row", np.array([target.name], object)), + "CODE": ("row", np.array(["T"], object)), + "SOURCE_ID": ("row", np.array([field_id], dtype=np.int32)), + "NUM_POLY": ("row", np.zeros(1, dtype=np.int32)), + "TIME": ("row", np.array([time])), + "DELAY_DIR": ( + ("row", "field-poly", "field-dir"), + np.array([[radec]], dtype=np.float64), + ), + "PHASE_DIR": ( + ("row", "field-poly", "field-dir"), + np.array([[radec]], dtype=np.float64), + ), + "REFERENCE_DIR": ( + ("row", "field-poly", "field-dir"), + np.array([[radec]], dtype=np.float64), + ), + "FLAG_ROW": ("row", np.zeros(1, dtype=np.int32)), + } + ) + for field_id, time, target, radec in field_data.values() + ] + + return xarray.concat(fields, dim="row") + + def _source_xarray_factory(self, field_data): + field_ids, times, targets, radecs = zip(*(field_data.values())) + times = np.array(times, dtype=np.float64) + nfields = len(field_ids) + return xarray.Dataset( + { + "NAME": ("row", np.array([t.name for t in targets], dtype=object)), + "SOURCE_ID": ("row", np.array(field_ids, dtype=np.int32)), + "PROPER_MOTION": ( + ("row", "radec-per-sec"), + np.zeros((nfields, 2), dtype=np.float32), + ), + "CALIBRATION_GROUP": ("row", np.full(nfields, -1, dtype=np.int32)), + "DIRECTION": (("row", "radec"), np.array(radecs)), + "TIME": ("row", times), + "NUM_LINES": ("row", np.ones(nfields, dtype=np.int32)), + "REST_FREQUENCY": ( + ("row", "lines"), + np.zeros((nfields, 1), dtype=np.float64), + ), + } + ) + + def _state_xarray_factory(self, state_modes): + state_ids, modes = zip(*sorted((i, m) for m, i in state_modes.items())) + nstates = len(state_ids) + return xarray.Dataset( + { + "SIG": np.ones(nstates, dtype=np.uint8), + "REF": np.zeros(nstates, dtype=np.uint8), + "CAL": np.zeros(nstates, 
+    def _state_xarray_factory(self, state_modes):
+        state_ids, modes = zip(*sorted((i, m) for m, i in state_modes.items()))
+        nstates = len(state_ids)
+        return xarray.Dataset(
+            {
+                "SIG": ("row", np.ones(nstates, dtype=np.uint8)),
+                "REF": ("row", np.zeros(nstates, dtype=np.uint8)),
+                "CAL": ("row", np.zeros(nstates, dtype=np.float64)),
+                "LOAD": ("row", np.zeros(nstates, dtype=np.float64)),
+                "SUB_SCAN": ("row", np.zeros(nstates, dtype=np.int32)),
+                "OBS_MODE": ("row", np.array(modes, dtype=object)),
+                "FLAG_ROW": ("row", np.zeros(nstates, dtype=np.int32)),
+            }
+        )
+
+    def _observation_xarray_factory(self):
+        ds = self._dataset
+        start, end = [to_mjds(t) for t in [ds.start_time, ds.end_time]]
+        return xarray.Dataset(
+            {
+                "OBSERVER": ("row", np.array([ds.observer], dtype=object)),
+                "PROJECT": ("row", np.array([ds.experiment_id], dtype=object)),
+                "LOG": (("row", "extra"), np.array([["unavailable"]], dtype=object)),
+                "SCHEDULE": (
+                    ("row", "extra"),
+                    np.array([["unavailable"]], dtype=object),
+                ),
+                "SCHEDULE_TYPE": ("row", np.array(["unknown"], dtype=object)),
+                "TELESCOPE": ("row", np.array(["MeerKAT"], dtype=object)),
+                "TIME_RANGE": (("row", "extent"), np.array([[start, end]])),
+                "FLAG_ROW": ("row", np.zeros(1, np.uint8)),
+            }
+        )
+
+    def xarray_datasets(self):
+        """Generates partitions of the main MSv2 table, as well as the subtables.
+
+        Returns
+        -------
+        main_xds: list of :code:`xarray.Dataset`
+            A list of xarray datasets corresponding to Measurement Set 2
+            partitions
+        subtable_xds: dict of :code:`xarray.Dataset`
+            A dictionary of datasets keyed on subtable names
+        """
+        main_xds = []
+        field_data = {}
+        UNKNOWN_STATE_ID = 0
+        state_modes = {"UNKNOWN": UNKNOWN_STATE_ID}
+
+        # Generate MAIN table xarray partition datasets
+        for scan_index, scan_state, target in self._dataset.scans():
+            # Retrieve existing field data, or create
+            try:
+                field_id, _, _, _ = field_data[target.name]
+            except KeyError:
+                field_id = len(field_data)
+                time_origin = Timestamp(self._dataset.timestamps[0])
+                field_data[target.name] = (
+                    field_id,
+                    to_mjds(time_origin),
+                    target,
+                    target.radec(time_origin),
+                )
+
+            # Create or retrieve the state_id associated
+            # with the tags of the current source
+            state_tag = ",".join(
+                TAG_TO_INTENT[tag] for tag in target.tags if tag in TAG_TO_INTENT
+            )
+            if state_tag and state_tag not in state_modes:
+                state_modes[state_tag] = len(state_modes)
+            state_id = state_modes.get(state_tag, UNKNOWN_STATE_ID)
+
+            main_xds.append(
+                self._main_xarray_factory(
+                    field_id, state_id, scan_index, scan_state, target
+                )
+            )
+
+        # Generate subtable xarray datasets
+        subtables = {
+            "ANTENNA": self._antenna_xarray_factory(),
+            "DATA_DESCRIPTION": self._ddid_xarray_factory(),
+            "SPECTRAL_WINDOW": self._spw_xarray_factory(),
+            "POLARIZATION": self._pol_xarray_factory(),
+            "FEED": self._feed_xarray_factory(),
+            "FIELD": self._field_xarray_factory(field_data),
+            "SOURCE": self._source_xarray_factory(field_data),
+            "OBSERVATION": self._observation_xarray_factory(),
+            "STATE": self._state_xarray_factory(state_modes),
+        }
+
+        # Assign empty partition schemas to subtables
+        subtables = {
+            n: dss.assign_attrs(EMPTY_PARTITION_SCHEMA)
+            if isinstance(dss, xarray.Dataset)
+            else [ds.assign_attrs(EMPTY_PARTITION_SCHEMA) for ds in dss]
+            for n, dss in subtables.items()
+        }
+
+        return main_xds, subtables
diff --git a/daskms/experimental/katdal/tests/conftest.py b/daskms/experimental/katdal/tests/conftest.py
new file mode 100644
index 00000000..c6de9ff9
--- /dev/null
+++ b/daskms/experimental/katdal/tests/conftest.py
@@ -0,0 +1,57 @@
+import pytest
+
+from daskms.experimental.katdal.meerkat_antennas import MEERKAT_ANTENNA_DESCRIPTIONS
+import numpy as np
+
+NTIME = 20
+NCHAN = 16
+NANT = 4
+DUMP_RATE = 8.0
+
+DEFAULT_PARAM = {"ntime": NTIME, "nchan": NCHAN, "nant": NANT, "dump_rate": DUMP_RATE}
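+
+
+# The session-scoped dataset fixture below is parametrised with DEFAULT_PARAM.
+# Tests may override any subset of these values via pytest's indirect
+# parametrisation (a sketch; see test_chunkstore.py for real usage):
+#
+#     @pytest.mark.parametrize("dataset", [{"ntime": 40}], indirect=True)
+#
+# Keys omitted from the dict fall back to the module-level defaults above.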
+@pytest.fixture(scope="session", params=[DEFAULT_PARAM])
+def dataset(request, tmp_path_factory):
+    MockDataset = pytest.importorskip(
+        "daskms.experimental.katdal.mock_dataset"
+    ).MockDataset
+    SpectralWindow = pytest.importorskip("katdal.spectral_window").SpectralWindow
+    Target = pytest.importorskip("katpoint").Target
+
+    DEFAULT_TARGETS = [
+        # It would have been nice to have radec = 19:39, -63:42 but then
+        # selection by description string does not work because the catalogue's
+        # description string pads it out to radec = 19:39:00.00, -63:42:00.0.
+        # (XXX Maybe fix Target comparison in katpoint to support this?)
+        Target("J1939-6342 | PKS1934-638, radec bpcal, 19:39:25.03, -63:42:45.6"),
+        Target("J1939-6342, radec gaincal, 19:39:25.03, -63:42:45.6"),
+        Target("J0408-6545 | PKS 0408-65, radec bpcal, 4:08:20.38, -65:45:09.1"),
+        Target("J1346-6024 | Cen B, radec, 13:46:49.04, -60:24:29.4"),
+    ]
+    targets = request.param.get("targets", DEFAULT_TARGETS)
+    ntime = request.param.get("ntime", NTIME)
+    nchan = request.param.get("nchan", NCHAN)
+    nant = request.param.get("nant", NANT)
+    dump_rate = request.param.get("dump_rate", DUMP_RATE)
+
+    # Ensure that len(timestamps) is an integer multiple of len(targets)
+    timestamps = 1234667890.0 + dump_rate * np.arange(ntime)
+    assert ntime > len(targets)
+    assert ntime % len(targets) == 0
+
+    spw = SpectralWindow(
+        centre_freq=1284e6,
+        channel_width=0,
+        num_chans=nchan,
+        sideband=1,
+        bandwidth=856e6,
+    )
+
+    return MockDataset(
+        tmp_path_factory.mktemp("chunks"),
+        targets,
+        timestamps,
+        antennas=MEERKAT_ANTENNA_DESCRIPTIONS[:nant],
+        spw=spw,
+    )
diff --git a/daskms/experimental/katdal/tests/test_chunkstore.py b/daskms/experimental/katdal/tests/test_chunkstore.py
new file mode 100644
index 00000000..f906a833
--- /dev/null
+++ b/daskms/experimental/katdal/tests/test_chunkstore.py
@@ -0,0 +1,61 @@
+import pytest
+
+xarray = pytest.importorskip("xarray")
+katdal = pytest.importorskip("katdal")
+
+import dask
+import numpy as np
+from numpy.testing import assert_array_equal
+
+from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr
+from daskms.experimental.katdal.msv2_facade import XarrayMSV2Facade
+
+
+@pytest.mark.parametrize(
+    "dataset", [{"ntime": 20, "nchan": 16, "nant": 4}], indirect=True
+)
+@pytest.mark.parametrize("auto_corrs", [True])
+@pytest.mark.parametrize("row_dim", [True, False])
+@pytest.mark.parametrize("out_store", ["output.zarr"])
+def test_chunkstore(tmp_path_factory, dataset, auto_corrs, row_dim, out_store):
+    facade = XarrayMSV2Facade(dataset, not auto_corrs, row_dim)
+    xds, sub_xds = facade.xarray_datasets()
+
+    # Reintroduce the shutil.rmtree and remove the tmp_path_factory
+    # to test in the local directory
+    # shutil.rmtree(out_store, ignore_errors=True)
+    out_store = tmp_path_factory.mktemp("output") / out_store
+
+    writes = [
+        xds_to_zarr(xds, out_store),
+        *(xds_to_zarr(ds, f"{out_store}::{k}") for k, ds in sub_xds.items()),
+    ]
+    dask.compute(writes)
+
+    # Compare visibilities, weights and flags
+    (read_xds,) = dask.compute(xds_from_zarr(out_store))
+    read_xds = xarray.concat(read_xds, dim="row" if row_dim else "time")
+
+    test_data = dataset._test_data["correlator_data"]
+    # Defer to ChunkStoreVisFlagsWeights application of weight scaling
+    test_weights = dataset._vfw.weights
+    assert test_weights.shape == test_data.shape
+    # Clamp flags to [0, 1]
+    test_flags = np.where(dataset._test_data["flags"] != 0, 1, 0)
+    ntime, nchan, _ = test_data.shape
+    (nbl,) = facade.cp_info.ant1_index.shape
+    ncorr = 
read_xds.sizes["corr"]
+
+    # This must hold for test_transpose to work
+    assert_array_equal(facade.cp_info.cp_index.ravel(), np.arange(nbl * ncorr))
+
+    def assert_transposed_equal(a, e):
+        """Simple transpose of katdal (time, chan, corrprod) to
+        (time, bl, chan, corr)."""
+        t = a.reshape(ntime, nchan, nbl, ncorr).transpose(0, 2, 1, 3)
+        t = t.reshape(-1, nchan, ncorr) if row_dim else t
+        return assert_array_equal(t, e)
+
+    assert_transposed_equal(test_data, read_xds.DATA.values)
+    assert_transposed_equal(test_weights, read_xds.WEIGHT_SPECTRUM.values)
+    assert_transposed_equal(test_flags, read_xds.FLAG.values)
diff --git a/daskms/experimental/katdal/transpose.py b/daskms/experimental/katdal/transpose.py
new file mode 100644
index 00000000..068a13c4
--- /dev/null
+++ b/daskms/experimental/katdal/transpose.py
@@ -0,0 +1,138 @@
+# The numba transposition code is derived from
+# https://github.com/ska-sa/katdal/blob/v0.22/scripts/mvftoms.py
+# under the following license
+#
+# ################################################################################
+# Copyright (c) 2011-2023, National Research Foundation (SARAO)
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use
+# this file except in compliance with the License. You may obtain a copy
+# of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+
+import dask.array as da
+import numpy as np
+
+from numba import njit, literally
+from numba.extending import overload, SentryLiteralArgs, register_jitable
+from numba.core.errors import TypingError
+
+
+JIT_OPTIONS = {"nogil": True, "cache": True}
+
+
+@njit(**JIT_OPTIONS)
+def transpose_core(in_data, cp_index, data_type, row):
+    return transpose_impl(in_data, cp_index, data_type, row)
+
+
+def transpose_impl(in_data, cp_index, data_type, row):
+    raise NotImplementedError
+
+
+@overload(transpose_impl, jit_options=JIT_OPTIONS, prefer_literal=True)
+def nb_transpose(in_data, cp_index, data_type, row):
+    SentryLiteralArgs(["data_type", "row"]).for_function(nb_transpose).bind(
+        in_data, cp_index, data_type, row
+    )
+
+    try:
+        data_type = data_type.literal_value
+    except AttributeError as e:
+        raise TypingError(f"data_type {data_type} is not a literal_value") from e
+    else:
+        if not isinstance(data_type, str):
+            raise TypingError(f"data_type {data_type} is not a string: {type(data_type)}")
+
+    try:
+        row_dim = row.literal_value
+    except AttributeError as e:
+        raise TypingError(f"row {row} is not a literal_value") from e
+    else:
+        if not isinstance(row_dim, bool):
+            raise TypingError(f"row_dim {row_dim} is not a boolean: {type(row_dim)}")
+
+    if data_type == "flags":
+        get_value = lambda v: v != 0
+        default_value = np.bool_(True)
+    elif data_type == "vis":
+        get_value = lambda v: v
+        default_value = in_data.dtype(0 + 0j)
+    elif data_type == "weights":
+        get_value = lambda v: v
+        default_value = in_data.dtype(0)
+    else:
+        raise TypingError(f"data_type {data_type} is not supported")
+
+    get_value = register_jitable(get_value)
+
+    def impl(in_data, cp_index, data_type, row):
+        n_time, n_chans, _ = in_data.shape
+        n_bls, n_pol = cp_index.shape
+        out_data = 
np.empty((n_time, n_bls, n_chans, n_pol), in_data.dtype) + + bstep = 128 + bblocks = (n_bls + bstep - 1) // bstep + for t in range(n_time): + for bblock in range(bblocks): # numba.prange + bstart = bblock * bstep + bstop = min(n_bls, bstart + bstep) + for c in range(n_chans): + for b in range(bstart, bstop): + for p in range(out_data.shape[3]): + idx = cp_index[b, p] + data = ( + get_value(in_data[t, c, idx]) + if idx >= 0 + else default_value + ) + out_data[t, b, c, p] = data + + if row_dim: + return out_data.reshape(n_time * n_bls, n_chans, n_pol) + + return out_data + + return impl + + +def transpose(data, cp_index, data_type, row=False): + ntime, _, _ = data.shape + nbl, ncorr = cp_index.shape + + if row: + out_dims = ("row", "chan", "corr") + new_axes = {"row": ntime * nbl, "corr": ncorr} + else: + out_dims = ("time", "bl", "chan", "corr") + new_axes = {"bl": nbl, "corr": ncorr} + + output = da.blockwise( + transpose_core, + out_dims, + data, + ("time", "chan", "corrprod"), + cp_index, + None, + literally(data_type), + None, + row, + None, + concatenate=True, + new_axes=new_axes, + dtype=data.dtype, + ) + + if row: + return output.rechunk({0: ntime * (nbl,)}) + + return output diff --git a/daskms/experimental/katdal/uvw.py b/daskms/experimental/katdal/uvw.py new file mode 100644 index 00000000..21e8e4ea --- /dev/null +++ b/daskms/experimental/katdal/uvw.py @@ -0,0 +1,76 @@ +# The uvw calculation code is derived from +# https://github.com/ska-sa/katdal/blob/v0.22/katdal/ms_async.py +# under the following license +# +# ################################################################################ +# Copyright (c) 2011-2023, National Research Foundation (SARAO) +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use +# this file except in compliance with the License. You may obtain a copy +# of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + + +import dask.array as da +from katpoint import Target +import numpy as np + + +def _uvw(target_description, time_utc, antennas, ant1, ant2, row): + """Calculate UVW coordinates""" + array_centre = antennas[0].array_reference_antenna() + target = Target(target_description) + uvw_ant = target.uvw(antennas, time_utc, array_centre) + uvw_ant = np.transpose(uvw_ant, (1, 2, 0)) + # Compute baseline UVW coordinates from per-antenna coordinates. + # The sign convention matches `CASA`_, rather than the + # Measurement Set `definition`_. + # .. _CASA: https://casa.nrao.edu/Memos/CoordConvention.pdf + # .. 
_definition: https://casa.nrao.edu/Memos/229.html#SECTION00064000000000000000 + uvw_bl = np.take(uvw_ant, ant1, axis=1) - np.take(uvw_ant, ant2, axis=1) + return uvw_bl.reshape(-1, 3) if row else uvw_bl + + +def uvw_coords(target, time_utc, antennas, cp_info, row=True): + (ntime,) = time_utc.shape + (nbl,) = cp_info.ant1_index.shape + + if row: + out_dims = ("row", "uvw") + new_axes = {"row": ntime * nbl, "uvw": 3} + else: + out_dims = ("time", "bl", "uvw") + new_axes = {"uvw": 3} + + out = da.blockwise( + _uvw, + out_dims, + target.description, + None, + time_utc, + ("time",), + antennas, + ("ant",), + cp_info.ant1_index, + ("bl",), + cp_info.ant2_index, + ("bl",), + row, + None, + concatenate=True, + new_axes=new_axes, + meta=np.empty((0,) * len(out_dims), np.float64), + ) + + if row: + return out.rechunk({0: ntime * (nbl,)}) + + return out diff --git a/daskms/table_schemas.py b/daskms/table_schemas.py index 086955a0..4981f454 100644 --- a/daskms/table_schemas.py +++ b/daskms/table_schemas.py @@ -100,7 +100,7 @@ SOURCE_SCHEMA = { "DIRECTION": {"dims": ("radec",)}, "POSITION": {"dims": ("position",)}, - "PROPER_MOTION": {"dims": ("radec_per_sec",)}, + "PROPER_MOTION": {"dims": ("radec-per-sec",)}, "REST_FREQUENCY": {"dims": ("lines",)}, "SYSVEL": {"dims": ("lines",)}, "TRANSITION": {"dims": ("lines",)}, diff --git a/daskms/utils.py b/daskms/utils.py index 16b79b35..843b33b0 100644 --- a/daskms/utils.py +++ b/daskms/utils.py @@ -192,7 +192,7 @@ def requires(*args): def decorator(fn): lines = [ f"Optional extras required by " - f"{funcname(fn)} are missing due to " + f"{fn.__name__} are missing due to " f"the following ImportErrors:" ] diff --git a/pyproject.toml b/pyproject.toml index c941629d..71674195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ s3fs = {version = ">= 2023.1.0", optional=true} minio = {version = "^7.2.0", optional = true} pytest = {version = "^7.1.3", optional=true} pandas = {version = "^2.1.2", optional = true} +katdal = {version = "^0.22", optional = true} [tool.poetry.scripts] dask-ms = "daskms.apps.entrypoint:main" @@ -27,10 +28,11 @@ fragments = "daskms.apps.fragments:main" [tool.poetry.extras] arrow = ["pandas", "pyarrow"] +katdal = ["katdal", "xarray", "zarr"] xarray = ["xarray"] zarr = ["zarr"] s3 = ["s3fs"] -complete = ["s3fs", "pandas", "pyarrow", "xarray", "zarr"] +complete = ["s3fs", "pandas", "pyarrow", "xarray", "zarr", "katdal"] testing = ["minio", "pytest"] [tool.poetry.group.dev]