Awskube (#91)
* Bizarre segfault when running clark clean test but not if running from terminal

* only mop when clean stalls. rstrip / from gain_table and ms inputs in init and degrid workers

* Do not write fits files during tests

* awskube

* Add option to overwrite xds in init

* change dtype of ms to str in degrid config

* update defaults in degrid worker

* Remove debug statement in degrid

* rechunk model_vis before writing when model_column does not exist

* set model_exists to True when it already exists

* rechunk model_vis to on disk chunks regardless

* populate on_disc_chunks regardless

* Update Dockerfile

* fix dockerfile

* Let daskms take care of rechunking new columns

* replace . with _ in stimela_cabs

* remove from_url_and_kw, depend on dask-ms master

* update dockerfile, actually grid weights when computing image space weighting

* Parallel over wavelet basis using numba.prange

* Fix failing tests, make wavelet setup consistent everywhere

* Modify cab definition to pass missing as none

* remove curly braces from command

* All workers -> python-ext flavour

* Deprecate agroclean deconv mode

* Changes to weighting

* Test stricter preconditioning for stability

* fix for bsmooth per scan plots

* Make detrending optional in bsmooth

* clean modifications

* robust kalman interpolation

* add file

* revert to stimela lbkibe branch

* remove debug statement and revert to using kube branch of stimela

* Add robust GPR regression

* ref ant in bsmooth

* Also remove ref phase from per scan plots

* 1D Kalman interpolation with integrated Wiener process

* Weirdness in fast regression

* jit not working

* Array free evidence implementation

* jitted version

* use kanterp in b/gsmooth

* do not skip plotting amp of ref antenna

* Hack to move K offset into G

* typo

* typo

* typo

* unwrap phases in gsmooth

* Update K table after zeroing offset

* Sanity check

* More sanity checks

* Add 2pi to delay term

* Reinstate smoothing after sanity checks

* fixes to bsmooth

* bsmooth per scan

* Fix failing tests

* k=0 in compute counts test

* fix Dockerfile and setup.py

* Add performance report when using distributed scheduler

* Only normalise by weights if they are non-zero

* generic root_dir in stimela cabs

* reorganise config files

* remove pdb import

* move stimela config files so that they are consistent with those in cult-cargo

* depend on stimela@FIASCO3-daskjob

* revert changes for stimela@issue-158

* pin bokeh to < 3

* remove compute_context

* ipdb before compute

* print statements

* replace inlined_array everywhere

* add comments where things go wrong in init

* remove debug statements

* test prange over chan in weight_data

* separate log directory for all workers. Restructure init and grid workers so they can be chained without writing visibilities to disk

* depend on QuartiCal@v0.2.1-degridder

* do not duplicate QuartiCal dependencies in setup.py

* start go worker

* write out performance report for grid worker

* make sure log directories are created if they don't exist. Report log location when worker is invoked

* restructure tests to (partially) avoid writing to disk

* test if graph optimisation improves memory consumption

* fix ms_chunks when using cpi of -1 with freq_range selection

* rechunk subds after slice

* also rechunk gains after slice. don't use graph optimisation for init. dds[i] <- ds[i] when rechunking psf in gridder

* avoid div by zero in normdiff, log to logdir in spotless

* wstack -> do_wgridding

* fix failing test

* update test_forwardmodel to test imaging with forward model vs imaging corrected data + weight

* missing Path import in spotless and further modifications to forward model test

* exponential model in fwdbwd

* use Blocker instead of blockwise to compute Stokes data and weights

* correct logic to reuse counts if it already exists

* set sigmainv even if not using hessnorm

* save param slice

* initialise from PARAM when present in dds

* allow switching between non-linearities in fwdbwd, remove mode before backward step

* precondition grad21

* ignore missing subbands in sfactor

* add missing imports

* allow passing in list of measurement sets from cli

* quick fix for unity beam

* allow gain-table to be list

* unpack datasets sum_beam and sum_overlap

* comment out visualization and increase gridder verbosity

* allow asynchronous in LocalCluster

* allow processes scheduler

* fix failing beam tests

* add warning when extrapolating beam

* fix failing spotless test (init.max_fov > grid.fov)

* drop python3.8

* automate stokes product computation for both linear and circular feeds in the presence of full jones matrices

* jit outer function

* Neaten stokes code

* reshape jones into (corr, corr) in full 2x2 mode

* do not compute vis and wgt at flagged locations (avoid divide by zero errors)

* don't apply Winv to corrected data

* correct jones reshape operation in stokes

* clean up stokes funcs

* rechunk beam coords on read

* always init imwgt

* inspect vis and wgt funcs

* swap Q sign convention

* no unnecessary symb math

* add rudimentary test for pol products

* restore Winv in symbolic expression and apply wgt during gridding

* test refined gains without Winv in symbolic expression

* test pol products in presence of gains, start adding fastim worker

* add fastim config and worker files

* tweaks to coeff naming

* allow empty string for transfer_model_from parameter, allow l2reweight_dof=0

* remove 3.10 from testing matrix

* hack for dual pol data

* dispatch based on ncorr

* add missing Path import in model2comps

* pin numba to < 0.59

* ncorr -> nc

* compare to literal

* typo

* use -1 to get last corr dimension

* fix complex warning

* Track best model based on rms. Check for divergence

* add noise in test_polproducts test

* depend on latest QC release. remove ipdb from requirements

* restore old QC dependency

* remove dependencies installed by qc

* depend on QuartiCal[degrid]

* qc on pypi, manually install ducc0

* reformat uncabbedcabs.yml

* don't jit subminor, remove parallel

* add sympy to dependencies

* no fastmath in jit

* don't cache lambdified functions

* missed cache=True

* track bestim in clean

* diverge_count typo

* dims -> sizes

* tweak fits output names for clean

* Allow explicit path to xds in grid worker

* diverging -> diverge_count in clean

* replace _ with - in config, use . instead of - for stimela cab names

* manually replace - with _ in defaults dict when creating parser

* change relative include to include from (.). Subs - for _ in tests

* changes to config

* pin stimela

* move return statement into try except to print beam interp warning instead of bombing

* Simpler wavelet implementation with vertical parallelism (#93)

* replace old wavelet code with simplified 2D versions

* incorporate new wavelets and psi operator

---------

Co-authored-by: landmanbester <lbester@ska.ac.za>

* replace generated_jit with overloads (a sketch follows the changed-files summary below)

* add 3.9 and 3.11 to testing matrix. Unpin numba

---------

Co-authored-by: landmanbester <lbester@ska.ac.za>
landmanbester authored Mar 15, 2024
1 parent 2432997 commit e51475b
Showing 66 changed files with 3,617 additions and 5,356 deletions.
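
One of the commit messages above, "replace generated_jit with overloads", refers to numba's deprecation of generated_jit. A hedged sketch of that migration pattern, using an illustrative function name rather than anything from this repository:

import numpy as np
from numba import njit, types
from numba.extending import overload

def as_array(x):
    # pure-Python reference implementation
    return np.asarray(x)

@overload(as_array)
def _as_array_impl(x):
    # dispatch on the numba type at compile time, which is the role
    # generated_jit used to play with runtime type objects
    if isinstance(x, types.Array):
        def impl(x):
            return x
    else:
        def impl(x):
            return np.full(1, x)
    return impl

@njit(cache=True)
def use_it(x):
    return as_array(x).sum()

print(use_it(3.0), use_it(np.arange(4.0)))  # 3.0 6.0
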
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
@@ -12,7 +12,7 @@ jobs:
if: "!contains(github.event.head_commit.message, '[skip ci]')"
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- name: Set up Python ${{ matrix.python-version }}
@@ -28,8 +28,8 @@
- name: Upgrade pip and setuptools
run: python -m pip install -U pip setuptools

- name: Pin setuptools
run: python -m pip install setuptools==65.5
# - name: Pin setuptools
# run: python -m pip install setuptools==65.5

- name: Install pfb-clean
run: python -m pip install .[testing]
@@ -46,7 +46,7 @@
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: 3.10

- name: Install latest setuptools, wheel, pip
run: python3 -m pip install -U pip setuptools wheel
21 changes: 15 additions & 6 deletions pfb/__init__.py
@@ -31,22 +31,23 @@ def set_client(opts, stack, log, scheduler='distributed'):
os.environ["OPENBLAS_NUM_THREADS"] = str(opts.nvthreads)
os.environ["MKL_NUM_THREADS"] = str(opts.nvthreads)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(opts.nvthreads)
os.environ["NUMBA_NUM_THREADS"] = str(opts.nvthreads)
import numexpr as ne
max_cores = ne.detect_number_of_cores()
# ne_threads = min(max_cores, opts.nvthreads)
os.environ["NUMEXPR_NUM_THREADS"] = str(max_cores)
os.environ["NUMBA_NUM_THREADS"] = str(max_cores)

import dask
if scheduler=='distributed':
# TODO - investigate what difference this makes
# with dask.config.set({"distributed.scheduler.worker-saturation": 1.1}):
# client = distributed.Client()
# set up client
if opts.host_address is not None:
host_address = opts.host_address or os.environ.get("DASK_SCHEDULER_ADDRESS")
if host_address is not None:
from distributed import Client
print("Initialising distributed client.", file=log)
client = stack.enter_context(Client(opts.host_address))
client = stack.enter_context(Client(host_address))
else:
if opts.nthreads_dask * opts.nvthreads > nthreads_max:
print("Warning - you are attempting to use more threads than "
@@ -55,9 +56,11 @@ def set_client(opts, stack, log, scheduler='distributed'):
from dask.distributed import Client, LocalCluster
print("Initialising client with LocalCluster.", file=log)
with dask.config.set({"distributed.scheduler.worker-saturation": 1.1}):
cluster = LocalCluster(processes=True, n_workers=opts.nworkers,
threads_per_worker=opts.nthreads_dask,
memory_limit=0) # str(mem_limit/nworkers)+'GB'
cluster = LocalCluster(processes=opts.nworkers > 1,
n_workers=opts.nworkers,
threads_per_worker=opts.nthreads_dask,
memory_limit=0, # str(mem_limit/nworkers)+'GB'
asynchronous=False)
cluster = stack.enter_context(cluster)
client = stack.enter_context(Client(cluster))

@@ -73,6 +76,12 @@ def set_client(opts, stack, log, scheduler='distributed'):
dask.config.set(pool=ThreadPool(opts.nthreads_dask))
print(f"Initialising ThreadPool with {opts.nthreads_dask} threads",
file=log)
elif scheduler=='processes':
# TODO - why is the performance so terrible in this case?
from multiprocessing.pool import Pool
dask.config.set(pool=Pool(opts.nthreads_dask))
print(f"Initialising Pool with {opts.nthreads_dask} processes",
file=log)
else:
raise ValueError(f"Unknown scheduler option {opts.scheduler}")

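
The pfb/__init__.py hunks above add a 'processes' scheduler branch, let the distributed scheduler address fall back to DASK_SCHEDULER_ADDRESS, and only use separate processes in the LocalCluster when more than one worker is requested. A minimal sketch of how a worker script might drive set_client, assuming it is importable from the pfb package as in this file and using sys.stdout in place of the project's pyscilog logger; the option values are illustrative only:

import sys
from contextlib import ExitStack
from types import SimpleNamespace

from pfb import set_client  # defined in pfb/__init__.py above

# Illustrative options; the real workers populate these from their CLI config.
opts = SimpleNamespace(nvthreads=4, nthreads_dask=2, nworkers=1,
                       host_address=None, scheduler='threads')

with ExitStack() as stack:
    # 'threads' selects the ThreadPool branch; 'distributed' would start a
    # LocalCluster or connect to host_address / DASK_SCHEDULER_ADDRESS.
    set_client(opts, stack, log=sys.stdout, scheduler=opts.scheduler)
    # ... build and compute dask graphs here ...
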
12 changes: 7 additions & 5 deletions pfb/deconv/clark.py
@@ -1,22 +1,22 @@
import numpy as np
import numexpr as ne
from functools import partial
import numba
from numba import njit
import dask.array as da
from pfb.operators.psf import psf_convolve_cube
from ducc0.misc import make_noncritical
import pyscilog
log = pyscilog.get_logger('CLARK')

@numba.jit(parallel=True, nopython=True, nogil=True, cache=True, inline='always')
@njit(nogil=True, cache=True) # parallel=True,
def subtract(A, psf, Ip, Iq, xhat, nxo2, nyo2):
'''
Subtract psf centered at location of xhat
'''
# loop over active indices
nband = xhat.size
for b in numba.prange(nband):
# for b in range(nband):
# for b in numba.prange(nband):
for b in range(nband):
for i in range(Ip.size):
pp = nxo2 - Ip[i]
qq = nyo2 - Iq[i]
@@ -25,7 +25,7 @@ def subtract(A, psf, Ip, Iq, xhat, nxo2, nyo2):
return A


@numba.jit(parallel=True, nopython=True, nogil=True, cache=True)
@njit(nogil=True, cache=True) # parallel=True,
def subminor(A, psf, Ip, Iq, model, wsums, gamma=0.05, th=0.0, maxit=10000):
"""
Run subminor loop in active set
@@ -55,6 +55,8 @@ def subminor(A, psf, Ip, Iq, model, wsums, gamma=0.05, th=0.0, maxit=10000):
q = Iq[pq]
Amax = np.sqrt(Asearch[pq])
fsel = wsums > 0
if fsel.sum() == 0:
raise ValueError("wsums are all zero")
k = 0
while Amax > th and k < maxit:
xhat = A[:, pq]
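
The clark.py hunks drop numba's parallel threading layer: parallel=True disappears from the decorators and the band loop goes from numba.prange back to a plain range. A small self-contained sketch of that toggle on a toy function (not the repository's subtract), showing the two variants agree:

import numpy as np
from numba import njit, prange

@njit(parallel=True, nogil=True, cache=True)
def row_sums_parallel(x):
    out = np.zeros(x.shape[0])
    for i in prange(x.shape[0]):   # distributed across threads
        out[i] = x[i].sum()
    return out

@njit(nogil=True, cache=True)
def row_sums_serial(x):
    out = np.zeros(x.shape[0])
    for i in range(x.shape[0]):    # ordinary loop, no threading layer
        out[i] = x[i].sum()
    return out

x = np.random.rand(8, 16)
assert np.allclose(row_sums_parallel(x), row_sums_serial(x))
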
4 changes: 2 additions & 2 deletions pfb/operators/gauss.py
@@ -8,7 +8,7 @@
Fs = np.fft.fftshift


@njit(parallel=True, nogil=True, fastmath=True, inline='always')
@njit(parallel=True, nogil=True, cache=True, inline='always')
def freqmul(A, x):
nchan, npix = x.shape
out = np.zeros((nchan, npix), dtype=x.dtype)
@@ -19,7 +19,7 @@ def freqmul(A, x):
return out


@njit(parallel=True, nogil=True, fastmath=True, inline='always')
@njit(parallel=True, nogil=True, cache=True, inline='always')
def make_kernel(nx_psf, ny_psf, sigma0, length_scale):
K = np.zeros((1, nx_psf, ny_psf), dtype=np.float64)
for j in range(nx_psf):
39 changes: 22 additions & 17 deletions pfb/operators/gridder.py
@@ -298,7 +298,7 @@ def _loc2psf_vis_impl(uvw,
cell,
x0=0,
y0=0,
wstack=True,
do_wgridding=True,
epsilon=1e-7,
nthreads=1,
precision='single',
@@ -327,7 +327,7 @@ def _loc2psf_vis_impl(uvw,
center_x=x0,
center_y=y0,
epsilon=1e-7,
do_wgridding=wstack,
do_wgridding=do_wgridding,
nthreads=nthreads,
divide_by_n=divide_by_n)

@@ -344,7 +344,7 @@ def _loc2psf_vis(uvw,
cell,
x0=0,
y0=0,
wstack=True,
do_wgridding=True,
epsilon=1e-7,
nthreads=1,
precision='single',
@@ -355,7 +355,7 @@ def _loc2psf_vis(uvw,
cell,
x0,
y0,
wstack,
do_wgridding,
epsilon,
nthreads,
precision,
@@ -368,7 +368,7 @@ def loc2psf_vis(uvw,
cell,
x0=0,
y0=0,
wstack=True,
do_wgridding=True,
epsilon=1e-7,
nthreads=1,
precision='single',
@@ -380,7 +380,7 @@ def loc2psf_vis(uvw,
cell, None,
x0, None,
y0, None,
wstack, None,
do_wgridding, None,
epsilon, None,
nthreads, None,
precision, None,
@@ -405,7 +405,7 @@ def comps2vis(uvw,
x0=0, y0=0,
epsilon=1e-7,
nthreads=1,
wstack=True,
do_wgridding=True,
divide_by_n=False,
ncorr_out=4):

@@ -436,7 +436,7 @@ def comps2vis(uvw,
y0, None,
epsilon, None,
nthreads, None,
wstack, None,
do_wgridding, None,
divide_by_n, None,
ncorr_out, None,
new_axes={'c': ncorr_out},
@@ -462,7 +462,7 @@ def _comps2vis(uvw,
x0=0, y0=0,
epsilon=1e-7,
nthreads=1,
wstack=True,
do_wgridding=True,
divide_by_n=False,
ncorr_out=4):
return _comps2vis_impl(uvw[0],
@@ -481,7 +481,7 @@ def _comps2vis(uvw,
x0=x0, y0=y0,
epsilon=epsilon,
nthreads=nthreads,
wstack=wstack,
do_wgridding=do_wgridding,
divide_by_n=divide_by_n,
ncorr_out=ncorr_out)

@@ -503,7 +503,7 @@ def _comps2vis_impl(uvw,
x0=0, y0=0,
epsilon=1e-7,
nthreads=1,
wstack=True,
do_wgridding=True,
divide_by_n=False,
ncorr_out=4):
# adjust for chunking
@@ -538,7 +538,7 @@ def _comps2vis_impl(uvw,
pixsize_x=cellx, pixsize_y=celly,
center_x=x0, center_y=y0,
epsilon=epsilon,
do_wgridding=wstack,
do_wgridding=do_wgridding,
divide_by_n=divide_by_n,
nthreads=nthreads)
if ncorr_out > 1:
@@ -608,8 +608,11 @@ def image_data_products(uvw,
ressq = (residual_vis*residual_vis.conj()).real
wcount = mask.sum()
if wcount:
ovar = ressq.sum()/wcount
wgt = (dof + 1)/(dof + ressq/ovar)/ovar
ovar = ressq.sum()/wcount # use 67% quantile?
wgt = (l2reweight_dof + 1)/(l2reweight_dof + ressq/ovar)/ovar
else:
wgt = None


# we usually want to re-evaluate this since the robustness may change
if robustness is not None:
@@ -623,8 +626,10 @@ def image_data_products(uvw,
nx, ny,
cellx, celly,
robustness)

wgt *= imwgt
if wgt is not None:
wgt *= imwgt
else:
wgt = imwgt

if do_weight:
out_dict['WEIGHT'] = wgt
@@ -674,7 +679,7 @@ def image_data_products(uvw,
center_x=x0,
center_y=y0,
epsilon=1e-7,
do_wgridding=wstack,
do_wgridding=do_wgridding,
nthreads=nthreads,
divide_by_n=False,
flip_v=False, # hardcoded for now
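
The image_data_products hunk above renames dof to l2reweight_dof and only forms the robust weights when unflagged data exist, falling back to the imaging weights otherwise. A hedged numpy sketch of that Student-t style re-weighting step, with illustrative names rather than the worker's actual variables:

import numpy as np

def l2_reweight(residual_vis, mask, l2reweight_dof):
    # residual power per visibility, as in the hunk above
    ressq = (residual_vis * residual_vis.conj()).real
    wcount = mask.sum()
    if not wcount:
        return None  # nothing unflagged: caller falls back to imaging weights
    ovar = ressq.sum() / wcount  # overall residual variance
    return (l2reweight_dof + 1) / (l2reweight_dof + ressq / ovar) / ovar
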
21 changes: 9 additions & 12 deletions pfb/operators/hessian.py
@@ -1,7 +1,6 @@
import numpy as np
import dask
import dask.array as da
from daskms.optimisation import inlined_array
from ducc0.wgridder.experimental import vis2dirty, dirty2vis
from ducc0.misc import make_noncritical
from uuid import uuid4
@@ -47,8 +46,6 @@ def hessian_xds(x, xds, hessopts, wsum, sigmainv, mask,

convim = hessian(x[b], uvw, wgt, vis_mask, freq, beam, hessopts)

# convim = inlined_array(convim, uvw)

convims[b] += convim

convim = da.stack(convims)/wsum
@@ -66,7 +63,7 @@ def _hessian_impl(x, uvw, weight, vis_mask, freq, beam,
x0=0.0,
y0=0.0,
cell=None,
wstack=None,
do_wgridding=None,
epsilon=None,
double_accum=None,
nthreads=None):
@@ -83,7 +80,7 @@ def _hessian_impl(x, uvw, weight, vis_mask, freq, beam,
center_y=y0,
epsilon=epsilon,
nthreads=nthreads,
do_wgridding=wstack,
do_wgridding=do_wgridding,
divide_by_n=False)

convim = vis2dirty(uvw=uvw,
@@ -99,7 +96,7 @@ def _hessian_impl(x, uvw, weight, vis_mask, freq, beam,
center_y=y0,
epsilon=epsilon,
nthreads=nthreads,
do_wgridding=wstack,
do_wgridding=do_wgridding,
double_precision_accumulation=double_accum,
divide_by_n=False)

@@ -162,11 +159,11 @@ def _hessian_psf_slice(

from pfb.operators.hessian import _hessian_impl
class hessian_psf_slice(object):
def __init__(self, ds, nbasis, nmax, nthreads, sigmainv, cell, wstack, epsilon, double_accum):
def __init__(self, ds, nbasis, nmax, nthreads, sigmainv, cell, do_wgridding, epsilon, double_accum):
self.nthreads = nthreads
self.sigmainv = sigmainv
self.cell = cell
self.wstack = wstack
self.do_wgridding = do_wgridding
self.epsilon = epsilon
self.double_accum = double_accum
self.lastsize = ds.PSF.shape[-1]
@@ -244,7 +241,7 @@ def compute_residual(self, x):
self.freq,
None,
cell=self.cell,
wstack=self.wstack,
do_wgridding=self.do_wgridding,
epsilon=self.epsilon,
double_accum=self.double_accum,
nthreads=self.nthreads)
@@ -289,7 +286,7 @@ def hess_vis(xds,
xout,
x,
sigmainv=1.0,
wstack=True,
do_wgridding=True,
nthreads=1,
epsilon=1e-7,
divide_by_n=False):
@@ -318,7 +315,7 @@ def hess_vis(xds,
center_x=x0,
center_y=y0,
epsilon=epsilon,
do_wgridding=wstack,
do_wgridding=do_wgridding,
nthreads=nthreads,
divide_by_n=divide_by_n)

@@ -340,7 +337,7 @@ def hess_vis(xds,
center_x=x0,
center_y=y0,
epsilon=epsilon,
do_wgridding=wstack,
do_wgridding=do_wgridding,
nthreads=nthreads,
divide_by_n=divide_by_n)
xout[field][f't{t}b{b}'] += sigmainv * x[field][f't{t}b{b}']
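
Throughout hessian.py the wstack keyword becomes do_wgridding, matching the ducc0 wgridder argument it is forwarded to. For context, a minimal sketch of the degrid-then-regrid round trip that _hessian_impl and hess_vis perform, assuming a single square-pixel image and weights passed through the wgt argument; parameter values are illustrative:

import numpy as np
from ducc0.wgridder.experimental import dirty2vis, vis2dirty

def hessian_action(x, uvw, freq, weight, cell,
                   do_wgridding=True, epsilon=1e-7, nthreads=1):
    nx, ny = x.shape
    # predict model visibilities from the image
    modelvis = dirty2vis(uvw=uvw, freq=freq, dirty=x,
                         pixsize_x=cell, pixsize_y=cell,
                         epsilon=epsilon, do_wgridding=do_wgridding,
                         nthreads=nthreads, divide_by_n=False)
    # grid them straight back with the visibility weights applied
    return vis2dirty(uvw=uvw, freq=freq, vis=modelvis, wgt=weight,
                     npix_x=nx, npix_y=ny,
                     pixsize_x=cell, pixsize_y=cell,
                     epsilon=epsilon, do_wgridding=do_wgridding,
                     nthreads=nthreads, divide_by_n=False,
                     double_precision_accumulation=True)
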
1 change: 0 additions & 1 deletion pfb/operators/psf.py
@@ -1,7 +1,6 @@
import numpy as np
import numexpr as ne
import dask.array as da
from daskms.optimisation import inlined_array
from uuid import uuid4
from ducc0.fft import r2c, c2r, c2c, good_size
from ducc0.misc import roll_resize_roll as rrr
Expand Down
