Skip to content

Commit

Permalink
no fastmath in jit
Browse files: browse the repository at this point in the history
  • Loading branch information
landmanbester committed Feb 21, 2024
1 parent 2e8d9ed commit 1889d09
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 40 deletions.
6 changes: 3 additions & 3 deletions pfb/deconv/clark.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import numpy as np
import numexpr as ne
from functools import partial
import numba
from numba import njit
import dask.array as da
from pfb.operators.psf import psf_convolve_cube
from ducc0.misc import make_noncritical
import pyscilog
log = pyscilog.get_logger('CLARK')

@numba.jit(nopython=True, nogil=True, cache=True) # parallel=True,
@njit(nogil=True, cache=True) # parallel=True,
def subtract(A, psf, Ip, Iq, xhat, nxo2, nyo2):
'''
Subtract psf centered at location of xhat
Expand All @@ -25,7 +25,7 @@ def subtract(A, psf, Ip, Iq, xhat, nxo2, nyo2):
return A


# @numba.jit(nopython=True, nogil=True, cache=True) # parallel=True,
# @njit(nogil=True, cache=True) # parallel=True,
def subminor(A, psf, Ip, Iq, model, wsums, gamma=0.05, th=0.0, maxit=10000):
"""
Run subminor loop in active set
Expand Down
4 changes: 2 additions & 2 deletions pfb/operators/gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Fs = np.fft.fftshift


@njit(parallel=True, nogil=True, fastmath=True, inline='always')
@njit(parallel=True, nogil=True, cache=True, inline='always')
def freqmul(A, x):
nchan, npix = x.shape
out = np.zeros((nchan, npix), dtype=x.dtype)
Expand All @@ -19,7 +19,7 @@ def freqmul(A, x):
return out


@njit(parallel=True, nogil=True, fastmath=True, inline='always')
@njit(parallel=True, nogil=True, cache=True, inline='always')
def make_kernel(nx_psf, ny_psf, sigma0, length_scale):
K = np.zeros((1, nx_psf, ny_psf), dtype=np.float64)
for j in range(nx_psf):
Expand Down
6 changes: 3 additions & 3 deletions pfb/operators/psi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pfb.wavelets.wavelets import wavedecn, waverecn, ravel_coeffs, unravel_coeffs


@numba.njit(nogil=True, fastmath=True, cache=True)
@numba.njit(nogil=True, cache=True)
def pad(x, n):
'''
pad 1D array by n zeros
Expand Down Expand Up @@ -135,7 +135,7 @@ def im2coef(x, alpha, bases, ntot, nmax, nlevels, nthreads=1):
# return graph


@numba.njit(nogil=True, fastmath=True, cache=True, parallel=True)
@numba.njit(nogil=True, cache=True, parallel=True)
def im2coef_dist(x, bases, ntot, nmax, nlevels):
'''
Per band image to coefficients
Expand All @@ -158,7 +158,7 @@ def im2coef_dist(x, bases, ntot, nmax, nlevels):
return alpha


@numba.njit(nogil=True, fastmath=True, cache=True, parallel=True)
@numba.njit(nogil=True, cache=True, parallel=True)
def coef2im_dist(alpha, bases, ntot, iy, sy, nx, ny):
'''
Per band coefficients to image
Expand Down
4 changes: 2 additions & 2 deletions pfb/prox/prox_21.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def prox_21(v, sigma, weight=None, axis=0):
return v * np.expand_dims(ratio, axis=axis) # restores axis


@njit(nogil=True, fastmath=True, cache=True, parallel=True)
@njit(nogil=True, cache=True, parallel=True)
def prox_21_numba(v, result, lam, sigma=1.0, weight=None):
"""
Computes weighted version of
Expand Down Expand Up @@ -63,7 +63,7 @@ def dual_update(v, x, psiH, lam, sigma=1.0, weight=1.0):



@njit(nogil=True, fastmath=True, cache=True, parallel=True)
@njit(nogil=True, cache=True, parallel=True)
def dual_update_numba(vp, v, lam, sigma=1.0, weight=None):
"""
Computes weighted version of
Expand Down
4 changes: 2 additions & 2 deletions pfb/prox/prox_21m.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def prox_21m(v, sigma, weight=1.0, axis=0):
return v * np.expand_dims(ratio, axis=axis) # restores axis


@njit(nogil=True, fastmath=True, cache=True, parallel=True)
@njit(nogil=True, cache=True, parallel=True)
def prox_21m_numba(v, result, lam, sigma=1.0, weight=None):
"""
Computes weighted version of
Expand Down Expand Up @@ -68,7 +68,7 @@ def dual_update(v, x, psiH, lam, sigma=1.0, weight=1.0):



@njit(nogil=True, fastmath=True, cache=True, parallel=True)
@njit(nogil=True, cache=True, parallel=True)
def dual_update_numba(vp, v, lam, sigma=1.0, weight=None):
"""
Computes weighted version of
Expand Down
4 changes: 2 additions & 2 deletions pfb/utils/correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,14 @@ def corr_funcs(data, jones):
# The expressions for DIAG_DIAG and DIAG mode are essentially the same
if jones.ndim == 5:
# I and Q have identical weights
@njit(nogil=True, fastmath=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0]
gq00 = gq[0]
W0 = W[0]
return np.real(W0*gp00*gq00*np.conjugate(gp00)*np.conjugate(gq00))

@njit(nogil=True, fastmath=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def vfunc(gp, gq, W, V):
gp00 = gp[0]
gq00 = gq[0]
Expand Down
5 changes: 2 additions & 3 deletions pfb/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

JIT_OPTIONS = {
"nogil": True,
"fastmath": True,
"cache": True
}

Expand Down Expand Up @@ -71,7 +70,7 @@ def kron_matvec(A, b):
x = Z.ravel()
return x.reshape(b.shape)

@jit(nopython=True, fastmath=True, parallel=False, cache=True, nogil=True)
@njit(parallel=False, cache=True, nogil=True)
def kron_matvec2(A, b):
D = len(A)
N = b.size
Expand Down Expand Up @@ -1346,7 +1345,7 @@ def remove_large_islands(x, max_island_size=100):
return x


@njit(parallel=False, nogil=True, fastmath=True, inline='always')
@njit(parallel=False, nogil=True, cache=True, inline='always')
def freqmul(A, x):
nband, nx, ny = x.shape
out = np.zeros_like(x)
Expand Down
20 changes: 10 additions & 10 deletions pfb/utils/stokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,17 +374,17 @@ def stokes_funcs(data, jones, product, pol, nc):
gq00, gq01, gq10, gq11,
w0, w1, w2, w3),
sm.simplify(sm.expand(W[i,i])))
Wjfn = njit(nogil=True, fastmath=True, inline='always')(Wsymb)
Wjfn = njit(nogil=True, inline='always')(Wsymb)


Dsymb = lambdify((gp00, gp01, gp10, gp11,
gq00, gq01, gq10, gq11,
w0, w1, w2, w3,
v00, v01, v10, v11),
sm.simplify(sm.expand(C[i])))
Djfn = njit(nogil=True, fastmath=True, inline='always')(Dsymb)
Djfn = njit(nogil=True, inline='always')(Dsymb)

@njit(nogil=True, fastmath=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0,0]
gp01 = gp[0,1]
Expand All @@ -402,7 +402,7 @@ def wfunc(gp, gq, W):
gq00, gq01, gq10, gq11,
W00, W01, W10, W11).real

@njit(nogil=True, fastmath=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def vfunc(gp, gq, W, V):
gp00 = gp[0,0]
gp01 = gp[0,1]
Expand Down Expand Up @@ -439,18 +439,18 @@ def vfunc(gp, gq, W, V):
gq00, gq11,
w0, w1, w2, w3),
sm.simplify(sm.expand(W[i,i])))
Wjfn = njit(nogil=True, fastmath=True, inline='always')(Wsymb)
Wjfn = njit(nogil=True, cache=True, inline='always')(Wsymb)


Dsymb = lambdify((gp00, gp11,
gq00, gq11,
w0, w1, w2, w3,
v00, v01, v10, v11),
sm.simplify(sm.expand(C[i])))
Djfn = njit(nogil=True, fastmath=True, inline='always')(Dsymb)
Djfn = njit(nogil=True, cache=True, inline='always')(Dsymb)

if nc==literal('4'):
@njit(nogil=True, fastmath=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0]
gp11 = gp[1]
Expand All @@ -464,7 +464,7 @@ def wfunc(gp, gq, W):
gq00, gq11,
W00, W01, W10, W11).real

@njit(nogil=True, fastmath=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def vfunc(gp, gq, W, V):
gp00 = gp[0]
gp11 = gp[1]
Expand All @@ -483,7 +483,7 @@ def vfunc(gp, gq, W, V):
W00, W01, W10, W11,
V00, V01, V10, V11)
elif nc==literal('2'):
@njit(nogil=True, fastmath=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0]
gp11 = gp[1]
Expand All @@ -497,7 +497,7 @@ def wfunc(gp, gq, W):
gq00, gq11,
W00, W01, W10, W11).real

@njit(nogil=True, fastmath=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def vfunc(gp, gq, W, V):
gp00 = gp[0]
gp11 = gp[1]
Expand Down
8 changes: 4 additions & 4 deletions pfb/utils/weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def compute_counts_wrapper(uvw, freq, mask, nx, ny,



@njit(nogil=True, fastmath=True, cache=True, parallel=True)
@njit(nogil=True, cache=True, parallel=True)
def _compute_counts(uvw, freq, mask, nx, ny,
cell_size_x, cell_size_y, dtype,
k=6, ngrid=1): # support hardcoded for now
Expand Down Expand Up @@ -99,7 +99,7 @@ def _compute_counts(uvw, freq, mask, nx, ny,
counts[g, u_idx, v_idx] += 1.0
return counts #.sum(axis=0, keepdims=True)

@njit(nogil=True, fastmath=True, cache=True, inline='always')
@njit(nogil=True, cache=True, inline='always')
def _es_kernel(x, beta, k):
return np.exp(beta*k*(np.sqrt((1-x)*(1+x)) - 1))

Expand All @@ -124,7 +124,7 @@ def counts_to_weights_wrapper(counts, uvw, freq, nx, ny,
cell_size_x, cell_size_y, robust)


@njit(nogil=True, fastmath=True, cache=True)
@njit(nogil=True, cache=True)
def _counts_to_weights(counts, uvw, freq, nx, ny,
cell_size_x, cell_size_y, robust):
# ufreq
Expand Down Expand Up @@ -178,7 +178,7 @@ def filter_extreme_counts(counts, nbox=16, nlevel=10):



@njit(nogil=True, fastmath=True, cache=True)
@njit(nogil=True, cache=True)
def _filter_extreme_counts(counts, nbox=16, level=10):
'''
Replaces extreme counts by local mean computed i
Expand Down
18 changes: 9 additions & 9 deletions pfb/wavelets/wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def impl(axis, ndim):
return impl


@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True)
@numba.generated_jit(nopython=True, nogil=True, cache=True)
def dwt_axis(data, wavelet, mode, axis):
def impl(data, wavelet, mode, axis):
coeff_len = dwt_coeff_length(
Expand Down Expand Up @@ -245,7 +245,7 @@ def impl(data, wavelet, mode, axis):
return impl


@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True)
@numba.generated_jit(nopython=True, nogil=True, cache=True)
def idwt_axis(approx_coeffs, detail_coeffs,
wavelet, mode, axis):

Expand Down Expand Up @@ -343,7 +343,7 @@ def impl(approx_coeffs, detail_coeffs,
return impl


@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True)
@numba.generated_jit(nopython=True, nogil=True, cache=True)
def dwt(data, wavelet, mode="zero", axis=None):

if isinstance(data, nbtypes.misc.Optional):
Expand Down Expand Up @@ -406,7 +406,7 @@ def coeff_product(args, repeat=1):
return result


@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True)
@numba.generated_jit(nopython=True, nogil=True, cache=True)
def idwt(coeffs, wavelet, mode='zero', axis=None):

have_axis = not is_nonelike(axis)
Expand Down Expand Up @@ -492,7 +492,7 @@ def impl(sizes, dec_lens, level=None):
return impl


@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True)
@numba.generated_jit(nopython=True, nogil=True, cache=True)
def wavedecn(data, wavelet, mode='zero', level=None, axis=None):
have_axis = not is_nonelike(axis)

Expand Down Expand Up @@ -527,7 +527,7 @@ def impl(data, wavelet, mode='zero', level=None, axis=None):
return impl


@numba.generated_jit(nopython=True, nogil=True, fastmath=True, cache=True)
@numba.generated_jit(nopython=True, nogil=True, cache=True)
def waverecn(coeffs, wavelet, mode='zero', axis=None):
# ca = coeffs[0]['aa']
# if not isinstance(ca, nbtypes.npytypes.Array):
Expand Down Expand Up @@ -570,7 +570,7 @@ def impl(coeffs, wavelet, mode='zero', axis=None):
return impl


@numba.njit(nogil=True, fastmath=True, cache=True)
@numba.njit(nogil=True, cache=True)
def ravel_coeffs(coeffs):
a_coeffs = coeffs[0]['aa']

Expand Down Expand Up @@ -624,7 +624,7 @@ def ravel_coeffs(coeffs):
return coeff_arr, coeff_tuples, coeff_shapes


@numba.njit(nogil=True, fastmath=True, cache=True)
@numba.njit(nogil=True, cache=True)
def unravel_coeffs(arr, coeff_tuples, coeff_shapes, output_format='wavedecn'):
arr = np.asarray(arr)
coeffs = List()
Expand All @@ -644,7 +644,7 @@ def unravel_coeffs(arr, coeff_tuples, coeff_shapes, output_format='wavedecn'):
return coeffs


@numba.njit(nogil=True, fastmath=True, cache=True)
@numba.njit(nogil=True, cache=True)
def wavelet_setup(x, bases, nlevels):
# set up dictionary info
tys = Dict()
Expand Down

0 comments on commit 1889d09

Please sign in to comment.