Skip to content

Commit

Permalink
use bfgs for fitting psf
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Apr 15, 2024
1 parent c1305e8 commit d083703
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 22 deletions.
1 change: 1 addition & 0 deletions pfb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def set_client(opts, stack, log, scheduler='distributed'):
os.environ["MKL_NUM_THREADS"] = str(opts.nvthreads)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(opts.nvthreads)
os.environ["NUMBA_NUM_THREADS"] = str(opts.nvthreads)
os.environ["JAX_ENABLE_X64"] = 'True'
# this may be required for numba parallelism
# find python and set LD_LIBRARY_PATH
paths = sys.path
Expand Down
53 changes: 33 additions & 20 deletions pfb/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from daskms.experimental.zarr import xds_from_zarr
from omegaconf import ListConfig
from skimage.morphology import label
from scipy.optimize import curve_fit
from scipy.optimize import curve_fit, fmin_l_bfgs_b
from collections import namedtuple
from africanus.coordinates.coordinates import radec_to_lmn
import xarray as xr
Expand All @@ -25,6 +25,9 @@
import sympy as sm
from sympy.utilities.lambdify import lambdify
from sympy.parsing.sympy_parser import parse_expr
import jax.numpy as jnp
from jax import value_and_grad
import jax

JIT_OPTIONS = {
"nogil": True,
Expand Down Expand Up @@ -499,6 +502,29 @@ def _restore_corrs(vis, ncorr):
model_vis[:, :, -1] = vis
return model_vis


# model to fit
@jax.jit
def psf_errorsq(x, data, xy):
emaj, emin, pa = x
Smin = jnp.minimum(emaj, emin)
Smaj = jnp.maximum(emaj, emin)
# print(emaj, emin, pa)
A = jnp.array([[1. / Smin ** 2, 0],
[0, 1. / Smaj ** 2]])

c, s, t = jnp.cos, jnp.sin, jnp.deg2rad(-pa)
R = jnp.array([[c(t), -s(t)],
[s(t), c(t)]])
B = jnp.dot(jnp.dot(R.T, A), R)
Q = jnp.einsum('nb,bc,cn->n', xy.T, B, xy)
# GaussPar should corresponds to FWHM
fwhm_conv = 2 * jnp.sqrt(2 * np.log(2))
model = jnp.exp(-fwhm_conv * Q)
res = data - model
return jnp.vdot(res, res)


def fitcleanbeam(psf: np.ndarray,
level: float = 0.5,
pixsize: float = 1.0,
Expand All @@ -516,23 +542,6 @@ def fitcleanbeam(psf: np.ndarray,
y = np.arange(-ny / 2, ny / 2)
xx, yy = np.meshgrid(x, y, indexing='ij')

# model to fit
def func(xy, emaj, emin, pa):
Smin = np.minimum(emaj, emin)
Smaj = np.maximum(emaj, emin)

A = np.array([[1. / Smin ** 2, 0],
[0, 1. / Smaj ** 2]])

c, s, t = np.cos, np.sin, np.deg2rad(-pa)
R = np.array([[c(t), -s(t)],
[s(t), c(t)]])
A = np.dot(np.dot(R.T, A), R)
R = np.einsum('nb,bc,cn->n', xy.T, A, xy)
# GaussPar should corresponds to FWHM
fwhm_conv = 2 * np.sqrt(2 * np.log(2))
return np.exp(-fwhm_conv * R)

Gausspars = []
for v in range(nband):
# make sure psf is normalised
Expand Down Expand Up @@ -563,8 +572,12 @@ def func(xy, emaj, emin, pa):
xy = np.vstack((x, y))
emaj0 = np.maximum(xdiff, ydiff)
emin0 = np.minimum(xdiff, ydiff)
p, _ = curve_fit(func, xy, psfv, p0=(emaj0, emin0, 0.0),
maxfev=2000)
dfunc = value_and_grad(psf_errorsq)
p, f, d = fmin_l_bfgs_b(dfunc,
np.array((emaj0, emin0, 0.0)),
args=(psfv, xy),
bounds=((0, None), (0, None), (None, None)),
factr=1e11)
Gausspars.append([p[0] * pixsize, p[1] * pixsize, p[2]])

return Gausspars
Expand Down
1 change: 0 additions & 1 deletion pfb/workers/spotless.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def _spotless(ddsi=None, **kw):
rms_comps,
alpha=opts.alpha)
l1weight = reweighter(model)
# l1weight[l1weight < 1.0] = 0.0
else:
l1weight = np.ones((nbasis, Nymax, Nxmax), dtype=dirty.dtype)
reweighter = None
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
"sympy",
"stimela >= 2.0rc14",
"streamjoy",
"tbb"
"tbb",
"jax[cpu]"
]


Expand Down

0 comments on commit d083703

Please sign in to comment.