From fcc1274171a916c860f2f2abb6c6b1c6960f5a29 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 28 Mar 2024 16:50:46 +0200 Subject: [PATCH] Add chunk specification to the katdal export --- daskms/apps/convert.py | 31 ++------------------- daskms/apps/katdal_import.py | 12 ++++++-- daskms/experimental/katdal/katdal_import.py | 13 +++++++-- daskms/experimental/katdal/msv2_facade.py | 27 ++++++++++++++++-- daskms/tests/test_utils.py | 9 ++++++ daskms/utils.py | 26 ++++++++++++++--- 6 files changed, 79 insertions(+), 39 deletions(-) diff --git a/daskms/apps/convert.py b/daskms/apps/convert.py index fded293e..462b23d1 100644 --- a/daskms/apps/convert.py +++ b/daskms/apps/convert.py @@ -1,4 +1,3 @@ -import ast from argparse import ArgumentTypeError from collections import defaultdict import logging @@ -8,37 +7,11 @@ from daskms.apps.formats import TableFormat, CasaFormat from daskms.fsspec_store import DaskMSStore +from daskms.utils import parse_chunks_dict log = logging.getLogger(__name__) -class ChunkTransformer(ast.NodeTransformer): - def visit_Module(self, node): - if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr): - raise ValueError("Module must contain a single expression") - - expr = node.body[0] - - if not isinstance(expr.value, ast.Dict): - raise ValueError("Expression must contain a dictionary") - - return self.visit(expr).value - - def visit_Dict(self, node): - keys = [self.visit(k) for k in node.keys] - values = [self.visit(v) for v in node.values] - return {k: v for k, v in zip(keys, values)} - - def visit_Name(self, node): - return node.id - - def visit_Tuple(self, node): - return tuple(self.visit(v) for v in node.elts) - - def visit_Constant(self, node): - return node.n - - NONUNIFORM_SUBTABLES = ["SPECTRAL_WINDOW", "POLARIZATION", "FEED", "SOURCE"] @@ -88,7 +61,7 @@ def _check_exclude_columns(ctx, param, value): def parse_chunks(ctx, param, value): - return ChunkTransformer().visit(ast.parse(value)) + return parse_chunks_dict(value) 
def col_converter(ctx, param, value): diff --git a/daskms/apps/katdal_import.py b/daskms/apps/katdal_import.py index e5c629a5..6f79b271 100644 --- a/daskms/apps/katdal_import.py +++ b/daskms/apps/katdal_import.py @@ -1,5 +1,7 @@ import click +from daskms.utils import parse_chunks_dict + @click.group() @click.pass_context @@ -58,10 +60,16 @@ def convert(self, value, param, ctx): "'K,B,G'. Use 'default' for L1 + L2 and 'all' for " "all available products.", ) -def _import(ctx, rdb_url, no_auto, pols_to_use, applycal, output_store): +@click.option( + "--chunks", + callback=lambda c, p, v: parse_chunks_dict(v), + default="{time: 10}", + help="Chunking values to apply to each dimension", +) +def _import(ctx, rdb_url, output_store, no_auto, pols_to_use, applycal, chunks): """Export an observation in the SARAO archive to zarr formation RDB_URL is the SARAO archive link""" from daskms.experimental.katdal import katdal_import - katdal_import(rdb_url, output_store, no_auto, applycal) + katdal_import(rdb_url, output_store, no_auto, applycal, chunks) diff --git a/daskms/experimental/katdal/katdal_import.py b/daskms/experimental/katdal/katdal_import.py index 85898d94..b6c71500 100644 --- a/daskms/experimental/katdal/katdal_import.py +++ b/daskms/experimental/katdal/katdal_import.py @@ -1,10 +1,14 @@ +import logging import os import urllib import dask +from daskms.fsspec_store import DaskMSStore from daskms.utils import requires +log = logging.getLogger(__name__) + try: import katdal from katdal.dataset import DataSet @@ -31,7 +35,7 @@ def default_output_name(url): @requires("pip install dask-ms[katdal]", import_error) -def katdal_import(url: str, out_store: str, no_auto: bool, applycal: str): +def katdal_import(url: str, out_store: str, no_auto: bool, applycal: str, chunks: dict): if isinstance(url, str): dataset = katdal.open(url, appycal=applycal) elif isinstance(url, DataSet):
else: raise TypeError(f"{url} must be a string or a katdal DataSet") - facade = XarrayMSV2Facade(dataset, no_auto=no_auto) + facade = XarrayMSV2Facade(dataset, no_auto=no_auto, chunks=chunks) main_xds, subtable_xds = facade.xarray_datasets() if not out_store: out_store = default_output_name(url) + out_store = DaskMSStore(out_store) + if out_store.exists(): + log.warning("Removing previously existing %s", out_store) + out_store.rm("", recursive=True) + writes = [ xds_to_zarr(main_xds, out_store), *(xds_to_zarr(ds, f"{out_store}::{k}") for k, ds in subtable_xds.items()), diff --git a/daskms/experimental/katdal/msv2_facade.py b/daskms/experimental/katdal/msv2_facade.py index 9c47fa0b..b528bce6 100644 --- a/daskms/experimental/katdal/msv2_facade.py +++ b/daskms/experimental/katdal/msv2_facade.py @@ -58,10 +58,22 @@ def to_mjds(timestamp: Timestamp): return timestamp.to_mjd() * 24 * 60 * 60 +DEFAULT_TIME_CHUNKS = 100 +DEFAULT_CHAN_CHUNKS = 4096 +DEFAULT_CHUNKS = {"time": DEFAULT_TIME_CHUNKS, "chan": DEFAULT_CHAN_CHUNKS} + + class XarrayMSV2Facade: """Provides a simplified xarray Dataset view over a katdal dataset""" - def __init__(self, dataset: DataSet, no_auto: bool = True, row_view: bool = True): + def __init__( + self, + dataset: DataSet, + no_auto: bool = True, + row_view: bool = True, + chunks: dict = None, + ): + self._chunks = chunks or DEFAULT_CHUNKS self._dataset = dataset self._no_auto = no_auto self._row_view = row_view @@ -81,6 +93,10 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe time_utc = dataset.timestamps t_chunks, chan_chunks, cp_chunks = dataset.vis.dataset.chunks + # Override time and channel chunking + t_chunks = self._chunks.get("time", t_chunks) + chan_chunks = self._chunks.get("chan", chan_chunks) + # Modified Julian Date in Seconds time_mjds = np.asarray([to_mjds(t) for t in map(Timestamp, time_utc)]) @@ -110,7 +126,14 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe flags = 
DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose)) weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose)) - vis = DaskLazyIndexer(dataset.vis, (), transforms=(vis_transpose,)) + vis = DaskLazyIndexer( + dataset.vis, + (), + transforms=( + rechunk, + vis_transpose, + ), + ) time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1)) ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0])) diff --git a/daskms/tests/test_utils.py b/daskms/tests/test_utils.py index b3756228..beb89b85 100644 --- a/daskms/tests/test_utils.py +++ b/daskms/tests/test_utils.py @@ -7,6 +7,7 @@ import pytest from daskms.utils import ( + parse_chunks_dict, promote_columns, natural_order, table_path_split, @@ -15,6 +16,14 @@ ) +def test_parse_chunks_dict(): + assert parse_chunks_dict("{row: 1000}") == {"row": 1000} + assert parse_chunks_dict("{row: 1000, chan: 64}") == {"row": 1000, "chan": 64} + + with pytest.raises(ValueError): + parse_chunks_dict("row:1000}") + + def test_natural_order(): data = [f"{i}.parquet" for i in reversed(range(20))] expected = [f"{i}.parquet" for i in range(20)] diff --git a/daskms/utils.py b/daskms/utils.py index 843b33b0..3b6b6ab2 100644 --- a/daskms/utils.py +++ b/daskms/utils.py @@ -1,17 +1,13 @@ # -*- coding: utf-8 -*- from collections import OrderedDict -import importlib.util import logging from pathlib import PurePath, Path import re -import sys import time import inspect import warnings -from dask.utils import funcname - # The numpy module may disappear during interpreter shutdown # so explicitly import ndarray from numpy import ndarray @@ -21,6 +17,28 @@ log = logging.getLogger(__name__) +def parse_chunks_dict(chunks_str): + chunks_str = chunks_str.strip() + e = ValueError( + f"{chunks_str} is not of the form {{dim_1: size_1, ..., dim_n: size_n}}" + ) + if not (chunks_str.startswith("{") and chunks_str.endswith("}")): + raise e + + chunks = {} + + for kvmap in chunks_str[1:-1].split(","): + try: + k, v = 
(p.strip() for p in kvmap.split(":")) + v = int(v) + except (IndexError, ValueError): + raise e + else: + chunks[k] = v + + return chunks + + def natural_order(key): return tuple( int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(key))