From 96e01282e8213ec295526738fef09b0e54809a54 Mon Sep 17 00:00:00 2001 From: landmanbester Date: Wed, 17 Apr 2024 12:32:37 +0200 Subject: [PATCH] add option to write out FFT of residual in restore worker --- pfb/parser/init.yaml | 2 ++ pfb/parser/restore.yaml | 22 +++++++++++------ pfb/utils/misc.py | 55 ++++++++++++++++++++++++++++++----------- pfb/workers/restore.py | 25 +++++++++++++++++++ 4 files changed, 82 insertions(+), 22 deletions(-) diff --git a/pfb/parser/init.yaml b/pfb/parser/init.yaml index f453b9b28..d1bd358b8 100644 --- a/pfb/parser/init.yaml +++ b/pfb/parser/init.yaml @@ -5,6 +5,8 @@ inputs: abbreviation: ms info: Path to measurement set + policies: + repeat: '[]' scans: dtype: List[int] info: diff --git a/pfb/parser/restore.yaml b/pfb/parser/restore.yaml index acb73a9ec..54e934573 100644 --- a/pfb/parser/restore.yaml +++ b/pfb/parser/restore.yaml @@ -6,31 +6,37 @@ inputs: dtype: str abbreviation: mname default: MODEL - info: 'Name of model in mds' + info: + Name of model in dds residual-name: dtype: str abbreviation: rname default: RESIDUAL - info: 'Name of residual in dds' + info: + Name of residual in dds nband: dtype: int required: true abbreviation: nb - info: 'Number of imaging bands' + info: + Number of imaging bands postfix: dtype: str default: 'main' - info: 'Can be used to specify a custom name for the image space data \ - products' + info: + Can be used to specify a custom name for the image space data products outputs: dtype: str default: mMrRiI - info: 'Output products. (m)odel, (r)esidual, (i)mage, (c)lean beam. \ - Captitals correspond to cubes.' + info: + Output products (m)odel, (r)esidual, (i)mage, (c)lean beam, (d)irty, + (f)ft_residuals (amplitude and phase will be produced). + Use captitals to produce corresponding cubes. overwrite: dtype: bool default: false - info: Allow overwrite of output xds + info: + Allow overwriting fits files outputs: {} diff --git a/pfb/utils/misc.py b/pfb/utils/misc.py index 76a94d068..a813e5358 100644 --- a/pfb/utils/misc.py +++ b/pfb/utils/misc.py @@ -1460,27 +1460,54 @@ def combine_columns(x, y, dc, dc1, dc2): # shifty - shift y coordinate by this amount # All sizes are assumed to be in radians. -# ''' +# ''' +# import matplotlib.pyplot as plt +# from scipy.fft import fftn, ifftn +# Fs = np.fft.fftshift +# iFs = np.fft.ifftshift +# # basic +# nx, ny = image.shape +# imhat = Fs(fftn(image)) +# imabs = np.abs(imhat) +# imphase = np.angle(imhat) - 1.0 +# # imphase = np.roll(imphase, nx//2, axis=0) +# imshift = ifftn(iFs(imabs*np.exp(1j*imphase))).real +# impad = np.pad(imhat, ((nx//2, nx//2), (ny//2, ny//2)), mode='constant') +# imo = ifftn(iFs(impad)).real -# # coordinates on input grid -# nx, ny = image.shape -# x = np.arange(-(nx//2), nx//2) * cellxi -# y = np.arange(-(ny//2), ny//2) * cellyi -# xx, yy = np.meshgrid(x, y, indexing='ij') +# print(np.sum(image) - np.sum(imo)) + +# plt.figure(1) +# plt.imshow(image/image.max(), vmin=0, vmax=1, interpolation=None) +# plt.colorbar() +# plt.figure(2) +# plt.imshow(imo/imo.max(), vmin=0, vmax=1, interpolation=None) +# plt.colorbar() +# plt.figure(3) +# plt.imshow(imshift/imshift.max() - image/image.max(), vmin=0, vmax=1, interpolation=None) +# plt.colorbar() + +# plt.show() + + # # coordinates on input grid + # nx, ny = image.shapeimhat + # x = np.arange(-(nx//2), nx//2) * cellxi + # y = np.arange(-(ny//2), ny//2) * cellyi + # xx, yy = np.meshgrid(x, y, indexing='ij') -# # frequencies on output grid -# celluo = 1/(nxo*cellxo) -# cellvo = 1/(nyo*cellyo) -# uo = np.arange(-(nxo//2), nxo//2) * celluo/nxo -# vo = np.arange(-(nyo//2), nyo//2) * cellvo/nyo + # # frequencies on output grid + # celluo = 1/(nxo*cellxo) + # cellvo = 1/(nyo*cellyo) + # uo = np.arange(-(nxo//2), nxo//2) * celluo/nxo + # vo = np.arange(-(nyo//2), nyo//2) * cellvo/nyo -# uu, vv = np.meshgrid(uo, vo, indexing='ij') -# uv = np.vstack((uo, vo)).T + # uu, vv = np.meshgrid(uo, vo, indexing='ij') + # uv = np.vstack((uo, vo)).T -# res1 = finufft.nufft2d3(xx.ravel(), yy.ravel(), image.ravel(), uu.ravel(), vv.ravel()) + # res1 = finufft.nufft2d3(xx.ravel(), yy.ravel(), image.ravel(), uu.ravel(), vv.ravel()) diff --git a/pfb/workers/restore.py b/pfb/workers/restore.py index afbbd4b8e..d17ced940 100644 --- a/pfb/workers/restore.py +++ b/pfb/workers/restore.py @@ -51,6 +51,7 @@ def _restore(**kw): from pfb.utils.fits import (save_fits, add_beampars, set_wcs, dds2fits, dds2fits_mfs) from pfb.utils.misc import Gaussian2D, fitcleanbeam, convolve2gaussres, dds2cubes + from ducc0.fft import r2c basename = f'{opts.output_filename}_{opts.product.upper()}' dds_name = f'{basename}_{opts.postfix}.dds' @@ -154,6 +155,30 @@ def _restore(**kw): hdr, overwrite=opts.overwrite) + if 'f' in opts.outputs: + rhat_mfs = r2c(residual_mfs, forward=True, + nthreads=opts.opts.nvthreads, inorm=0) + save_fits(np.abs(rhat_mfs), + f'{basename}_{opts.postfix}.abs_fft_residual_mfs.fits', + hdr_mfs, + overwrite=opts.overwrite) + save_fits(np.angle(rhat_mfs), + f'{basename}_{opts.postfix}.phase_fft_residual_mfs.fits', + hdr_mfs, + overwrite=opts.overwrite) + + if 'F' in opts.outputs: + rhat = r2c(residual, axes=(1,2), forward=True, + nthreads=opts.opts.nvthreads, inorm=0) + save_fits(np.abs(rhat), + f'{basename}_{opts.postfix}.abs_fft_residual.fits', + hdr, + overwrite=opts.overwrite) + save_fits(np.angle(rhat), + f'{basename}_{opts.postfix}.phase_fft_residual.fits', + hdr, + overwrite=opts.overwrite) + if 'd' in opts.outputs: dirty_mfs = np.sum(dirty, axis=0) save_fits(dirty_mfs,