Use click to define command line applications (#317)
sjperkins authored Mar 28, 2024
1 parent 8c52d71 commit 5eec1d9
Showing 7 changed files with 204 additions and 210 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
@@ -4,6 +4,7 @@ History
 
 X.Y.Z (YYYY-MM-DD)
 ------------------
+* Define dask-ms command line applications with click (:pr:`317`)
 * Make poetry dev and docs groups optional (:pr:`316`)
 * Only test Github Action Push events on master (:pr:`313`)
 * Move consolidated metadata into partition subdirectories (:pr:`312`)
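
A quick way to exercise the new click command defined in daskms/apps/convert.py below is click's built-in test runner. This is a minimal sketch, not part of the commit; the paths and option values are illustrative assumptions:

    from click.testing import CliRunner

    from daskms.apps.convert import convert

    runner = CliRunner()
    # Drive the command in-process, as the console script entry point would.
    result = runner.invoke(
        convert,
        [
            "input.ms",  # hypothetical input path, validated by _check_input_path
            "-o", "output.zarr",  # hypothetical output store
            "-x", "CORRECTED_DATA,SPECTRAL_WINDOW::EFFECTIVE_BW",
            "-f", "zarr",
            "--chunks", "{row: 10000}",
            "--force",
        ],
    )
    print(result.exit_code, result.output)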
285 changes: 159 additions & 126 deletions daskms/apps/convert.py
@@ -3,9 +3,9 @@
 from collections import defaultdict
 import logging
 
+import click
 import dask.array as da
 
-from daskms.apps.application import Application
 from daskms.apps.formats import TableFormat, CasaFormat
 from daskms.fsspec_store import DaskMSStore
 
@@ -42,26 +42,26 @@ def visit_Constant(self, node):
 NONUNIFORM_SUBTABLES = ["SPECTRAL_WINDOW", "POLARIZATION", "FEED", "SOURCE"]
 
 
-def _check_input_path(input: str):
-    input_path = DaskMSStore(input)
+def _check_input_path(ctx, param, value):
+    input_path = DaskMSStore(value)
 
     if not input_path.exists():
-        raise ArgumentTypeError(f"{input} is an invalid path.")
+        raise ArgumentTypeError(f"{value} is an invalid path.")
 
     return input_path
 
 
-def _check_output_path(output: str):
-    return DaskMSStore(output)
+def _check_output_path(ctx, param, value):
+    return DaskMSStore(value)
 
 
-def _check_exclude_columns(columns: str):
-    if not columns:
+def _check_exclude_columns(ctx, param, value):
+    if not value:
         return {}
 
     outputs = defaultdict(set)
 
-    for column in (c.strip() for c in columns.split(",")):
+    for column in (c.strip() for c in value.split(",")):
         bits = column.split("::")
 
         if len(bits) == 2:
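
The argparse type= converters become click callbacks here; click invokes each callback as callback(ctx, param, value). A hedged sketch of the usual click idiom follows — note the commit keeps argparse's ArgumentTypeError, whereas click.BadParameter (shown below) is what click renders as a clean usage error:

    import click

    from daskms.fsspec_store import DaskMSStore

    def check_input_path(ctx, param, value):
        # Sketch only, not the committed code: click.BadParameter is
        # reported as a usage error rather than a raw traceback.
        store = DaskMSStore(value)

        if not store.exists():
            raise click.BadParameter(f"{value} is an invalid path.")

        return store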
@@ -87,113 +87,146 @@ def _check_exclude_columns(columns: str):
     return outputs
 
 
-def parse_chunks(chunks: str):
-    return ChunkTransformer().visit(ast.parse(chunks))
-
-
-class Convert(Application):
-    def __init__(self, args, log):
-        self.log = log
-        self.args = args
-
-    @staticmethod
-    def col_converter(columns):
-        if not columns:
-            return None
-
-        return [c.strip() for c in columns.split(",")]
-
-    @classmethod
-    def setup_parser(cls, parser):
-        parser.add_argument("input", type=_check_input_path)
-        parser.add_argument("-o", "--output", type=_check_output_path, required=True)
-        parser.add_argument(
-            "-x",
-            "--exclude",
-            type=_check_exclude_columns,
-            default="",
-            help="Comma-separated list of columns to exclude. "
-            "For example 'CORRECTED_DATA,"
-            "SPECTRAL_WINDOW::EFFECTIVE_BW' "
-            "will exclude CORRECTED_DATA "
-            "from the main table and "
-            "EFFECTIVE_BW from the SPECTRAL_WINDOW "
-            "subtable. SPECTRAL_WINDOW::* will exclude "
-            "the entire SPECTRAL_WINDOW subtable",
-        )
-        parser.add_argument(
-            "-g",
-            "--group-columns",
-            type=Convert.col_converter,
-            default="",
-            help="Comma-separatred list of columns to group "
-            "or partition the input dataset by. "
-            "This defaults to the default "
-            "for the underlying storage mechanism."
-            "This is only supported when converting "
-            "from casa format.",
-        )
-        parser.add_argument(
-            "-i",
-            "--index-columns",
-            type=Convert.col_converter,
-            default="",
-            help="Columns to sort "
-            "the input dataset by. "
-            "This defaults to the default "
-            "for the underlying storage mechanism."
-            "This is only supported when converting "
-            "from casa format.",
-        )
-        parser.add_argument(
-            "--taql-where",
-            default="",
-            help="TAQL where clause. "
-            "Only useable with CASA inputs. "
-            "For example, to exclude auto-correlations "
-            '"ANTENNA1 != ANTENNA2"',
-        )
-        parser.add_argument(
-            "-f",
-            "--format",
-            choices=["ms", "casa", "zarr", "parquet"],
-            default="zarr",
-            help="Output format",
-        )
-        parser.add_argument(
-            "--force",
-            action="store_true",
-            default=False,
-            help="Force overwrite of output",
-        )
-        parser.add_argument(
-            "-c",
-            "--chunks",
-            default="{row: 10000}",
-            help=(
-                "chunking schema applied to each dataset "
-                "e.g. {row: 1000, chan: 16, corr: 1}"
-            ),
-            type=parse_chunks,
-        )
+def parse_chunks(ctx, param, value):
+    return ChunkTransformer().visit(ast.parse(value))
+
+
+def col_converter(ctx, param, value):
+    if not value:
+        return None
+
+    return [c.strip() for c in value.split(",")]
+
+
+@click.command
+@click.pass_context
+@click.argument("input", required=True, callback=_check_input_path)
+@click.option("-o", "--output", callback=_check_output_path, required=True)
+@click.option(
+    "-x",
+    "--exclude",
+    default="",
+    callback=_check_exclude_columns,
+    help="Comma-separated list of columns to exclude. "
+    "For example 'CORRECTED_DATA,"
+    "SPECTRAL_WINDOW::EFFECTIVE_BW' "
+    "will exclude CORRECTED_DATA "
+    "from the main table and "
+    "EFFECTIVE_BW from the SPECTRAL_WINDOW "
+    "subtable. SPECTRAL_WINDOW::* will exclude "
+    "the entire SPECTRAL_WINDOW subtable",
+)
+@click.option(
+    "-g",
+    "--group-columns",
+    default="",
+    callback=col_converter,
+    help="Comma-separated list of columns to group "
+    "or partition the input dataset by. "
+    "This defaults to the default "
+    "for the underlying storage mechanism. "
+    "This is only supported when converting "
+    "from casa format.",
+)
+@click.option(
+    "-i",
+    "--index-columns",
+    default="",
+    callback=col_converter,
+    help="Columns to sort "
+    "the input dataset by. "
+    "This defaults to the default "
+    "for the underlying storage mechanism. "
+    "This is only supported when converting "
+    "from casa format.",
+)
+@click.option(
+    "--taql-where",
+    default="",
+    help="TAQL where clause. "
+    "Only usable with CASA inputs. "
+    "For example, to exclude auto-correlations "
+    '"ANTENNA1 != ANTENNA2"',
+)
+@click.option(
+    "-f",
+    "--format",
+    type=click.Choice(["ms", "casa", "zarr", "parquet"]),
+    default="zarr",
+)
+@click.option("--force", is_flag=True)
+@click.option(
+    "-c",
+    "--chunks",
+    default="{row: 10000}",
+    callback=parse_chunks,
+    help="chunking schema applied to each dataset "
+    "e.g. {row: 1000, chan: 16, corr: 1}",
+)
+def convert(
+    ctx,
+    input,
+    output,
+    exclude,
+    group_columns,
+    index_columns,
+    taql_where,
+    format,
+    force,
+    chunks,
+):
+    converter = Convert(
+        input,
+        output,
+        exclude,
+        group_columns,
+        index_columns,
+        taql_where,
+        format,
+        force,
+        chunks,
+    )
+    converter.execute()
+
+
+class Convert:
+    def __init__(
+        self,
+        input,
+        output,
+        exclude,
+        group_columns,
+        index_columns,
+        taql_where,
+        format,
+        force,
+        chunks,
+    ):
+        self.input = input
+        self.output = output
+        self.exclude = exclude
+        self.group_columns = group_columns
+        self.index_columns = index_columns
+        self.taql_where = taql_where
+        self.format = format
+        self.force = force
+        self.chunks = chunks
 
     def execute(self):
         import dask
 
-        if self.args.output.exists():
-            if self.args.force:
-                self.args.output.rm(recursive=True)
+        if self.output.exists():
+            if self.force:
+                self.output.rm(recursive=True)
             else:
-                raise ValueError(
-                    f"{self.args.output} exists. " f"Use --force to overwrite."
-                )
+                raise ValueError(f"{self.output} exists. " f"Use --force to overwrite.")
 
-        writes = self.convert_table(self.args)
+        writes = self.convert_table()
 
         dask.compute(writes)
 
-    def _expand_group_columns(self, datasets, args):
-        if not args.group_columns:
+    def _expand_group_columns(self, datasets):
+        if not self.group_columns:
             return datasets
 
         new_datasets = []
@@ -202,10 +235,10 @@ def _expand_group_columns(self, datasets, args):
             # Remove grouping attribute and recreate grouping columns
             new_group_vars = {}
             row_chunks = ds.chunks["row"]
-            row_dims = ds.dims["row"]
+            row_dims = ds.sizes["row"]
             attrs = ds.attrs
 
-            for column in args.group_columns:
+            for column in self.group_columns:
                 value = attrs.pop(column)
                 group_column = da.full(row_dims, value, chunks=row_chunks)
                 new_group_vars[column] = (("row",), group_column)
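
The ds.dims["row"] to ds.sizes["row"] change in this hunk tracks xarray's deprecation of treating Dataset.dims as a dimension-size mapping. A hedged sketch of the grouping-column reconstruction above, with illustrative sizes rather than values from the commit:

    import dask.array as da

    row_dims, row_chunks, value = 10_000, (2_500,) * 4, 3

    # da.full builds a lazy per-row column holding the scalar grouping
    # value, aligned with the dataset's existing row chunking.
    group_column = da.full(row_dims, value, chunks=row_chunks)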
@@ -215,46 +248,46 @@ def _expand_group_columns(self, datasets, args):
 
         return new_datasets
 
-    def convert_table(self, args):
-        in_fmt = TableFormat.from_store(args.input)
-        out_fmt = TableFormat.from_type(args.format)
+    def convert_table(self):
+        in_fmt = TableFormat.from_store(self.input)
+        out_fmt = TableFormat.from_type(self.format)
 
         reader = in_fmt.reader(
-            group_columns=args.group_columns,
-            index_columns=args.index_columns,
-            taql_where=args.taql_where,
+            group_columns=self.group_columns,
+            index_columns=self.index_columns,
+            taql_where=self.taql_where,
         )
         writer = out_fmt.writer()
 
-        datasets = reader(args.input, chunks=args.chunks)
+        datasets = reader(self.input, chunks=self.chunks)
 
-        if exclude_columns := args.exclude.get("MAIN", False):
+        if exclude_columns := self.exclude.get("MAIN", False):
             datasets = [
                 ds.drop_vars(exclude_columns, errors="ignore") for ds in datasets
             ]
 
         if isinstance(out_fmt, CasaFormat):
             # Reintroduce any grouping columns
-            datasets = self._expand_group_columns(datasets, args)
+            datasets = self._expand_group_columns(datasets)
 
-        log.info("Input: '%s' %s", in_fmt, str(args.input))
-        log.info("Output: '%s' %s", out_fmt, str(args.output))
+        log.info("Input: '%s' %s", in_fmt, str(self.input))
+        log.info("Output: '%s' %s", out_fmt, str(self.output))
 
-        writes = [writer(datasets, args.output)]
+        writes = [writer(datasets, self.output)]
 
         # Now do the subtables
         for table in list(in_fmt.subtables):
             if (
                 table in {"SORTED_TABLE", "SOURCE"}
-                or args.exclude.get(table, "") == "*"
+                or self.exclude.get(table, "") == "*"
             ):
                 log.warning(f"Ignoring {table}")
                 continue
 
-            in_store = args.input.subtable_store(table)
+            in_store = self.input.subtable_store(table)
             in_fmt = TableFormat.from_store(in_store)
-            out_store = args.output.subtable_store(table)
-            out_fmt = TableFormat.from_type(args.format, subtable=table)
+            out_store = self.output.subtable_store(table)
+            out_fmt = TableFormat.from_type(self.format, subtable=table)
 
             reader = in_fmt.reader()
             writer = out_fmt.writer()
@@ -264,7 +297,7 @@ def convert_table(self, args):
             else:
                 datasets = reader(in_store)
 
-            if exclude_columns := args.exclude.get(table, False):
+            if exclude_columns := self.exclude.get(table, False):
                 datasets = [
                     ds.drop_vars(exclude_columns, errors="ignore") for ds in datasets
                 ]
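
The callbacks at the top of the file reduce raw option strings to the structures convert_table consumes. The middle of _check_exclude_columns is collapsed in this diff, so the mappings below are inferred from the visible self.exclude.get(...) lookups — a sketch of expected results, not verified output:

    from daskms.apps.convert import _check_exclude_columns, parse_chunks

    # "TABLE::COLUMN" selects a subtable column; a bare column is assumed
    # to land under the main-table key queried via self.exclude.get("MAIN", ...).
    excl = _check_exclude_columns(None, None, "CORRECTED_DATA,SPECTRAL_WINDOW::EFFECTIVE_BW")
    # -> roughly {"MAIN": {"CORRECTED_DATA"}, "SPECTRAL_WINDOW": {"EFFECTIVE_BW"}}

    # parse_chunks evaluates the chunk schema via ast and ChunkTransformer.
    chunks = parse_chunks(None, None, "{row: 10000, chan: 16}")
    # -> roughly {"row": 10000, "chan": 16}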