From de9558e82610aeac2bd8958375544324372e92b5 Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 2 Jan 2024 23:29:23 -0800 Subject: [PATCH] bug fix: when dask import, use `dask.array.where` instead of `np.where` --- lib/eofs/standard.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/eofs/standard.py b/lib/eofs/standard.py index 445b6d7..811b528 100644 --- a/lib/eofs/standard.py +++ b/lib/eofs/standard.py @@ -153,7 +153,10 @@ def __init__(self, dataset, weights=None, center=True, ddof=1): if not self._valid_nan(self._data): raise ValueError('missing values detected in different ' 'locations at different times') - nonMissingIndex = np.where(np.logical_not(np.isnan(self._data[0])))[0] + if has_dask: + nonMissingIndex = dask.array.where(np.logical_not(np.isnan(self._data[0])))[0] + else: + nonMissingIndex = np.where(np.logical_not(np.isnan(self._data[0])))[0] # Remove missing values from the design matrix. dataNoMissing = self._data[:, nonMissingIndex] if dataNoMissing.size == 0: @@ -741,7 +744,10 @@ def projectField(self, field, neofs=None, eofscaling=0, weighted=True): if not self._valid_nan(field_flat): raise ValueError('missing values detected in different ' 'locations at different times') - nonMissingIndex = np.where(np.logical_not(np.isnan(field_flat[0])))[0] + if has_dask: + nonMissingIndex = dask.array.where(np.logical_not(np.isnan(self._data[0])))[0] # lee1043 testing + else: + nonMissingIndex = np.where(np.logical_not(np.isnan(field_flat[0])))[0] try: # Compute chunk sizes if nonMissingIndex is a dask array, so its # shape can be compared with eofsNonMissingIndex later.