From 1ff7673c9e8988bd0fe76be8c270d0814eb9aefb Mon Sep 17 00:00:00 2001
From: landmanbester
Date: Wed, 4 Sep 2024 17:44:02 +0200
Subject: [PATCH] remove QuartiCal dependency

---
 pfb/__init__.py           |   5 +-
 pfb/operators/gridder.py  |   1 -
 pfb/operators/psi.py      |   2 +-
 pfb/utils/correlations.py |   1 -
 pfb/utils/misc.py         | 305 --------------------------------------
 pfb/utils/stokes2vis.py   |   2 -
 pfb/utils/weighting.py    |   1 -
 pfb/workers/init.py       |  15 +-
 setup.py                  |  11 +-
 9 files changed, 12 insertions(+), 331 deletions(-)

diff --git a/pfb/__init__.py b/pfb/__init__.py
index 3e4b8ffc5..6238248e1 100644
--- a/pfb/__init__.py
+++ b/pfb/__init__.py
@@ -40,7 +40,7 @@ def set_client(nworkers, log, stack=None, host_address=None,
     if client_log_level == 'error':
         logging.getLogger("distributed").setLevel(logging.ERROR)
         logging.getLogger("bokeh").setLevel(logging.ERROR)
-        logging.getLogger("tornado").setLevel(logging.ERROR)
+        logging.getLogger("tornado").setLevel(logging.CRITICAL)
     elif client_log_level == 'warning':
         logging.getLogger("distributed").setLevel(logging.WARNING)
         logging.getLogger("bokeh").setLevel(logging.WARNING)
@@ -55,7 +55,6 @@ def set_client(nworkers, log, stack=None, host_address=None,
         logging.getLogger("tornado").setLevel(logging.DEBUG)
 
     import dask
-    dask.config.set({'distributed.comm.compression': 'lz4'})
     # set up client
     host_address = host_address or os.environ.get("DASK_SCHEDULER_ADDRESS")
     if host_address is not None:
@@ -68,7 +67,7 @@ def set_client(nworkers, log, stack=None, host_address=None,
         dask.config.set({
             'distributed.comm.compression': {
                 'on': True,
-                'type': 'blosc'
+                'type': 'lz4'
             }
         })
         cluster = LocalCluster(processes=True,
diff --git a/pfb/operators/gridder.py b/pfb/operators/gridder.py
index 52fd26d56..f2cadfab9 100644
--- a/pfb/operators/gridder.py
+++ b/pfb/operators/gridder.py
@@ -13,7 +13,6 @@
 from ducc0.wgridder.experimental import vis2dirty, dirty2vis
 from ducc0.fft import c2r, r2c, c2c
 from africanus.constants import c as lightspeed
-from quartical.utils.dask import Blocker
 from pfb.utils.weighting import counts_to_weights, _compute_counts
 from pfb.utils.beam import eval_beam
 from pfb.utils.naming import xds_from_list
diff --git a/pfb/operators/psi.py b/pfb/operators/psi.py
index e778bf24f..11376bcc1 100644
--- a/pfb/operators/psi.py
+++ b/pfb/operators/psi.py
@@ -286,7 +286,7 @@ def __init__(self, nband, nx, ny, bases, nlevel, nthreads):
         self.Nxmax = self.psib[0].Nxmax
         self.Nymax = self.psib[0].Nymax
 
-        self.nthreads_per_band = nthreads//nband
+        self.nthreads_per_band = np.maximum(1, nthreads//nband)
 
     def dot(self, x, alphao):
         '''
diff --git a/pfb/utils/correlations.py b/pfb/utils/correlations.py
index 410f396c9..a29f53f9e 100644
--- a/pfb/utils/correlations.py
+++ b/pfb/utils/correlations.py
@@ -4,7 +4,6 @@
 from dask.graph_manipulation import clone
 import dask.array as da
 from xarray import Dataset
-from quartical.utils.numba import coerce_literal
 from operator import getitem
 from pfb.utils.beam import interp_beam
 
diff --git a/pfb/utils/misc.py b/pfb/utils/misc.py
index 688d50c4c..783eb23a2 100644
--- a/pfb/utils/misc.py
+++ b/pfb/utils/misc.py
@@ -21,7 +21,6 @@
 from collections import namedtuple
 from africanus.coordinates.coordinates import radec_to_lmn
 import xarray as xr
-from quartical.utils.dask import Blocker
 from scipy.interpolate import RegularGridInterpolator
 from scipy.linalg import solve_triangular
 import sympy as sm
@@ -763,310 +762,6 @@ def chunkify_rows(time, utimes_per_chunk, daskify_idx=False):
     return tuple(row_chunks), time_bin_indices, time_bin_counts
 
 
-def rephase_vis(vis, uvw, radec_in, radec_out):
-    return da.blockwise(_rephase_vis, 'rf',
-                        vis, 'rf',
-                        uvw, 'r3',
-                        radec_in, None,
-                        radec_out, None,
-                        dtype=vis.dtype)
-
-def _rephase_vis(vis, uvw, radec_in, radec_out):
-    l_in, m_in, n_in = radec_to_lmn(radec_in)
-    l_out, m_out, n_out = radec_to_lmn(radec_out)
-    return vis * np.exp(1j*(uvw[:, 0]*(l_out-l_in) +
-                            uvw[:, 1]*(m_out-m_in) +
-                            uvw[:, 2]*(n_out-n_in)))
-
-
-# TODO - should allow coarsening to values other than 1
-def concat_row(xds):
-    times_in = []
-    freqs = []
-    for ds in xds:
-        times_in.append(ds.time_out)
-        freqs.append(ds.freq_out)
-
-    times_in = np.unique(times_in)
-    freqs = np.unique(freqs)
-
-    nband = freqs.size
-    ntime_in = times_in.size
-
-    if ntime_in == 1:  # no need to concatenate
-        return xds
-
-    # do merge manually because different variables require different
-    # treatment anyway eg. the BEAM should be computed as a weighted sum
-    xds_out = []
-    for b in range(nband):
-        xdsb = []
-        times = []
-        freq_max = []
-        freq_min = []
-        time_max = []
-        time_min = []
-        nu = freqs[b]
-        for ds in xds:
-            if ds.freq_out == nu:
-                xdsb.append(ds)
-                times.append(ds.time_out)
-                freq_max.append(ds.freq_max)
-                freq_min.append(ds.freq_min)
-                time_max.append(ds.time_max)
-                time_min.append(ds.time_min)
-
-        wgt = [ds.WEIGHT for ds in xdsb]
-        vis = [ds.VIS for ds in xdsb]
-        mask = [ds.MASK for ds in xdsb]
-        uvw = [ds.UVW for ds in xdsb]
-
-        # get weighted sum of beams
-        beam = sum_beam(xdsb)
-        l_beam = xdsb[0].l_beam.data
-        m_beam = xdsb[0].m_beam.data
-
-        wgto = xr.concat(wgt, dim='row')
-        viso = xr.concat(vis, dim='row')
-        masko = xr.concat(mask, dim='row')
-        uvwo = xr.concat(uvw, dim='row')
-
-        xdso = xr.merge((wgto, viso, masko, uvwo))
-        xdso = xdso.assign({'BEAM': (('l_beam', 'm_beam'), beam)})
-        xdso['FREQ'] = xdsb[0].FREQ  # is this always going to be the case?
-
-        xdso = xdso.chunk({'row':-1, 'l_beam':-1, 'm_beam':-1})
-
-        xdso = xdso.assign_coords({
-            'chan': (('chan',), xdsb[0].chan.data),
-            'l_beam': (('l_beam',), l_beam),
-            'm_beam': (('m_beam',), m_beam)
-        })
-
-        times = np.array(times)
-        freq_max = np.array(freq_max)
-        freq_min = np.array(freq_min)
-        time_max = np.array(time_max)
-        time_min = np.array(time_min)
-        tout = np.round(np.mean(times), 5)  # avoid precision issues
-        xdso = xdso.assign_attrs({
-            'dec': xdsb[0].dec,  # always the case?
-            'ra': xdsb[0].ra,  # always the case?
-            'time_out': tout,
-            'time_max': time_max.max(),
-            'time_min': time_min.min(),
-            'timeid': 0,
-            'freq_out': nu,
-            'freq_max': freq_max.max(),
-            'freq_min': freq_min.min(),
-        })
-        xds_out.append(xdso)
-    return xds_out
-
-
-def concat_chan(xds, nband_out=1):
-    times = []
-    freqs_in = []
-    freqs_min = []
-    freqs_max = []
-    all_freqs = []
-    for ds in xds:
-        times.append(ds.time_out)
-        freqs_in.append(ds.freq_out)
-        freqs_min.append(ds.freq_min)
-        freqs_max.append(ds.freq_max)
-        all_freqs.append(ds.chan)
-
-    times = np.unique(times)
-    freqs_in = np.unique(freqs_in)
-    freqs_min = np.unique(freqs_min)
-    freqs_max = np.unique(freqs_max)
-    all_freqs = np.unique(np.concatenate(all_freqs))
-
-    nband_in = freqs_in.size
-    ntime = times.size
-
-    if nband_in == nband_out or nband_in == 1:  # no need to concatenate
-        return xds
-
-    # currently assuming linearly spaced frequencies
-    freq_bins = np.linspace(freqs_min.min(), freqs_max.max(), nband_out+1)
-    bin_centers = (freq_bins[1:] + freq_bins[0:-1])/2
-
-    xds_out = []
-    for t in range(ntime):
-        time = times[t]
-        for b in range(nband_out):
-            xdst = []
-            flow = freq_bins[b]
-            fhigh = freq_bins[b+1]
-            freqsb = all_freqs[all_freqs >= flow]
-            # exclusive except for the last one
-            if b==nband_out-1:
-                freqsb = freqsb[freqsb <= fhigh]
-            else:
-                freqsb = freqsb[freqsb < fhigh]
-            time_max = []
-            time_min = []
-            for ds in xds:
-                # ds overlaps output if either ds.freq_min or ds.freq_max lies in the bin
-                low_in = ds.freq_min > flow and ds.freq_min < fhigh
-                high_in = ds.freq_max > flow and ds.freq_max < fhigh
-
-                if ds.time_out == time and (low_in or high_in):
-                    xdst.append(ds)
-                    time_max.append(ds.time_max)
-                    time_min.append(ds.time_min)
-
-            nrow = xdst[0].row.size
-            nchan = freqsb.size
-
-            freqs_dask = da.from_array(freqsb, chunks=nchan)
-            blocker = Blocker(sum_overlap, 'rc')
-            blocker.add_input('ufreq', freqs_dask, 'f')
-            blocker.add_input('flow', flow, None)
-            blocker.add_input('fhigh', fhigh, None)
-
-            for i, ds in enumerate(xdst):
-                ds = ds.chunk({'row':-1, 'chan':-1})
-                blocker.add_input(f'vis{i}', ds.VIS.data, 'rc')
-                blocker.add_input(f'wgt{i}', ds.WEIGHT.data, 'rc')
-                blocker.add_input(f'mask{i}', ds.MASK.data, 'rc')
-                blocker.add_input(f'freq{i}', ds.FREQ.data, 'c')
-
-            blocker.add_output('viso', 'rf', ((nrow,), (nchan,)), xdst[0].VIS.dtype)
-            blocker.add_output('wgto', 'rf', ((nrow,), (nchan,)), xdst[0].WEIGHT.dtype)
-            blocker.add_output('masko', 'rf', ((nrow,), (nchan,)), xdst[0].MASK.dtype)
-
-            out_dict = blocker.get_dask_outputs()
-
-            # get weighted sum of beam
-            beam = sum_beam(xdst)
-            l_beam = xdst[0].l_beam.data
-            m_beam = xdst[0].m_beam.data
-
-            data_vars = {
-                'VIS': (('row', 'chan'), out_dict['viso']),
-                'WEIGHT': (('row', 'chan'), out_dict['wgto']),
-                'MASK': (('row', 'chan'), out_dict['masko']),
-                'FREQ': (('chan',), freqs_dask),
-                'UVW': (('row', 'three'), xdst[0].UVW.data),  # should be the same across data sets
-                'BEAM': (('l_beam', 'm_beam'), beam)
-            }
-
-            coords = {
-                'chan': (('chan',), freqsb),
-                'l_beam': (('l_beam',), l_beam),
-                'm_beam': (('m_beam',), m_beam)
-            }
-
-            fout = np.round(bin_centers[b], 5)  # avoid precision issues
-            time_max = np.array(time_max)
-            time_min = np.array(time_min)
-            attrs = {
-                'freq_out': fout,
-                'freq_max': fhigh,
-                'freq_min': flow,
-                'bandid': b,
-                'dec': xdst[0].dec,
-                'ra': xdst[0].ra,
-                'time_out': time,
-                'time_max': time_max.max(),
-                'time_min': time_min.min()
-            }
-
-            xdso = xr.Dataset(data_vars=data_vars,
-                              coords=coords,
-                              attrs=attrs)
-
-            xds_out.append(xdso)
-    return xds_out
-
-
-def sum_beam(xds):
-    '''
-    Compute the weighted sum of the beams contained in xds
-    weighting by the sum of the weights in each ds
-    '''
-    nx, ny = xds[0].BEAM.shape
-    btype = xds[0].BEAM.dtype
-    blocker = Blocker(_sum_beam, 'xy')
-    blocker.add_input('nx', nx, None)
-    blocker.add_input('ny', ny, None)
-    blocker.add_input('btype', btype, None)
-    for i, ds in enumerate(xds):
-        blocker.add_input(f'beam{i}', ds.BEAM.data, 'xy')
-        blocker.add_input(f'wgt{i}', ds.WEIGHT.data, 'rf')
-
-    blocker.add_output('beam', 'xy', ((nx,),(ny,)), btype)
-    out_dict = blocker.get_dask_outputs()
-    return out_dict['beam']
-
-def _sum_beam(nx, ny, btype, **kwargs):
-    beam = np.zeros((nx, ny), dtype=btype)
-    # need to separate the different variables in kwargs
-    # i.e. beam, wgt -> nvars=2
-    nitems = len(kwargs)//2
-    wsum = 0.0
-    for i in range(nitems):
-        wgti = kwargs[f'wgt{i}']
-        wsumi = wgti.sum()
-        beam += wsumi * kwargs[f'beam{i}']
-        wsum += wsumi
-
-    if wsum:
-        beam /= wsum
-
-    # blocker expects dict as output
-    out_dict = {}
-    out_dict['beam'] = beam
-
-    return out_dict
-
-def sum_overlap(ufreq, flow, fhigh, **kwargs):
-    # need to separate the different variables in kwargs
-    # i.e. vis, wgt, mask, freq -> nvars=4
-    nitems = len(kwargs)//4
-
-    # output grids
-    nchan = ufreq.size
-    nrow = kwargs['vis0'].shape[0]
-    viso = np.zeros((nrow, nchan), dtype=kwargs['vis0'].dtype)
-    wgto = np.zeros((nrow, nchan), dtype=kwargs['wgt0'].dtype)
-    masko = np.zeros((nrow, nchan), dtype=kwargs['mask0'].dtype)
-
-    # weighted sum at overlap
-    for i in range(nitems):
-        vis = kwargs[f'vis{i}']
-        wgt = kwargs[f'wgt{i}']
-        mask = kwargs[f'mask{i}']
-        nu = kwargs[f'freq{i}']
-        _, idx0, idx1 = np.intersect1d(nu, ufreq, assume_unique=True, return_indices=True)
-        try:
-            viso[:, idx1] += vis[:, idx0] * wgt[:, idx0] * mask[:, idx0]
-            wgto[:, idx1] += wgt[:, idx0] * mask[:, idx0]
-            masko[:, idx1] += mask[:, idx0]
-        except Exception as e:
-            print(flow, fhigh, ufreq, nu)
-            raise e
-
-    # unmasked where at least one data point is unflagged
-    masko = np.where(masko > 0, True, False)
-    # TODO - why does this get trigerred?
-    # if (wgto[masko]==0).any():
-    #     print(np.where(wgto[masko]==0))
-    #     raise ValueError("Weights are zero at unflagged location")
-    viso[masko] = viso[masko]/wgto[masko]
-
-    # blocker expects a dictionary as output
-    out_dict = {}
-    out_dict['viso'] = viso
-    out_dict['wgto'] = wgto
-    out_dict['masko'] = masko.astype(np.uint8)
-
-    return out_dict
-
-
 def l1reweight_func(model,
                     psiH=None,
                     outvar=None,
diff --git a/pfb/utils/stokes2vis.py b/pfb/utils/stokes2vis.py
index 8259eff04..6d102ad74 100644
--- a/pfb/utils/stokes2vis.py
+++ b/pfb/utils/stokes2vis.py
@@ -6,12 +6,10 @@
 from distributed import worker_client
 import dask.array as da
 from xarray import Dataset
-# from quartical.utils.numba import coerce_literal
 from operator import getitem
 from pfb.utils.beam import interp_beam
 from pfb.utils.misc import weight_from_sigma, combine_columns
 import dask
-from quartical.utils.dask import Blocker
 from pfb.utils.stokes import stokes_funcs
 from pfb.utils.weighting import weight_data
 from uuid import uuid4
diff --git a/pfb/utils/weighting.py b/pfb/utils/weighting.py
index 808f9208e..64f86b53a 100644
--- a/pfb/utils/weighting.py
+++ b/pfb/utils/weighting.py
@@ -5,7 +5,6 @@
 import dask.array as da
 from ducc0.fft import c2c
 from africanus.constants import c as lightspeed
-from quartical.utils.dask import Blocker
 from pfb.utils.misc import JIT_OPTIONS
 from pfb.utils.stokes import stokes_funcs
 from pfb.utils.naming import xds_from_list
diff --git a/pfb/workers/init.py b/pfb/workers/init.py
index f252f63e0..d33bef864 100644
--- a/pfb/workers/init.py
+++ b/pfb/workers/init.py
@@ -14,16 +14,6 @@
 from pfb.parser.schemas import schema
 
 
-# from pfb import parser
-# from scabha.configuratt import load_nested
-# from scabha.schema_utils import Schema
-# config_file = os.path.dirname(parser.__file__) + '/init.yaml'
-# schema, _ = load_nested([config_file],
-#                         structured=OmegaConf.structured(Schema),
-#                         config_class="PfbCleanCabs",
-#                         use_cache=False)
-# schema = OmegaConf.create(schema).init
-
 @cli.command(context_settings={'show_default': True})
 @clickify_parameters(schema.init)
 def init(**kw):
@@ -144,7 +134,10 @@ def _init(**kw):
             gain_names = None
 
     if opts.freq_range is not None and len(opts.freq_range):
-        fmin, fmax = opts.freq_range.strip(' ').split(':')
+        try:
+            fmin, fmax = opts.freq_range.strip(' ').split(':')
+        except:
+            import ipdb; ipdb.set_trace()
         if len(fmin) > 0:
             freq_min = float(fmin)
         else:
diff --git a/setup.py b/setup.py
index 07776b843..0fdcc1d8d 100644
--- a/setup.py
+++ b/setup.py
@@ -14,17 +14,16 @@
     'Click',
     "ducc0"
     "@git+https://github.com/mreineck/ducc.git"
-    "@tweak_wgridder_conventions",
-    "QuartiCal"
-    "@git+https://github.com/ratt-ru/QuartiCal.git"
-    "@unpinneddeps",
+    "@ducc0",
     "sympy",
-    "stimela >= 2.0rc18",
+    "stimela"
+    "@git+https://github.com/caracal-pipeline/stimela.git"
+    "@clickify_missing_as_none",
     "streamjoy >= 0.0.8",
    "codex-africanus[complete] >= 0.3.7",
+    "dask-ms[xarray, zarr, s3]",
     "tbb",
     "jax[cpu]",
-    "ipycytoscape",
     "lz4",
     "ipdb",
     "psutil"
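
As a minimal sketch of what dropping quartical.utils.dask.Blocker implies for downstream code: the weighted beam average that the removed sum_beam/_sum_beam helpers computed can be expressed with plain dask.array reductions instead. The function below is illustrative only; sum_beam_sketch and its argument conventions are assumptions, not part of the pfb API or of this patch.

    import dask.array as da

    def sum_beam_sketch(beams, weights):
        # beams: one (nx, ny) beam array per dataset
        # weights: one per-visibility weight array per dataset
        # each beam is weighted by the total weight of its dataset,
        # mirroring what the removed _sum_beam computed via Blocker
        wsums = [w.sum() for w in weights]
        wsum = sum(wsums)
        beam = sum(ws * b for ws, b in zip(wsums, beams))
        # guard against division by zero when all inputs are flagged
        return da.where(wsum > 0, beam / wsum, beam)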