Skip to content

Commit

Permalink
Test case runs through
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Mar 25, 2024
1 parent 66c7944 commit 1534fed
Show file tree
Hide file tree
Showing 7 changed files with 546 additions and 0 deletions.
4 changes: 4 additions & 0 deletions daskms/experimental/katdal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
try:
import katdal
except ImportError as e:
raise ImportError("katdal is not installed\n" "pip install dask-ms[katdal]") from e
47 changes: 47 additions & 0 deletions daskms/experimental/katdal/corr_products.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from collections import namedtuple

import numpy as np

CPInfo = namedtuple("CPInfo", "ant1_index ant2_index ant1 ant2 cp_index")


def corrprod_index(dataset, pols_to_use, include_auto_corrs=False):
"""The correlator product index (with -1 representing missing indices)."""
corrprod_to_index = {tuple(cp): n for n, cp in enumerate(dataset.corr_products)}

# ==========================================
# Generate per-baseline antenna pairs and
# correlator product indices
# ==========================================

def _cp_index(a1, a2, pol):
"""Create correlator product index from antenna pair and pol."""
a1 = a1.name + pol[0].lower()
a2 = a2.name + pol[1].lower()
return corrprod_to_index.get((a1, a2), -1)

# Generate baseline antenna pairs
auto_corrs = 0 if include_auto_corrs else 1
ant1_index, ant2_index = np.triu_indices(len(dataset.ants), auto_corrs)
ant1_index, ant2_index = (a.astype(np.int32) for a in (ant1_index, ant2_index))

# Order as similarly to the input as possible, which gives better performance
# in permute_baselines.
bl_indices = list(zip(ant1_index, ant2_index))
bl_indices.sort(
key=lambda ants: _cp_index(
dataset.ants[ants[0]], dataset.ants[ants[1]], pols_to_use[0]
)
)
# Undo the zip
ant1_index[:] = [bl[0] for bl in bl_indices]
ant2_index[:] = [bl[1] for bl in bl_indices]
ant1 = [dataset.ants[a1] for a1 in ant1_index]
ant2 = [dataset.ants[a2] for a2 in ant2_index]

# Create actual correlator product index
cp_index = [_cp_index(a1, a2, p) for a1, a2 in zip(ant1, ant2) for p in pols_to_use]
cp_index = np.array(cp_index, dtype=np.int32)
cp_index = cp_index.reshape(-1, len(pols_to_use))

return CPInfo(ant1_index, ant2_index, ant1, ant2, cp_index)
70 changes: 70 additions & 0 deletions daskms/experimental/katdal/meerkat_antennas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
MEERKAT_ANTENNA_DESCRIPTIONS = [
"m000, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -8.264 -207.290 8.597",
"m001, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1.121 -171.762 8.471",
"m002, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -32.113 -224.236 8.645",
"m003, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -66.518 -202.276 8.285",
"m004, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -123.624 -252.946 8.513",
"m005, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -102.088 -283.120 8.875",
"m006, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -18.232 -295.428 9.188",
"m007, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -89.592 -402.732 9.769",
"m008, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -93.527 -535.026 10.445",
"m009, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 32.357 -371.056 10.140",
"m010, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 88.095 -511.872 11.186",
"m011, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 84.012 -352.078 10.151",
"m012, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 140.019 -368.267 10.449",
"m013, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 236.792 -393.460 11.124",
"m014, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 280.669 -285.792 10.547",
"m015, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 210.644 -219.142 9.738",
"m016, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 288.159 -185.873 9.795",
"m017, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 199.624 -112.263 8.955",
"m018, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 105.727 -245.870 9.529",
"m019, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 170.787 -285.223 10.071",
"m020, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 97.016 -299.638 9.877",
"m021, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -295.966 -327.241 8.117",
"m022, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -373.002 0.544 5.649",
"m023, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -322.306 -142.185 6.825",
"m024, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -351.046 150.088 4.845",
"m025, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -181.978 225.617 5.068",
"m026, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -99.004 17.045 6.811",
"m027, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 40.475 -23.112 7.694",
"m028, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -51.179 -87.170 7.636",
"m029, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -88.762 -124.111 7.700",
"m030, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 171.281 113.949 7.278",
"m031, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 246.567 93.756 7.469",
"m032, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 461.275 175.505 7.367",
"m033, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 580.678 863.959 3.600",
"m034, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 357.811 -28.308 8.972",
"m035, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 386.152 -180.894 10.290",
"m036, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 388.257 -290.759 10.812",
"m037, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 380.286 -459.309 12.172",
"m038, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 213.308 -569.080 11.946",
"m039, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 253.748 -592.147 12.441",
"m040, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -26.858 -712.219 11.833",
"m041, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -287.545 -661.678 9.949",
"m042, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -361.714 -460.318 8.497",
"m043, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -629.853 -128.326 5.264",
"m044, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -896.164 600.497 -0.640",
"m045, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1832.860 266.750 0.108",
"m046, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1467.341 1751.923 -7.078",
"m047, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -578.296 -517.297 7.615",
"m048, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -2805.653 2686.863 -9.755",
"m049, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -3605.957 436.462 2.696",
"m050, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -2052.336 -843.715 5.338",
"m051, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -850.255 -769.359 7.614",
"m052, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -593.192 -1148.652 10.550",
"m053, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 9.365 -1304.462 15.032",
"m054, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 871.980 -499.812 13.364",
"m055, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1201.780 96.492 10.023",
"m056, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 1598.403 466.668 6.990",
"m057, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 294.645 3259.915 -10.637",
"m058, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 2805.764 2686.873 -3.660",
"m059, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 3686.427 758.895 11.822",
"m060, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, 3419.683 -1840.478 23.697",
"m061, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -16.409 -2323.779 21.304",
"m062, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -1440.632 -2503.773 21.683",
"m063, -30:42:39.8, 21:26:38.0, 1086.6, 13.5, -3419.585 -1840.480 16.383",
]

assert all(
int(a.split(",")[0][1:]) == i for i, a in enumerate(MEERKAT_ANTENNA_DESCRIPTIONS)
)
252 changes: 252 additions & 0 deletions daskms/experimental/katdal/tests/test_chunkstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
from functools import partial

import dask.array as da

from daskms.experimental.zarr import xds_to_zarr

from katdal.chunkstore_npy import NpyFileChunkStore
from katdal.dataset import Subarray
from katdal.lazy_indexer import DaskLazyIndexer
from katdal.spectral_window import SpectralWindow
from katdal.vis_flags_weights import ChunkStoreVisFlagsWeights
from katdal.test.test_vis_flags_weights import put_fake_dataset
from katdal.test.test_dataset import MinimalDataSet
from katpoint import Antenna, Target, Timestamp

import numba
import numpy as np
import pytest
import xarray

from daskms.experimental.katdal.meerkat_antennas import MEERKAT_ANTENNA_DESCRIPTIONS
from daskms.experimental.katdal.transpose import transpose
from daskms.experimental.katdal.corr_products import corrprod_index
from daskms.experimental.katdal.uvw import uvw_coords

SPW = SpectralWindow(
centre_freq=1284e6, channel_width=0, num_chans=16, sideband=1, bandwidth=856e6
)


class FakeDataset(MinimalDataSet):
def __init__(
self,
path,
targets,
timestamps,
antennas=MEERKAT_ANTENNA_DESCRIPTIONS,
spw=SPW,
):
antennas = list(map(Antenna, antennas))
corr_products = [
(a1.name + c1, a2.name + c2)
for i, a1 in enumerate(antennas)
for a2 in antennas[i:]
for c1 in ("h", "v")
for c2 in ("h", "v")
]

subarray = Subarray(antennas, corr_products)
assert len(subarray.ants) > 0

store = NpyFileChunkStore(str(path))
shape = (len(timestamps), spw.num_chans, len(corr_products))
# data, chunk_info = put_fake_dataset(store, "cb1", shape)
data, chunk_info = put_fake_dataset(
store,
"cb1",
shape,
chunk_overrides={
"correlator_data": (1, spw.num_chans, len(corr_products)),
"flags": (1, spw.num_chans, len(corr_products)),
"weights": (1, spw.num_chans, len(corr_products)),
},
)
self._vfw = ChunkStoreVisFlagsWeights(store, chunk_info)
self._vis = None
self._weights = None
self._flags = None
super().__init__(targets, timestamps, subarray, spw)

def _set_keep(
self,
time_keep=None,
freq_keep=None,
corrprod_keep=None,
weights_keep=None,
flags_keep=None,
):
super()._set_keep(time_keep, freq_keep, corrprod_keep, weights_keep, flags_keep)
stage1 = (time_keep, freq_keep, corrprod_keep)
self._vis = DaskLazyIndexer(self._vfw.vis, stage1)
self._weights = DaskLazyIndexer(self._vfw.weights, stage1)
self._flags = DaskLazyIndexer(self._vfw.flags, stage1)

@property
def vis(self):
if self._vis is None:
raise ValueError("Selection has not yet been performed")
return self._vis

@property
def flags(self):
if self._flags is None:
raise ValueError("Selection has not yet been performed")
return self._flags

@property
def weights(self):
if self._weights is None:
raise ValueError("Selection has not yet been performed")

return self._weights


@pytest.fixture(scope="session")
def dataset(request, tmp_path_factory):
path = tmp_path_factory.mktemp("chunks")
targets = [
# It would have been nice to have radec = 19:39, -63:42 but then
# selection by description string does not work because the catalogue's
# description string pads it out to radec = 19:39:00.00, -63:42:00.0.
# (XXX Maybe fix Target comparison in katpoint to support this?)
Target("J1939-6342 | PKS1934-638, radec bpcal, 19:39:25.03, -63:42:45.6"),
Target("J1939-6342, radec gaincal, 19:39:25.03, -63:42:45.6"),
Target("J0408-6545 | PKS 0408-65, radec bpcal, 4:08:20.38, -65:45:09.1"),
Target("J1346-6024 | Cen B, radec, 13:46:49.04, -60:24:29.4"),
]
# Ensure that len(timestamps) is an integer multiple of len(targets)
timestamps = 1234667890.0 + 8.0 * np.arange(20)

spw = SpectralWindow(
centre_freq=1284e6, channel_width=0, num_chans=4096, sideband=1, bandwidth=856e6
)

return FakeDataset(
path, targets, timestamps, antennas=MEERKAT_ANTENNA_DESCRIPTIONS[:16], spw=spw
)


@pytest.mark.parametrize("include_auto_corrs", [False])
@pytest.mark.parametrize("row_dim", [True, False])
@pytest.mark.parametrize("out_store", ["output.zarr"])
def test_chunkstore(tmp_path_factory, dataset, include_auto_corrs, row_dim, out_store):
cp_info = corrprod_index(dataset, ["HH", "HV", "VH", "VV"], include_auto_corrs)
all_antennas = dataset.ants

xds = []

for scan_index, scan_state, target in dataset.scans():
# Extract numpy and dask products
time_utc = dataset.timestamps
t_chunks, chan_chunks, cp_chunks = dataset.vis.dataset.chunks

# Modified Julian Date in Seconds
time_mjds = np.asarray(
[t.to_mjd() * 24 * 60 * 60 for t in map(Timestamp, time_utc)]
)

# Create a dask chunking transform
rechunk = partial(da.rechunk, chunks=(t_chunks, chan_chunks, cp_chunks))

# Transpose from (time, chan, corrprod) to (time, bl, chan, corr)
cpi = cp_info.cp_index
flag_transpose = partial(
transpose,
cp_index=cpi,
data_type=numba.literally("flags"),
row=row_dim,
)
weight_transpose = partial(
transpose,
cp_index=cpi,
data_type=numba.literally("weights"),
row=row_dim,
)
vis_transpose = partial(
transpose, cp_index=cpi, data_type=numba.literally("vis"), row=row_dim
)

flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose))
weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose))
vis = DaskLazyIndexer(dataset.vis, (), transforms=(vis_transpose,))

time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1))
ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0]))
ant2 = da.from_array(cp_info.ant2_index[None, :], chunks=(1, cpi.shape[0]))

uvw = uvw_coords(
target,
da.from_array(time_utc, chunks=t_chunks),
all_antennas,
cp_info,
row=row_dim,
)

time, ant1, ant2 = da.broadcast_arrays(time, ant1, ant2)

if row_dim:
primary_dims = ("row",)
time = time.ravel().rechunk({0: vis.dataset.chunks[0]})
ant1 = ant1.ravel().rechunk({0: vis.dataset.chunks[0]})
ant2 = ant2.ravel().rechunk({0: vis.dataset.chunks[0]})
else:
primary_dims = ("time", "baseline")

xds.append(
xarray.Dataset(
{
# Primary indexing columns
"TIME": (primary_dims, time),
"ANTENNA1": (primary_dims, ant1),
"ANTENNA2": (primary_dims, ant2),
"FEED1": (primary_dims, da.zeros_like(ant1)),
"FEED2": (primary_dims, da.zeros_like(ant1)),
# TODO(sjperkins)
# Fill these in with real values
"DATA_DESC_ID": (primary_dims, da.zeros_like(ant1)),
"FIELD_ID": (primary_dims, da.zeros_like(ant1)),
"STATE_ID": (primary_dims, da.zeros_like(ant1)),
"ARRAY_ID": (primary_dims, da.zeros_like(ant1)),
"OBSERVATION_ID": (primary_dims, da.zeros_like(ant1)),
"PROCESSOR_ID": (primary_dims, da.ones_like(ant1)),
"SCAN_NUMBER": (primary_dims, da.full_like(ant1, scan_index)),
"TIME_CENTROID": (primary_dims, time),
"INTERVAL": (primary_dims, da.full_like(time, dataset.dump_period)),
"EXPOSURE": (primary_dims, da.full_like(time, dataset.dump_period)),
"UVW": (primary_dims + ("uvw",), uvw),
"DATA": (primary_dims + ("chan", "corr"), vis.dataset),
"FLAG": (primary_dims + ("chan", "corr"), flags.dataset),
"WEIGHT_SPECTRUM": (
primary_dims + ("chan", "corr"),
weights.dataset,
),
# Estimated RMS noise per frequency channel
# note this column is used when computing calibration weights
# in CASA - WEIGHT_SPECTRUM may be modified based on the
# values in this column. See
# https://casadocs.readthedocs.io/en/stable/notebooks/data_weights.html
# for further details
"SIGMA_SPECTRUM": (
primary_dims + ("chan", "corr"),
weights.dataset**-0.5,
),
}
)
)

xds = xarray.concat(xds, dim=primary_dims[0])

import dask
import shutil

# out_store = tmp_path_factory.mktemp(out_store)

print(xds)

shutil.rmtree(out_store, ignore_errors=True)
dask.compute(xds_to_zarr(xds, out_store))

# auto_corrs = 0 if include_auto_corrs else 1
# ant1, ant2 = np.triu_indices(len(dataset.ants), auto_corrs)
# ant1, ant2 = (a.astype(np.int32) for a in (ant1, ant2))
Loading

0 comments on commit 1534fed

Please sign in to comment.