diff --git a/pfb/deconv/clark.py b/pfb/deconv/clark.py index 4796a59bb..322899afd 100644 --- a/pfb/deconv/clark.py +++ b/pfb/deconv/clark.py @@ -1,14 +1,14 @@ import numpy as np import numexpr as ne from functools import partial -import numba +from numba import njit import dask.array as da from pfb.operators.psf import psf_convolve_cube from ducc0.misc import make_noncritical import pyscilog log = pyscilog.get_logger('CLARK') -@numba.jit(nopython=True, nogil=True, cache=True) # parallel=True, +@njit(nogil=True, cache=True) # parallel=True, def subtract(A, psf, Ip, Iq, xhat, nxo2, nyo2): ''' Subtract psf centered at location of xhat @@ -25,7 +25,7 @@ def subtract(A, psf, Ip, Iq, xhat, nxo2, nyo2): return A -# @numba.jit(nopython=True, nogil=True, cache=True) # parallel=True, +# @njit(nogil=True, cache=True) # parallel=True, def subminor(A, psf, Ip, Iq, model, wsums, gamma=0.05, th=0.0, maxit=10000): """ Run subminor loop in active set diff --git a/pfb/operators/gauss.py b/pfb/operators/gauss.py index e46b364a8..45b015642 100644 --- a/pfb/operators/gauss.py +++ b/pfb/operators/gauss.py @@ -8,7 +8,7 @@ Fs = np.fft.fftshift -@njit(parallel=True, nogil=True, fastmath=True, inline='always') +@njit(parallel=True, nogil=True, cache=True, inline='always') def freqmul(A, x): nchan, npix = x.shape out = np.zeros((nchan, npix), dtype=x.dtype) @@ -19,7 +19,7 @@ def freqmul(A, x): return out -@njit(parallel=True, nogil=True, fastmath=True, inline='always') +@njit(parallel=True, nogil=True, cache=True, inline='always') def make_kernel(nx_psf, ny_psf, sigma0, length_scale): K = np.zeros((1, nx_psf, ny_psf), dtype=np.float64) for j in range(nx_psf): diff --git a/pfb/operators/psi.py b/pfb/operators/psi.py index 2e76f4b9a..f3a826d89 100644 --- a/pfb/operators/psi.py +++ b/pfb/operators/psi.py @@ -7,7 +7,7 @@ from pfb.wavelets.wavelets import wavedecn, waverecn, ravel_coeffs, unravel_coeffs -@numba.njit(nogil=True, fastmath=True, cache=True) +@numba.njit(nogil=True, cache=True) def 
pad(x, n): ''' pad 1D array by n zeros @@ -135,7 +135,7 @@ def im2coef(x, alpha, bases, ntot, nmax, nlevels, nthreads=1): # return graph -@numba.njit(nogil=True, fastmath=True, cache=True, parallel=True) +@numba.njit(nogil=True, cache=True, parallel=True) def im2coef_dist(x, bases, ntot, nmax, nlevels): ''' Per band image to coefficients @@ -158,7 +158,7 @@ def im2coef_dist(x, bases, ntot, nmax, nlevels): return alpha -@numba.njit(nogil=True, fastmath=True, cache=True, parallel=True) +@numba.njit(nogil=True, cache=True, parallel=True) def coef2im_dist(alpha, bases, ntot, iy, sy, nx, ny): ''' Per band coefficients to image diff --git a/pfb/prox/prox_21.py b/pfb/prox/prox_21.py index 526d7f349..98fe61713 100644 --- a/pfb/prox/prox_21.py +++ b/pfb/prox/prox_21.py @@ -22,7 +22,7 @@ def prox_21(v, sigma, weight=None, axis=0): return v * np.expand_dims(ratio, axis=axis) # restores axis -@njit(nogil=True, fastmath=True, cache=True, parallel=True) +@njit(nogil=True, cache=True, parallel=True) def prox_21_numba(v, result, lam, sigma=1.0, weight=None): """ Computes weighted version of @@ -63,7 +63,7 @@ def dual_update(v, x, psiH, lam, sigma=1.0, weight=1.0): -@njit(nogil=True, fastmath=True, cache=True, parallel=True) +@njit(nogil=True, cache=True, parallel=True) def dual_update_numba(vp, v, lam, sigma=1.0, weight=None): """ Computes weighted version of diff --git a/pfb/prox/prox_21m.py b/pfb/prox/prox_21m.py index ecb5e60c5..62ded441a 100644 --- a/pfb/prox/prox_21m.py +++ b/pfb/prox/prox_21m.py @@ -26,7 +26,7 @@ def prox_21m(v, sigma, weight=1.0, axis=0): return v * np.expand_dims(ratio, axis=axis) # restores axis -@njit(nogil=True, fastmath=True, cache=True, parallel=True) +@njit(nogil=True, cache=True, parallel=True) def prox_21m_numba(v, result, lam, sigma=1.0, weight=None): """ Computes weighted version of @@ -68,7 +68,7 @@ def dual_update(v, x, psiH, lam, sigma=1.0, weight=1.0): -@njit(nogil=True, fastmath=True, cache=True, parallel=True) +@njit(nogil=True, cache=True, 
parallel=True) def dual_update_numba(vp, v, lam, sigma=1.0, weight=None): """ Computes weighted version of diff --git a/pfb/utils/correlations.py b/pfb/utils/correlations.py index d7fd5d7a0..a198add48 100644 --- a/pfb/utils/correlations.py +++ b/pfb/utils/correlations.py @@ -227,14 +227,14 @@ def corr_funcs(data, jones): # The expressions for DIAG_DIAG and DIAG mode are essentially the same if jones.ndim == 5: # I and Q have identical weights - @njit(nogil=True, fastmath=True, inline='always') + @njit(nogil=True, cache=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0] gq00 = gq[0] W0 = W[0] return np.real(W0*gp00*gq00*np.conjugate(gp00)*np.conjugate(gq00)) - @njit(nogil=True, fastmath=True, inline='always') + @njit(nogil=True, cache=True, inline='always') def vfunc(gp, gq, W, V): gp00 = gp[0] gq00 = gq[0] diff --git a/pfb/utils/misc.py b/pfb/utils/misc.py index 11a7d632d..ebe817ec1 100644 --- a/pfb/utils/misc.py +++ b/pfb/utils/misc.py @@ -28,7 +28,6 @@ JIT_OPTIONS = { "nogil": True, - "fastmath": True, "cache": True } @@ -71,7 +70,7 @@ def kron_matvec(A, b): x = Z.ravel() return x.reshape(b.shape) -@jit(nopython=True, fastmath=True, parallel=False, cache=True, nogil=True) +@njit(parallel=False, cache=True, nogil=True) def kron_matvec2(A, b): D = len(A) N = b.size @@ -1346,7 +1345,7 @@ def remove_large_islands(x, max_island_size=100): return x -@njit(parallel=False, nogil=True, fastmath=True, inline='always') +@njit(parallel=False, nogil=True, cache=True, inline='always') def freqmul(A, x): nband, nx, ny = x.shape out = np.zeros_like(x) diff --git a/pfb/utils/stokes.py b/pfb/utils/stokes.py index 66fb16e40..4b80fd43d 100644 --- a/pfb/utils/stokes.py +++ b/pfb/utils/stokes.py @@ -374,7 +374,7 @@ def stokes_funcs(data, jones, product, pol, nc): gq00, gq01, gq10, gq11, w0, w1, w2, w3), sm.simplify(sm.expand(W[i,i]))) - Wjfn = njit(nogil=True, fastmath=True, inline='always')(Wsymb) + Wjfn = njit(nogil=True, inline='always')(Wsymb) Dsymb = lambdify((gp00, gp01, 
gp10, gp11, @@ -382,9 +382,9 @@ def stokes_funcs(data, jones, product, pol, nc): w0, w1, w2, w3, v00, v01, v10, v11), sm.simplify(sm.expand(C[i]))) - Djfn = njit(nogil=True, fastmath=True, inline='always')(Dsymb) + Djfn = njit(nogil=True, inline='always')(Dsymb) - @njit(nogil=True, fastmath=True, inline='always') + @njit(nogil=True, cache=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0,0] gp01 = gp[0,1] @@ -402,7 +402,7 @@ def wfunc(gp, gq, W): gq00, gq01, gq10, gq11, W00, W01, W10, W11).real - @njit(nogil=True, fastmath=True, inline='always') + @njit(nogil=True, cache=True, inline='always') def vfunc(gp, gq, W, V): gp00 = gp[0,0] gp01 = gp[0,1] @@ -439,7 +439,7 @@ def vfunc(gp, gq, W, V): gq00, gq11, w0, w1, w2, w3), sm.simplify(sm.expand(W[i,i]))) - Wjfn = njit(nogil=True, fastmath=True, inline='always')(Wsymb) + Wjfn = njit(nogil=True, inline='always')(Wsymb) Dsymb = lambdify((gp00, gp11, @@ -447,10 +447,10 @@ def vfunc(gp, gq, W, V): w0, w1, w2, w3, v00, v01, v10, v11), sm.simplify(sm.expand(C[i]))) - Djfn = njit(nogil=True, fastmath=True, inline='always')(Dsymb) + Djfn = njit(nogil=True, inline='always')(Dsymb) if nc==literal('4'): - @njit(nogil=True, fastmath=True, inline='always') + @njit(nogil=True, cache=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0] gp11 = gp[1] @@ -464,7 +464,7 @@ def wfunc(gp, gq, W): gq00, gq11, W00, W01, W10, W11).real - @njit(nogil=True, fastmath=True, inline='always') + @njit(nogil=True, cache=True, inline='always') def vfunc(gp, gq, W, V): gp00 = gp[0] gp11 = gp[1] @@ -483,7 +483,7 @@ def vfunc(gp, gq, W, V): W00, W01, W10, W11, V00, V01, V10, V11) elif nc==literal('2'): - @njit(nogil=True, fastmath=True, inline='always') + @njit(nogil=True, cache=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0] gp11 = gp[1] @@ -497,7 +497,7 @@ def wfunc(gp, gq, W): gq00, gq11, W00, W01, W10, W11).real - @njit(nogil=True, fastmath=True, inline='always') + @njit(nogil=True, cache=True, 
inline='always') def vfunc(gp, gq, W, V): gp00 = gp[0] gp11 = gp[1] diff --git a/pfb/utils/weighting.py b/pfb/utils/weighting.py index 25ef62c88..cb801ddac 100644 --- a/pfb/utils/weighting.py +++ b/pfb/utils/weighting.py @@ -37,7 +37,7 @@ def compute_counts_wrapper(uvw, freq, mask, nx, ny, -@njit(nogil=True, fastmath=True, cache=True, parallel=True) +@njit(nogil=True, cache=True, parallel=True) def _compute_counts(uvw, freq, mask, nx, ny, cell_size_x, cell_size_y, dtype, k=6, ngrid=1): # support hardcoded for now @@ -99,7 +99,7 @@ def _compute_counts(uvw, freq, mask, nx, ny, counts[g, u_idx, v_idx] += 1.0 return counts #.sum(axis=0, keepdims=True) -@njit(nogil=True, fastmath=True, cache=True, inline='always') +@njit(nogil=True, cache=True, inline='always') def _es_kernel(x, beta, k): return np.exp(beta*k*(np.sqrt((1-x)*(1+x)) - 1)) @@ -124,7 +124,7 @@ def counts_to_weights_wrapper(counts, uvw, freq, nx, ny, cell_size_x, cell_size_y, robust) -@njit(nogil=True, fastmath=True, cache=True) +@njit(nogil=True, cache=True) def _counts_to_weights(counts, uvw, freq, nx, ny, cell_size_x, cell_size_y, robust): # ufreq @@ -178,7 +178,7 @@ def filter_extreme_counts(counts, nbox=16, nlevel=10): -@njit(nogil=True, fastmath=True, cache=True) +@njit(nogil=True, cache=True) def _filter_extreme_counts(counts, nbox=16, level=10): ''' Replaces extreme counts by local mean computed i diff --git a/pfb/wavelets/wavelets.py b/pfb/wavelets/wavelets.py index 961449219..410e4b366 100644 --- a/pfb/wavelets/wavelets.py +++ b/pfb/wavelets/wavelets.py @@ -188,7 +188,7 @@ def impl(axis, ndim): return impl -@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True) +@numba.generated_jit(nopython=True, nogil=True, cache=True) def dwt_axis(data, wavelet, mode, axis): def impl(data, wavelet, mode, axis): coeff_len = dwt_coeff_length( @@ -245,7 +245,7 @@ def impl(data, wavelet, mode, axis): return impl -@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True) 
+@numba.generated_jit(nopython=True, nogil=True, cache=True) def idwt_axis(approx_coeffs, detail_coeffs, wavelet, mode, axis): @@ -343,7 +343,7 @@ def impl(approx_coeffs, detail_coeffs, return impl -@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True) +@numba.generated_jit(nopython=True, nogil=True, cache=True) def dwt(data, wavelet, mode="zero", axis=None): if isinstance(data, nbtypes.misc.Optional): @@ -406,7 +406,7 @@ def coeff_product(args, repeat=1): return result -@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True) +@numba.generated_jit(nopython=True, nogil=True, cache=True) def idwt(coeffs, wavelet, mode='zero', axis=None): have_axis = not is_nonelike(axis) @@ -492,7 +492,7 @@ def impl(sizes, dec_lens, level=None): return impl -@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True) +@numba.generated_jit(nopython=True, nogil=True, cache=True) def wavedecn(data, wavelet, mode='zero', level=None, axis=None): have_axis = not is_nonelike(axis) @@ -527,7 +527,7 @@ def impl(data, wavelet, mode='zero', level=None, axis=None): return impl -@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True) +@numba.generated_jit(nopython=True, nogil=True, cache=True) def waverecn(coeffs, wavelet, mode='zero', axis=None): # ca = coeffs[0]['aa'] # if not isinstance(ca, nbtypes.npytypes.Array): @@ -570,7 +570,7 @@ def impl(coeffs, wavelet, mode='zero', axis=None): return impl -@numba.njit(nogil=True, fastmath=True, cache=True) +@numba.njit(nogil=True, cache=True) def ravel_coeffs(coeffs): a_coeffs = coeffs[0]['aa'] @@ -624,7 +624,7 @@ def ravel_coeffs(coeffs): return coeff_arr, coeff_tuples, coeff_shapes -@numba.njit(nogil=True, fastmath=True, cache=True) +@numba.njit(nogil=True, cache=True) def unravel_coeffs(arr, coeff_tuples, coeff_shapes, output_format='wavedecn'): arr = np.asarray(arr) coeffs = List() @@ -644,7 +644,7 @@ def unravel_coeffs(arr, coeff_tuples, coeff_shapes, 
output_format='wavedecn'): return coeffs -@numba.njit(nogil=True, fastmath=True, cache=True) +@numba.njit(nogil=True, cache=True) def wavelet_setup(x, bases, nlevels): # set up dictionary info tys = Dict()