From 8c75d2f623c9a463fcda4dacc8b4136a9b7a3a50 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 25 Mar 2024 16:25:05 +0200 Subject: [PATCH] Sanity check written visibilities, weights and flags --- .../katdal/tests/test_chunkstore.py | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/daskms/experimental/katdal/tests/test_chunkstore.py b/daskms/experimental/katdal/tests/test_chunkstore.py index d61acc9c..cfd7c21e 100644 --- a/daskms/experimental/katdal/tests/test_chunkstore.py +++ b/daskms/experimental/katdal/tests/test_chunkstore.py @@ -10,9 +10,10 @@ import dask.array as da import numba import numpy as np +from numpy.testing import assert_array_equal import xarray -from daskms.experimental.zarr import xds_to_zarr +from daskms.experimental.zarr import xds_from_zarr, xds_to_zarr from katdal.chunkstore_npy import NpyFileChunkStore from katdal.dataset import Subarray @@ -57,7 +58,6 @@ def __init__( store = NpyFileChunkStore(str(path)) shape = (len(timestamps), spw.num_chans, len(corr_products)) - # data, chunk_info = put_fake_dataset(store, "cb1", shape) self._test_data, chunk_info = put_fake_dataset( store, "cb1", @@ -266,8 +266,6 @@ def test_chunkstore(tmp_path_factory, dataset, include_auto_corrs, row_dim, out_ } ) ] - print(ant_xds) - print(xds) # Reintroduce the shutil.rmtree and remote the tmp_path_factory # to test in the local directory @@ -276,3 +274,29 @@ def test_chunkstore(tmp_path_factory, dataset, include_auto_corrs, row_dim, out_ dask.compute(xds_to_zarr(xds, out_store)) dask.compute(xds_to_zarr(ant_xds, f"{out_store}::ANTENNA")) + + # Compare visibilities, weights and flags + (read_xds,) = dask.compute(xds_from_zarr(out_store)) + (read_xds,) = read_xds + + test_data = dataset._test_data["correlator_data"] + # Defer to ChunkStoreVisWeights application of weight scaling + test_weights = dataset._vfw.weights + assert test_weights.shape == test_data.shape + # Clamp test data to [0, 1] + test_flags = np.where(dataset._test_data["flags"] != 0, 1, 0) + ntime, nchan, _ = test_data.shape + (nbl,) = cp_info.ant1_index.shape + ncorr = read_xds.sizes["corr"] + + def test_transpose(a): + """Simple transpose of katdal (time, chan, corrprod) to + (time, bl, chan, corr).""" + # This must hold for this simple tranpose to work + assert_array_equal(cp_info.cp_index.ravel(), np.arange(nbl * ncorr)) + o = a.reshape(ntime, nchan, nbl, ncorr).transpose(0, 2, 1, 3) + return o.reshape(-1, nchan, ncorr) if row_dim else o + + assert_array_equal(test_transpose(test_data), read_xds.DATA.values) + assert_array_equal(test_transpose(test_weights), read_xds.WEIGHT_SPECTRUM.values) + assert_array_equal(test_transpose(test_flags), read_xds.FLAG.values)