Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chunk specification to the katdal export #318

Merged
merged 5 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Add chunk specification to ``dask-ms katdal import`` (:pr:`318`)
* Add a ``dask-ms katdal import`` application for exporting SARAO archive data directly to zarr (:pr:`315`)
* Define dask-ms command line applications with click (:pr:`317`)
* Make poetry dev and docs groups optional (:pr:`316`)
Expand Down
31 changes: 2 additions & 29 deletions daskms/apps/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ast
from argparse import ArgumentTypeError
from collections import defaultdict
import logging
Expand All @@ -8,37 +7,11 @@

from daskms.apps.formats import TableFormat, CasaFormat
from daskms.fsspec_store import DaskMSStore
from daskms.utils import parse_chunks_dict

log = logging.getLogger(__name__)


class ChunkTransformer(ast.NodeTransformer):
    """Transforms the AST of a chunk string like ``"{row: 1000, chan: 64}"``
    into the corresponding Python dictionary.

    Dimension names are bare identifiers (``ast.Name``) rather than quoted
    strings, which is why ``ast.literal_eval`` is insufficient here.
    """

    def visit_Module(self, node):
        # The parsed string must be exactly one expression statement
        if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr):
            raise ValueError("Module must contain a single expression")

        expr = node.body[0]

        # ... and that expression must be a dict literal
        if not isinstance(expr.value, ast.Dict):
            raise ValueError("Expression must contain a dictionary")

        return self.visit(expr).value

    def visit_Dict(self, node):
        keys = [self.visit(k) for k in node.keys]
        values = [self.visit(v) for v in node.values]
        return {k: v for k, v in zip(keys, values)}

    def visit_Name(self, node):
        # Bare identifiers (dimension names) become plain strings
        return node.id

    def visit_Tuple(self, node):
        # Tuples express explicit per-block chunking, e.g. (10, 10)
        return tuple(self.visit(v) for v in node.elts)

    def visit_Constant(self, node):
        # Use node.value: node.n is a deprecated alias (since Python 3.8)
        return node.value


NONUNIFORM_SUBTABLES = ["SPECTRAL_WINDOW", "POLARIZATION", "FEED", "SOURCE"]


Expand Down Expand Up @@ -88,7 +61,7 @@ def _check_exclude_columns(ctx, param, value):


def parse_chunks(ctx, param, value):
return ChunkTransformer().visit(ast.parse(value))
return parse_chunks_dict(value)


def col_converter(ctx, param, value):
Expand Down
12 changes: 10 additions & 2 deletions daskms/apps/katdal_import.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import click

from daskms.utils import parse_chunks_dict


@click.group()
@click.pass_context
Expand Down Expand Up @@ -58,10 +60,16 @@ def convert(self, value, param, ctx):
"'K,B,G'. Use 'default' for L1 + L2 and 'all' for "
"all available products.",
)
def _import(ctx, rdb_url, no_auto, pols_to_use, applycal, output_store):
@click.option(
"--chunks",
callback=lambda c, p, v: parse_chunks_dict(v),
default="{time: 10}",
help="Chunking values to apply to each dimension",
)
def _import(ctx, rdb_url, output_store, no_auto, pols_to_use, applycal, chunks):
    """Export an observation in the SARAO archive to zarr format

    RDB_URL is the SARAO archive link"""
    # Deferred import: katdal is an optional dependency, only required
    # when this subcommand is actually invoked
    from daskms.experimental.katdal import katdal_import

    # NOTE(review): pols_to_use is parsed by click but never forwarded to
    # katdal_import -- confirm whether this is intentional
    katdal_import(rdb_url, output_store, no_auto, applycal, chunks)
13 changes: 11 additions & 2 deletions daskms/experimental/katdal/katdal_import.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import logging
import os
import urllib

import dask

from daskms.fsspec_store import DaskMSStore
from daskms.utils import requires

log = logging.getLogger(__file__)

try:
import katdal
from katdal.dataset import DataSet
Expand All @@ -31,20 +35,25 @@ def default_output_name(url):


@requires("pip install dask-ms[katdal]", import_error)
def katdal_import(url: str, out_store: str, no_auto: bool, applycal: str):
def katdal_import(url: str, out_store: str, no_auto: bool, applycal: str, chunks: dict):
if isinstance(url, str):
dataset = katdal.open(url, appycal=applycal)
elif isinstance(url, DataSet):
dataset = url
else:
raise TypeError(f"{url} must be a string or a katdal DataSet")

facade = XarrayMSV2Facade(dataset, no_auto=no_auto)
facade = XarrayMSV2Facade(dataset, no_auto=no_auto, chunks=chunks)
main_xds, subtable_xds = facade.xarray_datasets()

if not out_store:
out_store = default_output_name(url)

out_store = DaskMSStore(out_store)
if out_store.exists():
log.warn("Removing previously existing %s", out_store)
out_store.rm("", recursive=True)

writes = [
xds_to_zarr(main_xds, out_store),
*(xds_to_zarr(ds, f"{out_store}::{k}") for k, ds in subtable_xds.items()),
Expand Down
27 changes: 25 additions & 2 deletions daskms/experimental/katdal/msv2_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,22 @@ def to_mjds(timestamp: Timestamp):
return timestamp.to_mjd() * 24 * 60 * 60


DEFAULT_TIME_CHUNKS = 100
DEFAULT_CHAN_CHUNKS = 4096
DEFAULT_CHUNKS = {"time": DEFAULT_TIME_CHUNKS, "chan": DEFAULT_CHAN_CHUNKS}


class XarrayMSV2Facade:
"""Provides a simplified xarray Dataset view over a katdal dataset"""

def __init__(self, dataset: DataSet, no_auto: bool = True, row_view: bool = True):
def __init__(
self,
dataset: DataSet,
no_auto: bool = True,
row_view: bool = True,
chunks: dict = None,
):
self._chunks = chunks or DEFAULT_CHUNKS
self._dataset = dataset
self._no_auto = no_auto
self._row_view = row_view
Expand All @@ -81,6 +93,10 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe
time_utc = dataset.timestamps
t_chunks, chan_chunks, cp_chunks = dataset.vis.dataset.chunks

# Override time and channel chunking
t_chunks = self._chunks.get("time", t_chunks)
chan_chunks = self._chunks.get("chan", chan_chunks)

# Modified Julian Date in Seconds
time_mjds = np.asarray([to_mjds(t) for t in map(Timestamp, time_utc)])

Expand Down Expand Up @@ -110,7 +126,14 @@ def _main_xarray_factory(self, field_id, state_id, scan_index, scan_state, targe

flags = DaskLazyIndexer(dataset.flags, (), (rechunk, flag_transpose))
weights = DaskLazyIndexer(dataset.weights, (), (rechunk, weight_transpose))
vis = DaskLazyIndexer(dataset.vis, (), transforms=(vis_transpose,))
vis = DaskLazyIndexer(
dataset.vis,
(),
transforms=(
rechunk,
vis_transpose,
),
)

time = da.from_array(time_mjds[:, None], chunks=(t_chunks, 1))
ant1 = da.from_array(cp_info.ant1_index[None, :], chunks=(1, cpi.shape[0]))
Expand Down
13 changes: 13 additions & 0 deletions daskms/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from daskms.utils import (
parse_chunks_dict,
promote_columns,
natural_order,
table_path_split,
Expand All @@ -15,6 +16,18 @@
)


def test_parse_chunks_dict():
    """parse_chunks_dict turns '{dim: spec}' strings into dictionaries."""
    cases = [
        ("{row: 1000}", {"row": 1000}),
        ("{row: 1000, chan: 64}", {"row": 1000, "chan": 64}),
        ("{row: (10, 10), chan: (4, 4)}", {"row": (10, 10), "chan": (4, 4)}),
    ]

    for chunks_str, expected in cases:
        assert parse_chunks_dict(chunks_str) == expected

    # Malformed input propagates the SyntaxError from ast.parse
    with pytest.raises(SyntaxError):
        parse_chunks_dict("row:1000}")


def test_natural_order():
data = [f"{i}.parquet" for i in reversed(range(20))]
expected = [f"{i}.parquet" for i in range(20)]
Expand Down
36 changes: 32 additions & 4 deletions daskms/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
# -*- coding: utf-8 -*-

import ast
from collections import OrderedDict
import importlib.util
import logging
from pathlib import PurePath, Path
import re
import sys
import time
import inspect
import warnings

from dask.utils import funcname

# The numpy module may disappear during interpreter shutdown
# so explicitly import ndarray
from numpy import ndarray
Expand All @@ -21,6 +18,37 @@
log = logging.getLogger(__name__)


class ChunkTransformer(ast.NodeTransformer):
    """Converts a parsed chunk expression such as ``"{row: 1000, chan: 64}"``
    into a Python dictionary.

    ``ast.literal_eval`` cannot be used because dimension names appear as
    bare identifiers (``ast.Name``) rather than quoted strings.
    """

    def visit_Module(self, node):
        # Exactly one expression statement is expected
        if len(node.body) != 1 or not isinstance(node.body[0], ast.Expr):
            raise ValueError("Module must contain a single expression")

        expr = node.body[0]

        # ... containing a dict literal
        if not isinstance(expr.value, ast.Dict):
            raise ValueError("Expression must contain a dictionary")

        return self.visit(expr).value

    def visit_Dict(self, node):
        keys = [self.visit(k) for k in node.keys]
        values = [self.visit(v) for v in node.values]
        return {k: v for k, v in zip(keys, values)}

    def visit_Name(self, node):
        # Dimension names (bare identifiers) become strings
        return node.id

    def visit_Tuple(self, node):
        # Explicit per-block chunk tuples, e.g. (10, 10)
        return tuple(self.visit(v) for v in node.elts)

    def visit_Constant(self, node):
        # Use node.value: node.n is a deprecated alias (since Python 3.8)
        return node.value


def parse_chunks_dict(chunks_str):
    """Parse a chunk string like ``"{row: 1000, chan: 64}"`` into a dict."""
    tree = ast.parse(chunks_str)
    return ChunkTransformer().visit(tree)


def natural_order(key):
return tuple(
int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", str(key))
Expand Down
Loading