From 1c2357d5b1a0b14a90f5385aca7d1e33d25b6ce8 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Mon, 25 May 2020 17:32:59 +0200 Subject: [PATCH 01/33] moving custom reductions into separate module --- shade_ms/data_plots.py | 169 +----------------------------------- shade_ms/ds_ext.py | 192 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+), 166 deletions(-) create mode 100644 shade_ms/ds_ext.py diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index eb4d8eb..778524b 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -7,9 +7,7 @@ import daskms import dask.array as da -import dask.array.ma as dama import dask.dataframe as dask_df -import datashape.coretypes import xarray import holoviews as holoviews import holoviews.operation.datashader @@ -25,7 +23,7 @@ from collections import OrderedDict from . import data_mappers from .data_mappers import DataAxis - +from .ds_ext import by_integers, by_span USE_REDUCE_BY = False @@ -202,169 +200,6 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, log.info(": complete") return output_dataframes, total_num_points -from datashader.utils import ngjit - -try: - import cudf -except ImportError: - cudf = None - -class count_integers(datashader.count_cat): - """Aggregator. Counts of all elements in ``column``, grouped by value of another column. Like datashader.count_cat, - but for normal columns.""" - _dshape = datashape.dshape(datashape.coretypes.int32) - - def __init__(self, column, modulo): - datashader.count_cat.__init__(self, column) - self.modulo = modulo - self.codes = xarray.DataArray(list(range(self.modulo))) - - def validate(self, in_dshape): - pass - - @property - def inputs(self): - return (datashader.reductions.extract(self.column), ) - - @staticmethod - @ngjit - def _append(x, y, agg, field): - agg[y, x, int(field)] += 1 - - def out_dshape(self, input_dshape): - return datashape.util.dshape(datashape.Record([(c, datashape.coretypes.int32) for c in range(self.modulo)])) - - def _build_finalize(self, dshape): - def finalize(bases, cuda=False, **kwargs): - dims = kwargs['dims'] + [self.column] - coords = kwargs['coords'] - coords[self.column] = list(self.codes.values) - return xarray.DataArray(bases[0], dims=dims, coords=coords) - return finalize - -class extract_multi(datashader.reductions.category_values): - """Extract multiple columns from a dataframe as a numpy array of values. 
Like datashader.category_values, - but with two columns.""" - def apply(self, df): - if cudf and isinstance(df, cudf.DataFrame): - import cupy - cols = [] - for column in self.columns: - nullval = np.nan if df[self.columns[1]].dtype.kind == 'f' else 0 - cols.append(df[self.columns[0]].to_gpu_array(fillna=nullval)) - return cupy.stack(cols, axis=-1) - else: - return np.stack([df[col].values for col in self.columns], axis=-1) - -class by_integers(datashader.by): - """Like datashader.by, but for integer-valued columns.""" - def __init__(self, cat_column, reduction, modulo): - super().__init__(cat_column, reduction) - self.modulo = modulo - - def _build_temps(self, cuda=False): - return tuple(by_integers(self.cat_column, tmp, self.modulo) for tmp in self.reduction._build_temps(cuda)) - - def validate(self, in_dshape): - if not self.cat_column in in_dshape.dict: - raise ValueError("specified column not found") - - self.reduction.validate(in_dshape) - - def out_dshape(self, input_dshape): - cats = list(range(self.modulo)) - red_shape = self.reduction.out_dshape(input_dshape) - return datashape.util.dshape(datashape.Record([(c, red_shape) for c in cats])) - - @property - def inputs(self): - if self.val_column is not None: - return (extract_multi(self.columns),) - else: - return (datashader.reductions.extract(self.columns[0]),) - - def _build_bases(self, cuda=False): - bases = self.reduction._build_bases(cuda) - if len(bases) == 1 and bases[0] is self: - return bases - return tuple(by_integers(self.cat_column, base, self.modulo) for base in bases) - - def _build_append(self, dshape, schema, cuda=False): - f = self.reduction._build_append(dshape, schema, cuda) - # because we transposed, we also need to flip the - # order of the x/y arguments - if isinstance(self.reduction, datashader.reductions.m2): - def _categorical_append(x, y, agg, cols, tmp1, tmp2, mod=self.modulo): - _agg = agg.transpose() - _ind = int(cols[0]) % mod - f(y, x, _agg[_ind], cols[1], tmp1[_ind], tmp2[_ind]) - elif self.val_column is not None: - def _categorical_append(x, y, agg, field, mod=self.modulo): - _agg = agg.transpose() - f(y, x, _agg[int(field[0]) % mod], field[1]) - else: - def _categorical_append(x, y, agg, field, mod=self.modulo): - _agg = agg.transpose() - f(y, x, _agg[int(field) % mod]) - - return ngjit(_categorical_append) - - def _build_finalize(self, dshape): - cats = list(range(self.modulo)) - - def finalize(bases, cuda=False, **kwargs): - kwargs['dims'] += [self.cat_column] - kwargs['coords'][self.cat_column] = cats - return self.reduction._finalize(bases, cuda=cuda, **kwargs) - - return finalize - -class by_span(by_integers): - """Like datashader.by, but for float-valued columns.""" - def __init__(self, cat_column, reduction, offset, delta, nsteps): - super().__init__(cat_column, reduction, nsteps+1) # allocate extra category for NaNs - self.offset = offset - self.delta = delta - self.nsteps = nsteps - - def _build_temps(self, cuda=False): - return tuple(by_span(self.cat_column, tmp, self.offset, self.delta, self.nsteps) for tmp in self.reduction._build_temps(cuda)) - - def _build_bases(self, cuda=False): - bases = self.reduction._build_bases(cuda) - if len(bases) == 1 and bases[0] is self: - return bases - return tuple(by_span(self.cat_column, base, self.offset, self.delta, self.nsteps) for base in bases) - - def _build_append(self, dshape, schema, cuda=False): - f = self.reduction._build_append(dshape, schema, cuda) - # because we transposed, we also need to flip the - # order of the x/y arguments - if 
isinstance(self.reduction, datashader.reductions.m2): - def _categorical_append(x, y, agg, cols, tmp1, tmp2, minval=self.offset, d=self.delta, n=self.nsteps): - _agg = agg.transpose() - value = cols[0] - _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN - #print("a", _ind, cols[0]) - f(y, x, _agg[_ind], cols[1], tmp1[_ind], tmp2[_ind]) - elif self.val_column is not None: - def _categorical_append(x, y, agg, field, minval=self.offset, d=self.delta, n=self.nsteps): - _agg = agg.transpose() - value = field[0] - _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN - #print("b", _ind, field[0]) - f(y, x, _agg[_ind], field[1]) - else: - def _categorical_append(x, y, agg, field, minval=self.offset, d=self.delta, n=self.nsteps): - _agg = agg.transpose() - value = field - _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN - #print("c", _ind, field, field != field) - f(y, x, _agg[_ind]) - - return ngjit(_categorical_append) - - def compute_bounds(unknowns, bounds, ddf): """ Given a list of axis with unknown bounds, computes missing bounds and updates the bounds dict @@ -385,6 +220,8 @@ def compute_bounds(unknowns, bounds, ddf): minval, maxval = minval-1, minval+1 bounds[axis] = minval, maxval + + def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize, xlabel, ylabel, title, pngname, options=None): diff --git a/shade_ms/ds_ext.py b/shade_ms/ds_ext.py new file mode 100644 index 0000000..34bd060 --- /dev/null +++ b/shade_ms/ds_ext.py @@ -0,0 +1,192 @@ +# -*- coding: future_fstrings -*- + +import datashape.coretypes +import datashader.transfer_functions +import datashader.reductions + +import xarray +import datashape.coretypes +import datashader.transfer_functions +import datashader.reductions +import numpy as np +from datashader.utils import ngjit +from datashader.reductions import Preprocess + +try: + import cudf +except ImportError: + cudf = None + +# class count_integers(datashader.count_cat): +# """Aggregator. Counts of all elements in ``column``, grouped by value of another column. Like datashader.count_cat, +# but for normal columns.""" +# _dshape = datashape.dshape(datashape.coretypes.int32) +# +# def __init__(self, column, modulo): +# datashader.count_cat.__init__(self, column) +# self.modulo = modulo +# self.codes = xarray.DataArray(list(range(self.modulo))) +# +# def validate(self, in_dshape): +# pass +# +# @property +# def inputs(self): +# return (datashader.reductions.extract(self.column), ) +# +# @staticmethod +# @ngjit +# def _append(x, y, agg, field): +# agg[y, x, int(field)] += 1 +# +# def out_dshape(self, input_dshape): +# return datashape.util.dshape(datashape.Record([(c, datashape.coretypes.int32) for c in range(self.modulo)])) +# +# def _build_finalize(self, dshape): +# def finalize(bases, cuda=False, **kwargs): +# dims = kwargs['dims'] + [self.column] +# coords = kwargs['coords'] +# coords[self.column] = list(self.codes.values) +# return xarray.DataArray(bases[0], dims=dims, coords=coords) +# return finalize + + +class extract_multi(datashader.reductions.category_values): + """Extract multiple columns from a dataframe as a numpy array of values. 
Like datashader.category_values, + but with two columns.""" + def apply(self, df): + if cudf and isinstance(df, cudf.DataFrame): + import cupy + cols = [] + for column in self.columns: + nullval = np.nan if df[self.columns[1]].dtype.kind == 'f' else 0 + cols.append(df[self.columns[0]].to_gpu_array(fillna=nullval)) + return cupy.stack(cols, axis=-1) + else: + return np.stack([df[col].values for col in self.columns], axis=-1) + +class integer_modulo(Preprocess): + def __init__(self, column, modulo): + super().__init__(column) + self.modulo = modulo + + """Extract just the category codes from a categorical column.""" + def apply(self, df): + if cudf and isinstance(df, cudf.DataFrame): + return (df[self.column] % self.modulo).to_gpu_array() + else: + return df[self.column].values % self.modulo + +class integer_modulo_values(Preprocess): + """Extract multiple columns from a dataframe as a numpy array of values.""" + def __init__(self, columns, modulo): + self.columns = list(columns) + self.modulo = modulo + + @property + def inputs(self): + return self.columns + + def apply(self, df): + if cudf and isinstance(df, cudf.DataFrame): + import cupy + if df[self.columns[1]].dtype.kind == 'f': + nullval = np.nan + else: + nullval = 0 + a = (df[self.columns[0]] % self.modulo).to_gpu_array() + b = df[self.columns[1]].to_gpu_array(fillna=nullval) + return cupy.stack((a, b), axis=-1) + else: + a = df[self.columns[0]].values % self.modulo + b = df[self.columns[1]].values + return np.stack((a, b), axis=-1) + + +class by_integers(datashader.by): + """Like datashader.by, but for integer-valued columns.""" + def __init__(self, cat_column, reduction, modulo): + super().__init__(cat_column, reduction) + self.modulo = modulo + + def _build_temps(self, cuda=False): + return tuple(by_integers(self.cat_column, tmp, self.modulo) for tmp in self.reduction._build_temps(cuda)) + + def validate(self, in_dshape): + if not self.cat_column in in_dshape.dict: + raise ValueError("specified column not found") + + self.reduction.validate(in_dshape) + + def out_dshape(self, input_dshape): + cats = list(range(self.modulo)) + red_shape = self.reduction.out_dshape(input_dshape) + return datashape.util.dshape(datashape.Record([(c, red_shape) for c in cats])) + + @property + def inputs(self): + if self.val_column is not None: + return (integer_modulo_values(self.columns, self.modulo),) + else: + return (integer_modulo(self.columns[0], self.modulo),) + + def _build_bases(self, cuda=False): + bases = self.reduction._build_bases(cuda) + if len(bases) == 1 and bases[0] is self: + return bases + return tuple(by_integers(self.cat_column, base, self.modulo) for base in bases) + + def _build_finalize(self, dshape): + cats = list(range(self.modulo)) + + def finalize(bases, cuda=False, **kwargs): + kwargs['dims'] += [self.cat_column] + kwargs['coords'][self.cat_column] = cats + return self.reduction._finalize(bases, cuda=cuda, **kwargs) + + return finalize + +class by_span(by_integers): + """Like datashader.by, but for float-valued columns.""" + def __init__(self, cat_column, reduction, offset, delta, nsteps): + super().__init__(cat_column, reduction, nsteps+1) # allocate extra category for NaNs + self.offset = offset + self.delta = delta + self.nsteps = nsteps + + def _build_temps(self, cuda=False): + return tuple(by_span(self.cat_column, tmp, self.offset, self.delta, self.nsteps) for tmp in self.reduction._build_temps(cuda)) + + def _build_bases(self, cuda=False): + bases = self.reduction._build_bases(cuda) + if len(bases) == 1 and bases[0] 
is self: + return bases + return tuple(by_span(self.cat_column, base, self.offset, self.delta, self.nsteps) for base in bases) + + def _build_append(self, dshape, schema, cuda=False): + f = self.reduction._build_append(dshape, schema, cuda) + # because we transposed, we also need to flip the + # order of the x/y arguments + if isinstance(self.reduction, datashader.reductions.m2): + def _categorical_append(x, y, agg, cols, tmp1, tmp2, minval=self.offset, d=self.delta, n=self.nsteps): + _agg = agg.transpose() + value = cols[0] + _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN + #print("a", _ind, cols[0]) + f(y, x, _agg[_ind], cols[1], tmp1[_ind], tmp2[_ind]) + elif self.val_column is not None: + def _categorical_append(x, y, agg, field, minval=self.offset, d=self.delta, n=self.nsteps): + _agg = agg.transpose() + value = field[0] + _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN + #print("b", _ind, field[0]) + f(y, x, _agg[_ind], field[1]) + else: + def _categorical_append(x, y, agg, field, minval=self.offset, d=self.delta, n=self.nsteps): + _agg = agg.transpose() + value = field + _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN + #print("c", _ind, field, field != field) + f(y, x, _agg[_ind]) + + return ngjit(_categorical_append) From 95080a97d5e32eec37811ed63b24783c59c47879 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Mon, 25 May 2020 22:30:19 +0200 Subject: [PATCH 02/33] updated reductions for compaitiblity with ds-0.11.x --- shade_ms/ds_ext.py | 191 +++++++++++++++++++++------------------------ 1 file changed, 89 insertions(+), 102 deletions(-) diff --git a/shade_ms/ds_ext.py b/shade_ms/ds_ext.py index 34bd060..fe7bd09 100644 --- a/shade_ms/ds_ext.py +++ b/shade_ms/ds_ext.py @@ -1,108 +1,108 @@ # -*- coding: future_fstrings -*- -import datashape.coretypes -import datashader.transfer_functions -import datashader.reductions - -import xarray -import datashape.coretypes -import datashader.transfer_functions +import datashape +import datashape.coretypes as ct import datashader.reductions import numpy as np -from datashader.utils import ngjit -from datashader.reductions import Preprocess try: import cudf except ImportError: cudf = None -# class count_integers(datashader.count_cat): -# """Aggregator. Counts of all elements in ``column``, grouped by value of another column. 
Like datashader.count_cat, -# but for normal columns.""" -# _dshape = datashape.dshape(datashape.coretypes.int32) -# -# def __init__(self, column, modulo): -# datashader.count_cat.__init__(self, column) -# self.modulo = modulo -# self.codes = xarray.DataArray(list(range(self.modulo))) -# -# def validate(self, in_dshape): -# pass -# -# @property -# def inputs(self): -# return (datashader.reductions.extract(self.column), ) -# -# @staticmethod -# @ngjit -# def _append(x, y, agg, field): -# agg[y, x, int(field)] += 1 -# -# def out_dshape(self, input_dshape): -# return datashape.util.dshape(datashape.Record([(c, datashape.coretypes.int32) for c in range(self.modulo)])) -# -# def _build_finalize(self, dshape): -# def finalize(bases, cuda=False, **kwargs): -# dims = kwargs['dims'] + [self.column] -# coords = kwargs['coords'] -# coords[self.column] = list(self.codes.values) -# return xarray.DataArray(bases[0], dims=dims, coords=coords) -# return finalize - - -class extract_multi(datashader.reductions.category_values): - """Extract multiple columns from a dataframe as a numpy array of values. Like datashader.category_values, - but with two columns.""" - def apply(self, df): - if cudf and isinstance(df, cudf.DataFrame): - import cupy - cols = [] - for column in self.columns: - nullval = np.nan if df[self.columns[1]].dtype.kind == 'f' else 0 - cols.append(df[self.columns[0]].to_gpu_array(fillna=nullval)) - return cupy.stack(cols, axis=-1) - else: - return np.stack([df[col].values for col in self.columns], axis=-1) - -class integer_modulo(Preprocess): +# couldn't find anything in the datashape docs about how to check if a CType is an integer, +# so just define a big set +IntegerTypes = {ct.bool_, ct.uint8, ct.uint16, ct.uint32, ct.uint64, ct.int8, ct.int16, ct.int32, ct.int64} + + +def _column_modulo(df, column, modulo): + """ + Helper function. Takes a DataFrame column, modulo integer value + """ + if cudf and isinstance(df, cudf.DataFrame): + ## dunno how to do this in CUDA + raise NotImplementedError("this feature is not implemented in cudf") + ## but is it as simple as this? + # return (df[column] % modulo).to_gpu_array() + else: + return df[column].values % modulo + + +def _column_discretization(df, column, offset, delta, nsteps): + """ + Helper function. 
Takes a DataFrame column, modulo integer value + """ + if cudf and isinstance(df, cudf.DataFrame): + ## dunno how to do this in CUDA + raise NotImplementedError("this feature is not implemented in cudf") + else: + value = df[column].values + index = ((value - offset) / delta).astype(np.uint32) + index[index < 0] = 0 + index[index >= nsteps] = nsteps - 1 + index[np.isnan(value)] = nsteps + return index + + +class integer_modulo(datashader.reductions.category_codes): + """A variation on category_codes that replaces categories by the values from an integer column, modulo a certain number""" def __init__(self, column, modulo): super().__init__(column) self.modulo = modulo - """Extract just the category codes from a categorical column.""" def apply(self, df): - if cudf and isinstance(df, cudf.DataFrame): - return (df[self.column] % self.modulo).to_gpu_array() - else: - return df[self.column].values % self.modulo + return _column_discretization(df, self.column, self.modulo) + -class integer_modulo_values(Preprocess): - """Extract multiple columns from a dataframe as a numpy array of values.""" +class float_discretization(integer_modulo): + """A variation on category_codes that replaces categories by the values from an integer column, modulo a certain number""" + def __init__(self, column, offset, delta, nsteps): + super().__init__(column, nsteps) + self.offset = offset + self.delta = delta + self.nsteps = nsteps + + def apply(self, df): + return _column_discretization(df, self.column, self.offset, self.delta, self.nsteps) + + +class integer_modulo_values(datashader.reductions.category_values): + """A variation on category_values that replaces categories by the values from an integer column, modulo a certain number""" def __init__(self, columns, modulo): - self.columns = list(columns) + super().__init__(columns) self.modulo = modulo - @property - def inputs(self): - return self.columns - def apply(self, df): + a = _column_modulo(df, self.columns[0], self.modulo) + return self._attach_value(df, a) + + def _attach_value(self, df, a): if cudf and isinstance(df, cudf.DataFrame): import cupy if df[self.columns[1]].dtype.kind == 'f': nullval = np.nan else: nullval = 0 - a = (df[self.columns[0]] % self.modulo).to_gpu_array() b = df[self.columns[1]].to_gpu_array(fillna=nullval) return cupy.stack((a, b), axis=-1) else: - a = df[self.columns[0]].values % self.modulo b = df[self.columns[1]].values return np.stack((a, b), axis=-1) +class float_discretization_values(integer_modulo_values): + """A variation on category_codes that replaces categories by the values from an integer column, modulo a certain number""" + def __init__(self, columns, offset, delta, nsteps): + super().__init__(columns, nsteps+1) # one extra category for NaNs + self.offset = offset + self.delta = delta + self.nsteps = nsteps + + def apply(self, df): + a = _column_discretization(df, self.columns[0], self.offset, self.delta, self.nsteps) + return self._attach_value(df, a) + + class by_integers(datashader.by): """Like datashader.by, but for integer-valued columns.""" def __init__(self, cat_column, reduction, modulo): @@ -115,7 +115,8 @@ def _build_temps(self, cuda=False): def validate(self, in_dshape): if not self.cat_column in in_dshape.dict: raise ValueError("specified column not found") - + if not in_dshape.measure[self.cat_column] in IntegerTypes: + raise ValueError("input must be an integer column") self.reduction.validate(in_dshape) def out_dshape(self, input_dshape): @@ -147,7 +148,7 @@ def finalize(bases, cuda=False, **kwargs): 
return finalize class by_span(by_integers): - """Like datashader.by, but for float-valued columns.""" + """Like datashader.by, but for float-valued columns, discretized by steps over a certain span""" def __init__(self, cat_column, reduction, offset, delta, nsteps): super().__init__(cat_column, reduction, nsteps+1) # allocate extra category for NaNs self.offset = offset @@ -155,7 +156,20 @@ def __init__(self, cat_column, reduction, offset, delta, nsteps): self.nsteps = nsteps def _build_temps(self, cuda=False): - return tuple(by_span(self.cat_column, tmp, self.offset, self.delta, self.nsteps) for tmp in self.reduction._build_temps(cuda)) + return tuple(by_span(self.cat_column, tmp, self.offset, self.delta, self.nsteps) + for tmp in self.reduction._build_temps(cuda)) + + def validate(self, in_dshape): + if not self.cat_column in in_dshape.dict: + raise ValueError("specified column not found") + self.reduction.validate(in_dshape) + + @property + def inputs(self): + if self.val_column is not None: + return (float_discretization_values(self.columns, self.offset, self.delta, self.nsteps),) + else: + return (float_discretization(self.columns[0], self.offset, self.delta, self.nsteps),) def _build_bases(self, cuda=False): bases = self.reduction._build_bases(cuda) @@ -163,30 +177,3 @@ def _build_bases(self, cuda=False): return bases return tuple(by_span(self.cat_column, base, self.offset, self.delta, self.nsteps) for base in bases) - def _build_append(self, dshape, schema, cuda=False): - f = self.reduction._build_append(dshape, schema, cuda) - # because we transposed, we also need to flip the - # order of the x/y arguments - if isinstance(self.reduction, datashader.reductions.m2): - def _categorical_append(x, y, agg, cols, tmp1, tmp2, minval=self.offset, d=self.delta, n=self.nsteps): - _agg = agg.transpose() - value = cols[0] - _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN - #print("a", _ind, cols[0]) - f(y, x, _agg[_ind], cols[1], tmp1[_ind], tmp2[_ind]) - elif self.val_column is not None: - def _categorical_append(x, y, agg, field, minval=self.offset, d=self.delta, n=self.nsteps): - _agg = agg.transpose() - value = field[0] - _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN - #print("b", _ind, field[0]) - f(y, x, _agg[_ind], field[1]) - else: - def _categorical_append(x, y, agg, field, minval=self.offset, d=self.delta, n=self.nsteps): - _agg = agg.transpose() - value = field - _ind = min(max(0, int((value - minval)/d)), n-1) if value == value else n # value != itself when NaN - #print("c", _ind, field, field != field) - f(y, x, _agg[_ind]) - - return ngjit(_categorical_append) From ba32c8ce21f7fd80a77d5558e16f35260f753be1 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Fri, 29 May 2020 16:39:15 +0200 Subject: [PATCH 03/33] updated for DS 0.11 --- shade_ms/data_plots.py | 48 ++++++++++++++++++++++++------------------ shade_ms/ds_ext.py | 2 +- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 778524b..a4711e7 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -249,6 +249,14 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor canvas = datashader.Canvas(options.xcanvas, options.ycanvas, x_range=bounds[xaxis], y_range=bounds[yaxis]) + + if aaxis is not None: + agg_alpha = getattr(datashader.reductions, ared, None) + if agg_alpha is None: + raise ValueError(f"unknown alpha 
reduction function {ared}") + agg_alpha = agg_alpha(aaxis) + ared = ared or 'count' + if aaxis is not None: agg_alpha = getattr(datashader.reductions, ared, None) if agg_alpha is None: @@ -292,26 +300,26 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor log.info(": no valid data in plot. Check your flags and/or plot limits.") return None - # work around https://github.com/holoviz/datashader/issues/899 - # Basically, 0 is treated as a nan and masked out in _colorize(), which is not correct for float reductions. - # Also, _colorize() does not normalize the totals somehow. - if np.issubdtype(raster.dtype, np.bool_): - pass - elif np.issubdtype(raster.dtype, np.integer): - ## TODO: unfinished business here - ## normalizing the raster bleaches out all colours again (fucks with log scaling, I guess?) - # int values: simply normalize to max total 1. Null values will be masked - # raster = raster.astype(np.float32) / raster.sum(axis=2).max() - pass - else: - # float values: first rescale raster to [0.001, 1]. Not 0, because 0 is masked out in _colorize() - maxval = np.nanmax(raster) - offset = np.nanmin(raster) - raster = .001 + .999*(raster - offset)/(maxval - offset) - # replace NaNs with zeroes (because when we take the total, and 1 channel is present while others are missing...) - raster.data[np.isnan(raster.data)] = 0 - # now rescale so that max total is 1 - raster /= raster.sum(axis=2).max() + # # work around https://github.com/holoviz/datashader/issues/899 + # # Basically, 0 is treated as a nan and masked out in _colorize(), which is not correct for float reductions. + # # Also, _colorize() does not normalize the totals somehow. + # if np.issubdtype(raster.dtype, np.bool_): + # pass + # elif np.issubdtype(raster.dtype, np.integer): + # ## TODO: unfinished business here + # ## normalizing the raster bleaches out all colours again (fucks with log scaling, I guess?) + # # int values: simply normalize to max total 1. Null values will be masked + # # raster = raster.astype(np.float32) / raster.sum(axis=2).max() + # pass + # else: + # # float values: first rescale raster to [0.001, 1]. Not 0, because 0 is masked out in _colorize() + # maxval = np.nanmax(raster) + # offset = np.nanmin(raster) + # raster = .001 + .999*(raster - offset)/(maxval - offset) + # # replace NaNs with zeroes (because when we take the total, and 1 channel is present while others are missing...) 
+ # raster.data[np.isnan(raster.data)] = 0 + # # now rescale so that max total is 1 + # raster /= raster.sum(axis=2).max() if cdatum.is_discrete: # discard empty bins diff --git a/shade_ms/ds_ext.py b/shade_ms/ds_ext.py index fe7bd09..f293bcd 100644 --- a/shade_ms/ds_ext.py +++ b/shade_ms/ds_ext.py @@ -51,7 +51,7 @@ def __init__(self, column, modulo): self.modulo = modulo def apply(self, df): - return _column_discretization(df, self.column, self.modulo) + return _column_modulo(df, self.column, self.modulo) class float_discretization(integer_modulo): From 32c325279558235a0096fbce4f62dd75912b9926 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Sat, 30 May 2020 18:24:06 +0200 Subject: [PATCH 04/33] using proposed DS reduction changes --- shade_ms/data_plots.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index a4711e7..3b5d8ad 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -13,6 +13,7 @@ import holoviews.operation.datashader import datashader.transfer_functions import datashader.reductions +from datashader.reductions import category_modulo, category_binning import numpy as np import pylab import textwrap @@ -23,7 +24,7 @@ from collections import OrderedDict from . import data_mappers from .data_mappers import DataAxis -from .ds_ext import by_integers, by_span +# from .ds_ext import by_integers, by_span USE_REDUCE_BY = False @@ -274,28 +275,32 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor if data_mappers.USE_COUNT_CAT: color_bins = [int(x) for x in getattr(ddf.dtypes, caxis).categories] - log.debug(f'colourizing with count_cat, {len(color_bins)} bins') - agg = datashader.by(caxis, agg_by) + log.debug(f'colourizing using {caxis} categorical, {len(color_bins)} bins') + category = caxis else: color_bins = list(range(cdatum.nlevels)) if cdatum.is_discrete: - log.debug(f'colourizing with by_integers, {len(color_bins)} bins') - agg = by_integers(caxis, agg_by, cdatum.nlevels) + log.debug(f'colourizing using {caxis} modulo {len(color_bins)}') + category = category_modulo(caxis, cdatum.nlevels) else: - log.debug(f'colourizing with by_span, {len(color_bins)} bins') + log.debug(f'colourizing using {caxis} with {len(color_bins)} bins') cmin = bounds[caxis][0] cdelta = (bounds[caxis][1] - cmin) / cdatum.nlevels - agg = by_span(caxis, agg_by, cmin, cdelta, cdatum.nlevels) + category = category_binning(caxis, cmin, cdelta, cdatum.nlevels) - raster = canvas.points(ddf, xaxis, yaxis, agg=agg) + raster = canvas.points(ddf, xaxis, yaxis, agg=datashader.by(category, agg_by)) + is_integer_raster = np.issubdtype(raster.dtype, np.integer) # the by-span aggregator accumulates flagged points in an extra raster plane - if type(agg) is by_span: - flag_raster = raster[..., -1] + if isinstance(category, category_binning): + if is_integer_raster: + log.info(f": {raster[..., -1].data.sum():.3g} points were flagged ") raster = raster[...,:-1] - log.info(f": {flag_raster.data.sum():.3g} points were flagged ") - non_empty = np.array(raster.any(axis=(0, 1))) + if is_integer_raster: + non_empty = np.array(raster.any(axis=(0, 1))) + else: + non_empty = ~(np.isnan(raster.data).all(axis=(0, 1))) if not non_empty.any(): log.info(": no valid data in plot. 
Check your flags and/or plot limits.") return None From 680678f5477f302306b3f27a3563ad324cf27191 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Mon, 1 Jun 2020 17:35:40 +0200 Subject: [PATCH 05/33] added df-factory code --- shade_ms/dask_utils.py | 99 +++++++++++++++++++++++++++++++++++++++ shade_ms/data_mappers.py | 4 +- shade_ms/data_plots.py | 64 +++++++++---------------- test-dataframe-factory.py | 23 +++++++++ 4 files changed, 147 insertions(+), 43 deletions(-) create mode 100644 test-dataframe-factory.py diff --git a/shade_ms/dask_utils.py b/shade_ms/dask_utils.py index b25483e..93c35fd 100644 --- a/shade_ms/dask_utils.py +++ b/shade_ms/dask_utils.py @@ -143,3 +143,102 @@ def dataframe_factory(out_ind, x, x_ind, y, y_ind, columns=None): # Create the actual Dataframe return dd.DataFrame(graph, name, meta=meta, divisions=divisions) + + +def multicol_dataframe_factory(out_ind, arrays, array_dims): + """ + Creates a dask Dataframe by broadcasting arrays (given by the arrays dict-like object) + against each other and then ravelling them. The array_indices mapping specifies which indices + the arrays have + + .. code-block:: python + + df = dataframe_factory(("row", "chan"), {'x': x, 'y': y}, {x: ("row",), y: ("chan",)}) + + Parameters + ---------- + out_ind : sequence + Output dimensions. + e.g. :code:`(row, chan)` + """ + columns = list(arrays.keys()) + + have_nan_chunks = None + expand = {} + barr = {} + # build up list of arguments for blockwise call below + blockwise_args = [np.broadcast_arrays, out_ind] + + for col, arr in arrays.items(): + if col not in array_dims: + raise ValueError(f"{col} dimensions not specified") + arr_ind = array_dims[col] + if not all(i in out_ind for i in arr_ind): + raise ValueError(f"{col} dimensions not in out_ind") + if not len(arr_ind) == arr.ndim: + raise ValueError(f"len({col}_ind) != {col}.ndim") + have_nan_chunks = have_nan_chunks or any(np.isnan(c) for dc in arr.chunks for c in dc) + + # Generate slicing tuples that will expand arr up to the full + # resolution + expand[col] = tuple(slice(None) if i in arr_ind else None for i in out_ind) + # broadcast vesion of array + barr[col] = arr[expand[col]] + + blockwise_args += [barr[col], out_ind] + + # Create meta data so that blockwise doesn't call + # np.broadcast_arrays and fall over on the tuple + # of arrays that it returns + dtype = np.result_type(*arrays.values()) + meta = np.empty((0,) * len(out_ind), dtype=dtype) + + bcast = da.blockwise(*blockwise_args, + align_arrays=not have_nan_chunks, + meta=meta, + dtype=dtype) + + # Now create a dataframe from the broadcasted arrays + # with lower-level dask graph API + + # Flattened list of broadcast array keys + # We'll use this to generate a 1D (ravelled) dataframe + keys = product((bcast.name,), *(range(b) for b in bcast.numblocks)) + name = "dataframe-" + tokenize(bcast) + + # dictionary defining the graph for this part of the operation + layers = {} + + if have_nan_chunks: + # We can't create proper indices if we don't known our chunk sizes + divisions = [None] + + for i, key in enumerate(keys): + layers[(name, i)] = (_create_dataframe, key, None, None) + divisions.append(None) + else: + # We do know all our chunk sizes, create reasonable dataframe indices + start_idx = 0 + divisions = [0] + + expr = ((e - s for s, e in start_ends(dim_chunks)) + for dim_chunks in bcast.chunks) + chunk_sizes = (reduce(mul, shape, 1) for shape in product(*expr)) + chunk_ranges = start_ends(chunk_sizes) + + for i, (key, (start, end)) in enumerate(zip(keys, 
chunk_ranges)): + layers[(name, i)] = (_create_dataframe, key, start, end) + start_idx += end - start + divisions.append(start_idx) + + assert len(layers) == bcast.npartitions + assert len(divisions) == bcast.npartitions + 1 + + # Create the HighLevelGraph + graph = HighLevelGraph.from_collections(name, layers, [bcast]) + # Metadata representing the broadcasted and ravelled data + meta = pd.DataFrame(data={col: np.empty((0,), dtype=arr.dtype) for col, arr in arrays.items()}, + columns=columns) + + # Create the actual Dataframe + return dd.DataFrame(graph, name, meta=meta, divisions=divisions) diff --git a/shade_ms/data_mappers.py b/shade_ms/data_mappers.py index 88e9f72..17912b2 100644 --- a/shade_ms/data_mappers.py +++ b/shade_ms/data_mappers.py @@ -282,8 +282,8 @@ def get_value(self, group, corr, extras, flag, flag_row, chanslice): coldata = mapper.mapper(coldata, **{name:extras[name] for name in self.mapper.extras }) # scalar expanded to row vector if np.isscalar(coldata): - coldata = da.full_like(flag_row, fill_value=coldata, dtype=type(coldata)) - flag = flag_row + coldata = da.array(coldata, shape=()) + flag = None else: # apply channel slicing, if there's a channel axis in the array (and the array is a DataArray) if type(coldata) is xarray.DataArray and 'chan' in coldata.dims: diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 3b5d8ad..285d0e7 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -24,6 +24,7 @@ from collections import OrderedDict from . import data_mappers from .data_mappers import DataAxis +from .dask_utils import multicol_dataframe_factory # from .ds_ext import by_integers, by_span USE_REDUCE_BY = False @@ -111,14 +112,17 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, nchan = flag.shape[1] shape = (len(group.row), nchan) - datums = OrderedDict() + arrays = OrderedDict() + shapes = OrderedDict() + ddf = None + num_points = 0 # counts number of new points generated for corr in subset.corr.numbers: # make dictionary of extra values for DataMappers extras['corr'] = corr # loop over datums to be computed for axis in DataAxis.all_axes.values(): - value = datums[axis.label][-1] if axis.label in datums else None + value = arrays.get(axis.label) # a datum was already computed? if value is not None: # if not joining correlations, then that's the only one we'll need, so continue @@ -129,50 +133,29 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, value = None if value is None: value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row, chanslice=chanslice) - # reshape values of shape NTIME to (NTIME,1) and NFREQ to (1,NFREQ), and scalar to (NTIME,1) - if value.ndim == 1: + num_points = max(num_points, value.size) + if value.ndim == 0: + shapes[axis.label] = () + elif value.ndim == 1: timefreq_axis = axis.mapper.axis or 0 assert value.shape[0] == shape[timefreq_axis], \ f"{axis.mapper.fullname}: size {value.shape[0]}, expected {shape[timefreq_axis]}" - shape1 = [1,1] - shape1[timefreq_axis] = value.shape[0] - value = value.reshape(shape1) - if timefreq_axis > 0: - value = da.broadcast_to(value, shape) - log.debug(f"axis {axis.mapper.fullname} has shape {value.shape}") + shapes[axis.label] = ("row",) if timefreq_axis == 0 else ("chan",) # else 2D value better match expected shape else: assert value.shape == shape, f"{axis.mapper.fullname}: shape {value.shape}, expected {shape}" - datums.setdefault(axis.label, []).append(value) - - # if joining correlations, stick all elements together. 
Otherwise, we'd better have one per label - if join_corrs: - datums = OrderedDict({label: da.concatenate(arrs) for label, arrs in datums.items()}) - else: - assert all([len(arrs) == 1 for arrs in datums.values()]) - datums = OrderedDict({label: arrs[0] for label, arrs in datums.items()}) - - # broadcast to same shape, and unravel all datums - datums = OrderedDict({ key: arr.ravel() for key, arr in zip(datums.keys(), - da.broadcast_arrays(*datums.values()))}) - - # if any axis needs to be conjugated, double up all of them - if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]): - for axis in DataAxis.all_axes.values(): - if axis.conjugate: - datums[axis.label] = da.concatenate([datums[axis.label], -datums[axis.label]]) - else: - datums[axis.label] = da.concatenate([datums[axis.label], datums[axis.label]]) - - labels, values = list(datums.keys()), list(datums.values()) - total_num_points += values[0].size - - # now stack them all into a big dataframe - rectype = [(axis.label, np.int32 if axis.nlevels else np.float32) for axis in DataAxis.all_axes.values()] - recarr = da.empty_like(values[0], dtype=rectype) - ddf = dask_df.from_array(recarr) - for label, value in zip(labels, values): - ddf[label] = value + shapes[axis.label] = ("row", "chan") + arrays[axis.label] = value + # any new data generated for this correlation? Make dataframe + if num_points: + total_num_points += num_points + df1 = multicol_dataframe_factory(("row", "chan"), arrays, shapes) + # if any axis needs to be conjugated, double up all of them + if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]): + conj_arrays = {axis.label: -arrays[axis.label] if axis.conjugate else arrays[axis.label] + for axis in DataAxis.all_axes.values()} + df1 = df1.append(multicol_dataframe_factory(("row", "chan"), conj_arrays, shapes)) + ddf = ddf.append(df1) if ddf is not None else df1 # now, are we iterating or concatenating? 
Make frame key accordingly dataframe_key = (fld if iter_field else None, @@ -250,7 +233,6 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor canvas = datashader.Canvas(options.xcanvas, options.ycanvas, x_range=bounds[xaxis], y_range=bounds[yaxis]) - if aaxis is not None: agg_alpha = getattr(datashader.reductions, ared, None) if agg_alpha is None: diff --git a/test-dataframe-factory.py b/test-dataframe-factory.py new file mode 100644 index 0000000..ed13f93 --- /dev/null +++ b/test-dataframe-factory.py @@ -0,0 +1,23 @@ +import dask.array as da +from shade_ms.dask_utils import dataframe_factory, multicol_dataframe_factory + + +nrow, nfreq, ncorr = 100, 100, 4 + +data1a = da.arange(nrow, chunks=(10,)) + +data1b = da.zeros(dtype=float, shape=(nfreq,), chunks=(100,)) + +data1c = da.zeros(dtype=float, shape=(nfreq,ncorr), chunks=(100,4)) + +data1d = da.zeros(dtype=float, shape=()) + +df = dataframe_factory(("row", "chan"), + data1a, ("row",), + data1b, ("chan",)) + +df1 = multicol_dataframe_factory(("row", "chan", "corr"), + dict(a=data1a, b=data1b, x=data1c, y=data1d), + dict(a=("row",), b=("chan",), x=("chan", "corr"), y=())) + +print(df1['y']) \ No newline at end of file From 3882900f01effde1e9a8c5a3046d3910cb3b808b Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Mon, 1 Jun 2020 19:35:45 +0200 Subject: [PATCH 06/33] Support multicolumn in original dataframe_factory (#59) --- shade_ms/dask_utils.py | 79 +++++++++++++++---------------- shade_ms/tests/test_dask_utils.py | 27 +++++++++++ 2 files changed, 65 insertions(+), 41 deletions(-) diff --git a/shade_ms/dask_utils.py b/shade_ms/dask_utils.py index 93c35fd..798bf1a 100644 --- a/shade_ms/dask_utils.py +++ b/shade_ms/dask_utils.py @@ -19,17 +19,15 @@ def start_ends(chunks): s = e -def _create_dataframe(arrays, start, end): +def _create_dataframe(arrays, start, end, columns): index = None if start is None else np.arange(start, end) - return pd.DataFrame({'x': arrays[0].ravel(), - 'y': arrays[1].ravel()}, + return pd.DataFrame({k: a.ravel() for k, a in zip(columns, arrays)}, index=index) - -def dataframe_factory(out_ind, x, x_ind, y, y_ind, columns=None): +def dataframe_factory(out_ind, *arginds, columns=None): """ - Creates a dask Dataframe by broadcasting the ``x`` and ``y`` arrays + Creates a dask Dataframe by broadcasting *arginds against each other and then ravelling them. .. code-block:: python @@ -43,57 +41,55 @@ def dataframe_factory(out_ind, x, x_ind, y, y_ind, columns=None): out_ind : sequence Output dimensions. e.g. :code:`(row, chan)` - x : :class:`dask.array.Array` - x data - x_ind : sequence - x dimensions. e.g. :code:`(row,)` - y : :class:`dask.array.Array` - y data - y_ind : sequence - y dimensions. e.g. :code:(row,)` + *arginds : Sequence of (:class:`dask.array.Array`, index) + document me columns : sequence, optional Dataframe column names. 
Defaults to :code:`[x, y]` """ + if not len(arginds) % 2 == 0: + raise ValueError("Must supply an index for each argument") + + args = arginds[::2] + inds = arginds[1::2] + if columns is None: - columns = ['x', 'y'] + columns = ['x', 'y'] + ["c%d" % i for i in range(len(args) - 2)] else: - if not isinstance(columns, (tuple, list)) and len(columns) != 2: - raise ValueError("Columns must be a tuple/list " - "of two column names") + if (not isinstance(columns, (tuple, list)) and + len(columns) != len(args)): - if not all(i in out_ind for i in x_ind): - raise ValueError("x_ind dimensions not in out_ind") + raise ValueError("Columns must be a tuple/list of columns " + "matching the number of arrays") - if not all(i in out_ind for i in y_ind): - raise ValueError("y_ind dimensions not in out_ind") + have_nan_chunks = False - if not len(x_ind) == x.ndim: - raise ValueError("len(x_ind) != x.ndim") + new_args = [] - if not len(y_ind) == y.ndim: - raise ValueError("len(y_ind) != y.ndim") + for a, (arg, ind) in enumerate(zip(args, inds)): + if not all(i in out_ind for i in ind): + raise ValueError("Argument %d dimensions not in out_ind" % a) - have_nan_chunks = (any(np.isnan(c) for dc in x.chunks for c in dc) or - any(np.isnan(c) for dc in y.chunks for c in dc)) + if not len(ind) == arg.ndim: + raise ValueError("Argument %d len(ind) != arg.ndim" % a) - # Generate slicing tuples that will expand x and y up to the full - # resolution - expand_x = tuple(slice(None) if i in x_ind else None for i in out_ind) - expand_y = tuple(slice(None) if i in y_ind else None for i in out_ind) + have_nan_chunks = (any(np.isnan(c) for dc in arg.chunks for c in dc) or + have_nan_chunks) - bx = x[expand_x] - by = y[expand_y] + # Generate slicing tuple that will expand arg up to full resolution + expand = tuple(slice(None) if i in ind else None for i in out_ind) + new_args.append(arg[expand]) # Create meta data so that blockwise doesn't call # np.broadcast_arrays and fall over on the tuple # of arrays that it returns - dtype = np.result_type(x, y) + dtype = np.result_type(*args) meta = np.empty((0,) * len(out_ind), dtype=dtype) + blockargs = (v for pair in ((a, out_ind) for a in new_args) for v in pair) + bcast = da.blockwise(np.broadcast_arrays, out_ind, - bx, out_ind, - by, out_ind, + *blockargs, align_arrays=not have_nan_chunks, meta=meta, dtype=dtype) @@ -114,7 +110,7 @@ def dataframe_factory(out_ind, x, x_ind, y, y_ind, columns=None): divisions = [None] for i, key in enumerate(keys): - layers[(name, i)] = (_create_dataframe, key, None, None) + layers[(name, i)] = (_create_dataframe, key, None, None, columns) divisions.append(None) else: # We do know all our chunk sizes, create reasonable dataframe indices @@ -127,7 +123,7 @@ def dataframe_factory(out_ind, x, x_ind, y, y_ind, columns=None): chunk_ranges = start_ends(chunk_sizes) for i, (key, (start, end)) in enumerate(zip(keys, chunk_ranges)): - layers[(name, i)] = (_create_dataframe, key, start, end) + layers[(name, i)] = (_create_dataframe, key, start, end, columns) start_idx += end - start divisions.append(start_idx) @@ -137,8 +133,9 @@ def dataframe_factory(out_ind, x, x_ind, y, y_ind, columns=None): # Create the HighLevelGraph graph = HighLevelGraph.from_collections(name, layers, [bcast]) # Metadata representing the broadcasted and ravelled data - meta = pd.DataFrame(data={'x': np.empty((0,), dtype=x.dtype), - 'y': np.empty((0,), dtype=y.dtype)}, + + meta = pd.DataFrame(data={k: np.empty((0,), dtype=a.dtype) + for k, a in zip(columns, args)}, columns=columns) # 
Create the actual Dataframe diff --git a/shade_ms/tests/test_dask_utils.py b/shade_ms/tests/test_dask_utils.py index 1c7af81..cc426cb 100644 --- a/shade_ms/tests/test_dask_utils.py +++ b/shade_ms/tests/test_dask_utils.py @@ -53,3 +53,30 @@ def test_dataframe_factory(test_nan_shapes): # Compare our lazy dataframe series vs (dask or numpy) arrays assert_array_equal(df['x'], x) assert_array_equal(df['y'], y) + + +def test_dataframe_factory_multicol(): + nrow, nfreq, ncorr = 100, 100, 4 + + data1a = da.arange(nrow, chunks=(10,)) + data1b = da.zeros(dtype=float, shape=(nfreq, ncorr), chunks=(100, 4)) + data1c = da.ones(dtype=np.int32, shape=(ncorr,), chunks=(4,)) + + df = dataframe_factory(("row", "chan", "corr"), + data1a, ("row",), + data1b, ("chan", "corr"), + data1c, ("corr",)) + + assert isinstance(df, dd.DataFrame) + assert isinstance(df['x'], dd.Series) + assert isinstance(df['y'], dd.Series) + assert isinstance(df['c0'], dd.Series) + + + x, y, c0 = da.broadcast_arrays(data1a[:, None, None], + data1b[None, :, :], + data1c[None, None, :]) + + assert_array_equal(df['x'], x.ravel()) + assert_array_equal(df['y'], y.ravel()) + assert_array_equal(df['c0'], c0.ravel()) \ No newline at end of file From 396f87e6310d07b168afce1ad1b5a04ef1677437 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Mon, 1 Jun 2020 19:50:38 +0200 Subject: [PATCH 07/33] reverted to @sjperkins pattern for multicol DF --- shade_ms/data_plots.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 285d0e7..608cba0 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -24,7 +24,7 @@ from collections import OrderedDict from . import data_mappers from .data_mappers import DataAxis -from .dask_utils import multicol_dataframe_factory +from .dask_utils import dataframe_factory # from .ds_ext import by_integers, by_span USE_REDUCE_BY = False @@ -149,12 +149,13 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, # any new data generated for this correlation? Make dataframe if num_points: total_num_points += num_points - df1 = multicol_dataframe_factory(("row", "chan"), arrays, shapes) + args = (v for pair in ((array, shapes[key]) for key, array in arrays.items()) for v in pair) + df1 = dataframe_factory(("row", "chan"), *args, columns=arrays.keys()) # if any axis needs to be conjugated, double up all of them if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]): - conj_arrays = {axis.label: -arrays[axis.label] if axis.conjugate else arrays[axis.label] - for axis in DataAxis.all_axes.values()} - df1 = df1.append(multicol_dataframe_factory(("row", "chan"), conj_arrays, shapes)) + args = (v for pair in ((-array if DataAxis[key].conjugate else array, shapes[key]) + for key, array in arrays.items()) for v in pair) + df1 = df1.append(dataframe_factory(("row", "chan"), *args, columns=arrays.keys())) ddf = ddf.append(df1) if ddf is not None else df1 # now, are we iterating or concatenating? 
Make frame key accordingly From 1d01d78e87f851510453ed2619c8b296748b0894 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Tue, 2 Jun 2020 11:28:40 +0200 Subject: [PATCH 08/33] added debug print --- shade_ms/data_mappers.py | 2 +- shade_ms/data_plots.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/shade_ms/data_mappers.py b/shade_ms/data_mappers.py index 17912b2..3f8bff9 100644 --- a/shade_ms/data_mappers.py +++ b/shade_ms/data_mappers.py @@ -282,7 +282,7 @@ def get_value(self, group, corr, extras, flag, flag_row, chanslice): coldata = mapper.mapper(coldata, **{name:extras[name] for name in self.mapper.extras }) # scalar expanded to row vector if np.isscalar(coldata): - coldata = da.array(coldata, shape=()) + coldata = da.array(coldata) flag = None else: # apply channel slicing, if there's a channel axis in the array (and the array is a DataArray) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 608cba0..9edf822 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -133,6 +133,7 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, value = None if value is None: value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row, chanslice=chanslice) + print(axis.label, value.compute().min(), value.compute().max()) num_points = max(num_points, value.size) if value.ndim == 0: shapes[axis.label] = () @@ -182,6 +183,12 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, for key, ddf in list(output_dataframes.items()): output_dataframes[key] = ddf.categorize(categorical_axes) + print("===") + for ddf in output_dataframes.values(): + for axis in DataAxis.all_axes.values(): + value = ddf[axis.label].values.compute() + print(axis.label, value.min(), value.max()) + log.info(": complete") return output_dataframes, total_num_points From cacabcbd26e78bb6a2af677694e65e5b6fc4fe4a Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Tue, 2 Jun 2020 12:50:39 +0200 Subject: [PATCH 09/33] replaced dataframe append with concat, to not much avail --- shade_ms/data_plots.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 9edf822..ae98dae 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -156,8 +156,9 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]): args = (v for pair in ((-array if DataAxis[key].conjugate else array, shapes[key]) for key, array in arrays.items()) for v in pair) - df1 = df1.append(dataframe_factory(("row", "chan"), *args, columns=arrays.keys())) - ddf = ddf.append(df1) if ddf is not None else df1 + df2 = dataframe_factory(("row", "chan"), *args, columns=arrays.keys()) + df1 = dask_df.concat([df1, df2], axis=0) + ddf = dask_df.concat([ddf, df1], axis=0) if ddf is not None else df1 # now, are we iterating or concatenating? 
Make frame key accordingly dataframe_key = (fld if iter_field else None, @@ -173,7 +174,7 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, output_dataframes[dataframe_key] = ddf else: log.debug(f"appending to frame for {dataframe_key}") - output_dataframes[dataframe_key] = ddf0.append(ddf) + output_dataframes[dataframe_key] = dask_df.concat([ddf0, ddf], axis=0) # convert discrete axes into categoricals if data_mappers.USE_COUNT_CAT: From 01083269a30b40dd3395ee157ff6a481d4911234 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Tue, 2 Jun 2020 13:36:24 +0200 Subject: [PATCH 10/33] Call broadcast_array with subok=True (#60) * Make test data random * Add .gitignore * Add missing __init__.py * Test dataframe min/max * Test min/max with append * Pass sub-classes through to np.broadcast_array --- .gitignore | 138 ++++++++++++++++++++++++++++++ shade_ms/dask_utils.py | 1 + shade_ms/tests/__init__.py | 0 shade_ms/tests/test_dask_utils.py | 30 +++++-- 4 files changed, 164 insertions(+), 5 deletions(-) create mode 100644 .gitignore create mode 100644 shade_ms/tests/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5391d87 --- /dev/null +++ b/.gitignore @@ -0,0 +1,138 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ \ No newline at end of file diff --git a/shade_ms/dask_utils.py b/shade_ms/dask_utils.py index 798bf1a..9b5bb2d 100644 --- a/shade_ms/dask_utils.py +++ b/shade_ms/dask_utils.py @@ -90,6 +90,7 @@ def dataframe_factory(out_ind, *arginds, columns=None): bcast = da.blockwise(np.broadcast_arrays, out_ind, *blockargs, + subok=True, align_arrays=not have_nan_chunks, meta=meta, dtype=dtype) diff --git a/shade_ms/tests/__init__.py b/shade_ms/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shade_ms/tests/test_dask_utils.py b/shade_ms/tests/test_dask_utils.py index cc426cb..1ba44fa 100644 --- a/shade_ms/tests/test_dask_utils.py +++ b/shade_ms/tests/test_dask_utils.py @@ -18,7 +18,7 @@ def test_dataframe_factory(test_nan_shapes): if test_nan_shapes: data1a = data1a[da.where(data1a > 4)] - data1b = da.zeros(dtype=float, shape=(nfreq,), chunks=(100,)) + data1b = da.random.random(size=(nfreq,), chunks=(100,)) df = dataframe_factory(("row", "chan"), data1a, ("row",), @@ -53,14 +53,18 @@ def test_dataframe_factory(test_nan_shapes): # Compare our lazy dataframe series vs (dask or numpy) arrays assert_array_equal(df['x'], x) assert_array_equal(df['y'], y) + assert_array_equal(df['x'].min(), data1a.min()) + assert_array_equal(df['y'].min(), data1b.min()) + assert_array_equal(df['x'].max(), data1a.max()) + assert_array_equal(df['y'].max(), data1b.max()) def test_dataframe_factory_multicol(): nrow, nfreq, ncorr = 100, 100, 4 - data1a = da.arange(nrow, chunks=(10,)) - data1b = da.zeros(dtype=float, shape=(nfreq, ncorr), chunks=(100, 4)) - data1c = da.ones(dtype=np.int32, shape=(ncorr,), chunks=(4,)) + data1a = da.random.random(size=nrow, chunks=(10,)) + data1b = da.random.random(size=(nfreq, ncorr), chunks=(100, 4)) + data1c = da.random.random(size=(ncorr,), chunks=(4,)) df = dataframe_factory(("row", "chan", "corr"), data1a, ("row",), @@ -79,4 +83,20 @@ def test_dataframe_factory_multicol(): assert_array_equal(df['x'], x.ravel()) assert_array_equal(df['y'], y.ravel()) - assert_array_equal(df['c0'], c0.ravel()) \ No newline at end of file + assert_array_equal(df['c0'], c0.ravel()) + + assert_array_equal(df['x'].min(), data1a.min()) + assert_array_equal(df['x'].max(), data1a.max()) + assert_array_equal(df['y'].min(), data1b.min()) + assert_array_equal(df['y'].max(), data1b.max()) + assert_array_equal(df['c0'].min(), data1c.min()) + assert_array_equal(df['c0'].max(), data1c.max()) + + df = df.append(df) + + assert_array_equal(df['x'].min(), data1a.min()) + assert_array_equal(df['x'].max(), data1a.max()) + assert_array_equal(df['y'].min(), data1b.min()) + assert_array_equal(df['y'].max(), data1b.max()) + assert_array_equal(df['c0'].min(), data1c.min()) + assert_array_equal(df['c0'].max(), data1c.max()) From d969ebdd22cbc343b811875b9b7b7ba55dd26f2b Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Tue, 2 Jun 2020 17:14:31 +0200 Subject: [PATCH 11/33] make baselines dicrete again --- shade_ms/data_plots.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git 
a/shade_ms/data_plots.py b/shade_ms/data_plots.py index ae98dae..7e5e614 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -99,7 +99,7 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data) a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data) - baselines = a1*len(msinfo.antenna) - a1*(a1-1)//2. + a2 + baselines = a1*len(msinfo.antenna) - a1*(a1-1)//2 + a2 freqs = chan_freqs[ddid] chans = xarray.DataArray(range(len(freqs)), dims=("chan",)) @@ -133,7 +133,7 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, value = None if value is None: value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row, chanslice=chanslice) - print(axis.label, value.compute().min(), value.compute().max()) + # print(axis.label, value.compute().min(), value.compute().max()) num_points = max(num_points, value.size) if value.ndim == 0: shapes[axis.label] = () @@ -184,11 +184,11 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, for key, ddf in list(output_dataframes.items()): output_dataframes[key] = ddf.categorize(categorical_axes) - print("===") - for ddf in output_dataframes.values(): - for axis in DataAxis.all_axes.values(): - value = ddf[axis.label].values.compute() - print(axis.label, value.min(), value.max()) + # print("===") + # for ddf in output_dataframes.values(): + # for axis in DataAxis.all_axes.values(): + # value = ddf[axis.label].values.compute() + # print(axis.label, np.nanmin(value), np.nanmax(value)) log.info(": complete") return output_dataframes, total_num_points @@ -282,7 +282,7 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor raster = canvas.points(ddf, xaxis, yaxis, agg=datashader.by(category, agg_by)) is_integer_raster = np.issubdtype(raster.dtype, np.integer) - # the by-span aggregator accumulates flagged points in an extra raster plane + # the binning aggregator accumulates flagged points in an extra raster plane if isinstance(category, category_binning): if is_integer_raster: log.info(f": {raster[..., -1].data.sum():.3g} points were flagged ") From f01baef6124a45f134783d03cce1f4a248902edc Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Tue, 2 Jun 2020 18:33:09 +0200 Subject: [PATCH 12/33] added --saturate-alpha and --saturate-perc options --- shade_ms/data_plots.py | 25 +++++++++++++++++++++++-- shade_ms/main.py | 11 +++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 7e5e614..cf6b690 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -217,6 +217,7 @@ def compute_bounds(unknowns, bounds, ddf): def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize, xlabel, ylabel, title, pngname, + min_alpha=40, saturate_percentile=None, saturate_alpha=None, options=None): figx = options.xcanvas / 60 @@ -350,7 +351,7 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor raster[raster<0] = 0 raster[raster>1] = 1 log.info(f": adjusting alpha (alpha raster was {amin} to {amax})") - img = datashader.transfer_functions.shade(raster, color_key=color_key, how=normalize) + img = datashader.transfer_functions.shade(raster, color_key=color_key, how=normalize, min_alpha=min_alpha) else: log.debug(f'rasterizing using {ared}') raster = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha) @@ -358,7 +359,27 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor log.info(": no valid data in 
plot. Check your flags and/or plot limits.") return None log.debug('shading') - img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize) + img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize, min_alpha=min_alpha) + + # resaturate if needed + if saturate_alpha is not None or saturate_percentile is not None: + # get alpha channel + imgval = img.values + alpha = (imgval >> 24)&255 + nulls = alpha255] = 255 + imgval[:] = (imgval & 0xFFFFFF) | alpha.astype(np.uint32)<<24 if options.spread_pix: img = datashader.transfer_functions.dynspread(img, options.spread_thr, max_px=options.spread_pix) diff --git a/shade_ms/main.py b/shade_ms/main.py index 46fc8c1..377d2d5 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -173,6 +173,14 @@ def main(argv): help='Colorcet map used when colouring by a continuous axis (default = %(default)s)') group_opts.add_argument('--dmap', default='glasbey_dark', help='Colorcet map used when colouring by a discrete axis (default = %(default)s)') + group_opts.add_argument('--min-alpha', default=40, type=int, metavar="0-255", + help="""Minimum alpha value used in rendering the canvas. Increase to saturate colour at + the expense of dynamic range. Default is %(default)s.""") + group_opts.add_argument('--saturate-perc', default=95, type=int, metavar="0-100", + help="""Saturate colors so that the range [min-alpha, X] is mapped to [min-alpha, 255], + where X is the given percentile. Default is %(default)s.""") + group_opts.add_argument('--saturate-alpha', default=None, type=int, metavar="0-255", + help="""Saturate colors as above, but with a fixed value of X. Overrides --saturate-perc.""") group_opts.add_argument('--spread-pix', type=int, default=0, metavar="PIX", help="""Dynamically spread rendered pixels to this size""") group_opts.add_argument('--spread-thr', type=float, default=0.5, metavar="THR", @@ -606,6 +614,9 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, with context() as profiler: result = data_plots.create_plot(df, xdatum, ydatum, adatum, ared, cdatum, cmap=cmap, bmap=bmap, dmap=dmap, normalize=normalize, + min_alpha=options.min_alpha, + saturate_alpha=options.saturate_alpha, + saturate_percentile=options.saturate_perc, xlabel=xlabel, ylabel=ylabel, title=title, pngname=pngname, options=options) if result: From 7663f44671eb74e623291064371487adce70c5b0 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Wed, 3 Jun 2020 19:24:44 +0200 Subject: [PATCH 13/33] added proper treatment of sunsets for discrete axes --- shade_ms/data_mappers.py | 77 +++++++++++++++++++++++------------- shade_ms/data_plots.py | 84 ++++++++++++++++++++++------------------ shade_ms/main.py | 12 ++++-- shade_ms/ms_info.py | 30 +++++++++----- 4 files changed, 126 insertions(+), 77 deletions(-) diff --git a/shade_ms/data_mappers.py b/shade_ms/data_mappers.py index 3f8bff9..7c44487 100644 --- a/shade_ms/data_mappers.py +++ b/shade_ms/data_mappers.py @@ -1,4 +1,5 @@ import dask.array as da +import dask.array.core import dask.array.ma as dama import xarray import numpy as np @@ -168,20 +169,22 @@ def __init__(self, column, function, corr, ms, minmax=None, ncol=None, label=Non self.function = function # function to apply to column (see list of DataMappers below) self.corr = corr if corr != "all" else None self.nlevels = ncol - self.minmax = vmin, vmax = tuple(minmax) if minmax is not None else (None, None) + self.minmax = tuple(minmax) if minmax is not None else (None, None) self.label = label self._corr_reduce = None 
self._is_discrete = None - self.discretized_labels = None # filled for corrs and fields and so + self.mapper = data_mappers[function] - # set up discretized continuous axis - if self.nlevels and vmin is not None and vmax is not None: - self.discretized_delta = delta = (vmax - vmin) / self.nlevels - self.discretized_bin_centers = np.arange(vmin + delta/2, vmax, delta) - else: - self.discretized_delta = self.discretized_bin_centers = None + # if set, axis is discrete and labelled + self.discretized_labels = None - self.mapper = data_mappers[function] + # for discrete axes: if a subset of N indices is explicitly selected for plotting, then this + # is a list of the selected indices, of length N + self.subset_indices = None + # ...and this is a dask array that maps selected indices into bins 0...N-1, and all other values into bin N + self.subset_remapper = None + # ...and this is the maximum valid index in MS + maxind = None # columns with labels? if function == 'CORR' or function == 'STOKES': @@ -198,14 +201,36 @@ def __init__(self, column, function, corr, ms, minmax=None, ncol=None, label=Non # we're creating one for a mapper that will iterate over correlations if corr is not None: self.mapper = DataMapper(name, "", column=False, axis=-1, mapper=lambda x: corr) - self.discretized_labels = subset.corr.names + self.subset_indices = subset.corr + maxind = ms.all_corr.numbers[-1] elif column == "FIELD_ID": - self.discretized_labels = [name for name in ms.field.names if name in subset.field] + self.subset_indices = subset.field + maxind = ms.field.numbers[-1] elif column == "ANTENNA1" or column == "ANTENNA2": - self.discretized_labels = [name for name in ms.all_antenna.names if name in subset.ant] + self.subset_indices = subset.ant + maxind = ms.antenna.numbers[-1] + elif column == "SCAN_NUMBER": + self.subset_indices = subset.scan + maxind = ms.scan.numbers[-1] + elif function == "BASELINE": + self.subset_indices = subset.baseline + maxind = ms.baseline.numbers[-1] elif column == "FLAG" or column == "FLAG_ROW": self.discretized_labels = ["F", "T"] + # make a remapper + if self.subset_indices is not None: + # If last index of subset is max index anyway, mapping is 1:1 -- no remapper needed + # Otherwise map indices in subset into their ordinal numbers in the subset (0...N-1), + # and all other indices to N + if len(self.subset_indices) < maxind+1: + remapper = np.full(maxind+1, len(self.subset_indices)) + for i, index in enumerate(self.subset_indices.numbers): + remapper[index] = i + self.subset_remapper = da.array(remapper) + self.discretized_labels = self.subset_indices.names + self.subset_indices = self.subset_indices.numbers + if self.discretized_labels: self._is_discrete = True @@ -280,7 +305,7 @@ def get_value(self, group, corr, extras, flag, flag_row, chanslice): if np.iscomplexobj(coldata) and mapper is data_mappers["_"]: mapper = data_mappers["amp"] coldata = mapper.mapper(coldata, **{name:extras[name] for name in self.mapper.extras }) - # scalar expanded to row vector + # scalar is just a scalar if np.isscalar(coldata): coldata = da.array(coldata) flag = None @@ -307,24 +332,24 @@ def get_value(self, group, corr, extras, flag, flag_row, chanslice): if self._is_discrete is False: raise TypeError(f"{self.label}: column changed from continuous-valued to discrete. This is a bug, or a very weird MS.") self._is_discrete = True + # do we need to apply a remapping? 
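# An illustrative aside (hypothetical indices, not read from any MS): the subset_remapper
# applied below is just an integer lookup table. For an antenna subset [0, 2, 5] out of
# valid indices 0..6 it would be built roughly as
#
#     import numpy as np
#     subset, maxind = [0, 2, 5], 6
#     remapper = np.full(maxind + 1, len(subset))   # default: overflow bin N = 3
#     remapper[subset] = np.arange(len(subset))     # selected indices -> bins 0..N-1
#     remapper[np.array([0, 1, 2, 5, 6])]           # -> array([0, 3, 1, 2, 3])
#
# so values outside the subset all land in bin N, which is what the greater_equal test
# below then flags.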
+ if self.subset_remapper is not None: + if type(coldata) is not dask.array.core.Array: # could be xarray backed by dask array + coldata = coldata.data + coldata = self.subset_remapper[coldata] + bad_bins = da.greater_equal(coldata, len(self.subset_indices)) + if flag is None: + flag = bad_bins + else: + flag = da.logical_or(flag, bad_bins) else: if self._is_discrete is True: raise TypeError(f"{self.label}: column chnaged from discrete to continuous-valued. This is a bug, or a very weird MS.") self._is_discrete = False - # # minmax set? discretize over that - # if self.discretized_delta is not None: - # coldata = da.floor((coldata - self.minmax[0])/self.discretized_delta) - # coldata = da.minimum(da.maximum(coldata, 0), self.nlevels-1).astype(COUNT_DTYPE) - # else: - # if not coldata.dtype is bool: - # if not np.issubdtype(coldata.dtype, np.integer): - # raise TypeError(f"{self.name}: min/max must be set to colour by non-integer values") - # coldata = da.remainder(coldata, self.nlevels).astype(COUNT_DTYPE) - + bad_data = da.logical_not(da.isfinite(coldata)) if flag is not None: - flag |= ~da.isfinite(coldata) - return dama.masked_array(coldata, flag) + return dama.masked_array(coldata, da.logical_or(flag, bad_data)) else: - return dama.masked_array(coldata, ~da.isfinite(coldata)) + return dama.masked_array(coldata, bad_data) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index cf6b690..f99969e 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -18,6 +18,7 @@ import pylab import textwrap import argparse +import itertools import matplotlib.cm from shade_ms import log @@ -98,8 +99,7 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, a1 = da.minimum(group.ANTENNA1.data, group.ANTENNA2.data) a2 = da.maximum(group.ANTENNA1.data, group.ANTENNA2.data) - - baselines = a1*len(msinfo.antenna) - a1*(a1-1)//2 + a2 + baselines = msinfo.baseline_number(a1, a2) freqs = chan_freqs[ddid] chans = xarray.DataArray(range(len(freqs)), dims=("chan",)) @@ -231,16 +231,25 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor color_key = color_mapping = color_labels = agg_alpha = raster_alpha = cmin = cdelta = None + # do we need to compute any axis min/max? bounds = OrderedDict({xaxis: xdatum.minmax, yaxis: ydatum.minmax}) - if caxis: - bounds[caxis] = cdatum.minmax - - unknown = [axis for (axis, minmax) in bounds.items() if minmax[0] is None or minmax[1] is None] + unknown = [] + for datum in xdatum, ydatum, cdatum: + if datum is not None: + bounds[datum.label] = datum.minmax + if datum.minmax[0] is None or datum.minmax[1] is None: + unknown.append(datum.label) if unknown: log.info(f": scanning axis min/max for {' '.join(unknown)}") compute_bounds(unknown, bounds, ddf) + # adjust bounds for discrete axes + for datum in xdatum, ydatum: + if datum.is_discrete: + bounds[datum.label] = bounds[datum.label][0]-0.5, bounds[datum.label][0]+0.5 + + # create rendering canvas. TODO: https://github.com/ratt-ru/shadeMS/issues/42 canvas = datashader.Canvas(options.xcanvas, options.ycanvas, x_range=bounds[xaxis], y_range=bounds[yaxis]) if aaxis is not None: @@ -265,15 +274,31 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor # aggregation applied to by() agg_by = agg_alpha if USE_REDUCE_BY and agg_alpha is not None else datashader.count() + # color_bins will be a list of colors to use. If the subset is known, then we preferentially + # pick colours by subset, i.e. we try to preserve the mapping from index to specific color. 
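# (Illustrative example of the above: with a field subset [0, 3], color_bins keeps the
# original field indices, so the later color_key = [dmap[bin] for bin in color_bins]
# still assigns dmap[3] to field 3 -- the same colour it would get if every field were
# being plotted.)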
+ # color_labels will be set from discretized_labels, or from range of column values, if axis is discrete + color_labels = cdatum.discretized_labels if data_mappers.USE_COUNT_CAT: - color_bins = [int(x) for x in getattr(ddf.dtypes, caxis).categories] + if cdatum.subset_indices is not None: + color_bins = cdatum.subset_indices + else: + color_bins = [int(x) for x in getattr(ddf.dtypes, caxis).categories] log.debug(f'colourizing using {caxis} categorical, {len(color_bins)} bins') category = caxis + if color_labels is None: + color_labels = list(map(str,color_bins)) else: color_bins = list(range(cdatum.nlevels)) if cdatum.is_discrete: + if cdatum.subset_indices is not None: + num_categories = len(cdatum.subset_indices) + color_bins = cdatum.subset_indices[:cdatum.nlevels] + else: + num_categories = int(bounds[caxis][1]) + 1 log.debug(f'colourizing using {caxis} modulo {len(color_bins)}') - category = category_modulo(caxis, cdatum.nlevels) + category = category_modulo(caxis, len(color_bins)) + if color_labels is None: + color_labels = list(map(str, range(num_categories))) else: log.debug(f'colourizing using {caxis} with {len(color_bins)} bins') cmin = bounds[caxis][0] @@ -297,44 +322,27 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor log.info(": no valid data in plot. Check your flags and/or plot limits.") return None - # # work around https://github.com/holoviz/datashader/issues/899 - # # Basically, 0 is treated as a nan and masked out in _colorize(), which is not correct for float reductions. - # # Also, _colorize() does not normalize the totals somehow. - # if np.issubdtype(raster.dtype, np.bool_): - # pass - # elif np.issubdtype(raster.dtype, np.integer): - # ## TODO: unfinished business here - # ## normalizing the raster bleaches out all colours again (fucks with log scaling, I guess?) - # # int values: simply normalize to max total 1. Null values will be masked - # # raster = raster.astype(np.float32) / raster.sum(axis=2).max() - # pass - # else: - # # float values: first rescale raster to [0.001, 1]. Not 0, because 0 is masked out in _colorize() - # maxval = np.nanmax(raster) - # offset = np.nanmin(raster) - # raster = .001 + .999*(raster - offset)/(maxval - offset) - # # replace NaNs with zeroes (because when we take the total, and 1 channel is present while others are missing...) 
- # raster.data[np.isnan(raster.data)] = 0 - # # now rescale so that max total is 1 - # raster /= raster.sum(axis=2).max() - if cdatum.is_discrete: # discard empty bins non_empty = np.where(non_empty)[0] raster = raster[..., non_empty] - # just use bin numbers to look up a color directly + # get bin numbers corresponding to non-empty bins color_bins = [color_bins[i] for i in non_empty] + # get list of color labels corresponding to each bin (may be multiple) + color_labels = [color_labels[i::cdatum.nlevels] for i in non_empty] color_key = [dmap[bin] for bin in color_bins] # the numbers may be out of order -- reorder for color bar purposes - bin_color = sorted(zip(color_bins, color_key)) - color_mapping = [col for _, col in bin_color] - if bounds[caxis][1] > cdatum.nlevels: - color_labels = [f"+{bin}" for bin, _ in bin_color] - else: - if cdatum.discretized_labels and len(cdatum.discretized_labels) <= cdatum.nlevels: - color_labels = [cdatum.discretized_labels[bin] for bin, _ in bin_color] + bin_color_label = sorted(zip(color_bins, color_key, color_labels)) + color_mapping = [col for _, col, _ in bin_color_label] + # generate labels + color_labels = [] + for _, _, labels in bin_color_label: + if len(labels) == 1: + color_labels.append(labels[0]) + elif len(labels) == 2: + color_labels.append(f"{labels[0]},{labels[1]}") else: - color_labels = [f"{bin}" for bin, _ in bin_color] + color_labels.append(f"{labels[0]},{labels[1]},...") log.info(f": rendering using {len(color_bins)} colors (values {' '.join(color_labels)})") else: # color labels are bin centres diff --git a/shade_ms/main.py b/shade_ms/main.py index 377d2d5..f2d1a3c 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -357,20 +357,24 @@ class Subset(object): raise NotImplementedError("iteration over antennas not currently supported") if options.baseline != 'all': - subset.baseline = OrderedDict() + bls = set() + a1a2 = set() for blspec in options.baseline.split(","): match = re.fullmatch(r"(\w+)-(\w+)", blspec) ant1 = match and ms.antenna[match.group(1)] ant2 = match and ms.antenna[match.group(2)] if ant1 is None or ant2 is None: raise ValueError("invalid baseline '{blspec}'") - subset.baseline[blspec] = (ant1, ant2) + a1a2.add((ant1, ant2)) + bls.add(ms.baseline_number(ant1, ant2)) # group_cols.append('ANTENNA1') - log.info(f"Baseline(s) : {' '.join(subset.baseline.keys())}") + subset.baseline = ms.all_baseline.get_subset(sorted(bls)) + log.info(f"Baseline(s) : {' '.join(subset.baseline.names)}") mytaql.append("||".join([f'(ANTENNA1=={ant1}&&ANTENNA2=={ant2})||(ANTENNA1=={ant2}&&ANTENNA2=={ant1})' - for ant1, ant2 in subset.baseline.values()])) + for ant1, ant2 in a1a2])) else: log.info('Baseline(s) : all') + subset.baseline = ms.baseline if options.field != 'all': subset.field = ms.field.get_subset(options.field) diff --git a/shade_ms/ms_info.py b/shade_ms/ms_info.py index 004dacb..baf2614 100644 --- a/shade_ms/ms_info.py +++ b/shade_ms/ms_info.py @@ -79,15 +79,26 @@ def __init__(self, msname=None, log=None): all_scans = NamedList("scan", list(map(str, range(scan_numbers[-1]+1)))) self.scan = all_scans.get_subset(scan_numbers) - self.all_antenna = NamedList("antenna", table(msname +'::ANTENNA', ack=False).getcol("NAME")) + antnames = table(msname +'::ANTENNA', ack=False).getcol("NAME") + self.all_antenna = antnames = NamedList("antenna", antnames) - self.antenna = self.all_antenna.get_subset(list(set(tab.getcol("ANTENNA1"))|set(tab.getcol("ANTENNA2")))) + ant1col = tab.getcol("ANTENNA1") + ant2col = tab.getcol("ANTENNA2") + 
self.antenna = self.all_antenna.get_subset(list(set(ant1col)|set(ant2col))) - baselines = [(p,q) for p in self.antenna.numbers for q in self.antenna.numbers if p <= q] - self.baseline_numbering = { (p, q): i for i, (p, q) in enumerate(baselines)} - self.baseline_numbering.update({ (q, p): i for i, (p, q) in enumerate(baselines)}) + log and log.info(f": {len(self.antenna)}/{len(self.all_antenna)} antennas: {self.antenna.str_list()}") - log and log.info(f": {len(self.antenna)} antennas: {self.antenna.str_list()}") + # list of all possible baselines + blnames = [f"{a1}-{a2}" for i1, a1 in enumerate(antnames) for a2 in antnames[i1:]] + self.all_baseline = NamedList("baseline", blnames) + + # baselines actually present + a1 = np.minimum(ant1col, ant2col) + a2 = np.maximum(ant1col, ant2col) + bls = sorted(set(self.baseline_number(a1, a2))) + self.baseline = NamedList("baseline", [blnames[b] for b in bls], bls) + + log and log.info(f": {len(self.baseline)}/{len(self.all_baseline)} baselines present") pol_tab = table(msname + '::POLARIZATION', ack=False) @@ -137,6 +148,7 @@ def _or(x): log and log.info(f": corrs/Stokes {' '.join(self.all_corr.names)}") - def baseline_number(self, ant1, ant2): - a1 = DataArray.minimum(ant1, ant2) - a2 = DataArray.maximum(ant1, ant2) + def baseline_number(self, a1, a2): + """Returns baseline number, for a1<=a2""" + return a1 * len(self.all_antenna) - a1 * (a1 - 1) // 2 + a2 - a1 + From ce2469eedee6e755009724009e8e2f57d41fedf5 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Thu, 4 Jun 2020 10:57:36 +0200 Subject: [PATCH 14/33] fixed boundaries for discrete axes --- shade_ms/data_plots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index f99969e..9f74689 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -247,7 +247,7 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor # adjust bounds for discrete axes for datum in xdatum, ydatum: if datum.is_discrete: - bounds[datum.label] = bounds[datum.label][0]-0.5, bounds[datum.label][0]+0.5 + bounds[datum.label] = bounds[datum.label][0]-0.5, bounds[datum.label][1]+0.5 # create rendering canvas. TODO: https://github.com/ratt-ru/shadeMS/issues/42 canvas = datashader.Canvas(options.xcanvas, options.ycanvas, x_range=bounds[xaxis], y_range=bounds[yaxis]) From 6b1f96504b7225091c989a201003391083190654 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Thu, 4 Jun 2020 13:58:59 +0200 Subject: [PATCH 15/33] 0.4.0 release prep. 
Fixes #42 --- shade_ms/data_plots.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 9f74689..8e02284 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -238,19 +238,25 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor if datum is not None: bounds[datum.label] = datum.minmax if datum.minmax[0] is None or datum.minmax[1] is None: - unknown.append(datum.label) + if datum.is_discrete and datum.subset_indices is not None: + bounds[datum.label] = 0, len(datum.subset_indices)-1 + else: + unknown.append(datum.label) if unknown: log.info(f": scanning axis min/max for {' '.join(unknown)}") compute_bounds(unknown, bounds, ddf) # adjust bounds for discrete axes - for datum in xdatum, ydatum: + canvas_sizes = [] + for datum, size in (xdatum, options.xcanvas), (ydatum, options.ycanvas): if datum.is_discrete: bounds[datum.label] = bounds[datum.label][0]-0.5, bounds[datum.label][1]+0.5 + size = int(bounds[datum.label][1]) - int(bounds[datum.label][0]) + 1 + canvas_sizes.append(size) - # create rendering canvas. TODO: https://github.com/ratt-ru/shadeMS/issues/42 - canvas = datashader.Canvas(options.xcanvas, options.ycanvas, x_range=bounds[xaxis], y_range=bounds[yaxis]) + # create rendering canvas. + canvas = datashader.Canvas(canvas_sizes[0], canvas_sizes[1], x_range=bounds[xaxis], y_range=bounds[yaxis]) if aaxis is not None: agg_alpha = getattr(datashader.reductions, ared, None) @@ -409,7 +415,7 @@ def match(artist): fig = pylab.figure(figsize=(figx, figy)) ax = fig.add_subplot(111, facecolor=bgcol) ax.imshow(X=rgb.data, extent=[xmin, xmax, ymin, ymax], - aspect='auto', origin='lower') + aspect='auto', origin='lower', interpolation='nearest') ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='center') ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) From 218debdcfe675126970e1a6f49d5eef759ee2e15 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Thu, 4 Jun 2020 16:21:38 +0200 Subject: [PATCH 16/33] implements #55 --- shade_ms/data_mappers.py | 9 +++++- shade_ms/data_plots.py | 6 +++- shade_ms/main.py | 63 +++++++++++++++++++++++++++++++++++----- 3 files changed, 68 insertions(+), 10 deletions(-) diff --git a/shade_ms/data_mappers.py b/shade_ms/data_mappers.py index 7c44487..00baf2c 100644 --- a/shade_ms/data_mappers.py +++ b/shade_ms/data_mappers.py @@ -134,7 +134,7 @@ def parse_datum_spec(cls, axis_spec, default_column=None, ms=None): return function, column, corr, (has_corr_axis and corr is None) @classmethod - def register(cls, function, column, corr, ms, minmax=None, ncol=None, subset=None): + def register(cls, function, column, corr, ms, minmax=None, ncol=None, subset=None, minmax_cache=None): """ Registers a data axis, which ultimately ends up as a column in the assembled dataframe. For multiple plots, we want to reuse the same information (assuming the same @@ -145,6 +145,8 @@ def register(cls, function, column, corr, ms, minmax=None, ncol=None, subset=Non Corr selects a correlation (or a Stokes product such as I, Q,...) 
minmax sets axis clipping levels ncol discretizes the axis into N colours between min and max + minmax_cache provides a dict of cached min/max values, which will be looked up via the label, if minmax + is not explicitly set """ # form up label label = "{}_{}_{}".format(col_to_label(column or ''), function, corr) @@ -154,6 +156,11 @@ def register(cls, function, column, corr, ms, minmax=None, ncol=None, subset=Non if key in cls.all_axes: return cls.all_axes[key] else: + # see if minmax should be loaded + if (minmax is None or tuple(minmax) == (None,None)) and minmax_cache and label in minmax_cache: + log.info(f"loading {label} min/max from cache") + minmax = minmax_cache[label] + label0, i = label, 0 while label in cls.all_labels: i += 1 diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 8e02284..7479422 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -218,6 +218,7 @@ def compute_bounds(unknowns, bounds, ddf): def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize, xlabel, ylabel, title, pngname, min_alpha=40, saturate_percentile=None, saturate_alpha=None, + minmax_cache=None, options=None): figx = options.xcanvas / 60 @@ -246,6 +247,9 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor if unknown: log.info(f": scanning axis min/max for {' '.join(unknown)}") compute_bounds(unknown, bounds, ddf) + # populate cache + if minmax_cache is not None: + minmax_cache.update([(label, bounds[label]) for label in unknown]) # adjust bounds for discrete axes canvas_sizes = [] @@ -255,7 +259,7 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor size = int(bounds[datum.label][1]) - int(bounds[datum.label][0]) + 1 canvas_sizes.append(size) - # create rendering canvas. + # create rendering canvas. canvas = datashader.Canvas(canvas_sizes[0], canvas_sizes[1], x_range=bounds[xaxis], y_range=bounds[yaxis]) if aaxis is not None: diff --git a/shade_ms/main.py b/shade_ms/main.py index f2d1a3c..c15b5a0 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -1,6 +1,4 @@ # -*- coding: future_fstrings -*- -# ian.heywood@physics.ox.ac.uk - import matplotlib matplotlib.use('agg') @@ -15,10 +13,9 @@ import re import sys import colorcet -from collections import OrderedDict import dask.diagnostics from contextlib import contextmanager - +import json import argparse @@ -124,6 +121,20 @@ def main(argv): group_opts.add_argument('--cnum', action='append', help=f'Number of steps used to discretize a continuous axis. Default is {DEFAULT_CNUM}.') + group_opts.add_argument('--xlim-load', action='store_true', + help=f'Load x-axis limits from limits file, if available.') + group_opts.add_argument('--ylim-load', action='store_true', + help=f'Load y-axis limits from limits file, if available.') + group_opts.add_argument('--clim-load', action='store_true', + help=f'Load colour axis limits from limits file, if available.') + group_opts.add_argument('--lim-file', default="{ms}-minmax-cache.json", + help="""Name of limits file to save/load. '{ms}' will be substituted for MS base name. + Default is '%(default)s'.""") + group_opts.add_argument('--no-lim-save', action="store_false", dest="lim_save", + help="""Do not save auto-computed limits to limits file. Default is to save.""") + group_opts.add_argument('--lim-save-reset', action="store_true", + help="""Reset limits file when saving. 
Default adds to existing file.""") + group_opts = parser.add_argument_group('Options for multiple plots or combined plots') group_opts.add_argument('--iter-field', action="store_true", @@ -422,6 +433,25 @@ class Subset(object): blank() + # check minmax cache + msbase = os.path.splitext(os.path.basename(options.ms))[0] + cache_file = options.lim_file.format(ms=msbase) + if options.dir and not "/" in cache_file: + cache_file = os.path.join(options.dir, cache_file) + + # try to load the minmax cache file + if not os.path.exists(cache_file): + minmax_cache = {} + else: + log.info(f"loading minmax cache from {cache_file}") + try: + minmax_cache = json.load(open(cache_file, "rt")) + if type(minmax_cache) is not dict: + raise TypeError("cache cotent is not a dict") + except Exception as exc: + log.error(f"error reading cache file: {exc}. Minmax cache will be reset.") + minmax_cache = {} + # figure out list of plots to make all_plots = [] @@ -470,11 +500,14 @@ def describe_corr(corrvalue): plot_ycorr = corr if ycorr is None else ycorr plot_acorr = corr if acorr is None else acorr plot_ccorr = corr if ccorr is None else ccorr - xdatum = DataAxis.register(xfunction, xcolumn, plot_xcorr, ms=ms, minmax=(xmin, xmax), subset=subset) - ydatum = DataAxis.register(yfunction, ycolumn, plot_ycorr, ms=ms, minmax=(ymin, ymax), subset=subset) - adatum = afunction and DataAxis.register(afunction, acolumn, plot_acorr, ms=ms, subset=subset) + xdatum = DataAxis.register(xfunction, xcolumn, plot_xcorr, ms=ms, minmax=(xmin, xmax), subset=subset, + minmax_cache=minmax_cache if options.xlim_load else None) + ydatum = DataAxis.register(yfunction, ycolumn, plot_ycorr, ms=ms, minmax=(ymin, ymax), subset=subset, + minmax_cache=minmax_cache if options.ylim_load else None) + adatum = afunction and DataAxis.register(afunction, acolumn, plot_acorr, ms=ms, subset=subset) cdatum = cfunction and DataAxis.register(cfunction, ccolumn, plot_ccorr, ms=ms, - minmax=(cmin, cmax), ncol=cnum, subset=subset) + minmax=(cmin, cmax), ncol=cnum, subset=subset, + minmax_cache=minmax_cache if options.clim_load else None) # figure out plot properties -- basically construct a descriptive name and label # looks complicated, but we're just trying to figure out what to put in the plot title... 
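As a sketch of what the limits file holds (the axis labels and values below are purely
illustrative; the real keys are the internal axis labels that DataAxis.register constructs),
the cache is just a flat JSON mapping of label to [min, max]:

    import json

    minmax_cache = {
        "TIME__None": [5020543221.8, 5020576890.2],   # hypothetical label and values
        "DATA_amp_0": [0.0, 12.5],
    }
    with open("myms-minmax-cache.json", "wt") as f:
        json.dump(minmax_cache, f, sort_keys=True, indent=4)

    # a later run with --xlim-load/--ylim-load/--clim-load reads it back and looks up
    # limits by label before falling back to a min/max scan of the dataframe
    cache = json.load(open("myms-minmax-cache.json", "rt"))
    assert isinstance(cache, dict)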
@@ -540,6 +573,10 @@ def describe_corr(corrvalue): all_plots.append((props, xdatum, ydatum, adatum, ared, cdatum)) log.debug(f"adding plot for {props['title']}") + # reset minmax cache if requested + if options.lim_save_reset: + minmax_cache = {} + join_corrs = not options.iter_corr and len(subset.corr) > 1 and have_corr_dependence log.info(' : you have asked for {} plots employing {} unique datums'.format(len(all_plots), @@ -622,6 +659,7 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, saturate_alpha=options.saturate_alpha, saturate_percentile=options.saturate_perc, xlabel=xlabel, ylabel=ylabel, title=title, pngname=pngname, + minmax_cache=minmax_cache, options=options) if result: log.info(f' : wrote {pngname}') @@ -687,5 +725,14 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, elapsed = str(round((clock_stop-clock_start), 2)) log.info('Total time : %s seconds' % (elapsed)) + + if minmax_cache and options.lim_save: + # ensure floats, because in64s and such cause errors + minmax_cache = {axis: list(map(float, minmax)) for axis, minmax in minmax_cache.items()} + + with open(cache_file, "wt") as file: + json.dump(minmax_cache, file, sort_keys=True, indent=4, separators=(',', ': ')) + log.info(f"Saved minmax cache to {cache_file} (disable with --no-lim-save)") + log.info('Finished') blank() From c65501a992fec71524b6af937bfc1da5c7eabba3 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Thu, 4 Jun 2020 16:51:07 +0200 Subject: [PATCH 17/33] simplify min/max for 'constant' axes such as CHAN and FREQ --- shade_ms/data_mappers.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/shade_ms/data_mappers.py b/shade_ms/data_mappers.py index 00baf2c..563632e 100644 --- a/shade_ms/data_mappers.py +++ b/shade_ms/data_mappers.py @@ -28,7 +28,7 @@ def col_to_label(col): class DataMapper(object): """This class defines a mapping from a dask group to an array of real values to be plotted""" - def __init__(self, fullname, unit, mapper, column=None, extras=[], conjugate=False, axis=None): + def __init__(self, fullname, unit, mapper, column=None, extras=[], conjugate=False, axis=None, const=False): """ :param fullname: full name of parameter (real, amplitude, etc.) :param unit: unit string @@ -37,11 +37,15 @@ def __init__(self, fullname, unit, mapper, column=None, extras=[], conjugate=Fal :param extras: extra arguments needed by mapper (e.g. 
["freqs", "wavel"]) :param conjugate: sets conjugation flag :param axis: which axis the parameter represets (0 time, 1 freq), if 1-dimensional + :param const: if True, axis is constant (does not depend on MS rows) """ assert mapper is not None self.fullname, self.unit, self.mapper, self.column, self.extras = fullname, unit, mapper, column, extras self.conjugate = conjugate self.axis = axis + # if + self.const = const + _identity = lambda x:x @@ -56,10 +60,10 @@ def __init__(self, fullname, unit, mapper, column=None, extras=[], conjugate=Fal TIME = DataMapper("time", "s", axis=0, column="TIME", mapper=_identity), ROW = DataMapper("row number", "", column=False, axis=0, extras=["rows"], mapper=lambda x,rows: rows), BASELINE = DataMapper("baseline", "", column=False, axis=0, extras=["baselines"], mapper=lambda x,baselines: baselines), - CORR = DataMapper("correlation", "", column=False, axis=0, extras=["corr"], mapper=lambda x,corr: corr), - CHAN = DataMapper("channel", "", column=False, axis=1, extras=["chans"], mapper=lambda x,chans: chans), - FREQ = DataMapper("frequency", "Hz", column=False, axis=1, extras=["freqs"], mapper=lambda x, freqs: freqs), - WAVEL = DataMapper("wavelength", "m", column=False, axis=1, extras=["wavel"], mapper=lambda x, wavel: wavel), + CORR = DataMapper("correlation", "", column=False, axis=0, extras=["corr"], mapper=lambda x,corr: corr, const=True), + CHAN = DataMapper("channel", "", column=False, axis=1, extras=["chans"], mapper=lambda x,chans: chans, const=True), + FREQ = DataMapper("frequency", "Hz", column=False, axis=1, extras=["freqs"], mapper=lambda x, freqs: freqs, const=True), + WAVEL = DataMapper("wavelength", "m", column=False, axis=1, extras=["wavel"], mapper=lambda x, wavel: wavel, const=True), UV = DataMapper("uv-distance", "wavelengths", column="UVW", extras=["wavel"], mapper=lambda uvw, wavel: da.sqrt((uvw[:,:2]**2).sum(axis=1))/wavel), U = DataMapper("u", "wavelengths", column="UVW", extras=["wavel"], @@ -177,6 +181,8 @@ def __init__(self, column, function, corr, ms, minmax=None, ncol=None, label=Non self.corr = corr if corr != "all" else None self.nlevels = ncol self.minmax = tuple(minmax) if minmax is not None else (None, None) + self._minmax_autorange = (self.minmax == (None, None)) + self.label = label self._corr_reduce = None self._is_discrete = None @@ -312,6 +318,11 @@ def get_value(self, group, corr, extras, flag, flag_row, chanslice): if np.iscomplexobj(coldata) and mapper is data_mappers["_"]: mapper = data_mappers["amp"] coldata = mapper.mapper(coldata, **{name:extras[name] for name in self.mapper.extras }) + # for a constant axis, compute minmax on the fly + if mapper.const and self._minmax_autorange: + min1, max1 = coldata.data.min(), coldata.data.max() + self.minmax = min(self.minmax[0], min1) if self.minmax[0] is not None else min1, \ + min(self.minmax[1], max1) if self.minmax[1] is not None else max1 # scalar is just a scalar if np.isscalar(coldata): coldata = da.array(coldata) From 5c4164e86036bca3a4ebb3d63bc5101cbe92bfbd Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Sun, 7 Jun 2020 12:48:40 +0200 Subject: [PATCH 18/33] fixes #53 --- shade_ms/data_mappers.py | 10 +++++- shade_ms/data_plots.py | 66 +++++++++++++++++++++++++++++++++------- shade_ms/main.py | 9 ++++-- shade_ms/ms_info.py | 23 ++++++++++++-- 4 files changed, 90 insertions(+), 18 deletions(-) diff --git a/shade_ms/data_mappers.py b/shade_ms/data_mappers.py index 563632e..62f40c2 100644 --- a/shade_ms/data_mappers.py +++ b/shade_ms/data_mappers.py @@ -6,6 +6,7 @@ 
import math import re import argparse +from . import ms_info from collections import OrderedDict from shade_ms import log @@ -60,6 +61,7 @@ def __init__(self, fullname, unit, mapper, column=None, extras=[], conjugate=Fal TIME = DataMapper("time", "s", axis=0, column="TIME", mapper=_identity), ROW = DataMapper("row number", "", column=False, axis=0, extras=["rows"], mapper=lambda x,rows: rows), BASELINE = DataMapper("baseline", "", column=False, axis=0, extras=["baselines"], mapper=lambda x,baselines: baselines), + BASELINE_M = DataMapper("baseline", "", column=False, axis=0, extras=["baselines"], mapper=lambda x,baselines: baselines), CORR = DataMapper("correlation", "", column=False, axis=0, extras=["corr"], mapper=lambda x,corr: corr, const=True), CHAN = DataMapper("channel", "", column=False, axis=1, extras=["chans"], mapper=lambda x,chans: chans, const=True), FREQ = DataMapper("frequency", "Hz", column=False, axis=1, extras=["freqs"], mapper=lambda x, freqs: freqs, const=True), @@ -228,6 +230,12 @@ def __init__(self, column, function, corr, ms, minmax=None, ncol=None, label=Non elif function == "BASELINE": self.subset_indices = subset.baseline maxind = ms.baseline.numbers[-1] + elif function == "BASELINE_M": + bl_subset = set(subset.baseline.numbers) # active baselines + numbers = [i for i in ms.baseline_m.numbers if i in bl_subset] + names = [bl for i, bl in zip(ms.baseline_m.numbers, ms.baseline_m.names) if i in bl_subset] + self.subset_indices = ms_info.NamedList("baseline_m", names, numbers) + maxind = ms.baseline.numbers[-1] elif column == "FLAG" or column == "FLAG_ROW": self.discretized_labels = ["F", "T"] @@ -238,7 +246,7 @@ def __init__(self, column, function, corr, ms, minmax=None, ncol=None, label=Non # and all other indices to N if len(self.subset_indices) < maxind+1: remapper = np.full(maxind+1, len(self.subset_indices)) - for i, index in enumerate(self.subset_indices.numbers): + for i, index in enumerate(self.subset_indices.numbers): remapper[index] = i self.subset_remapper = da.array(remapper) self.discretized_labels = self.subset_indices.names diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 7479422..ad668ea 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -28,14 +28,16 @@ from .dask_utils import dataframe_factory # from .ds_ext import by_integers, by_span -USE_REDUCE_BY = False +USE_REDUCE_BY = True def add_options(parser): - parser.add_argument('--reduce-by', action="store_true", help=argparse.SUPPRESS) + # parser.add_argument('--reduce-by', action="store_true", help=argparse.SUPPRESS) + pass def set_options(options): - global USE_REDUCE_BY - USE_REDUCE_BY = options.reduce_by + # global USE_REDUCE_BY + # USE_REDUCE_BY = options.reduce_by + pass def freq_to_wavel(ff): @@ -420,9 +422,10 @@ def match(artist): ax = fig.add_subplot(111, facecolor=bgcol) ax.imshow(X=rgb.data, extent=[xmin, xmax, ymin, ymax], aspect='auto', origin='lower', interpolation='nearest') - ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='center') - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) + + ax.set_title("\n".join(textwrap.wrap(title, 90)), loc='center', fontdict=dict(fontsize=options.fontsize)) + ax.set_xlabel(xlabel, fontdict=dict(fontsize=options.fontsize)) + ax.set_ylabel(ylabel, fontdict=dict(fontsize=options.fontsize)) # ax.plot(xmin,ymin,'.',alpha=0.0) # ax.plot(xmax,ymax,'.',alpha=0.0) @@ -430,9 +433,48 @@ def match(artist): ax.set_xlim([xmin - dx/100, xmax + dx/100]) ax.set_ylim([ymin - dy/100, ymax + dy/100]) - # set fontsize on everything 
rendered so far - for textobj in fig.findobj(match=match): - textobj.set_fontsize(options.fontsize) + def decimate_list(x, maxel): + """Helper function to reduce a list to < given max number of elements, dividing it by decimal factors of 2 and 5""" + factors = 2, 5, 10 + base = divisor = 1 + while len(x)//divisor > maxel: + for fac in factors: + divisor = fac*base + if len(x)//divisor <= maxel: + break + base *= 10 + return x[::divisor] + + ax.tick_params(labelsize=options.fontsize*0.66) + + # max # of tickmarks and labels to draw for discrete axes + MAXLABELS = 64 # if we have up to this many labels, show them all + MAXLABELS1 = 32 # if we have >MAXLABELS to show, then sparsify and get below this number + MAXTICKS = 300 # if total number of points is within this range, draw them as minor tickmarks + + # do we have discrete labels to put on the axes? + if xdatum.discretized_labels is not None: + n = len(xdatum.discretized_labels) + ticks_labels = list(enumerate(xdatum.discretized_labels)) + if n > MAXLABELS: + ticks_labels = decimate_list(ticks_labels, MAXLABELS1) # enforce max number of tick labels + labels = [label for _, label in ticks_labels] + rot = 90 if max([len(label) for label in xdatum.discretized_labels])*n > 60 else 0 + ax.set_xticks([x[0] for x in ticks_labels]) + ax.set_xticklabels(labels, rotation=rot) + if len(ticks_labels) < n and n <= MAXTICKS: + ax.set_xticks(range(n), minor=True) + + if ydatum.discretized_labels is not None: + n = len(ydatum.discretized_labels) + ticks_labels = list(enumerate(ydatum.discretized_labels)) + if n > MAXLABELS: + ticks_labels = decimate_list(ticks_labels, MAXLABELS1) # enforce max number of tick labels + labels = [label for _, label in ticks_labels] + ax.set_yticks([y[0] for y in ticks_labels]) + ax.set_yticklabels(labels) + if len(ticks_labels) < n and n <= MAXTICKS: + ax.set_yticks(range(n), minor=True) # colorbar? 
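# A worked example of decimate_list above (numbers are illustrative): with 100 candidate
# tick positions and maxel=30 the divisor search tries 2, then 5, and settles on every
# 5th position (20 ticks); with 1000 positions it settles on every 50th. The divisor is
# always 1, 2 or 5 times a power of ten, so the surviving labels sit at round indices.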
if color_key: @@ -455,12 +497,14 @@ def match(artist): if caxis is not None and cdatum.is_discrete: rot = 0 # adjust fontsize for number of labels - fs = max(options.fontsize*min(1, 32./len(color_labels)), 6) + fs = max(options.fontsize*min(0.8, 20./len(color_labels)), 6) fontdict = dict(fontsize=fs) if max([len(lbl) for lbl in color_labels]) > 3 and len(color_labels) < 8: rot = 90 fontdict['verticalalignment'] ='center' cb.ax.set_yticklabels(color_labels, rotation=rot, fontdict=fontdict) + else: + cb.ax.tick_params(labelsize=options.fontsize*0.8) fig.savefig(pngname, bbox_inches='tight') diff --git a/shade_ms/main.py b/shade_ms/main.py index c15b5a0..5ad5aa7 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -199,7 +199,7 @@ def main(argv): group_opts.add_argument('--bgcol', dest='bgcol', help='RGB hex code for background colour (default = FFFFFF)', default='FFFFFF') group_opts.add_argument('--fontsize', dest='fontsize', - help='Font size for all text elements (default = 20)', default=20) + help='Font size for all text elements (default = 20)', default=16) group_opts = parser.add_argument_group('Output settings') @@ -520,7 +520,9 @@ def describe_corr(corrvalue): labels.append(col_to_label(ycolumn)) titles += describe_corr(plot_ycorr) labels += describe_corr(plot_ycorr) - titles += [ydatum.mapper.fullname, "vs"] + if ydatum.mapper.fullname: + titles += [ydatum.mapper.fullname] + titles += ["vs"] if ydatum.function: labels.append(ydatum.function) # add x column/subset.corr, if different @@ -530,7 +532,8 @@ def describe_corr(corrvalue): if plot_xcorr != plot_ycorr: titles += describe_corr(plot_xcorr) labels += describe_corr(plot_xcorr) - titles += [xdatum.mapper.fullname] + if xdatum.mapper.fullname: + titles += [xdatum.mapper.fullname] if xdatum.function: labels.append(xdatum.function) props['title'] = " ".join(titles) diff --git a/shade_ms/ms_info.py b/shade_ms/ms_info.py index baf2614..4114bac 100644 --- a/shade_ms/ms_info.py +++ b/shade_ms/ms_info.py @@ -2,9 +2,9 @@ from casacore.tables import table import re import daskms +import math import numpy as np from collections import OrderedDict -from xarray import DataArray class NamedList(object): """Holds a list of names (e.g. 
field names), and provides common indexing and subset operations""" @@ -79,8 +79,10 @@ def __init__(self, msname=None, log=None): all_scans = NamedList("scan", list(map(str, range(scan_numbers[-1]+1)))) self.scan = all_scans.get_subset(scan_numbers) - antnames = table(msname +'::ANTENNA', ack=False).getcol("NAME") + anttab = table(msname +'::ANTENNA', ack=False) + antnames = anttab.getcol("NAME") self.all_antenna = antnames = NamedList("antenna", antnames) + self.antpos = anttab.getcol("POSITION") ant1col = tab.getcol("ANTENNA1") ant2col = tab.getcol("ANTENNA2") @@ -89,15 +91,30 @@ def __init__(self, msname=None, log=None): log and log.info(f": {len(self.antenna)}/{len(self.all_antenna)} antennas: {self.antenna.str_list()}") # list of all possible baselines + nant = len(antnames) blnames = [f"{a1}-{a2}" for i1, a1 in enumerate(antnames) for a2 in antnames[i1:]] self.all_baseline = NamedList("baseline", blnames) + # list of baseline lengths + self.baseline_lengths = [math.sqrt(((pos1-pos2)**2).sum()) + for i1, pos1 in enumerate(self.antpos) for pos2 in self.antpos[i1:]] + + # sort order to put baselines by length + sorted_by_length = sorted([(x, i) for i, x in enumerate(self.baseline_lengths)]) + self.baseline_sorted_index = [i for _, i in sorted_by_length] + self.baseline_sorted_length = [x for x, _ in sorted_by_length] + # baselines actually present a1 = np.minimum(ant1col, ant2col) a2 = np.maximum(ant1col, ant2col) - bls = sorted(set(self.baseline_number(a1, a2))) + bl_set = set(self.baseline_number(a1, a2)) + bls = sorted(bl_set) self.baseline = NamedList("baseline", [blnames[b] for b in bls], bls) + # make list of baselines present, in meters + blm = [i for i in self.baseline_sorted_index if i in bl_set] + self.baseline_m = NamedList("baseline_m", [f"{int(round(self.baseline_lengths[b]))}m" for b in blm], blm) + log and log.info(f": {len(self.baseline)}/{len(self.all_baseline)} baselines present") pol_tab = table(msname + '::POLARIZATION', ack=False) From 2d611d06ae628c3ec26381f2098f1504210eb78a Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Sun, 7 Jun 2020 14:30:28 +0200 Subject: [PATCH 19/33] fixed baseline sorting when all baselines in place. Omit autocorr by default --- shade_ms/data_mappers.py | 10 ++++++---- shade_ms/main.py | 17 +++++++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/shade_ms/data_mappers.py b/shade_ms/data_mappers.py index 62f40c2..dc36635 100644 --- a/shade_ms/data_mappers.py +++ b/shade_ms/data_mappers.py @@ -241,10 +241,12 @@ def __init__(self, column, function, corr, ms, minmax=None, ncol=None, label=Non # make a remapper if self.subset_indices is not None: - # If last index of subset is max index anyway, mapping is 1:1 -- no remapper needed - # Otherwise map indices in subset into their ordinal numbers in the subset (0...N-1), - # and all other indices to N - if len(self.subset_indices) < maxind+1: + # if the mapping from indices to bins 1:1? + subind = np.array(self.subset_indices.numbers) + identity = subind[0] == 0 and ((subind[1:]-subind[:-1]) == 1).all() + # If mapping is not 1:1, or subset is short of full set, then we need a remapper. 
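# (For example, an antenna subset [0, 1, 2] out of indices 0..6 is contiguous but short
#  of the full set, and [0, 2, 5] is neither contiguous nor complete: both need the
#  lookup table. Only the full, in-order set 0..6 can use the raw values as bin numbers.)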
+ # Map indices in subset into their ordinal numbers in the subset (0...N-1), and all other indices to N + if len(self.subset_indices) < maxind+1 or not identity: remapper = np.full(maxind+1, len(self.subset_indices)) for i, index in enumerate(self.subset_indices.numbers): remapper[index] = i diff --git a/shade_ms/main.py b/shade_ms/main.py index 5ad5aa7..e666c41 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -156,8 +156,9 @@ def main(argv): help='Antennas to plot (comma-separated list of names, default = all)') group_opts.add_argument('--ant-num', help='Antennas to plot (comma-separated list of numbers, or a [start]:[stop][:step] slice, overrides --ant)') - group_opts.add_argument('--baseline', default='all', - help="Baselines to plot, as 'ant1-ant2' (comma-separated list, default = all)") + group_opts.add_argument('--baseline', default='noauto', + help="Baselines to plot, as 'ant1-ant2' (comma-separated list, default of 'noauto' omits " + "auto-correlations, use 'all' to select all)") group_opts.add_argument('--spw', default='all', help='Spectral windows (DDIDs) to plot (comma-separated list, default = all)') group_opts.add_argument('--field', default='all', @@ -367,7 +368,14 @@ class Subset(object): if options.iter_antenna: raise NotImplementedError("iteration over antennas not currently supported") - if options.baseline != 'all': + if options.baseline == 'all': + log.info('Baseline(s) : all') + subset.baseline = ms.baseline + elif options.baseline == 'noauto': + log.info('Baseline(s) : all except autocorrelations') + subset.baseline = ms.all_baseline.get_subset([i for i in ms.baseline.numbers if ms.baseline_lengths[i]!=0]) + mytaql.append("ANTENNA1!=ANTENNA2") + else: bls = set() a1a2 = set() for blspec in options.baseline.split(","): @@ -383,9 +391,6 @@ class Subset(object): log.info(f"Baseline(s) : {' '.join(subset.baseline.names)}") mytaql.append("||".join([f'(ANTENNA1=={ant1}&&ANTENNA2=={ant2})||(ANTENNA1=={ant2}&&ANTENNA2=={ant1})' for ant1, ant2 in a1a2])) - else: - log.info('Baseline(s) : all') - subset.baseline = ms.baseline if options.field != 'all': subset.field = ms.field.get_subset(options.field) From 179639876c470914bc16730fce3319d3b6475520 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Sun, 7 Jun 2020 16:20:13 +0200 Subject: [PATCH 20/33] fixes #52. Also makes antenna subset selection more sensible. 
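For reference, a small worked example of the triangular baseline numbering that
ms.baseline_number(a1, a2) implements for a1 <= a2 (the three antenna names here are
hypothetical; the formula is the one defined in ms_info.py):

    nant = 3   # say, antennas m000, m001, m002

    def baseline_number(a1, a2):
        return a1 * nant - a1 * (a1 - 1) // 2 + a2 - a1

    # enumeration order matches all_baseline:
    # m000-m000, m000-m001, m000-m002, m001-m001, m001-m002, m002-m002
    assert [baseline_number(a1, a2) for a1 in range(nant) for a2 in range(a1, nant)] == [0, 1, 2, 3, 4, 5]
    # autocorrelations (zero baseline length) get numbers 0, 3 and 5 -- exactly the
    # baselines that --baseline noauto leaves out
    assert [baseline_number(a, a) for a in range(nant)] == [0, 3, 5]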
--- shade_ms/data_plots.py | 27 ++++++++++++++++++++------- shade_ms/main.py | 41 ++++++++++++++++++++++++++--------------- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index ad668ea..89bf6d7 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -47,7 +47,7 @@ def freq_to_wavel(ff): def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, chanslice, subset, noflags, noconj, - iter_field, iter_spw, iter_scan, iter_ant, + iter_field, iter_spw, iter_scan, iter_ant, iter_baseline, join_corrs=False, row_chunk_size=100000): @@ -74,6 +74,10 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, if antenna is not None: taql = f"({mytaql})&&(ANTENNA1=={antenna} || ANTENNA2=={antenna})" if mytaql else \ f"(ANTENNA1=={antenna} || ANTENNA2=={antenna})" + # add baselines to group columns + if iter_baseline: + group_cols = list(group_cols) + ["ANTENNA1", "ANTENNA2"] + # get MS data msdata = daskms.xds_from_ms(msinfo.msname, columns=list(ms_cols), group_cols=group_cols, taql_where=taql, chunks=dict(row=row_chunk_size)) @@ -88,12 +92,20 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, # iterate over groups for group in msdata: + if not len(group.row): + continue ddid = group.DATA_DESC_ID # always present - fld = group.FIELD_ID # always present + fld = group.FIELD_ID # always present if fld not in subset.field or ddid not in subset.spw: log.debug(f"field {fld} ddid {ddid} not in selection, skipping") continue scan = getattr(group, 'SCAN_NUMBER', None) # will be present if iterating over scans + ant1 = getattr(group, 'ANTENNA1', None) # will be present if iterating over baselines + ant2 = getattr(group, 'ANTENNA2', None) # will be present if iterating over baselines + if ant1 is not None and ant2 is not None: + baseline = msinfo.baseline_number(ant1, ant2) + else: + baseline = None # always read flags -- easier that way flag = group.FLAG if not noflags else None @@ -166,7 +178,7 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, dataframe_key = (fld if iter_field else None, ddid if iter_spw else None, scan if iter_scan else None, - antenna) + antenna if antenna is not None else baseline) # do we already have a frame for this key ddf0 = output_dataframes.get(dataframe_key) @@ -200,10 +212,11 @@ def compute_bounds(unknowns, bounds, ddf): Given a list of axis with unknown bounds, computes missing bounds and updates the bounds dict """ # setup function to compute min/max on every column for which we don't have a min/max - r = ddf.map_partitions(lambda df: - np.array([[(np.nanmin(df[axis].values).item() if bounds[axis][0] is None else bounds[axis][0]) for axis in unknowns]+ - [(np.nanmax(df[axis].values).item() if bounds[axis][1] is None else bounds[axis][1]) for axis in unknowns]]), - ).compute() + with np.errstate(all='ignore'): + r = ddf.map_partitions(lambda df: + np.array([[(np.nanmin(df[axis].values).item() if bounds[axis][0] is None else bounds[axis][0]) for axis in unknowns]+ + [(np.nanmax(df[axis].values).item() if bounds[axis][1] is None else bounds[axis][1]) for axis in unknowns]]), + ).compute() # setup new bounds dict based on this for i, axis in enumerate(unknowns): diff --git a/shade_ms/main.py b/shade_ms/main.py index e666c41..894afad 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -139,8 +139,6 @@ def main(argv): group_opts.add_argument('--iter-field', action="store_true", help='Separate plots per field (default is to combine in one plot)') - 
group_opts.add_argument('--iter-antenna', action="store_true", - help='Separate plots per antenna (default is to combine in one plot)') group_opts.add_argument('--iter-spw', action="store_true", help='Separate plots per spw (default is to combine in one plot)') group_opts.add_argument('--iter-scan', action="store_true", @@ -149,6 +147,8 @@ def main(argv): help='Separate plots per correlation or Stokes (default is to combine in one plot)') group_opts.add_argument('--iter-ant', action="store_true", help='Separate plots per antenna (default is to combine in one plot)') + group_opts.add_argument('--iter-baseline', action="store_true", + help='Separate plots per baseline (default is to combine in one plot)') group_opts = parser.add_argument_group('Data subset selection') @@ -209,10 +209,10 @@ def main(argv): help='Send all plots to this output directory') group_opts.add_argument('-s', '--suffix', help="suffix to be included in filenames, can include {options}") group_opts.add_argument('--png', dest='pngname', - default="plot-{ms}{_field}{_Spw}{_Scan}{_Ant}-{label}{_alphalabel}{_colorlabel}{_suffix}.png", + default="plot-{ms}{_field}{_Spw}{_Scan}{_Ant}{_Baseline}-{label}{_alphalabel}{_colorlabel}{_suffix}.png", help='Template for output png files, default "%(default)s"') group_opts.add_argument('--title', - default="{ms}{_field}{_Spw}{_Scan}{_Ant}{_title}{_Alphatitle}{_Colortitle}", + default="{ms}{_field}{_Spw}{_Scan}{_Ant}{_Baseline}{_title}{_Alphatitle}{_Colortitle}", help='Template for plot titles, default "%(default)s"') group_opts.add_argument('--xlabel', default="{xname}{_xunit}", @@ -252,6 +252,8 @@ def main(argv): dmap = getattr(colorcet, options.dmap, None) if dmap is None: parser.error(f"unknown --dmap {options.dmap}") + if options.iter_ant and options.iter_baseline: + parser.error("cannot combine --iter-ant and --iter-baseline") options.ms = options.ms.rstrip('/') @@ -360,14 +362,12 @@ class Subset(object): else: subset.ant = ms.antenna.get_subset(options.ant) log.info(f"Antenna name(s) : {' '.join(subset.ant.names)}") - mytaql.append("||".join([f'ANTENNA1=={ant}||ANTENNA2=={ant}' for ant in subset.ant.numbers])) + antnum_set = f"[{','.join(map(str, subset.ant.numbers))}]" + mytaql.append(f"ANTENNA1 IN {antnum_set} && ANTENNA2 IN {antnum_set}") else: subset.ant = ms.antenna log.info('Antenna(s) : all') - if options.iter_antenna: - raise NotImplementedError("iteration over antennas not currently supported") - if options.baseline == 'all': log.info('Baseline(s) : all') subset.baseline = ms.baseline @@ -379,13 +379,20 @@ class Subset(object): bls = set() a1a2 = set() for blspec in options.baseline.split(","): - match = re.fullmatch(r"(\w+)-(\w+)", blspec) + match = re.fullmatch(r"(\w+)-(\w*|[*])", blspec) ant1 = match and ms.antenna[match.group(1)] - ant2 = match and ms.antenna[match.group(2)] + ant2 = match and (ms.antenna[match.group(2)] if match.group(2) not in ['', '*'] else '*') if ant1 is None or ant2 is None: raise ValueError("invalid baseline '{blspec}'") - a1a2.add((ant1, ant2)) - bls.add(ms.baseline_number(ant1, ant2)) + if ant2 == '*': + ant2set = ms.all_antenna.numbers + else: + ant2set = [ant2] + # loop + for ant2 in ant2set: + a1, a2 = min(ant1, ant2), max(ant1, ant2) + a1a2.add((a1, a2)) + bls.add(ms.baseline_number(a1, a2)) # group_cols.append('ANTENNA1') subset.baseline = ms.all_baseline.get_subset(sorted(bls)) log.info(f"Baseline(s) : {' '.join(subset.baseline.names)}") @@ -600,6 +607,7 @@ def describe_corr(corrvalue): noflags=options.noflags, noconj=options.noconj, 
iter_field=options.iter_field, iter_spw=options.iter_spw, iter_scan=options.iter_scan, iter_ant=options.iter_ant, + iter_baseline=options.iter_baseline, join_corrs=join_corrs, row_chunk_size=options.row_chunk_size) @@ -677,7 +685,7 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, log.info(f' : wrote profiler info to {profile_file}') - for (fld, spw, scan, antenna), df in dataframes.items(): + for (fld, spw, scan, antenna_or_baseline), df in dataframes.items(): # update keys to be substituted into title and filename if fld is not None: keys['field_num'] = fld @@ -686,8 +694,11 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, keys['spw'] = spw if scan is not None: keys['scan'] = scan - if antenna is not None: - keys['ant'] = ms.all_antenna[antenna] + if antenna_or_baseline is not None: + if options.iter_ant: + keys['ant'] = ms.all_antenna[antenna_or_baseline] + else: + keys['baseline'] = ms.all_baseline[antenna_or_baseline] # now loop over plot types for props, xdatum, ydatum, adatum, ared, cdatum in all_plots: From c009d19fc89c980a026e117d45a91857c68d4ac7 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Tue, 9 Jun 2020 18:32:01 +0200 Subject: [PATCH 21/33] fixes for when not iterating baselines --- shade_ms/data_plots.py | 13 +++++++------ shade_ms/main.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 89bf6d7..59b02d4 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -99,10 +99,10 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, if fld not in subset.field or ddid not in subset.spw: log.debug(f"field {fld} ddid {ddid} not in selection, skipping") continue - scan = getattr(group, 'SCAN_NUMBER', None) # will be present if iterating over scans - ant1 = getattr(group, 'ANTENNA1', None) # will be present if iterating over baselines - ant2 = getattr(group, 'ANTENNA2', None) # will be present if iterating over baselines - if ant1 is not None and ant2 is not None: + scan = getattr(group, 'SCAN_NUMBER', None) # will be present if iterating over scans + if iter_baseline: + ant1 = getattr(group, 'ANTENNA1', None) # will be present if iterating over baselines + ant2 = getattr(group, 'ANTENNA2', None) # will be present if iterating over baselines baseline = msinfo.baseline_number(ant1, ant2) else: baseline = None @@ -168,8 +168,9 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, df1 = dataframe_factory(("row", "chan"), *args, columns=arrays.keys()) # if any axis needs to be conjugated, double up all of them if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]): - args = (v for pair in ((-array if DataAxis[key].conjugate else array, shapes[key]) - for key, array in arrays.items()) for v in pair) + arr_shape = [(-arrays[axis.label] if axis.conjugate else arrays[axis.label], shapes[axis.label]) + for axis in DataAxis.all_axes.values()] + args = (v for pair in arr_shape for v in pair) df2 = dataframe_factory(("row", "chan"), *args, columns=arrays.keys()) df1 = dask_df.concat([df1, df2], axis=0) ddf = dask_df.concat([ddf, df1], axis=0) if ddf is not None else df1 diff --git a/shade_ms/main.py b/shade_ms/main.py index 894afad..5710aa3 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -627,6 +627,7 @@ def describe_corr(corrvalue): keys['scan'] = subset.scan.names if options.scan != 'all' else '' keys['ant'] = subset.ant.names if options.ant != 'all' else '' ## TODO: also handle ant-num settings 
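# --- Illustrative sketch only: msinfo.baseline_number(ant1, ant2), used when iterating over
# --- baselines in this patch, is defined in ms_info.py and not shown here. The helper below
# --- assumes the common upper-triangular packing (autocorrelations included, ant1 <= ant2),
# --- which may differ from the actual shade_ms convention.
def baseline_number_sketch(ant1, ant2, nant):
    """Map an antenna pair to a unique index in [0, nant*(nant+1)//2)."""
    a1, a2 = min(ant1, ant2), max(ant1, ant2)
    return a1 * nant - (a1 * (a1 - 1)) // 2 + (a2 - a1)

# For reference, the antenna-subset TAQL built in this patch expands, for antennas 0, 2 and 5,
# to:  ANTENNA1 IN [0,2,5] && ANTENNA2 IN [0,2,5]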
keys['spw'] = subset.spw.names if options.spw != 'all' else '' + keys['baseline'] = None keys['suffix'] = suffix = options.suffix.format(**options.__dict__) if options.suffix else '' keys['_suffix'] = f".{suffix}" if suffix else '' @@ -697,7 +698,7 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, if antenna_or_baseline is not None: if options.iter_ant: keys['ant'] = ms.all_antenna[antenna_or_baseline] - else: + elif options.iter_ant: keys['baseline'] = ms.all_baseline[antenna_or_baseline] # now loop over plot types From 7b91c28067b60b0eb2a1784dbc4545f4aeca59ad Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Thu, 11 Jun 2020 16:43:30 +0200 Subject: [PATCH 22/33] fixing scan indexing. WIP. --- setup.py | 4 ++-- shade_ms/main.py | 2 +- shade_ms/ms_info.py | 26 ++++++++++++++++---------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/setup.py b/setup.py index 113b4f6..5bd20ce 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ requirements = [ "dask-ms[xarray]", "dask[complete]", -"datashader", +"datashader @ git+ssh://git@github.com/o-smirnov/datashader.git", "holoviews", "matplotlib>2.2.3; python_version >= '3.5'", "future-fstrings", @@ -17,7 +17,7 @@ extras_require = {'testing': ['pytest', 'pytest-flake8']} PACKAGE_NAME = 'shadems' -__version__ = '0.3.0' +__version__ = '0.4.0' setup(name = PACKAGE_NAME, version = __version__, diff --git a/shade_ms/main.py b/shade_ms/main.py index 5710aa3..00ea4dd 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -417,7 +417,7 @@ class Subset(object): log.info(f'SPW(s) : all') if options.scan != 'all': - subset.scan = ms.scan.get_subset(options.scan) + subset.scan = ms.scan.get_subset(options.scan, allow_numeric_indices=False) log.info(f"Scan(s) : {' '.join(subset.scan.names)}") mytaql.append("||".join([f'SCAN_NUMBER=={n}' for n in subset.scan.numbers])) else: diff --git a/shade_ms/ms_info.py b/shade_ms/ms_info.py index 4114bac..66e2734 100644 --- a/shade_ms/ms_info.py +++ b/shade_ms/ms_info.py @@ -14,6 +14,7 @@ def __init__(self, label, names, numbers=None): self.names = names self.numbers = numbers or range(len(self.names)) self.map = dict(zip(names, self.numbers)) + self.numindex = {num: i for i, num in enumerate(self.numbers)} def __len__(self): return len(self.names) @@ -22,27 +23,32 @@ def __contains__(self, name): return name in self.map if type(name) is str else name in self.numbers def __getitem__(self, item, default=None): - return self.map.get(item, default) if type(item) is str else self.names[item] + if type(item) is str: + return self.map.get(item, default) + elif type(item) is slice: + return self.names[item] + else: + return self.names[self.numindex[item]] - def get_subset(self, subset): + def get_subset(self, subset, allow_numeric_indices=True): """Extracts subset using a comma-separated string or list of indices""" if type(subset) in (list, tuple): - return NamedList(self.label, [self.names[x] for x in subset], subset) + return NamedList(self.label, [self.names[x] for x in subset], [self.numbers[x] for x in subset]) elif type(subset) is str: if subset == "all": return self - numbers = [] + ind = [] for x in subset.split(","): - if re.fullmatch('\d+', x): + if allow_numeric_indices and re.fullmatch('\d+', x): x = int(x) - if x < 0 or x >= len(self): + if x not in self.numindex: raise ValueError(f"invalid {self.label} number {x}") - numbers.append(x) + ind.append(self.numindex[x]) elif x in self.map: - numbers.append(self.map[x]) + ind.append(self.numindex[self.map[x]]) else: raise 
ValueError(f"invalid {self.label} '{x}'") - return NamedList(self.label, [self.names[x] for x in numbers], numbers) + return NamedList(self.label, [self.names[x] for x in ind], [self.numbers[x] for x in ind]) else: raise TypeError(f"unknown subset of type {type(subset)}") @@ -75,7 +81,7 @@ def __init__(self, msname=None, log=None): log and log.info(f": {len(self.field)} fields: {' '.join(self.field.names)}") scan_numbers = sorted(set(tab.getcol("SCAN_NUMBER"))) - log and log.info(f": {len(scan_numbers)} scans, first #{scan_numbers[0]}, last #{scan_numbers[-1]}") + log and log.info(f": {len(scan_numbers)} scans: {' '.join(map(str, scan_numbers))}") all_scans = NamedList("scan", list(map(str, range(scan_numbers[-1]+1)))) self.scan = all_scans.get_subset(scan_numbers) From 6c5c66cefe8d83c74601e9cff34e47f54170503d Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Thu, 11 Jun 2020 21:13:51 +0200 Subject: [PATCH 23/33] fixing indexing of subsets --- shade_ms/ms_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shade_ms/ms_info.py b/shade_ms/ms_info.py index 66e2734..9e073e1 100644 --- a/shade_ms/ms_info.py +++ b/shade_ms/ms_info.py @@ -28,7 +28,7 @@ def __getitem__(self, item, default=None): elif type(item) is slice: return self.names[item] else: - return self.names[self.numindex[item]] + return self.names[item] # self.numindex[item]] def get_subset(self, subset, allow_numeric_indices=True): """Extracts subset using a comma-separated string or list of indices""" From 5c4f5a7a2c66866b28733b9c27267b41f34d6f73 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Thu, 11 Jun 2020 22:13:38 +0200 Subject: [PATCH 24/33] fixes #61 --- shade_ms/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/shade_ms/main.py b/shade_ms/main.py index 00ea4dd..5c290e3 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -541,7 +541,7 @@ def describe_corr(corrvalue): if xcolumn and (xcolumn != ycolumn or not xdatum.function) and not xdatum.mapper.column: titles.append(xcolumn) labels.append(col_to_label(xcolumn)) - if plot_xcorr != plot_ycorr: + if plot_xcorr is not plot_ycorr: titles += describe_corr(plot_xcorr) labels += describe_corr(plot_xcorr) if xdatum.mapper.fullname: @@ -556,7 +556,8 @@ def describe_corr(corrvalue): if acolumn and (acolumn != xcolumn or acolumn != ycolumn) and adatum.mapper.column is None: titles.append(acolumn) labels.append(col_to_label(acolumn)) - if plot_acorr and (plot_acorr != plot_xcorr or plot_acorr != plot_ycorr): + if plot_acorr is not None and plot_acorr is not False and \ + (plot_acorr is not plot_xcorr or plot_acorr is not plot_ycorr): titles += describe_corr(plot_acorr) labels += describe_corr(plot_acorr) titles += [adatum.mapper.fullname] @@ -572,7 +573,8 @@ def describe_corr(corrvalue): if ccolumn and (ccolumn != xcolumn or ccolumn != ycolumn) and cdatum.mapper.column is None: titles.append(ccolumn) labels.append(col_to_label(ccolumn)) - if plot_ccorr and (plot_ccorr != plot_xcorr or plot_ccorr != plot_ycorr): + if plot_ccorr is not None and plot_ccorr is not False and \ + (plot_ccorr is not plot_xcorr or plot_ccorr is not plot_ycorr): titles += describe_corr(plot_ccorr) labels += describe_corr(plot_ccorr) if cdatum.mapper.fullname: From c64084af7e9931a1cbc703d594df851054dfac29 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Fri, 12 Jun 2020 21:17:18 +0200 Subject: [PATCH 25/33] fixing usage of discrete colormaps --- setup.py | 1 + shade_ms/data_mappers.py | 5 ++++- shade_ms/data_plots.py | 33 
++++++++++++++++++++++++++++----- shade_ms/main.py | 18 +++++++----------- 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/setup.py b/setup.py index 5bd20ce..f1c6508 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ "datashader @ git+ssh://git@github.com/o-smirnov/datashader.git", "holoviews", "matplotlib>2.2.3; python_version >= '3.5'", +"cmasher", "future-fstrings", "requests", "MSUtils" diff --git a/shade_ms/data_mappers.py b/shade_ms/data_mappers.py index dc36635..07e5118 100644 --- a/shade_ms/data_mappers.py +++ b/shade_ms/data_mappers.py @@ -330,7 +330,10 @@ def get_value(self, group, corr, extras, flag, flag_row, chanslice): coldata = mapper.mapper(coldata, **{name:extras[name] for name in self.mapper.extras }) # for a constant axis, compute minmax on the fly if mapper.const and self._minmax_autorange: - min1, max1 = coldata.data.min(), coldata.data.max() + if np.isscalar(coldata): + min1 = max1 = coldata + else: + min1, max1 = coldata.data.min(), coldata.data.max() self.minmax = min(self.minmax[0], min1) if self.minmax[0] is not None else min1, \ min(self.minmax[1], max1) if self.minmax[1] is not None else max1 # scalar is just a scalar diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 59b02d4..297419c 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -21,6 +21,9 @@ import itertools import matplotlib.cm from shade_ms import log +import colorcet +import cmasher +import matplotlib.cm from collections import OrderedDict from . import data_mappers @@ -34,12 +37,29 @@ def add_options(parser): # parser.add_argument('--reduce-by', action="store_true", help=argparse.SUPPRESS) pass + def set_options(options): # global USE_REDUCE_BY # USE_REDUCE_BY = options.reduce_by pass +def get_colormap(cmap_name): + cmap = getattr(colorcet, cmap_name, None) + if cmap: + log.info(f"using colourmap colorcet.{cmap_name}") + return cmap + cmap = getattr(cmasher, cmap_name, None) + if cmap: + log.info(f"using colourmap cmasher.{cmap_name}") + else: + cmap = getattr(matplotlib.cm, cmap_name, None) + if cmap is None: + raise ValueError(f"unknown colourmap {cmap_name}") + log.info(f"using colourmap matplotplib.cm.{cmap_name}") + return [ f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}" for r,g,b in cmap.colors ] + + def freq_to_wavel(ff): c = 299792458.0 # m/s return c/ff @@ -316,7 +336,7 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor else: color_bins = list(range(cdatum.nlevels)) if cdatum.is_discrete: - if cdatum.subset_indices is not None: + if cdatum.subset_indices is not None and options.dmap_preserve: num_categories = len(cdatum.subset_indices) color_bins = cdatum.subset_indices[:cdatum.nlevels] else: @@ -353,10 +373,13 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor non_empty = np.where(non_empty)[0] raster = raster[..., non_empty] # get bin numbers corresponding to non-empty bins - color_bins = [color_bins[i] for i in non_empty] + if options.dmap_preserve: + color_bins = [color_bins[bin] for bin in non_empty] + else: + color_bins = [color_bins[i] for i, _ in enumerate(non_empty)] # get list of color labels corresponding to each bin (may be multiple) color_labels = [color_labels[i::cdatum.nlevels] for i in non_empty] - color_key = [dmap[bin] for bin in color_bins] + color_key = [dmap[bin%len(dmap)] for bin in color_bins] # the numbers may be out of order -- reorder for color bar purposes bin_color_label = sorted(zip(color_bins, color_key, color_labels)) color_mapping = [col 
for _, col, _ in bin_color_label] @@ -373,8 +396,8 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor else: # color labels are bin centres bin_centers = [cmin + cdelta*(i+0.5) for i in color_bins] - # map to colors pulled from 256 color map - color_key = [bmap[(i*256)//cdatum.nlevels] for i in color_bins] + # map to colors pulled from color map + color_key = [bmap[(i*len(bmap))//cdatum.nlevels] for i in color_bins] color_labels = list(map(str, bin_centers)) log.info(f": shading using {len(color_bins)} colors (bin centres are {' '.join(color_labels)})") diff --git a/shade_ms/main.py b/shade_ms/main.py index 5c290e3..55f5e20 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -12,7 +12,6 @@ import itertools import re import sys -import colorcet import dask.diagnostics from contextlib import contextmanager import json @@ -181,10 +180,12 @@ def main(argv): group_opts.add_argument('--cmap', default='bkr', help="""Colorcet map used without --colour-by (default = %(default)s), see https://colorcet.holoviz.org""") - group_opts.add_argument('--bmap', default='bkr', + group_opts.add_argument('--bmap', default='pride', help='Colorcet map used when colouring by a continuous axis (default = %(default)s)') group_opts.add_argument('--dmap', default='glasbey_dark', help='Colorcet map used when colouring by a discrete axis (default = %(default)s)') + group_opts.add_argument('--dmap-preserve', action='store_true', + help='Preserve colour assignments in discrete axes even when discrete values are missing.') group_opts.add_argument('--min-alpha', default=40, type=int, metavar="0-255", help="""Minimum alpha value used in rendering the canvas. Increase to saturate colour at the expense of dynamic range. Default is %(default)s.""") @@ -243,15 +244,10 @@ def main(argv): options = parser.parse_args(argv) - cmap = getattr(colorcet, options.cmap, None) - if cmap is None: - parser.error(f"unknown --cmap {options.cmap}") - bmap = getattr(colorcet, options.bmap, None) - if bmap is None: - parser.error(f"unknown --bmap {options.bmap}") - dmap = getattr(colorcet, options.dmap, None) - if dmap is None: - parser.error(f"unknown --dmap {options.dmap}") + cmap = data_plots.get_colormap(options.cmap) + bmap = data_plots.get_colormap(options.bmap) + dmap = data_plots.get_colormap(options.dmap) + if options.iter_ant and options.iter_baseline: parser.error("cannot combine --iter-ant and --iter-baseline") From 7fffc801a299d1c552f18e254f961ff950b0ba44 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Fri, 12 Jun 2020 21:28:35 +0200 Subject: [PATCH 26/33] fixed setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f1c6508..c985db3 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ requirements = [ "dask-ms[xarray]", "dask[complete]", -"datashader @ git+ssh://git@github.com/o-smirnov/datashader.git", +"datashader @ git+https://github.com/o-smirnov/datashader.git", "holoviews", "matplotlib>2.2.3; python_version >= '3.5'", "cmasher", From b02d7641fd05f31ac4fa82a85d4a85547858b963 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Sun, 14 Jun 2020 13:19:55 +0200 Subject: [PATCH 27/33] added colorbar to plots that have an A axis reduction --- shade_ms/data_plots.py | 41 +++++++++++------------------------------ shade_ms/main.py | 6 +++--- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 297419c..5854c3b 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -31,16 
+31,11 @@ from .dask_utils import dataframe_factory # from .ds_ext import by_integers, by_span -USE_REDUCE_BY = True - def add_options(parser): - # parser.add_argument('--reduce-by', action="store_true", help=argparse.SUPPRESS) pass def set_options(options): - # global USE_REDUCE_BY - # USE_REDUCE_BY = options.reduce_by pass @@ -266,7 +261,7 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor aaxis = adatum and adatum.label caxis = cdatum and cdatum.label - color_key = color_mapping = color_labels = agg_alpha = raster_alpha = cmin = cdelta = None + color_key = color_mapping = color_labels = color_minmax = agg_alpha = cmin = cdelta = None # do we need to compute any axis min/max? bounds = OrderedDict({xaxis: xdatum.minmax, yaxis: ydatum.minmax}) @@ -299,26 +294,14 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor canvas = datashader.Canvas(canvas_sizes[0], canvas_sizes[1], x_range=bounds[xaxis], y_range=bounds[yaxis]) if aaxis is not None: - agg_alpha = getattr(datashader.reductions, ared, None) - if agg_alpha is None: - raise ValueError(f"unknown alpha reduction function {ared}") - agg_alpha = agg_alpha(aaxis) - ared = ared or 'count' - - if aaxis is not None: - agg_alpha = getattr(datashader.reductions, ared, None) + agg_alpha = getattr(datashader.reductions, ared, None) if ared else datashader.reductions.count if agg_alpha is None: raise ValueError(f"unknown alpha reduction function {ared}") agg_alpha = agg_alpha(aaxis) - ared = ared or 'count' if cdatum is not None: - if agg_alpha is not None and not USE_REDUCE_BY: - log.debug(f'rasterizing alpha channel using {ared}(aaxis)') - raster_alpha = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha) - # aggregation applied to by() - agg_by = agg_alpha if USE_REDUCE_BY and agg_alpha is not None else datashader.count() + agg_by = agg_alpha if agg_alpha else datashader.count() # color_bins will be a list of colors to use. If the subset is known, then we preferentially # pick colours by subset, i.e. we try to preserve the mapping from index to specific color. @@ -400,21 +383,19 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor color_key = [bmap[(i*len(bmap))//cdatum.nlevels] for i in color_bins] color_labels = list(map(str, bin_centers)) log.info(f": shading using {len(color_bins)} colors (bin centres are {' '.join(color_labels)})") - - if raster_alpha is not None: - amin = adatum.minmax[0] if adatum.minmax[0] is not None else np.nanmin(raster_alpha) - amax = adatum.minnax[1] if adatum.minmax[1] is not None else np.nanmax(raster_alpha) - raster = raster*(raster_alpha-amin)/(amax-amin) - raster[raster<0] = 0 - raster[raster>1] = 1 - log.info(f": adjusting alpha (alpha raster was {amin} to {amax})") img = datashader.transfer_functions.shade(raster, color_key=color_key, how=normalize, min_alpha=min_alpha) + # set color_minmax for colorbar + color_minmax = bounds[caxis] else: log.debug(f'rasterizing using {ared}') raster = canvas.points(ddf, xaxis, yaxis, agg=agg_alpha) if not raster.data.any(): log.info(": no valid data in plot. Check your flags and/or plot limits.") return None + # get min/max cor colorbar + if aaxis: + color_minmax = np.nanmin(raster), np.nanmax(raster) + color_key = cmap log.debug('shading') img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize, min_alpha=min_alpha) @@ -514,7 +495,7 @@ def decimate_list(x, maxel): ax.set_yticks(range(n), minor=True) # colorbar? 
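# --- Illustrative sketch only: the norm/colormap pair assembled below is what drives the
# --- colour bar. Assumes plain matplotlib (>= 3.1); the toy color_minmax/color_key values
# --- stand in for the ones computed in this function, and the real colorbar plumbing in
# --- shade_ms may differ from this minimal version.
import matplotlib.colors
import matplotlib.cm
import pylab

color_minmax = (0.0, 10.0)                      # e.g. span of the colour or intensity axis
color_key = ["#000000", "#777777", "#dddddd"]   # one colour per rendered bin
norm = matplotlib.colors.Normalize(*color_minmax)
colormap = matplotlib.colors.ListedColormap(color_key)
fig = pylab.figure()
ax = fig.add_subplot(111)
fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=colormap), ax=ax)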
- if color_key: + if color_minmax: import matplotlib.colors # discrete axis if caxis is not None and cdatum.is_discrete: @@ -523,7 +504,7 @@ def decimate_list(x, maxel): colormap = matplotlib.colors.ListedColormap(color_mapping) # discretized axis else: - norm = matplotlib.colors.Normalize(*bounds[caxis]) + norm = matplotlib.colors.Normalize(*color_minmax) colormap = matplotlib.colors.ListedColormap(color_key) # auto-mark colorbar, since it represents a continuous range of values ticks = None diff --git a/shade_ms/main.py b/shade_ms/main.py index 55f5e20..4443e7d 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -78,7 +78,7 @@ def main(argv): help="""Intensity axis. Can be none, or given once, or given the same number of times as --xaxis. If none, plot intensity (a.k.a. alpha channel) is proportional to density of points. Otherwise, a reduction function (see --ared below) is applied to the given values, and the result is used - to determine intensity. + to determine intensity. All columns and variations listed under --xaxis are available for --aaxis. """) group_opts.add_argument('--ared', action="append", @@ -86,8 +86,8 @@ def main(argv): mean, std, first, last, mode. Default is mean.""") group_opts.add_argument('-c', '--colour-by', action="append", - help="""Colour axis. Can be none, or given once, or given the same number of times as --xaxis. - All columns and variations listed under --xaxis are available for colouring by.""") + help="""Colour (a.k.a. category) axis. Can be none, or given once, or given the same number of + times as --xaxis. All columns and variations listed under --xaxis are available for colouring by.""") group_opts.add_argument('-C', '--col', metavar="COLUMN", dest='col', action="append", default=[], help="""Name of visibility column (default is DATA), if needed. This is used if From bdf6e5c5376710b2605af339047ed1085f5da7cd Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Sun, 14 Jun 2020 13:53:05 +0200 Subject: [PATCH 28/33] added amin/amax options --- shade_ms/data_plots.py | 6 ++++-- shade_ms/main.py | 34 ++++++++++++++++++++++------------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 5854c3b..61948d4 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -394,10 +394,12 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor return None # get min/max cor colorbar if aaxis: - color_minmax = np.nanmin(raster), np.nanmax(raster) + amin, amax = adatum.minmax + color_minmax = (amin if amin is not None else np.nanmin(raster)), \ + (amax if amax is not None else np.nanmax(raster)) color_key = cmap log.debug('shading') - img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize, min_alpha=min_alpha) + img = datashader.transfer_functions.shade(raster, cmap=cmap, how=normalize, span=color_minmax, min_alpha=min_alpha) # resaturate if needed if saturate_alpha is not None or saturate_percentile is not None: diff --git a/shade_ms/main.py b/shade_ms/main.py index 4443e7d..e3a68ac 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -106,17 +106,21 @@ def main(argv): help="""Minimum x-axis value (default = data min). For multiple plots, you can give this multiple times, or use a comma-separated list, but note that the clipping is the same per axis across all plots, so only the last applicable setting will be used. The list may include empty - elements (or 'None') to not apply a clip.""") + elements (or 'None') to not apply a clip. 
Default computes clips from data min/max.""") group_opts.add_argument('--xmax', action='append', - help='Maximum x-axis value (default = data max).') + help='Maximum x-axis value.') group_opts.add_argument('--ymin', action='append', - help='Minimum y-axis value (default = data min).') + help='Minimum y-axis value.') group_opts.add_argument('--ymax', action='append', - help='Maximum y-axis value (default = data max).') + help='Maximum y-axis value.') + group_opts.add_argument('--amin', action='append', + help='Minimum intensity-axis value.') + group_opts.add_argument('--amax', action='append', + help='Maximum intensity-axis value.') group_opts.add_argument('--cmin', action='append', - help='Minimum colouring value. Must be supplied for every non-discrete axis to be coloured by.') + help='Minimum value to be coloured by.') group_opts.add_argument('--cmax', action='append', - help='Maximum colouring value. Must be supplied for every non-discrete axis to be coloured by.') + help='Maximum value to be coloured by.') group_opts.add_argument('--cnum', action='append', help=f'Number of steps used to discretize a continuous axis. Default is {DEFAULT_CNUM}.') @@ -176,7 +180,8 @@ def main(argv): group_opts.add_argument('-Y', '--ycanvas', type=int, help='Canvas y-size in pixels (default = %(default)s)', default=900) group_opts.add_argument('--norm', choices=['auto', 'eq_hist', 'cbrt', 'log', 'linear'], default='auto', - help="Pixel scale normalization (default is 'log' when colouring, and 'eq_hist' when not)") + help="Pixel scale normalization (default is 'log' with caxis, 'linear' with aaxis, and " + "'eq_hist' when neither is in use.)") group_opts.add_argument('--cmap', default='bkr', help="""Colorcet map used without --colour-by (default = %(default)s), see https://colorcet.holoviz.org""") @@ -302,6 +307,8 @@ def get_conformal_list(name, force_type=None, default=None): ymins = get_conformal_list('ymin', float) ymaxs = get_conformal_list('ymax', float) aaxes = get_conformal_list('aaxis') + amins = get_conformal_list('amin', float) + amaxs = get_conformal_list('amax', float) areds = get_conformal_list('ared', str, 'mean') caxes = get_conformal_list('colour_by') cmins = get_conformal_list('cmin', float) @@ -313,8 +320,10 @@ def get_conformal_list(name, force_type=None, default=None): parser.error("--xmin/--xmax must be either both set, or neither") if any([(a is None)^(b is None) for a, b in zip(ymins, ymaxs)]): parser.error("--xmin/--xmax must be either both set, or neither") - if any([(a is None)^(b is None) for a, b in zip(ymins, ymaxs)]): + if any([(a is None)^(b is None) for a, b in zip(cmins, cmaxs)]): parser.error("--cmin/--cmax must be either both set, or neither") + if any([(a is None)^(b is None) for a, b in zip(amins, amaxs)]): + parser.error("--amin/--amax must be either both set, or neither") # check chan slice def parse_slice_spec(spec, name): @@ -467,8 +476,8 @@ class Subset(object): have_corr_dependence = False # now go create definitions - for xaxis, yaxis, default_column, caxis, aaxis, ared, xmin, xmax, ymin, ymax, cmin, cmax, cnum in \ - zip(xaxes, yaxes, columns, caxes, aaxes, areds, xmins, xmaxs, ymins, ymaxs, cmins, cmaxs, cnums): + for xaxis, yaxis, default_column, caxis, aaxis, ared, xmin, xmax, ymin, ymax, amin, amax, cmin, cmax, cnum in \ + zip(xaxes, yaxes, columns, caxes, aaxes, areds, xmins, xmaxs, ymins, ymaxs, amins, amaxs, cmins, cmaxs, cnums): # get axis specs xspecs = DataAxis.parse_datum_spec(xaxis, default_column, ms=ms) yspecs = DataAxis.parse_datum_spec(yaxis, 
default_column, ms=ms) @@ -512,7 +521,8 @@ def describe_corr(corrvalue): minmax_cache=minmax_cache if options.xlim_load else None) ydatum = DataAxis.register(yfunction, ycolumn, plot_ycorr, ms=ms, minmax=(ymin, ymax), subset=subset, minmax_cache=minmax_cache if options.ylim_load else None) - adatum = afunction and DataAxis.register(afunction, acolumn, plot_acorr, ms=ms, subset=subset) + adatum = afunction and DataAxis.register(afunction, acolumn, plot_acorr, ms=ms, + minmax=(amin, amax), subset=subset) cdatum = cfunction and DataAxis.register(cfunction, ccolumn, plot_ccorr, ms=ms, minmax=(cmin, cmax), ncol=cnum, subset=subset, minmax_cache=minmax_cache if options.clim_load else None) @@ -662,7 +672,7 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, log.info(f": rendering {pngname}") normalize = options.norm if normalize == "auto": - normalize = "log" if cdatum is not None else "eq_hist" + normalize = "log" if cdatum is not None else ("eq_hist" if adatum is None else 'linear') if options.profile: context = dask.diagnostics.ResourceProfiler else: From c806f8a8689b71d329e361cd8c1a63a7398dced3 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Sun, 14 Jun 2020 14:52:30 +0200 Subject: [PATCH 29/33] added check for null alpha plane --- shade_ms/data_plots.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 61948d4..a1610ec 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -408,18 +408,21 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor alpha = (imgval >> 24)&255 nulls = alpha255] = 255 - imgval[:] = (imgval & 0xFFFFFF) | alpha.astype(np.uint32)<<24 + if nulls.all(): + log.debug(f"alpha255] = 255 + imgval[:] = (imgval & 0xFFFFFF) | alpha.astype(np.uint32)<<24 if options.spread_pix: img = datashader.transfer_functions.dynspread(img, options.spread_thr, max_px=options.spread_pix) From 13d38cfa2e0da5778ddc3682b350dea8a42afd47 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Mon, 15 Jun 2020 15:01:41 +0200 Subject: [PATCH 30/33] fixes #62 --- shade_ms/data_plots.py | 157 +++++++++++++++++++++++------------------ shade_ms/main.py | 11 +-- shade_ms/ms_info.py | 10 ++- 3 files changed, 102 insertions(+), 76 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index a1610ec..d00ca5a 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -67,6 +67,7 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, row_chunk_size=100000): ms_cols = {'ANTENNA1', 'ANTENNA2'} + ms_cols.update(msinfo.indexing_columns.keys()) if not noflags: ms_cols.update({'FLAG', 'FLAG_ROW'}) # get visibility columns @@ -75,10 +76,13 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, total_num_points = 0 # total number of points to plot - # output dataframes, indexed by (field, spw, scan, antenna, correlation) - # If any of these axes is not being iterated over, then the index is None + # output dataframes, indexed by (field, spw, scan, antenna_or_baseline) + # If any of these axes is not being iterated over, then the index at that position is None output_dataframes = OrderedDict() + # output subsets of indexing columns, indexed by same tuple + output_subsets = OrderedDict() + if iter_ant: antenna_subsets = zip(subset.ant.numbers, subset.ant.names) else: @@ -146,6 +150,18 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, ddf = None num_points = 0 # counts number of new points generated + # 
Make frame key -- data subset corresponds to this frame + dataframe_key = (fld if iter_field else None, + ddid if iter_spw else None, + scan if iter_scan else None, + antenna if antenna is not None else baseline) + + # update subsets of MS indexing columns that we've seen for this dataframe + output_subset1 = output_subsets.setdefault(dataframe_key, + {column:set() for column in msinfo.indexing_columns.keys()}) + for column, _ in msinfo.indexing_columns.items(): + output_subset1[column].update(getattr(group, column).compute().data) + for corr in subset.corr.numbers: # make dictionary of extra values for DataMappers extras['corr'] = corr @@ -190,12 +206,6 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, df1 = dask_df.concat([df1, df2], axis=0) ddf = dask_df.concat([ddf, df1], axis=0) if ddf is not None else df1 - # now, are we iterating or concatenating? Make frame key accordingly - dataframe_key = (fld if iter_field else None, - ddid if iter_spw else None, - scan if iter_scan else None, - antenna if antenna is not None else baseline) - # do we already have a frame for this key ddf0 = output_dataframes.get(dataframe_key) @@ -221,7 +231,7 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, # print(axis.label, np.nanmin(value), np.nanmax(value)) log.info(": complete") - return output_dataframes, total_num_points + return output_dataframes, output_subsets, total_num_points def compute_bounds(unknowns, bounds, ddf): """ @@ -246,7 +256,7 @@ def compute_bounds(unknowns, bounds, ddf): -def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize, +def create_plot(ddf, index_subsets, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, normalize, xlabel, ylabel, title, pngname, min_alpha=40, saturate_percentile=None, saturate_alpha=None, minmax_cache=None, @@ -261,7 +271,7 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor aaxis = adatum and adatum.label caxis = cdatum and cdatum.label - color_key = color_mapping = color_labels = color_minmax = agg_alpha = cmin = cdelta = None + color_key = color_labels = color_minmax = agg_alpha = None # do we need to compute any axis min/max? bounds = OrderedDict({xaxis: xdatum.minmax, yaxis: ydatum.minmax}) @@ -303,36 +313,71 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor # aggregation applied to by() agg_by = agg_alpha if agg_alpha else datashader.count() - # color_bins will be a list of colors to use. If the subset is known, then we preferentially - # pick colours by subset, i.e. we try to preserve the mapping from index to specific color. 
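# --- Illustrative sketch only: the rework below assumes category_modulo(col, N) aggregates an
# --- integer column into N raster planes keyed by value % N. The helper here is hypothetical
# --- and simply mirrors the label-joining logic of this patch, to show how several discrete
# --- values can end up sharing one colour plane (and hence one legend entry).
def plane_labels_sketch(active_subset, num_colors):
    """active_subset: {integer value: label} for the values actually present in the data.
    Returns one label string per raster plane, joining values that share a plane."""
    max_index = max(active_subset.keys())
    labels_per_plane = []
    for i in range(num_colors):
        labels = [active_subset[idx] for idx in range(i, max_index + 1, num_colors)
                  if idx in active_subset]
        labels_per_plane.append(",".join(labels if len(labels) < 3 else labels[:2] + ["..."]))
    return labels_per_plane

# e.g. plane_labels_sketch({0: "m000", 1: "m001", 5: "m005", 9: "m009"}, num_colors=4)
# returns ['m000', 'm001,m005,...', '', '']: values 1, 5 and 9 all land on plane 1, so their
# labels fold into one legend entry. The --dmap-preserve option added earlier in this series
# controls whether the remaining planes keep their original colour assignments when empty
# planes are discarded.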
- # color_labels will be set from discretized_labels, or from range of column values, if axis is discrete - color_labels = cdatum.discretized_labels + # figure out mapping from raster planes to colours + # after this if-else block, category will be an aggregator instance yielding N categories, + # color_key will be a list of N colors, and color_label will be a list of N textual labels + if data_mappers.USE_COUNT_CAT: - if cdatum.subset_indices is not None: - color_bins = cdatum.subset_indices - else: - color_bins = [int(x) for x in getattr(ddf.dtypes, caxis).categories] - log.debug(f'colourizing using {caxis} categorical, {len(color_bins)} bins') + cats = getattr(ddf.dtypes, caxis).categories + log.debug(f'colourizing using {caxis} categorical, {len(cats)} bins') category = caxis - if color_labels is None: - color_labels = list(map(str,color_bins)) + color_key = dmap[:len(cats)] + color_labels = list(map(str, cats)) else: - color_bins = list(range(cdatum.nlevels)) if cdatum.is_discrete: - if cdatum.subset_indices is not None and options.dmap_preserve: - num_categories = len(cdatum.subset_indices) - color_bins = cdatum.subset_indices[:cdatum.nlevels] + # make dictionary from index to label, omitting values that are not in the MS subset to begin with + if cdatum.discretized_labels: + active_subset = OrderedDict(enumerate(cdatum.discretized_labels)) + # else make up integer labels on the spot + else: + active_subset = OrderedDict(enumerate(map(str, range(bounds[caxis][1]+1)))) + # Check if the subset needs to be refined, because it is known to be smaller for this dataframe + if cdatum.columns[0] in index_subsets and len(cdatum.columns) == 1: + df_index_subset = index_subsets[cdatum.columns[0]] + if cdatum.subset_remapper is not None: + remapper = cdatum.subset_remapper.compute() + df_index_subset = set(remapper[x] for x in df_index_subset) + active_subset = OrderedDict((idx, active_subset[idx]) for idx in df_index_subset) + log.debug(f"subset of indices for this axis is a priori {list(active_subset.keys())}") + # max known index + max_index = max(active_subset.keys()) + num_colors = min(cdatum.nlevels, len(dmap)) + color_key = dmap[:num_colors] + # if we have fewer indices than colour levels, and the max index is sensible, we'll aggregate to one + # raster slice per index value directly + if len(active_subset) <= num_colors and max_index < max(num_colors, 256): + num_colors = max_index+1 + log.debug(f"aggregating directly into {max_index+1} categories") + category = category_modulo(caxis, max_index+1) + color_label_list = {idx: [value] for idx, value in active_subset.items()} else: - num_categories = int(bounds[caxis][1]) + 1 - log.debug(f'colourizing using {caxis} modulo {len(color_bins)}') - category = category_modulo(caxis, len(color_bins)) - if color_labels is None: - color_labels = list(map(str, range(num_categories))) + log.debug(f"aggregating modulo {num_colors} categories") + category = category_modulo(caxis, num_colors) + # each slice maps to, potentially, multiple labels from the subset + color_label_list = {i: [active_subset[idx] for idx in range(i, max_index+1, num_colors) if idx in active_subset] + for i in range(num_colors)} + # and colors just come from the bottom of the colormap + color_dict = dict(enumerate(options.dmap[:num_colors])) + # convert lists of color labels into strings + color_labels = ['']*num_colors + for i, labels in color_label_list.items(): + if len(labels) < 3: + color_labels[i] = ",".join(labels) + else: + color_labels[i] = ",".join(labels[:2] + ["..."]) + 
# else we discretize a span of values else: - log.debug(f'colourizing using {caxis} with {len(color_bins)} bins') + num_colors = min(cdatum.nlevels, len(bmap)) + log.debug(f'colourizing using {caxis} with {num_colors} bins') cmin = bounds[caxis][0] - cdelta = (bounds[caxis][1] - cmin) / cdatum.nlevels - category = category_binning(caxis, cmin, cdelta, cdatum.nlevels) + cdelta = (bounds[caxis][1] - cmin) / num_colors + category = category_binning(caxis, cmin, cdelta, num_colors) + # color labels are bin centres + bin_centers = [cmin + cdelta*(i+0.5) for i in range(num_colors)] + # map to colors pulled from entire extent of color map + color_key = [bmap[(i*len(bmap))//num_colors] for i in range(num_colors)] + color_labels = [str(bin) for bin in bin_centers] + log.info(f": aggregating using {num_colors} bins at {' '.join(color_labels)})") raster = canvas.points(ddf, xaxis, yaxis, agg=datashader.by(category, agg_by)) is_integer_raster = np.issubdtype(raster.dtype, np.integer) @@ -351,38 +396,17 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor log.info(": no valid data in plot. Check your flags and/or plot limits.") return None - if cdatum.is_discrete: - # discard empty bins + if cdatum.is_discrete and not data_mappers.USE_COUNT_CAT: + # discard empty planes non_empty = np.where(non_empty)[0] raster = raster[..., non_empty] - # get bin numbers corresponding to non-empty bins + # compress colours to bottom of colormap, unless asked to preserve assignments if options.dmap_preserve: - color_bins = [color_bins[bin] for bin in non_empty] + color_key = [color_key[bin] for bin in non_empty] else: - color_bins = [color_bins[i] for i, _ in enumerate(non_empty)] - # get list of color labels corresponding to each bin (may be multiple) - color_labels = [color_labels[i::cdatum.nlevels] for i in non_empty] - color_key = [dmap[bin%len(dmap)] for bin in color_bins] - # the numbers may be out of order -- reorder for color bar purposes - bin_color_label = sorted(zip(color_bins, color_key, color_labels)) - color_mapping = [col for _, col, _ in bin_color_label] - # generate labels - color_labels = [] - for _, _, labels in bin_color_label: - if len(labels) == 1: - color_labels.append(labels[0]) - elif len(labels) == 2: - color_labels.append(f"{labels[0]},{labels[1]}") - else: - color_labels.append(f"{labels[0]},{labels[1]},...") - log.info(f": rendering using {len(color_bins)} colors (values {' '.join(color_labels)})") - else: - # color labels are bin centres - bin_centers = [cmin + cdelta*(i+0.5) for i in color_bins] - # map to colors pulled from color map - color_key = [bmap[(i*len(bmap))//cdatum.nlevels] for i in color_bins] - color_labels = list(map(str, bin_centers)) - log.info(f": shading using {len(color_bins)} colors (bin centres are {' '.join(color_labels)})") + color_key = color_key[:len(non_empty)] + color_labels = [color_labels[bin] for bin in non_empty] + img = datashader.transfer_functions.shade(raster, color_key=color_key, how=normalize, min_alpha=min_alpha) # set color_minmax for colorbar color_minmax = bounds[caxis] @@ -438,9 +462,6 @@ def create_plot(ddf, xdatum, ydatum, adatum, ared, cdatum, cmap, bmap, dmap, nor log.debug('rendering image') - def match(artist): - return artist.__module__ == 'matplotlib.text' - fig = pylab.figure(figsize=(figx, figy)) ax = fig.add_subplot(111, facecolor=bgcol) ax.imshow(X=rgb.data, extent=[xmin, xmax, ymin, ymax], @@ -504,9 +525,9 @@ def decimate_list(x, maxel): import matplotlib.colors # discrete axis if caxis is not None and 
cdatum.is_discrete: - norm = matplotlib.colors.Normalize(-0.5, len(color_bins)-0.5) - ticks = np.arange(len(color_bins)) - colormap = matplotlib.colors.ListedColormap(color_mapping) + norm = matplotlib.colors.Normalize(-0.5, len(color_key)-0.5) + ticks = np.arange(len(color_key)) + colormap = matplotlib.colors.ListedColormap(color_key) # discretized axis else: norm = matplotlib.colors.Normalize(*color_minmax) diff --git a/shade_ms/main.py b/shade_ms/main.py index e3a68ac..3078848 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -609,7 +609,7 @@ def describe_corr(corrvalue): log.debug(f"taql is {mytaql}, group_cols is {group_cols}, join subset.corr is {join_corrs}") - dataframes, np = \ + dataframes, index_subsets, np = \ data_plots.get_plot_data(ms, group_cols, mytaql, ms.chan_freqs, chanslice=chanslice, subset=subset, noflags=options.noflags, noconj=options.noconj, @@ -667,7 +667,7 @@ def generate_string_from_keys(template, keys, listsep=" ", titlesep=" ", prefix= jobs = [] executor = None - def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel): + def render_single_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel): """Renders a single plot. Make this a function since we might call it in parallel""" log.info(f": rendering {pngname}") normalize = options.norm @@ -678,7 +678,7 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, else: context = nullcontext with context() as profiler: - result = data_plots.create_plot(df, xdatum, ydatum, adatum, ared, cdatum, + result = data_plots.create_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, cmap=cmap, bmap=bmap, dmap=dmap, normalize=normalize, min_alpha=options.min_alpha, saturate_alpha=options.saturate_alpha, @@ -695,6 +695,7 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, for (fld, spw, scan, antenna_or_baseline), df in dataframes.items(): + subset = index_subsets[fld, spw, scan, antenna_or_baseline] # update keys to be substituted into title and filename if fld is not None: keys['field_num'] = fld @@ -735,12 +736,12 @@ def render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, log.info(f' : created output directory {dirname}') if options.num_parallel < 2 or len(all_plots) < 2: - render_single_plot(df, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel) + render_single_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel) else: from concurrent.futures import ThreadPoolExecutor executor = ThreadPoolExecutor(options.num_parallel) log.info(f' : submitting job for {pngname}') - jobs.append(executor.submit(render_single_plot, df, xdatum, ydatum, adatum, ared, cdatum, + jobs.append(executor.submit(render_single_plot, df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname, title, xlabel, ylabel)) # wait for jobs to finish diff --git a/shade_ms/ms_info.py b/shade_ms/ms_info.py index 9e073e1..20e2797 100644 --- a/shade_ms/ms_info.py +++ b/shade_ms/ms_info.py @@ -70,6 +70,9 @@ def __init__(self, msname=None, log=None): self.valid_columns = set(tab.colnames()) + # indexing columns are read into memory in their entirety up front + self.indexing_columns = dict() + spw_tab = daskms.xds_from_table(msname + '::SPECTRAL_WINDOW', columns=['CHAN_FREQ']) self.chan_freqs = spw_tab[0].CHAN_FREQ # important for this to be an xarray self.nspw = self.chan_freqs.shape[0] @@ -80,7 +83,8 @@ def __init__(self, msname=None, log=None): 
self.field = NamedList("field", table(msname +'::FIELD', ack=False).getcol("NAME")) log and log.info(f": {len(self.field)} fields: {' '.join(self.field.names)}") - scan_numbers = sorted(set(tab.getcol("SCAN_NUMBER"))) + self.indexing_columns["SCAN_NUMBER"] = tab.getcol("SCAN_NUMBER") + scan_numbers = sorted(set(self.indexing_columns["SCAN_NUMBER"])) log and log.info(f": {len(scan_numbers)} scans: {' '.join(map(str, scan_numbers))}") all_scans = NamedList("scan", list(map(str, range(scan_numbers[-1]+1)))) self.scan = all_scans.get_subset(scan_numbers) @@ -90,8 +94,8 @@ def __init__(self, msname=None, log=None): self.all_antenna = antnames = NamedList("antenna", antnames) self.antpos = anttab.getcol("POSITION") - ant1col = tab.getcol("ANTENNA1") - ant2col = tab.getcol("ANTENNA2") + ant1col = self.indexing_columns["ANTENNA1"] = tab.getcol("ANTENNA1") + ant2col = self.indexing_columns["ANTENNA2"] = tab.getcol("ANTENNA2") self.antenna = self.all_antenna.get_subset(list(set(ant1col)|set(ant2col))) log and log.info(f": {len(self.antenna)}/{len(self.all_antenna)} antennas: {self.antenna.str_list()}") From dc3797535d4ff511a43125d31a860aeee167a4c2 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Mon, 15 Jun 2020 17:13:46 +0200 Subject: [PATCH 31/33] fixed adding scalar indexing columns to subset --- shade_ms/data_plots.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index d00ca5a..95a982c 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -160,7 +160,11 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, output_subset1 = output_subsets.setdefault(dataframe_key, {column:set() for column in msinfo.indexing_columns.keys()}) for column, _ in msinfo.indexing_columns.items(): - output_subset1[column].update(getattr(group, column).compute().data) + value = getattr(group, column) + if np.isscalar(value): + output_subset1[column].add(value) + else: + output_subset1[column].update(value.compute().data) for corr in subset.corr.numbers: # make dictionary of extra values for DataMappers From ef4c2f8f5023acb891a3b78057824ac6ad633061 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Mon, 15 Jun 2020 17:16:30 +0200 Subject: [PATCH 32/33] fixed missing baseline label when --iter-baseline --- shade_ms/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shade_ms/main.py b/shade_ms/main.py index 3078848..17cfe06 100644 --- a/shade_ms/main.py +++ b/shade_ms/main.py @@ -707,7 +707,7 @@ def render_single_plot(df, subset, xdatum, ydatum, adatum, ared, cdatum, pngname if antenna_or_baseline is not None: if options.iter_ant: keys['ant'] = ms.all_antenna[antenna_or_baseline] - elif options.iter_ant: + elif options.iter_baseline: keys['baseline'] = ms.all_baseline[antenna_or_baseline] # now loop over plot types From 13ac65966f955cda7b9e08912fa793e91db141e2 Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Tue, 16 Jun 2020 10:03:25 +0200 Subject: [PATCH 33/33] more fixes to labelling when subsets in effect --- shade_ms/data_plots.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/shade_ms/data_plots.py b/shade_ms/data_plots.py index 95a982c..b37557a 100644 --- a/shade_ms/data_plots.py +++ b/shade_ms/data_plots.py @@ -80,6 +80,9 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, # If any of these axes is not being iterated over, then the index at that position is None output_dataframes = OrderedDict() + # number of rows per each dataframe + 
output_rows = OrderedDict() + # output subsets of indexing columns, indexed by same tuple output_subsets = OrderedDict() @@ -126,6 +129,25 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, else: baseline = None + # Make frame key -- data subset corresponds to this frame + dataframe_key = (fld if iter_field else None, + ddid if iter_spw else None, + scan if iter_scan else None, + antenna if antenna is not None else baseline) + + # update subsets of MS indexing columns that we've seen for this dataframe + output_subset1 = output_subsets.setdefault(dataframe_key, + {column:set() for column in msinfo.indexing_columns.keys()}) + for column, _ in msinfo.indexing_columns.items(): + value = getattr(group, column) + if np.isscalar(value): + output_subset1[column].add(value) + else: + output_subset1[column].update(value.compute().data) + + # number of rows in dataframe + nrows0 = output_rows.setdefault(dataframe_key, 0) + # always read flags -- easier that way flag = group.FLAG if not noflags else None flag_row = group.FLAG_ROW if not noflags else None @@ -150,21 +172,6 @@ def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, ddf = None num_points = 0 # counts number of new points generated - # Make frame key -- data subset corresponds to this frame - dataframe_key = (fld if iter_field else None, - ddid if iter_spw else None, - scan if iter_scan else None, - antenna if antenna is not None else baseline) - - # update subsets of MS indexing columns that we've seen for this dataframe - output_subset1 = output_subsets.setdefault(dataframe_key, - {column:set() for column in msinfo.indexing_columns.keys()}) - for column, _ in msinfo.indexing_columns.items(): - value = getattr(group, column) - if np.isscalar(value): - output_subset1[column].add(value) - else: - output_subset1[column].update(value.compute().data) for corr in subset.corr.numbers: # make dictionary of extra values for DataMappers