
Commit 747408f

Identify channel and correlation-like dimensions in non-standard MS columns (#329)
1 parent b436c0e

7 files changed: +135 −16 lines

HISTORY.rst (+2)

@@ -4,6 +4,8 @@ History
 
 X.Y.Z (YYYY-MM-DD)
 ------------------
+* Improve table schema handling (:pr:`329`)
+* Identify channel and correlation-like dimensions in non-standard MS columns (:pr:`329`)
 * DaskMSStore depends on ``fsspec >= 2022.7.0`` (:pr:`328`)
 * Optimise `broadcast_arrays` in katdal import (:pr:`326`)
 * Change `dask-ms katdal import` to `dask-ms import katdal` (:pr:`325`)

daskms/columns.py (+15 −3)

@@ -222,11 +222,23 @@ def column_metadata(column, table_proxy, table_schema, chunks, exemplar_row=0):
             "match shape of exemplar=%s" % (ndim, shape)
         )
 
-    # Extract dimension schema
+    # Get the column schema, or create a default
    try:
-        dims = table_schema[column]["dims"]
+        column_schema = table_schema[column]
    except KeyError:
-        dims = tuple("%s-%d" % (column, i) for i in range(1, len(shape) + 1))
+        column_schema = {
+            "dims": tuple("%s-%d" % (column, i) for i in range(1, len(shape) + 1))
+        }
+
+    try:
+        dims = column_schema["dims"]
+    except KeyError:
+        raise ColumnMetadataError(
+            f"Column schema {column_schema} does not contain required 'dims' attribute"
+        )
+
+    if not isinstance(dims, tuple) or not all(isinstance(d, str) for d in dims):
+        raise ColumnMetadataError(f"Dimensions {dims} is not a tuple of strings")
 
     dim_chunks = []
 
daskms/dask_ms.py (+6 −1)

@@ -328,7 +328,12 @@ def xds_from_ms(ms, columns=None, index_cols=None, group_cols=None, **kwargs):
     kwargs.setdefault("table_schema", "MS")
 
     return xds_from_table(
-        ms, columns=columns, index_cols=index_cols, group_cols=group_cols, **kwargs
+        ms,
+        columns=columns,
+        index_cols=index_cols,
+        group_cols=group_cols,
+        context="ms",
+        **kwargs,
     )
 
daskms/reads.py (+60 −3)

@@ -318,6 +318,7 @@ def __init__(self, table, select_cols, group_cols, index_cols, **kwargs):
         self.table_keywords = kwargs.pop("table_keywords", False)
         self.column_keywords = kwargs.pop("column_keywords", False)
         self.table_proxy = kwargs.pop("table_proxy", False)
+        self.context = kwargs.pop("context", None)
 
         if len(kwargs) > 0:
             raise ValueError(f"Unhandled kwargs: {kwargs}")

@@ -359,8 +360,10 @@ def _single_dataset(self, table_proxy, orders, exemplar_row=0):
         coords = {"ROWID": rowid}
 
         attrs = {DASKMS_PARTITION_KEY: ()}
-
-        return Dataset(variables, coords=coords, attrs=attrs)
+        dataset = Dataset(variables, coords=coords, attrs=attrs)
+        return self.postprocess_dataset(
+            dataset, table_proxy, exemplar_row, orders, self.chunks[0], short_table_name
+        )
 
     def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
         _, t, s = table_path_split(self.canonical_name)

@@ -420,10 +423,64 @@ def _group_datasets(self, table_proxy, groups, exemplar_rows, orders):
             group_id = [gid.item() for gid in group_id]
             attrs.update(zip(self.group_cols, group_id))
 
-            datasets.append(Dataset(group_var_dims, attrs=attrs, coords=coords))
+            dataset = Dataset(group_var_dims, attrs=attrs, coords=coords)
+            dataset = self.postprocess_dataset(
+                dataset, table_proxy, exemplar_row, order, group_chunks, array_suffix
+            )
+            datasets.append(dataset)
 
         return datasets
 
+    def postprocess_dataset(
+        self, dataset, table_proxy, exemplar_row, order, chunks, array_suffix
+    ):
+        if not self.context or self.context != "ms":
+            return dataset
+
+        # Fixup any non-standard columns
+        # with dimensions like chan and corr
+        try:
+            chan = dataset.sizes["chan"]
+            corr = dataset.sizes["corr"]
+        except KeyError:
+            return dataset
+
+        schema_updates = {}
+
+        for name, var in dataset.data_vars.items():
+            new_dims = list(var.dims[1:])
+
+            unassigned = {"chan", "corr"}
+
+            for dim, dim_name in enumerate(var.dims[1:]):
+                # An automatically assigned dimension name
+                if dim_name == f"{name}-{dim + 1}":
+                    if dataset.sizes[dim_name] == chan and "chan" in unassigned:
+                        new_dims[dim] = "chan"
+                        unassigned.discard("chan")
+                    elif dataset.sizes[dim_name] == corr and "corr" in unassigned:
+                        new_dims[dim] = "corr"
+                        unassigned.discard("corr")
+
+            new_dims = tuple(new_dims)
+            if var.dims[1:] != new_dims:
+                schema_updates[name] = {"dims": new_dims}
+
+        if not schema_updates:
+            return dataset
+
+        return dataset.assign(
+            **_dataset_variable_factory(
+                table_proxy,
+                schema_updates,
+                list(schema_updates.keys()),
+                exemplar_row,
+                order,
+                chunks,
+                array_suffix,
+            )
+        )
+
     def datasets(self):
         table_proxy = self._table_proxy_factory()
 
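The heart of `postprocess_dataset` is a size-matching heuristic: for each data variable, any automatically named trailing dimension (`{column}-{n}`) whose extent equals the dataset's `chan` or `corr` extent is relabelled, with each label assigned at most once per variable (`chan` is tried first, so it wins ties when the two extents are equal). Below is a standalone sketch of just that heuristic, with `relabel_dims` a hypothetical helper operating on a plain dict in place of the dataset:

def relabel_dims(name, dims, sizes):
    """Relabel auto-generated dims whose extents match chan/corr.

    `dims` excludes the leading "row" dimension and `sizes` maps
    dimension names to extents, including "chan" and "corr".
    """
    chan, corr = sizes["chan"], sizes["corr"]
    new_dims = list(dims)
    unassigned = {"chan", "corr"}

    for i, dim_name in enumerate(dims):
        # Only touch automatically assigned names like "BITFLAG-1"
        if dim_name == f"{name}-{i + 1}":
            if sizes[dim_name] == chan and "chan" in unassigned:
                new_dims[i] = "chan"
                unassigned.discard("chan")
            elif sizes[dim_name] == corr and "corr" in unassigned:
                new_dims[i] = "corr"
                unassigned.discard("corr")

    return tuple(new_dims)


sizes = {"chan": 16, "corr": 4, "BITFLAG-1": 16, "BITFLAG-2": 4}
assert relabel_dims("BITFLAG", ("BITFLAG-1", "BITFLAG-2"), sizes) == ("chan", "corr")
# Dimensions that already have standard names are left untouched
assert relabel_dims("DATA", ("chan", "corr"), sizes) == ("chan", "corr")
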
daskms/table_schemas.py (+12 −9)

@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 
+from copy import deepcopy
+
 try:
     from collections.abc import Mapping
 except ImportError:

@@ -158,19 +160,20 @@ def lookup_table_schema(table_name, lookup_str):
     A dictionary of the form
     :code:`{column: {'dims': (...)}}`.
     """
-    if lookup_str is None:
-        table_type = infer_table_type(table_name)
+    table_type = infer_table_type(table_name)
 
-        try:
-            return _ALL_SCHEMAS[table_type]
-        except KeyError:
-            raise ValueError(f"No schema registered " f"for table type '{table_type}'")
+    # Use a base schema for the inferred table type, if available
+    try:
+        table_schema = deepcopy(_ALL_SCHEMAS[table_type])
+    except KeyError:
+        table_schema = {}
 
-    if not isinstance(lookup_str, (tuple, list)):
+    if lookup_str is None:
+        lookup_str = []
+    elif not isinstance(lookup_str, (tuple, list)):
         lookup_str = [lookup_str]
 
-    table_schema = {}
-
+    # Add extra schema information to the table
     for ls in lookup_str:
         if isinstance(ls, Mapping):
             table_schema.update(ls)
daskms/tests/test_dataset.py (+5)

@@ -272,6 +272,11 @@ def test_dataset_table_schemas(ms):
     table_schema = ["MS", {"DATA": {"dims": data_dims}}]
     datasets = read_datasets(ms, [], [], [], table_schema=table_schema)
     assert datasets[0].data_vars["DATA"].dims == ("row",) + data_dims
+    assert datasets[0].data_vars["UVW"].dims == ("row", "uvw")
+
+    datasets = read_datasets(ms, [], [], [], table_schema={"DATA": {"dims": data_dims}})
+    assert datasets[0].data_vars["DATA"].dims == ("row",) + data_dims
+    assert datasets[0].data_vars["UVW"].dims == ("row", "uvw")
 
 
 @pytest.mark.parametrize(

daskms/tests/test_ms_read_and_update.py (+35)

@@ -365,3 +365,38 @@ def test_mismatched_rowid(ms):
 
 def test_request_rowid(ms):
     xdsl = xds_from_ms(ms, columns=["TIME", "ROWID"])  # noqa
+
+
+def test_postprocess_ms(ms):
+    """Test that postprocessing of MS variables identifies chan/corr like data"""
+    xdsl = xds_from_ms(ms)
+
+    def _array(ds, dims):
+        shape = tuple(ds.sizes[d] for d in dims)
+        chunks = tuple(ds.chunks[d] for d in dims)
+        return (dims, da.random.random(size=shape, chunks=chunks))
+
+    # Write some non-standard columns back to the MS
+    for i, ds in enumerate(xdsl):
+        xdsl[i] = ds.assign(
+            **{
+                "BITFLAG": _array(ds, ("row", "chan", "corr")),
+                "HAS_CORRS": _array(ds, ("row", "corr")),
+                "HAS_CHANS": _array(ds, ("row", "chan")),
+            }
+        )
+
+    dask.compute(xds_to_table(xdsl, ms))
+
+    for ds in xds_from_ms(ms, chunks={"row": 1, "chan": 1, "corr": 1}):
+        assert ds.BITFLAG.dims == ("row", "chan", "corr")
+
+        assert ds.HAS_CORRS.dims == ("row", "corr")
+        assert ds.HAS_CHANS.dims == ("row", "chan")
+
+        assert dict(ds.chunks) == {
+            "uvw": (3,),
+            "row": (1,) * ds.sizes["row"],
+            "chan": (1,) * ds.sizes["chan"],
+            "corr": (1,) * ds.sizes["corr"],
+        }
