Use click to define command line applications #317

Merged: 10 commits, Mar 28, 2024
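
For context: this PR replaces the argparse-based Application subclass in daskms/apps/convert.py with a plain click command. Option parsing moves from argparse type= converters to click callback= functions, each of which receives (ctx, param, value) and returns the converted value. The following minimal sketch of that pattern is illustrative only; the greet command and to_list callback are hypothetical and not part of this PR:

import click


def to_list(ctx, param, value):
    # Turn a comma-separated string into a list of stripped names,
    # mirroring the col_converter callback used by convert.py below.
    if not value:
        return None
    return [v.strip() for v in value.split(",")]


@click.command()
@click.option("--names", default="", callback=to_list,
              help="Comma-separated list of names.")
def greet(names):
    # names is already a list (or None) by the time the command body runs.
    for name in names or ["world"]:
        click.echo(f"Hello, {name}!")


if __name__ == "__main__":
    greet()
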
1 change: 1 addition & 0 deletions HISTORY.rst
@@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Define dask-ms command line applications with click (:pr:`317`)
* Make poetry dev and docs groups optional (:pr:`316`)
* Only test Github Action Push events on master (:pr:`313`)
* Move consolidated metadata into partition subdirectories (:pr:`312`)
285 changes: 159 additions & 126 deletions daskms/apps/convert.py
@@ -3,9 +3,9 @@
from collections import defaultdict
import logging

import click
import dask.array as da

from daskms.apps.application import Application
from daskms.apps.formats import TableFormat, CasaFormat
from daskms.fsspec_store import DaskMSStore

@@ -42,26 +42,26 @@ def visit_Constant(self, node):
NONUNIFORM_SUBTABLES = ["SPECTRAL_WINDOW", "POLARIZATION", "FEED", "SOURCE"]


def _check_input_path(input: str):
input_path = DaskMSStore(input)
def _check_input_path(ctx, param, value):
input_path = DaskMSStore(value)

if not input_path.exists():
raise ArgumentTypeError(f"{input} is an invalid path.")
raise ArgumentTypeError(f"{value} is an invalid path.")

return input_path


def _check_output_path(output: str):
return DaskMSStore(output)
def _check_output_path(ctx, param, value):
return DaskMSStore(value)


def _check_exclude_columns(columns: str):
if not columns:
def _check_exclude_columns(ctx, param, value):
if not value:
return {}

outputs = defaultdict(set)

for column in (c.strip() for c in columns.split(",")):
for column in (c.strip() for c in value.split(",")):
bits = column.split("::")

if len(bits) == 2:
@@ -87,113 +87,146 @@ def _check_exclude_columns(columns: str):
return outputs


def parse_chunks(chunks: str):
return ChunkTransformer().visit(ast.parse(chunks))


class Convert(Application):
def __init__(self, args, log):
self.log = log
self.args = args

@staticmethod
def col_converter(columns):
if not columns:
return None

return [c.strip() for c in columns.split(",")]

@classmethod
def setup_parser(cls, parser):
parser.add_argument("input", type=_check_input_path)
parser.add_argument("-o", "--output", type=_check_output_path, required=True)
parser.add_argument(
"-x",
"--exclude",
type=_check_exclude_columns,
default="",
help="Comma-separated list of columns to exclude. "
"For example 'CORRECTED_DATA,"
"SPECTRAL_WINDOW::EFFECTIVE_BW' "
"will exclude CORRECTED_DATA "
"from the main table and "
"EFFECTIVE_BW from the SPECTRAL_WINDOW "
"subtable. SPECTRAL_WINDOW::* will exclude "
"the entire SPECTRAL_WINDOW subtable",
)
parser.add_argument(
"-g",
"--group-columns",
type=Convert.col_converter,
default="",
help="Comma-separated list of columns to group "
"or partition the input dataset by. "
"This defaults to the default "
"for the underlying storage mechanism. "
"This is only supported when converting "
"from casa format.",
)
parser.add_argument(
"-i",
"--index-columns",
type=Convert.col_converter,
default="",
help="Columns to sort "
"the input dataset by. "
"This defaults to the default "
"for the underlying storage mechanism. "
"This is only supported when converting "
"from casa format.",
)
parser.add_argument(
"--taql-where",
default="",
help="TAQL where clause. "
"Only usable with CASA inputs. "
"For example, to exclude auto-correlations "
'"ANTENNA1 != ANTENNA2"',
)
parser.add_argument(
"-f",
"--format",
choices=["ms", "casa", "zarr", "parquet"],
default="zarr",
help="Output format",
)
parser.add_argument(
"--force",
action="store_true",
default=False,
help="Force overwrite of output",
)
parser.add_argument(
"-c",
"--chunks",
default="{row: 10000}",
help=(
"chunking schema applied to each dataset "
"e.g. {row: 1000, chan: 16, corr: 1}"
),
type=parse_chunks,
)
def parse_chunks(ctx, param, value):
return ChunkTransformer().visit(ast.parse(value))


def col_converter(ctx, param, value):
if not value:
return None

return [c.strip() for c in value.split(",")]


@click.command
@click.pass_context
@click.argument("input", required=True, callback=_check_input_path)
@click.option("-o", "--output", callback=_check_output_path, required=True)
@click.option(
"-x",
"--exclude",
default="",
callback=_check_exclude_columns,
help="Comma-separated list of columns to exclude. "
"For example 'CORRECTED_DATA,"
"SPECTRAL_WINDOW::EFFECTIVE_BW' "
"will exclude CORRECTED_DATA "
"from the main table and "
"EFFECTIVE_BW from the SPECTRAL_WINDOW "
"subtable. SPECTRAL_WINDOW::* will exclude "
"the entire SPECTRAL_WINDOW subtable",
)
@click.option(
"-g",
"--group-columns",
default="",
callback=col_converter,
help="Comma-separated list of columns to group "
"or partition the input dataset by. "
"This defaults to the default "
"for the underlying storage mechanism. "
"This is only supported when converting "
"from casa format.",
)
@click.option(
"-i",
"--index-columns",
default="",
callback=col_converter,
help="Columns to sort "
"the input dataset by. "
"This defaults to the default "
"for the underlying storage mechanism. "
"This is only supported when converting "
"from casa format.",
)
@click.option(
"--taql-where",
default="",
help="TAQL where clause. "
"Only usable with CASA inputs. "
"For example, to exclude auto-correlations "
'"ANTENNA1 != ANTENNA2"',
)
@click.option(
"-f",
"--format",
type=click.Choice(["ms", "casa", "zarr", "parquet"]),
default="zarr",
)
@click.option("--force", is_flag=True)
@click.option(
"-c",
"--chunks",
default="{row: 10000}",
callback=parse_chunks,
help="chunking schema applied to each dataset "
"e.g. {row: 1000, chan: 16, corr: 1}",
)
def convert(
ctx,
input,
output,
exclude,
group_columns,
index_columns,
taql_where,
format,
force,
chunks,
):
converter = Convert(
input,
output,
exclude,
group_columns,
index_columns,
taql_where,
format,
force,
chunks,
)
converter.execute()


class Convert:
def __init__(
self,
input,
output,
exclude,
group_columns,
index_columns,
taql_where,
format,
force,
chunks,
):
self.input = input
self.output = output
self.exclude = exclude
self.group_columns = group_columns
self.index_columns = index_columns
self.taql_where = taql_where
self.format = format
self.force = force
self.chunks = chunks

def execute(self):
import dask

if self.args.output.exists():
if self.args.force:
self.args.output.rm(recursive=True)
if self.output.exists():
if self.force:
self.output.rm(recursive=True)
else:
raise ValueError(
f"{self.args.output} exists. " f"Use --force to overwrite."
)
raise ValueError(f"{self.output} exists. " f"Use --force to overwrite.")

writes = self.convert_table(self.args)
writes = self.convert_table()

dask.compute(writes)

def _expand_group_columns(self, datasets, args):
if not args.group_columns:
def _expand_group_columns(self, datasets):
if not self.group_columns:
return datasets

new_datasets = []
@@ -202,10 +235,10 @@ def _expand_group_columns(self, datasets, args):
# Remove grouping attribute and recreate grouping columns
new_group_vars = {}
row_chunks = ds.chunks["row"]
row_dims = ds.dims["row"]
row_dims = ds.sizes["row"]
attrs = ds.attrs

for column in args.group_columns:
for column in self.group_columns:
value = attrs.pop(column)
group_column = da.full(row_dims, value, chunks=row_chunks)
new_group_vars[column] = (("row",), group_column)
@@ -215,46 +248,46 @@ def _expand_group_columns(self, datasets, args):

return new_datasets

def convert_table(self, args):
in_fmt = TableFormat.from_store(args.input)
out_fmt = TableFormat.from_type(args.format)
def convert_table(self):
in_fmt = TableFormat.from_store(self.input)
out_fmt = TableFormat.from_type(self.format)

reader = in_fmt.reader(
group_columns=args.group_columns,
index_columns=args.index_columns,
taql_where=args.taql_where,
group_columns=self.group_columns,
index_columns=self.index_columns,
taql_where=self.taql_where,
)
writer = out_fmt.writer()

datasets = reader(args.input, chunks=args.chunks)
datasets = reader(self.input, chunks=self.chunks)

if exclude_columns := args.exclude.get("MAIN", False):
if exclude_columns := self.exclude.get("MAIN", False):
datasets = [
ds.drop_vars(exclude_columns, errors="ignore") for ds in datasets
]

if isinstance(out_fmt, CasaFormat):
# Reintroduce any grouping columns
datasets = self._expand_group_columns(datasets, args)
datasets = self._expand_group_columns(datasets)

log.info("Input: '%s' %s", in_fmt, str(args.input))
log.info("Output: '%s' %s", out_fmt, str(args.output))
log.info("Input: '%s' %s", in_fmt, str(self.input))
log.info("Output: '%s' %s", out_fmt, str(self.output))

writes = [writer(datasets, args.output)]
writes = [writer(datasets, self.output)]

# Now do the subtables
for table in list(in_fmt.subtables):
if (
table in {"SORTED_TABLE", "SOURCE"}
or args.exclude.get(table, "") == "*"
or self.exclude.get(table, "") == "*"
):
log.warning(f"Ignoring {table}")
continue

in_store = args.input.subtable_store(table)
in_store = self.input.subtable_store(table)
in_fmt = TableFormat.from_store(in_store)
out_store = args.output.subtable_store(table)
out_fmt = TableFormat.from_type(args.format, subtable=table)
out_store = self.output.subtable_store(table)
out_fmt = TableFormat.from_type(self.format, subtable=table)

reader = in_fmt.reader()
writer = out_fmt.writer()
@@ -264,7 +297,7 @@ def convert_table(self, args):
else:
datasets = reader(in_store)

if exclude_columns := args.exclude.get(table, False):
if exclude_columns := self.exclude.get(table, False):
datasets = [
ds.drop_vars(exclude_columns, errors="ignore") for ds in datasets
]
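
Usage note (not part of the diff): because convert is now a standard click command, it can be driven programmatically, for example in tests, via click's CliRunner. A minimal sketch, assuming a test.ms measurement set exists on disk and using illustrative option values:

from click.testing import CliRunner

from daskms.apps.convert import convert

runner = CliRunner()
result = runner.invoke(
    convert,
    [
        "test.ms",                      # positional INPUT, validated by _check_input_path
        "-o", "test.zarr",              # required output store
        "-x", "CORRECTED_DATA",         # exclude a MAIN table column
        "-g", "FIELD_ID,DATA_DESC_ID",  # group/partition columns (CASA inputs only)
        "-c", "{row: 50000}",           # chunking schema, parsed by parse_chunks
        "--force",                      # overwrite test.zarr if it already exists
    ],
)
assert result.exit_code == 0, result.output
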