From 5eec1d93e90423bdee485c74d3cd7ef42ee0eb36 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Thu, 28 Mar 2024 14:08:16 +0200
Subject: [PATCH] Use click to define command line applications (#317)

---
 HISTORY.rst                                 |   1 +
 daskms/apps/convert.py                      | 285 +++++++++++---------
 daskms/apps/entrypoint.py                   |  70 +----
 daskms/apps/tests/test_chunk_transformer.py |   9 +-
 daskms/apps/tests/test_convert.py           |  39 +--
 daskms/experimental/arrow/reads.py          |   8 +-
 daskms/experimental/zarr/__init__.py        |   2 +-
 7 files changed, 204 insertions(+), 210 deletions(-)

diff --git a/HISTORY.rst b/HISTORY.rst
index e40445a0..68026ec3 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -4,6 +4,7 @@ History
 
 X.Y.Z (YYYY-MM-DD)
 ------------------
+* Define dask-ms command line applications with click (:pr:`317`)
 * Make poetry dev and docs groups optional (:pr:`316`)
 * Only test Github Action Push events on master (:pr:`313`)
 * Move consolidated metadata into partition subdirectories (:pr:`312`)
diff --git a/daskms/apps/convert.py b/daskms/apps/convert.py
index 485c8846..fded293e 100644
--- a/daskms/apps/convert.py
+++ b/daskms/apps/convert.py
@@ -3,9 +3,9 @@
 from collections import defaultdict
 import logging
 
+import click
 import dask.array as da
 
-from daskms.apps.application import Application
 from daskms.apps.formats import TableFormat, CasaFormat
 from daskms.fsspec_store import DaskMSStore
 
@@ -42,26 +42,26 @@ def visit_Constant(self, node):
 NONUNIFORM_SUBTABLES = ["SPECTRAL_WINDOW", "POLARIZATION", "FEED", "SOURCE"]
 
 
-def _check_input_path(input: str):
-    input_path = DaskMSStore(input)
+def _check_input_path(ctx, param, value):
+    input_path = DaskMSStore(value)
 
     if not input_path.exists():
-        raise ArgumentTypeError(f"{input} is an invalid path.")
+        raise ArgumentTypeError(f"{value} is an invalid path.")
 
     return input_path
 
 
-def _check_output_path(output: str):
-    return DaskMSStore(output)
+def _check_output_path(ctx, param, value):
+    return DaskMSStore(value)
 
 
-def _check_exclude_columns(columns: str):
-    if not columns:
+def _check_exclude_columns(ctx, param, value):
+    if not value:
         return {}
 
     outputs = defaultdict(set)
 
-    for column in (c.strip() for c in columns.split(",")):
+    for column in (c.strip() for c in value.split(",")):
         bits = column.split("::")
 
         if len(bits) == 2:
@@ -87,113 +87,146 @@ def _check_exclude_columns(columns: str):
     return outputs
 
 
-def parse_chunks(chunks: str):
-    return ChunkTransformer().visit(ast.parse(chunks))
-
-
-class Convert(Application):
-    def __init__(self, args, log):
-        self.log = log
-        self.args = args
-
-    @staticmethod
-    def col_converter(columns):
-        if not columns:
-            return None
-
-        return [c.strip() for c in columns.split(",")]
-
-    @classmethod
-    def setup_parser(cls, parser):
-        parser.add_argument("input", type=_check_input_path)
-        parser.add_argument("-o", "--output", type=_check_output_path, required=True)
-        parser.add_argument(
-            "-x",
-            "--exclude",
-            type=_check_exclude_columns,
-            default="",
-            help="Comma-separated list of columns to exclude. "
-            "For example 'CORRECTED_DATA,"
-            "SPECTRAL_WINDOW::EFFECTIVE_BW' "
-            "will exclude CORRECTED_DATA "
-            "from the main table and "
-            "EFFECTIVE_BW from the SPECTRAL_WINDOW "
-            "subtable. SPECTRAL_WINDOW::* will exclude "
-            "the entire SPECTRAL_WINDOW subtable",
-        )
-        parser.add_argument(
-            "-g",
-            "--group-columns",
-            type=Convert.col_converter,
-            default="",
-            help="Comma-separatred list of columns to group "
-            "or partition the input dataset by. "
-            "This defaults to the default "
-            "for the underlying storage mechanism."
-            "This is only supported when converting "
-            "from casa format.",
-        )
-        parser.add_argument(
-            "-i",
-            "--index-columns",
-            type=Convert.col_converter,
-            default="",
-            help="Columns to sort "
-            "the input dataset by. "
-            "This defaults to the default "
-            "for the underlying storage mechanism."
-            "This is only supported when converting "
-            "from casa format.",
-        )
-        parser.add_argument(
-            "--taql-where",
-            default="",
-            help="TAQL where clause. "
-            "Only useable with CASA inputs. "
-            "For example, to exclude auto-correlations "
-            '"ANTENNA1 != ANTENNA2"',
-        )
-        parser.add_argument(
-            "-f",
-            "--format",
-            choices=["ms", "casa", "zarr", "parquet"],
-            default="zarr",
-            help="Output format",
-        )
-        parser.add_argument(
-            "--force",
-            action="store_true",
-            default=False,
-            help="Force overwrite of output",
-        )
-        parser.add_argument(
-            "-c",
-            "--chunks",
-            default="{row: 10000}",
-            help=(
-                "chunking schema applied to each dataset "
-                "e.g. {row: 1000, chan: 16, corr: 1}"
-            ),
-            type=parse_chunks,
-        )
+def parse_chunks(ctx, param, value):
+    return ChunkTransformer().visit(ast.parse(value))
+
+
+def col_converter(ctx, param, value):
+    if not value:
+        return None
+
+    return [c.strip() for c in value.split(",")]
+
+
+@click.command
+@click.pass_context
+@click.argument("input", required=True, callback=_check_input_path)
+@click.option("-o", "--output", callback=_check_output_path, required=True)
+@click.option(
+    "-x",
+    "--exclude",
+    default="",
+    callback=_check_exclude_columns,
+    help="Comma-separated list of columns to exclude. "
+    "For example 'CORRECTED_DATA,"
+    "SPECTRAL_WINDOW::EFFECTIVE_BW' "
+    "will exclude CORRECTED_DATA "
+    "from the main table and "
+    "EFFECTIVE_BW from the SPECTRAL_WINDOW "
+    "subtable. SPECTRAL_WINDOW::* will exclude "
+    "the entire SPECTRAL_WINDOW subtable",
+)
+@click.option(
+    "-g",
+    "--group-columns",
+    default="",
+    callback=col_converter,
+    help="Comma-separated list of columns to group "
+    "or partition the input dataset by. "
+    "This defaults to the default "
+    "for the underlying storage mechanism. "
+    "This is only supported when converting "
+    "from casa format.",
+)
+@click.option(
+    "-i",
+    "--index-columns",
+    default="",
+    callback=col_converter,
+    help="Columns to sort "
+    "the input dataset by. "
+    "This defaults to the default "
+    "for the underlying storage mechanism. "
+    "This is only supported when converting "
+    "from casa format.",
+)
+@click.option(
+    "--taql-where",
+    default="",
+    help="TAQL where clause. "
+    "Only usable with CASA inputs. "
+    "For example, to exclude auto-correlations "
+    '"ANTENNA1 != ANTENNA2"',
+)
+@click.option(
+    "-f",
+    "--format",
+    type=click.Choice(["ms", "casa", "zarr", "parquet"]),
+    default="zarr",
+)
+@click.option("--force", is_flag=True)
+@click.option(
+    "-c",
+    "--chunks",
+    default="{row: 10000}",
+    callback=parse_chunks,
+    help="chunking schema applied to each dataset "
+    "e.g. {row: 1000, chan: 16, corr: 1}",
+)
+def convert(
+    ctx,
+    input,
+    output,
+    exclude,
+    group_columns,
+    index_columns,
+    taql_where,
+    format,
+    force,
+    chunks,
+):
+    converter = Convert(
+        input,
+        output,
+        exclude,
+        group_columns,
+        index_columns,
+        taql_where,
+        format,
+        force,
+        chunks,
+    )
+    converter.execute()
+
+
+class Convert:
+    def __init__(
+        self,
+        input,
+        output,
+        exclude,
+        group_columns,
+        index_columns,
+        taql_where,
+        format,
+        force,
+        chunks,
+    ):
+        self.input = input
+        self.output = output
+        self.exclude = exclude
+        self.group_columns = group_columns
+        self.index_columns = index_columns
+        self.taql_where = taql_where
+        self.format = format
+        self.force = force
+        self.chunks = chunks
 
     def execute(self):
         import dask
 
-        if self.args.output.exists():
-            if self.args.force:
-                self.args.output.rm(recursive=True)
+        if self.output.exists():
+            if self.force:
+                self.output.rm(recursive=True)
             else:
-                raise ValueError(
-                    f"{self.args.output} exists. " f"Use --force to overwrite."
-                )
+                raise ValueError(f"{self.output} exists. " f"Use --force to overwrite.")
 
-        writes = self.convert_table(self.args)
+        writes = self.convert_table()
 
         dask.compute(writes)
 
-    def _expand_group_columns(self, datasets, args):
-        if not args.group_columns:
+    def _expand_group_columns(self, datasets):
+        if not self.group_columns:
             return datasets
 
         new_datasets = []
@@ -202,10 +235,10 @@
             # Remove grouping attribute and recreate grouping columns
             new_group_vars = {}
             row_chunks = ds.chunks["row"]
-            row_dims = ds.dims["row"]
+            row_dims = ds.sizes["row"]
             attrs = ds.attrs
 
-            for column in args.group_columns:
+            for column in self.group_columns:
                 value = attrs.pop(column)
                 group_column = da.full(row_dims, value, chunks=row_chunks)
                 new_group_vars[column] = (("row",), group_column)
@@ -215,46 +248,46 @@
 
         return new_datasets
 
-    def convert_table(self, args):
-        in_fmt = TableFormat.from_store(args.input)
-        out_fmt = TableFormat.from_type(args.format)
+    def convert_table(self):
+        in_fmt = TableFormat.from_store(self.input)
+        out_fmt = TableFormat.from_type(self.format)
 
         reader = in_fmt.reader(
-            group_columns=args.group_columns,
-            index_columns=args.index_columns,
-            taql_where=args.taql_where,
+            group_columns=self.group_columns,
+            index_columns=self.index_columns,
+            taql_where=self.taql_where,
         )
         writer = out_fmt.writer()
 
-        datasets = reader(args.input, chunks=args.chunks)
+        datasets = reader(self.input, chunks=self.chunks)
 
-        if exclude_columns := args.exclude.get("MAIN", False):
+        if exclude_columns := self.exclude.get("MAIN", False):
             datasets = [
                 ds.drop_vars(exclude_columns, errors="ignore") for ds in datasets
             ]
 
         if isinstance(out_fmt, CasaFormat):
             # Reintroduce any grouping columns
-            datasets = self._expand_group_columns(datasets, args)
+            datasets = self._expand_group_columns(datasets)
 
-        log.info("Input: '%s' %s", in_fmt, str(args.input))
-        log.info("Output: '%s' %s", out_fmt, str(args.output))
+        log.info("Input: '%s' %s", in_fmt, str(self.input))
+        log.info("Output: '%s' %s", out_fmt, str(self.output))
 
-        writes = [writer(datasets, args.output)]
+        writes = [writer(datasets, self.output)]
 
         # Now do the subtables
         for table in list(in_fmt.subtables):
             if (
                 table in {"SORTED_TABLE", "SOURCE"}
-                or args.exclude.get(table, "") == "*"
+                or self.exclude.get(table, "") == "*"
             ):
                 log.warning(f"Ignoring {table}")
                 continue
 
-            in_store = args.input.subtable_store(table)
+            in_store = self.input.subtable_store(table)
             in_fmt = TableFormat.from_store(in_store)
-            out_store = args.output.subtable_store(table)
-            out_fmt = TableFormat.from_type(args.format, subtable=table)
+            out_store = self.output.subtable_store(table)
+            out_fmt = TableFormat.from_type(self.format, subtable=table)
 
             reader = in_fmt.reader()
             writer = out_fmt.writer()
@@ -264,7 +297,7 @@ def convert_table(self, args):
             else:
                 datasets = reader(in_store)
 
-            if exclude_columns := args.exclude.get(table, False):
+            if exclude_columns := self.exclude.get(table, False):
                 datasets = [
                     ds.drop_vars(exclude_columns, errors="ignore") for ds in datasets
                 ]
diff --git a/daskms/apps/entrypoint.py b/daskms/apps/entrypoint.py
index d0469487..c910851a 100644
--- a/daskms/apps/entrypoint.py
+++ b/daskms/apps/entrypoint.py
@@ -1,67 +1,17 @@
-from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
 import logging
-import logging.config
-from pathlib import Path
-import sys
 
+import click
 
-def main():
-    return EntryPoint(sys.argv[1:]).execute()
+from daskms.apps.convert import convert
 
 
-class EntryPoint:
-    LOGGING_INI = Path(__file__).parents[0] / "conf" / "logging.ini"
+@click.group()
+@click.pass_context
+@click.option("--debug/--no-debug", default=False)
+def main(ctx, debug):
+    logging.basicConfig(format="%(levelname)s - %(message)s", level=logging.INFO)
+    ctx.ensure_object(dict)
+    ctx.obj["DEBUG"] = debug
 
-    def __init__(self, cmdline_args):
-        self.cmdline_args = cmdline_args
 
-    def execute(self):
-        log = self._setup_logging()
-
-        app_klasses = self._application_classes()
-        parser = self._create_parsers(app_klasses)
-        args = self._parse_args(parser, self.cmdline_args)
-
-        try:
-            cmd_klass = app_klasses[args.command]
-        except KeyError:
-            raise ValueError(
-                f"No implementation class found " f"for command {args.command}"
-            )
-
-        cmd = cmd_klass(args, log)
-        cmd.execute()
-
-    @classmethod
-    def _application_classes(cls):
-        from daskms.apps.convert import Convert
-
-        return {"convert": Convert}
-
-    @classmethod
-    def _create_parsers(cls, app_klasses):
-        parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-        subparsers = parser.add_subparsers(help="command", dest="command")
-
-        for app_name, klass in app_klasses.items():
-            app_parser = subparsers.add_parser(app_name)
-            klass.setup_parser(app_parser)
-
-        return parser
-
-    @classmethod
-    def _parse_args(cls, parser, args):
-        parsed_args = parser.parse_args(args)
-
-        if not parsed_args.command:
-            parser.print_help()
-            sys.exit(0)
-
-        return parsed_args
-
-    @classmethod
-    def _setup_logging(cls):
-        assert cls.LOGGING_INI.exists()
-
-        logging.config.fileConfig(fname=cls.LOGGING_INI, disable_existing_loggers=False)
-        return logging.getLogger(__name__)
+main.add_command(convert)
diff --git a/daskms/apps/tests/test_chunk_transformer.py b/daskms/apps/tests/test_chunk_transformer.py
index 477a47e9..97fa5955 100644
--- a/daskms/apps/tests/test_chunk_transformer.py
+++ b/daskms/apps/tests/test_chunk_transformer.py
@@ -2,5 +2,10 @@
 
 
 def test_chunk_parsing():
-    assert parse_chunks("{row: 1000, chan: 16}") == {"row": 1000, "chan": 16}
-    assert parse_chunks("{row: (1000, 1000, 10)}") == {"row": (1000, 1000, 10)}
+    assert parse_chunks(None, None, "{row: 1000, chan: 16}") == {
+        "row": 1000,
+        "chan": 16,
+    }
+    assert parse_chunks(None, None, "{row: (1000, 1000, 10)}") == {
+        "row": (1000, 1000, 10)
+    }
diff --git a/daskms/apps/tests/test_convert.py b/daskms/apps/tests/test_convert.py
index 933832c7..424585cd 100644
--- a/daskms/apps/tests/test_convert.py
+++ b/daskms/apps/tests/test_convert.py
@@ -1,7 +1,8 @@
-from argparse import ArgumentParser
 import logging
 
-from daskms.apps.convert import Convert
+from click.testing import CliRunner
+from daskms.apps.entrypoint import main
+
 from daskms import xds_from_storage_ms, xds_from_storage_table
 
 import pytest
@@ -10,9 +11,9 @@
 
 
 @pytest.mark.applications
-@pytest.mark.parametrize("format", ["ms", "zarr", "parquet"])
+@pytest.mark.parametrize("format", ["zarr"])
 def test_convert_application(tau_ms, format, tmp_path_factory):
-    OUTPUT = tmp_path_factory.mktemp(f"convert_{format}") / "output.{format}"
+    OUTPUT = tmp_path_factory.mktemp(f"convert_{format}") / f"output.{format}"
 
     exclude_columns = [
         "ASDM_ANTENNA::*",
@@ -22,37 +23,39 @@
         "ASDM_SOURCE::*",
         "ASDM_STATION::*",
         "POINTING::OVER_THE_TOP",
+        "SPECTRAL_WINDOW::ASSOC_SPW_ID",
+        "SPECTRAL_WINDOW::ASSOC_NATURE",
         "MODEL_DATA",
     ]
 
     args = [
         str(tau_ms),
-        # "-g",
-        # "FIELD_ID,DATA_DESC_ID,SCAN_NUMBER",
+        "-g",
+        "FIELD_ID,DATA_DESC_ID,SCAN_NUMBER",
        "-x",
         ",".join(exclude_columns),
         "-o",
         str(OUTPUT),
         "--format",
-        "zarr",
+        format,
         "--force",
     ]
 
-    p = ArgumentParser()
-    Convert.setup_parser(p)
-    args = p.parse_args(args)
-    app = Convert(args, log)
-    app.execute()
-
-    datasets = xds_from_storage_ms(OUTPUT)
+    runner = CliRunner()
+    result = runner.invoke(main, ["convert"] + args)
+    assert result.exit_code == 0
 
-    for ds in datasets:
+    for ds in xds_from_storage_ms(OUTPUT):
         assert "MODEL_DATA" not in ds.data_vars
         assert "FLAG" in ds.data_vars
         assert "ROWID" in ds.coords
 
-    datasets = xds_from_storage_table(f"{str(OUTPUT)}::POINTING")
-
-    for ds in datasets:
+    for ds in xds_from_storage_table(f"{OUTPUT}::POINTING"):
         assert "OVER_THE_TOP" not in ds.data_vars
         assert "NAME" in ds.data_vars
+
+    for ds in xds_from_storage_table(f"{OUTPUT}::SPECTRAL_WINDOW"):
+        assert "CHAN_FREQ" in ds.data_vars
+        assert "CHAN_WIDTH" in ds.data_vars
+        assert "ASSOC_SPW_ID" not in ds.data_vars
+        assert "ASSOC_NATURE" not in ds.data_vars
diff --git a/daskms/experimental/arrow/reads.py b/daskms/experimental/arrow/reads.py
index ae5287d2..bfc4531c 100644
--- a/daskms/experimental/arrow/reads.py
+++ b/daskms/experimental/arrow/reads.py
@@ -236,8 +236,7 @@ def xds_from_parquet(store, columns=None, chunks=None, **kwargs):
     else:
         raise TypeError("chunks must be None or dict or list of dict")
 
-    table_path = "" if store.table else "MAIN"
-
+    table_path = Path("" if store.table else "MAIN")
     fragments = list(map(Path, store.rglob("*.parquet")))
     ds_cfg = defaultdict(list)
 
@@ -246,7 +245,10 @@ def xds_from_parquet(store, columns=None, chunks=None, **kwargs):
     partition_schemas = set()
 
     for fragment in fragments:
-        *partitions, _ = fragment.relative_to(Path(table_path)).parts
+        if not fragment.is_relative_to(table_path):
+            continue
+
+        *partitions, _ = fragment.relative_to(table_path).parts
         fragment = ParquetFileProxy(store, str(fragment))
         fragment_meta = fragment.metadata
         metadata = json.loads(fragment_meta.metadata[DASKMS_METADATA.encode()])
diff --git a/daskms/experimental/zarr/__init__.py b/daskms/experimental/zarr/__init__.py
index 0ef46303..6e7fdde2 100644
--- a/daskms/experimental/zarr/__init__.py
+++ b/daskms/experimental/zarr/__init__.py
@@ -434,7 +434,7 @@ def xds_from_zarr(store, columns=None, chunks=None, consolidated=True, **kwargs)
         if entry["type"] == "directory":
             _, dir_name = os.path.split(entry["name"])
             if dir_name.startswith(table_name):
-                _, i = dir_name.split("_")
+                _, i = dir_name[len(table_name) :].split("_")
                 partition_ids.append(int(i))
 
     for g in sorted(partition_ids):
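
A minimal sketch of exercising the new click-based entry point programmatically, in the same
style as the updated test_convert.py above. CliRunner drives the command in-process; the input
Measurement Set and output path below are illustrative placeholders, not values taken from this
patch:

    from click.testing import CliRunner

    from daskms.apps.entrypoint import main

    # Drive the "convert" subcommand registered on the click group in
    # daskms/apps/entrypoint.py. CliRunner captures the output and exit code
    # without spawning a subprocess.
    runner = CliRunner()
    result = runner.invoke(
        main,
        ["convert", "input.ms", "-o", "output.zarr", "--format", "zarr", "--force"],
    )

    # A zero exit code indicates success; this presumes "input.ms" actually
    # exists, since the input argument's callback checks the path.
    print(result.exit_code, result.output)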