Skip to content

Commit

Permalink
add option to write out FFT of residual in restore worker
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Apr 17, 2024
1 parent 5b542b0 commit 96e0128
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 22 deletions.
2 changes: 2 additions & 0 deletions pfb/parser/init.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ inputs:
abbreviation: ms
info:
Path to measurement set
policies:
repeat: '[]'
scans:
dtype: List[int]
info:
Expand Down
22 changes: 14 additions & 8 deletions pfb/parser/restore.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,37 @@ inputs:
dtype: str
abbreviation: mname
default: MODEL
info: 'Name of model in mds'
info:
Name of model in dds
residual-name:
dtype: str
abbreviation: rname
default: RESIDUAL
info: 'Name of residual in dds'
info:
Name of residual in dds
nband:
dtype: int
required: true
abbreviation: nb
info: 'Number of imaging bands'
info:
Number of imaging bands
postfix:
dtype: str
default: 'main'
info: 'Can be used to specify a custom name for the image space data \
products'
info:
Can be used to specify a custom name for the image space data products
outputs:
dtype: str
default: mMrRiI
info: 'Output products. (m)odel, (r)esidual, (i)mage, (c)lean beam. \
Captitals correspond to cubes.'
info:
Output products (m)odel, (r)esidual, (i)mage, (c)lean beam, (d)irty,
(f)ft_residuals (amplitude and phase will be produced).
Use captitals to produce corresponding cubes.
overwrite:
dtype: bool
default: false
info: Allow overwrite of output xds
info:
Allow overwriting fits files

outputs:
{}
55 changes: 41 additions & 14 deletions pfb/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,27 +1460,54 @@ def combine_columns(x, y, dc, dc1, dc2):
# shifty - shift y coordinate by this amount

# All sizes are assumed to be in radians.
# '''
# '''
# import matplotlib.pyplot as plt
# from scipy.fft import fftn, ifftn
# Fs = np.fft.fftshift
# iFs = np.fft.ifftshift
# # basic
# nx, ny = image.shape
# imhat = Fs(fftn(image))

# imabs = np.abs(imhat)
# imphase = np.angle(imhat) - 1.0
# # imphase = np.roll(imphase, nx//2, axis=0)
# imshift = ifftn(iFs(imabs*np.exp(1j*imphase))).real

# impad = np.pad(imhat, ((nx//2, nx//2), (ny//2, ny//2)), mode='constant')
# imo = ifftn(iFs(impad)).real

# # coordinates on input grid
# nx, ny = image.shape
# x = np.arange(-(nx//2), nx//2) * cellxi
# y = np.arange(-(ny//2), ny//2) * cellyi
# xx, yy = np.meshgrid(x, y, indexing='ij')
# print(np.sum(image) - np.sum(imo))

# plt.figure(1)
# plt.imshow(image/image.max(), vmin=0, vmax=1, interpolation=None)
# plt.colorbar()
# plt.figure(2)
# plt.imshow(imo/imo.max(), vmin=0, vmax=1, interpolation=None)
# plt.colorbar()
# plt.figure(3)
# plt.imshow(imshift/imshift.max() - image/image.max(), vmin=0, vmax=1, interpolation=None)
# plt.colorbar()

# plt.show()

# # coordinates on input grid
# nx, ny = image.shapeimhat
# x = np.arange(-(nx//2), nx//2) * cellxi
# y = np.arange(-(ny//2), ny//2) * cellyi
# xx, yy = np.meshgrid(x, y, indexing='ij')

# # frequencies on output grid
# celluo = 1/(nxo*cellxo)
# cellvo = 1/(nyo*cellyo)
# uo = np.arange(-(nxo//2), nxo//2) * celluo/nxo
# vo = np.arange(-(nyo//2), nyo//2) * cellvo/nyo
# # frequencies on output grid
# celluo = 1/(nxo*cellxo)
# cellvo = 1/(nyo*cellyo)
# uo = np.arange(-(nxo//2), nxo//2) * celluo/nxo
# vo = np.arange(-(nyo//2), nyo//2) * cellvo/nyo

# uu, vv = np.meshgrid(uo, vo, indexing='ij')
# uv = np.vstack((uo, vo)).T
# uu, vv = np.meshgrid(uo, vo, indexing='ij')
# uv = np.vstack((uo, vo)).T


# res1 = finufft.nufft2d3(xx.ravel(), yy.ravel(), image.ravel(), uu.ravel(), vv.ravel())
# res1 = finufft.nufft2d3(xx.ravel(), yy.ravel(), image.ravel(), uu.ravel(), vv.ravel())



25 changes: 25 additions & 0 deletions pfb/workers/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _restore(**kw):
from pfb.utils.fits import (save_fits, add_beampars, set_wcs,
dds2fits, dds2fits_mfs)
from pfb.utils.misc import Gaussian2D, fitcleanbeam, convolve2gaussres, dds2cubes
from ducc0.fft import r2c

basename = f'{opts.output_filename}_{opts.product.upper()}'
dds_name = f'{basename}_{opts.postfix}.dds'
Expand Down Expand Up @@ -154,6 +155,30 @@ def _restore(**kw):
hdr,
overwrite=opts.overwrite)

if 'f' in opts.outputs:
rhat_mfs = r2c(residual_mfs, forward=True,
nthreads=opts.opts.nvthreads, inorm=0)
save_fits(np.abs(rhat_mfs),
f'{basename}_{opts.postfix}.abs_fft_residual_mfs.fits',
hdr_mfs,
overwrite=opts.overwrite)
save_fits(np.angle(rhat_mfs),
f'{basename}_{opts.postfix}.phase_fft_residual_mfs.fits',
hdr_mfs,
overwrite=opts.overwrite)

if 'F' in opts.outputs:
rhat = r2c(residual, axes=(1,2), forward=True,
nthreads=opts.opts.nvthreads, inorm=0)
save_fits(np.abs(rhat),
f'{basename}_{opts.postfix}.abs_fft_residual.fits',
hdr,
overwrite=opts.overwrite)
save_fits(np.angle(rhat),
f'{basename}_{opts.postfix}.phase_fft_residual.fits',
hdr,
overwrite=opts.overwrite)

if 'd' in opts.outputs:
dirty_mfs = np.sum(dirty, axis=0)
save_fits(dirty_mfs,
Expand Down

0 comments on commit 96e0128

Please sign in to comment.