From 9bbae76fde3ea181c44fccfb64c52086b8b3f6e5 Mon Sep 17 00:00:00 2001
From: Derrick Chambers
Date: Thu, 13 Jun 2024 12:29:30 -0600
Subject: [PATCH] Multi source correlate (#400)

---------

Co-authored-by: ahmadtourei
---
 dascore/core/coords.py            | 104 +++++++++----
 dascore/core/patch.py             |   2 +
 dascore/proc/__init__.py          |   2 +-
 dascore/proc/basic.py             | 132 ++++++++++------
 dascore/proc/correlate.py         | 250 +++++++++++++++++++-----------
 dascore/proc/whiten.py            |   8 +-
 dascore/transform/fourier.py      |   3 +-
 dascore/units.py                  |   3 +
 dascore/utils/patch.py            |  11 +-
 tests/test_core/test_coords.py    |  56 ++++++-
 tests/test_proc/test_basic.py     |  66 +++++++-
 tests/test_proc/test_correlate.py |  84 ++++++++--
 tests/test_proc/test_whiten.py    |   8 +-
 tests/test_units.py               |   6 +
 14 files changed, 531 insertions(+), 204 deletions(-)

diff --git a/dascore/core/coords.py b/dascore/core/coords.py
index bcea2956..b0d10caa 100644
--- a/dascore/core/coords.py
+++ b/dascore/core/coords.py
@@ -790,30 +790,38 @@ def get_next_index(self, value, samples=False, allow_out_of_bounds=False) -> int
         >>> # The next (not closest) index is returned for values not in coord.
         >>> assert coord.get_next_index(2.000001) == 3
         """
+        if not self.sorted:
+            msg = f"Coords must be sorted to use get_next_index, {self} is not."
+            raise CoordError(msg)
+        input_array_like = isinstance(value, Sized)
+        array = np.atleast_1d(value)
         # handle samples
         if samples:
             min_val, max_val = 0, len(self) - 1
-            value = int(np.round(value))
+            array = array.astype(np.int64)
+            wrap_around = array < 0
             # account for negative indexing
-            value = value if value >= 0 else value + max_val + 1
+            array[wrap_around] = array[wrap_around] + max_val + 1
         else:
-            value = self._get_compatible_value(value)
+            array = self._get_compatible_value(array)
             min_val, max_val = self.min(), self.max()
         # handle out of bounds cases
-        if (is_gt := value > max_val) or (value < min_val):
-            if not allow_out_of_bounds:
-                msg = f"Value: {value} is out of bounds for {self}"
-                raise ValueError(msg)
-            return max_val if is_gt else min_val
+        is_gt, is_lt = array > max_val, array < min_val
+        if not allow_out_of_bounds and np.any(is_gt | is_lt):
+            msg = f"Value: {array} is out of bounds for {self}"
+            raise ValueError(msg)
+        # Clip out-of-bounds values to the coord limits.
+        array[is_gt] = max_val
+        array[is_lt] = min_val
         # samples should already have the answer, just return
         if samples:
-            return value
+            return array if input_array_like else array[0]
         # otherwise get forward and backward inds
-        for_index = self._get_index(value, forward=True)
-        back_index = self._get_index(value, forward=False)
-        ranges = [x for x in [for_index, back_index] if x is not None]
-        assert len(ranges)
-        return ranges[0]
+        forward_index = self._get_index(array, forward=True)
+        back_index = self._get_index(array, forward=False)
+        bad_for_index = pd.isnull(forward_index) | (forward_index == -9999)
+        forward_index[bad_for_index] = back_index[bad_for_index]
+        return forward_index if input_array_like else forward_index[0]

     def approx_equal(self: BaseCoord, other: BaseCoord) -> bool:
         """
@@ -1015,6 +1023,9 @@ def validate_start_stop_step_len(cls, values):
             stop = start + step * length
         if pd.isnull(step):
             step = (stop - start) / length
+        # handle conversion to integer if other values are ints.
+        if isinstance(start, int) and isinstance(stop, int):
+            step = int(step) if np.isclose(np.round(step), step) else step
         if step != 0:
             int_val = int(np.ceil(np.round((stop - start) / step, 1)))
             stop = start + step * int_val
@@ -1031,11 +1042,11 @@ def validate_start_stop_step_len(cls, values):
             raise CoordError(msg)
         # Note: dtype was a property before but it messed up model
         # serialization.
-        values["dtype"] = np.asarray(start).dtype
+        values["dtype"] = np.asarray(start + step).dtype
         return values

     def __getitem__(self, item):
-        if isinstance(item, int):
+        if isinstance(item, int | np.integer):
             if item >= len(self):
                 raise IndexError(f"{item} exceeds coord length of {self}")
             return self.values[item]
@@ -1082,6 +1093,7 @@ def select(
         if self.reverse_sorted:
             start, stop = stop, start
         # we add 1 to stop in slice since its upper limit is exclusive
+        start = None if start == 0 else start
         data = slice(start, (stop + 1) if stop is not None else stop)
         if self._slice_degenerate(data):
             return self.empty(), slice(0, 0)
@@ -1109,14 +1121,19 @@ def _get_index(self, value, forward=True):
         """Get the index corresponding to a value."""
         if (value := self._get_compatible_value(value)) is None:
             return value
+        input_is_array = isinstance(value, Sized)
+        array = np.atleast_1d(value)
         func = np.ceil if forward else np.floor
-        start, _, step = self.start, self.stop, self.step
+        start, step = self.start, self.step
         # Due to float weirdness we need a little bit of a fudge factor here.
-        fraction = func(np.round((value - start) / step, decimals=10))
-        out = int(fraction)
-        if (out <= 0 and forward) or (out >= len(self) and not forward):
+        fraction = func(np.round((array - start) / step, decimals=10))
+        out = fraction.astype(np.int64)
+        lt_forward = (out < 0) & forward
+        gt_back = (out >= len(self)) & (not forward)
+        bad_values = lt_forward | gt_back
+        if not input_is_array and np.any(bad_values):
             return None
-        return out
+        return out if input_is_array else int(out[0])

     @compose_docstring(doc=BaseCoord.update_limits.__doc__)
     def update_limits(self, min=None, max=None, step=None, **kwargs) -> Self:
@@ -1160,7 +1177,9 @@ def values(self) -> ArrayLike:
         if is_datetime64(self.start) or is_timedelta64(self.start):
             out = np.arange(self.start, self.stop, self.step)
         else:
-            out = np.linspace(self.start, self.stop - self.step, num=len(self))
+            out = np.linspace(
+                self.start, self.stop - self.step, num=len(self), dtype=self.dtype
+            )
         # again, due to round-off error the array can be one element longer than
         # anticipated. The slice here just ensures shape and len match.
         return array(out[: len(self)])
@@ -1352,32 +1371,49 @@ def select(
         # by -1 in _get_index the inverted range is used.
         if self.reverse_sorted:
             v1, v2 = v2, v1
-        start = self._get_index(v1, left=True)
+        start = self._get_index(v1, forward=False)
         new_start = start if start is not None and start > 0 else None
-        stop = self._get_index(v2, left=False)
+        stop = self._get_index(v2, forward=True)
         new_stop = stop if stop is not None and stop < len(self) else None
+        # We need to add 1 to stop so one sample is selected if start == stop.
+        if new_stop is not None:
+            if self.values[new_stop] == v2:
+                new_stop = new_stop + 1
         out = slice(new_start, new_stop)
         if self._slice_degenerate(out):
             return self.empty(), slice(0, 0)
         return self.new(values=self.values[out]), out

-    def _get_index(self, value, left=True):
+    def _get_index(self, value, forward=True):
         """
         Get the index corresponding to a value.

-        Left indicates if this is the min value.
+        Forward indicates if this is the max (stop) value.
""" - if (value := self._get_compatible_value(value)) is None: - return value - values = self.values - side_dict = {True: "left", False: "right"} + if (new_value := self._get_compatible_value(value)) is None: + return new_value + values = np.atleast_1d(self.values) # since search sorted only works on ascending monotonic arrays we # negative descending arrays to get the same effect. if self.reverse_sorted: values = to_float(values) * -1 - value = to_float(value) * -1 - ind = np.searchsorted(values, value, side=side_dict[left]) - return ind + new_value = to_float(new_value) * -1 + # side = "right" if forward else "left" + # out = np.atleast_1d(np.searchsorted(values, new_value, side=side)) + # Search values. Ensure the returned index is in bounds (eg values GT + # coord max should still have a range in coords. + new_value = np.atleast_1d(new_value) + right = np.searchsorted(values, new_value, side="right") + # right_ok = (right < len(self)) & (right < 0) + left = np.searchsorted(values, new_value, side="left") + left_ok = (left < len(self)) & (left > 0) + eq = left_ok & (values.take(left, mode="clip") == new_value) + out = right if forward else left + # where equal it should also be left values. This makes the function + # behavior consistent with BaseCoord._get_index. + if not self.reverse_sorted: + out[eq] = left[eq] + return out if is_array(value) else int(out[0]) def _step_meets_requirement(self, op): """Return True is any data increment meets the comp. requirement.""" @@ -1560,7 +1596,7 @@ def _maybe_get_start_stop_step(data): if isinstance(data, BaseCoord): # just return coordinate return data if not isinstance(data, np.ndarray): - data = array(data) + data = np.atleast_1d(data) if np.size(data) == 0: dtype = dtype or data.dtype return CoordPartial(shape=data.shape, units=units, step=step, dtype=dtype) diff --git a/dascore/core/patch.py b/dascore/core/patch.py index ce8e508a..4d009a7b 100644 --- a/dascore/core/patch.py +++ b/dascore/core/patch.py @@ -279,6 +279,7 @@ def iselect(self, *args, **kwargs): return self.select(*args, samples=True, **kwargs) correlate = dascore.proc.correlate + correlate_shift = dascore.proc.correlate_shift decimate = dascore.proc.decimate detrend = dascore.proc.detrend dropna = dascore.proc.dropna @@ -289,6 +290,7 @@ def iselect(self, *args, **kwargs): savgol_filter = dascore.proc.savgol_filter gaussian_filter = dascore.proc.gaussian_filter abs = dascore.proc.abs + conj = dascore.proc.conj real = dascore.proc.real imag = dascore.proc.imag angle = dascore.proc.angle diff --git a/dascore/proc/__init__.py b/dascore/proc/__init__.py index 4ae8a322..0841c727 100644 --- a/dascore/proc/__init__.py +++ b/dascore/proc/__init__.py @@ -6,7 +6,7 @@ import dascore.proc.aggregate as agg from .basic import * # noqa from .coords import * # noqa -from .correlate import correlate +from .correlate import correlate, correlate_shift from .detrend import detrend from .filter import median_filter, pass_filter, sobel_filter, savgol_filter, gaussian_filter from .resample import decimate, interpolate, resample diff --git a/dascore/proc/basic.py b/dascore/proc/basic.py index 6ee1baaa..08508c42 100644 --- a/dascore/proc/basic.py +++ b/dascore/proc/basic.py @@ -7,14 +7,16 @@ import numpy as np import pandas as pd +from scipy.fft import next_fast_len import dascore as dc from dascore.constants import DEFAULT_ATTRS_TO_IGNORE, PatchType from dascore.core.attrs import PatchAttrs from dascore.core.coordmanager import CoordManager, get_coord_manager from dascore.core.coords import get_coord 
-from dascore.exceptions import PatchBroadcastError, UnitError
+from dascore.exceptions import ParameterError, PatchBroadcastError, UnitError
 from dascore.units import DimensionalityError, Quantity, Unit, get_quantity
+from dascore.utils.misc import _get_nullish
 from dascore.utils.models import ArrayLike
 from dascore.utils.patch import (
     _merge_aligned_coords,
@@ -205,6 +207,23 @@ def abs(patch: PatchType) -> PatchType:
     return patch.new(data=np.abs(patch.data))


+@patch_function()
+def conj(patch: PatchType) -> PatchType:
+    """
+    Take the complex conjugate of the patch data.
+
+    Examples
+    --------
+    >>> import dascore
+    >>> pa = dascore.get_example_patch()
+    >>>
+    >>> # Example 1
+    >>> dft = pa.dft(None)  # multi-dim dft
+    >>> conj = dft.conj()
+    """
+    return patch.new(data=np.conj(patch.data))
+
+
 @patch_function()
 def real(patch: PatchType) -> PatchType:
     """
@@ -551,7 +570,7 @@ def pad(
     patch: PatchType,
     mode: Literal["constant"] = "constant",
     constant_values: Any = 0,
-    expand_coords=False,
+    expand_coords=True,
     samples=False,
     **kwargs,
 ) -> PatchType:
@@ -568,72 +587,89 @@ def pad(
     expand_coords : bool, optional
         Determines how coordinates are adjusted when padding is applied.
         If set to True, the coordinates will be expanded to maintain their
-        order and even sampling (if originally evenly sampled), by extrapolating
-        based on the coordinate's step size.
-        If set to False, the new coordinates introduced by padding will be
-        filled with NaN values, preserving the original coordinate values but
-        not the order or sampling rate.
+        order and even sampling (if evenly sampled), by extrapolating
+        based on the coordinate's step size. If set to False, or if the
+        coordinate is not evenly sampled, the new coordinates introduced
+        by padding will be filled with NaN values.
     **kwargs:
         Used to specify the dimension and number of elements, either an
         integer or a tuple (before, after).
+        In addition, the following strings are supported:
+
+        "fft" - pad to the next fast fft length along the given dimension by
+        adding values to the end of the axis.
+
+        "correlate" - prepare the coordinate for correlation/convolution in
+        the frequency domain by padding to the next fast fft length at or
+        beyond 2*n - 1, where n is the current dimension length, by adding
+        values to the end of the axis.

     Examples
     --------
     >>> import dascore as dc
     >>> patch = dc.get_example_patch()
-    >>> # zero pad `time` dimension with 2 patch's time unit (e.g., sec)
-    >>> # zeros before and 3 zeros after
-    >>> padded_patch_1 = patch.pad(time = (2, 3))
-    >>> # zero pad `distance` dimension with 4 unit values before and after
-    >>> padded_patch_3 = patch.pad(distance = 4, constant_values = 1, samples=True)
-    """
-    if isinstance(constant_values, list | tuple):
-        raise TypeError("constant_values must be a scalar, not a sequence.")
-
-    pad_width = [(0, 0)] * len(patch.shape)
-    dimfo = get_multiple_dim_value_from_kwargs(patch, kwargs)
-    new_coords = {}
-
-    for _, info in dimfo.items():
-        axis, dim, value = info["axis"], info["dim"], info["value"]
-
-        # Ensure pad_width is a tuple, even if a single integer is provided
-        if isinstance(value, int):
+    >>> # Zero pad `time` dimension with 2 time units (e.g., sec) of zeros
+    >>> # before and 3 after.
+    >>> padded_patch_1 = patch.pad(time=(2, 3))
+    >>> # Pad `distance` dimension with 1s, 4 samples before and 4 after.
+    >>> padded_patch_3 = patch.pad(distance=4, constant_values=1, samples=True)
+    >>> # Get patch ready for fast fft along the time dimension.
+    >>> padded_fft = patch.pad(time="fft")
+    """
+
+    def _get_pad_tuple(value, samples, coord):
+        """
+        Get a tuple, in samples, of (pad_to_start, pad_to_end).
+        """
+        if value in {"fft", "correlate"}:
+            target_length = len(coord) if value == "fft" else 2 * len(coord) - 1
+            # Determine the pad needed so the output dim has a fast fft length.
+            value = (0, next_fast_len(target_length) - len(coord))
+            samples = True  # ensure padding isn't interpreted as coord units.
+        elif not isinstance(value, Sequence):
             value = (value, value)
-
-        # Ensure kwargs are in samples
-        if not samples:
-            coord = patch.get_coord(dim, require_evenly_sampled=True)
-            value = (
-                coord.get_sample_count(value[0], samples=samples),
-                coord.get_sample_count(value[1], samples=samples),
-            )
-        pad_width[axis] = value
-
-        # Get new coordinate
-        if expand_coords:
-            new_start = coord.min() - value[0] * coord.step
-            new_end = coord.max() + (value[1] + 1) * coord.step
-            coord = patch.get_coord(dim, require_evenly_sampled=True)
-            old_values = coord.values
+        if not samples:  # Ensure values are in samples.
+            value = tuple(coord.get_sample_count(x) for x in value)
+        return value
+
+    def _get_new_coord(coord, pad_tuple, expand_coords):
+        """Get the new coordinate along the expanded axis."""
+        if expand_coords and coord.evenly_sampled:
+            new_start = coord.min() - pad_tuple[0] * coord.step
+            new_end = coord.max() + (pad_tuple[1] + 1) * coord.step
             new_coord = get_coord(
                 start=new_start, stop=new_end, step=coord.step, units=coord.units
             )
         else:
-            coord = patch.get_coord(dim)
-            old_values = coord.values.astype(np.float64)
+            old_values = coord.values
+            # Need to convert ints to float so NaN can be used.
+            if np.issubdtype(old_values.dtype, np.integer):
+                old_values = old_values.astype(np.float64)
+            null_value = _get_nullish(old_values.dtype)
             added_nan_values = np.pad(
-                old_values, pad_width=value, constant_values=np.nan
+                old_values, pad_width=pad_tuple, constant_values=null_value
             )
             new_coord = coord.update(data=added_nan_values)
-        new_coords[dim] = new_coord
+        return new_coord

-    # Pad patch's data
-    new_data = np.pad(patch.data, pad_width, mode=mode, constant_values=constant_values)
+    if isinstance(constant_values, Sequence):
+        raise ParameterError("constant_values must be a scalar, not a sequence.")

-    # Update coord manager
-    new_coords = patch.coords.update(**new_coords)
+    pad_width = [(0, 0)] * len(patch.shape)
+    dimfo = get_multiple_dim_value_from_kwargs(patch, kwargs)
+    new_coords = {}
+    for _, info in dimfo.items():
+        axis, dim, value = info["axis"], info["dim"], info["value"]
+        coord = patch.get_coord(dim, require_evenly_sampled=not samples)
+        pad_tuple = _get_pad_tuple(value, samples, coord)
+        pad_width[axis] = pad_tuple
+        new_coords[dim] = _get_new_coord(coord, pad_tuple, expand_coords)
+
+    # Pad data, update coord manager, and return.
+    new_data = np.pad(patch.data, pad_width, mode=mode, constant_values=constant_values)
+    new_coords = patch.coords.update(**new_coords)
     return patch.new(data=new_data, coords=new_coords)
diff --git a/dascore/proc/correlate.py b/dascore/proc/correlate.py
index ac0c87de..9af48ac8 100644
--- a/dascore/proc/correlate.py
+++ b/dascore/proc/correlate.py
@@ -2,144 +2,218 @@

 from __future__ import annotations

+import warnings
+
 import numpy as np
-from scipy.fftpack import next_fast_len

 import dascore as dc
 from dascore.constants import PatchType
-from dascore.units import Quantity
-from dascore.utils.misc import broadcast_for_index
 from dascore.utils.patch import (
-    _get_dx_or_spacing_and_axes,
     get_dim_value_from_kwargs,
     patch_function,
 )
+from dascore.utils.time import to_float


+def _get_source_fft(patch, dim, source, source_axis, samples):
+    """
+    Get an array of the source channels from the transformed patch data.

-def _get_correlated_coord(old_coord, data_shape, real=True):
-    """Get the new coordinate which corresponds to correlated values."""
-    step = old_coord.step
-    one_sided_len = data_shape // 2
-    new = dc.core.get_coord(
-        start=-one_sided_len * step,
-        stop=(one_sided_len + 1) * step,
-        step=step,
-    )
-    assert len(new) == data_shape, "failed to create correlated coord"
-    return new
+    This function will place the new sources in a third dimension so
+    they broadcast with the original fft matrix.
+    """
+    # Extract an array containing just the sources.
+    coord_source = patch.get_coord(dim)
+    index_source = coord_source.get_next_index(source, samples=samples)
+    selecter = [slice(None), slice(None), None]
+    selecter[source_axis] = np.atleast_1d(index_source)
+    source = patch.data[tuple(selecter)]
+    # Now transpose source so the source dim is last. Essentially we just
+    # need to swap the source axis with the last axis.
+    out = np.swapaxes(source, source_axis, -1)
+    return out


-def _shift(data, cx_len, axis):
+@patch_function()
+def correlate_shift(patch, dim, undo_weighting=True):
     """
-    Re-assemble fft data so zero lag is in center.
+    Apply a shift to the patch data to undo correlation in the frequency domain.

-    Also accounts for padding for fast fft.
+    Also adds the appropriate coordinate, prefixed with "lag_", which has a
+    float datatype.
+
+    Parameters
+    ----------
+    patch
+        The input patch.
+    dim
+        The dimension name that was correlated in the freq. domain.
+    undo_weighting
+        If True, also undo the weighting artifact caused by DASCore's dft
+        weighting. This is done by simply dividing by the coordinate step.
+        See the [dft note](`docs/notes/dft_notes.qmd`) for more details.
+
+    Examples
+    --------
+    >>> import dascore as dc
+    >>> patch = dc.get_example_patch()
+    >>>
+    >>> # Example 1
+    >>> # An auto-correlation of the example patch
+    >>> dft = patch.dft("time", real=True)
+    >>> dft_sq = dft * dft.conj()
+    >>> idft = dft_sq.idft()
+    >>> auto_patch = idft.correlate_shift(dim="time")
     """
-    ndims = len(data.shape)
-    start_slice = slice(-np.floor(cx_len / 2).astype(int), None)
-    start_ind = broadcast_for_index(ndims, axis, start_slice)
-    stop_slice = slice(None, np.ceil(cx_len / 2).astype(int))
-    stop_ind = broadcast_for_index(ndims, axis, stop_slice)
-    data1 = data[start_ind]
-    data2 = data[stop_ind]
-    data = np.concatenate([data1, data2], axis=axis)
-    return data
+    coord = patch.get_coord(dim, require_evenly_sampled=True)
+    axis = patch.dims.index(dim)
+    data = np.fft.fftshift(patch.data, axes=axis)
+    if undo_weighting:
+        data = data / to_float(coord.step)
+    # It appears (from testing) there is one less sample on the positive
+    # side.
+    step = to_float(coord.step)
+    new_start = -np.ceil((len(coord) - 1) / 2) * step
+    new_end = np.ceil((len(coord) - 1) / 2) * step
+    new_coord = dc.get_coord(start=new_start, stop=new_end, step=step)
+    assert len(new_coord) == len(coord)
+    cm = patch.coords
+    new_cm = cm.update(**{dim: new_coord}).rename_coord(**{dim: f"lag_{dim}"})
+    out = patch.update(data=data, coords=new_cm)
+    return out


 @patch_function()
 def correlate(
-    patch: PatchType, lag: int | float | Quantity | None = None, samples=False, **kwargs
+    patch: PatchType,
+    samples=False,
+    lag=None,
+    **kwargs,
 ) -> PatchType:
     """
-    Correlate a single row/column in a 2D patch with every other row/column.
+    Correlate source rows/columns in a 2D patch with all other rows/columns.
+
+    Correlations are done in the frequency domain. This function can accept a
+    patch whose target dimension has already been transformed with the
+    [`Patch.dft`](`dascore.transform.fourier.dft`) method, otherwise the dft
+    will be performed. If the input has already been transformed,
+    [`Patch.correlate_shift`](`dascore.proc.correlate.correlate_shift`)
+    is useful to undo dft artifacts after the idft is applied.
+
+    While a 2D patch is required for input, a 3D patch is returned where the
+    3rd dimension corresponds to the source rows/columns. For the case of a
+    single source, the [`Patch.squeeze`](`dascore.Patch.squeeze`) method
+    can be helpful to remove length 1 dimensions.

     Parameters
     ----------
     patch : PatchType
         The input data patch to be cross-correlated. Must be 2-dimensional.
-    lag :
-        An optional argument to save only certain lag times instead of full
-        output.
+        The patch can be in the time or frequency domain.
     samples : bool, optional (default = False)
         If True, the argument specified in kwargs refers to the *samples*,
         not values, along that axis. See examples for details.
+    lag
+        Deprecated, just use select on the output patch instead.
     **kwargs
-        Additional arguments to specify cross correlation dimension and the
-        master source, to which
-        we cross-correlate all other channels/time samples.
+        Specifies the correlation dimension and the master source(s), to which
+        we want to cross-correlate all other channels/time samples. If the
+        master source is an array, the function will compute correlations for
+        all the possible pairs.

     Examples
     --------
     >>> import dascore as dc
     >>> from dascore.units import m, s
-    >>> patch = dc.get_example_patch()
-
+    >>> # Get a patch composed of sin waves whose correlation results
+    >>> # can easily be checked.
+    >>> patch = dc.get_example_patch(
+    ...     "sin_wav",
+    ...     sample_rate=100,
+    ...     frequency=range(10, 20),
+    ...     duration=5,
+    ...     channel_count=10,
+    ... ).taper(time=0.5).set_units(distance='m')
+    >>>
     >>> # Example 1
     >>> # Calculate cc for all channels as receivers and
-    >>> # the 10 m channel as the master channel.
-    >>> cc_patch = patch.correlate(distance = 10 * m)
-
+    >>> # the 10 m channel as the master channel. Squeeze the output
+    >>> # so the returned patch is 2D.
+    >>> cc_patch = patch.correlate(distance = 10 * m).squeeze()
+    >>>
     >>> # Example 2
-    >>> # Calculate cc within (-2,2) sec of lag for all channels as receivers and
-    >>> # the 10 m channel as the master channel. The new patch has dimensions
-    >>> # (lag_time, distance)
-    >>> cc_patch = patch.correlate(distance = 10 * m, lag = 2 * s)
-
+    >>> # Get cc within (-2,2) sec of lag for all channels as receivers
+    >>> # and the 10 m channel as the master channel. The new patch has
+    >>> # dimensions (lag_time, distance, source_distance).
+    >>> cc_patch = (
+    ...     patch.correlate(distance = 10 * m)
+    ...     .select(lag_time=(-2, 2))
+    ... )
+    >>>
     >>> # Example 3
-    >>> # Use 2nd channel (python is 0 indexed) along distance as master channel
-    >>> cc_patch = patch.correlate(distance=1, samples=True)
-
+    >>> # First remove every other distance channel (less memory usage)
+    >>> # then use the new 2nd channel as the source.
+    >>> cc_patch = (
+    ...     patch.decimate(distance=2, filter_type=None)
+    ...     .correlate(distance=1, samples=True)
+    ... )
+    >>>
     >>> # Example 4
-    >>> # Correlate along time dimension
+    >>> # Correlate along the time dimension (perhaps for template matching
+    >>> # applications)
     >>> cc_patch = patch.correlate(time=100, samples=True)
+    >>>
+    >>> # Example 5
+    >>> # A pipeline of frequency domain correlation and an array of sources
+    >>> padded_patch = patch.pad(time="correlate")  # pad to at least 2n - 1
+    >>> dft_patch = padded_patch.dft("time", real=True)
+    >>> # Any other pre-processing steps go here...
+    >>> # ...
+    >>> # Perform the correlation with 3 source channels
+    >>> cc_patch = dft_patch.correlate(distance=[1, 3, 7], samples=True)
+    >>> # Perform any post-processing here
+    >>> # ...
+    >>> # Convert back to time domain, apply `correlate_shift` to undo
+    >>> # fft related shifting and scaling as well as create the lag coordinate.
+    >>> cc_out = cc_patch.idft().correlate_shift("time")

     Notes
     -----
-    The cross-correlation is performed in the frequency domain for efficiency
-    reasons.
+    1 - The cross-correlation is performed in the frequency domain.

-    The output dimension is opposite of the one specified in kwargs, has
+    2 - The output dimension is opposite of the one specified in kwargs, has
     a float dtype, and the string "lag_" prepended. For example,
     "lag_time".
     """
-    assert len(patch.dims) == 2, "must be a 2D patch"
+    if lag is not None:
+        msg = "Correlate's lag parameter is deprecated. Simply use select on the output patch."
+        warnings.warn(msg, DeprecationWarning)
+    assert len(patch.dims) == 2, "must be a 2D patch."
     dim, source_axis, source = get_dim_value_from_kwargs(patch, kwargs)
-    # get the axis and coord over which fft should be calculated.
+    # Get the axis and coord over which the fft should be calculated.
fft_axis = next(iter(set(range(len(patch.dims))) - {source_axis})) fft_dim = patch.dims[fft_axis] - # ensure coordinate is evenly spaced - _get_dx_or_spacing_and_axes(patch, fft_dim, require_evenly_spaced=True) - fft_coord = patch.get_coord(fft_dim) - # get the coordinate which contains the source - coord_source = patch.get_coord(dim) - index_source = coord_source.get_next_index(source, samples=samples) - # get the closest fast length. Some padding is applied in the fft to avoid - # inefficient lengths. Note: This is not always a power of 2. - cx_len = patch.shape[fft_axis] * 2 - 1 - fast_len = next_fast_len(cx_len) - # determine proper fft, ifft functions based on data being real or complex + # Determine if the input patch has already been transformed. + input_dft = fft_dim.startswith("ft_") is_real = not np.issubdtype(patch.data.dtype, np.complexfloating) - fft_func = np.fft.rfft if is_real else np.fft.fft - ifft_func = np.fft.irfft if is_real else np.fft.ifft - # perform ffts and get source array (a sub-slice of larger fft) - fft = fft_func(patch.data, axis=fft_axis, n=fast_len) - ndims = len(patch.shape) - slicer = slice(index_source, index_source + 1) - inds = broadcast_for_index(ndims, axis=source_axis, value=slicer) - source_fft = fft[inds] - # perform correlation in freq domain and transform back to time domain - fft_prod = fft * np.conj(source_fft) - # the n parameter needs to be odd so we have a 0 lag time. This only - # applies to real fft - n_out = fast_len if (not is_real or fast_len % 2 != 0) else fast_len - 1 - corr_array = ifft_func(fft_prod, axis=fft_axis, n=n_out) - corr_data = _shift(corr_array, cx_len, axis=fft_axis) - # get new coordinate along correlation dimension - new_coord = _get_correlated_coord(fft_coord, corr_data.shape[fft_axis]) - coords = patch.coords.update(**{fft_dim: new_coord}).rename_coord( - **{fft_dim: f"lag_{fft_dim}"} - ) - out = dc.Patch(coords=coords, data=corr_data, attrs=patch.attrs) - if lag is not None: - out = out.select(**{f"lag_{fft_dim}": (-lag, +lag)}) + if not input_dft: # Standard dft workflow for correlation + # Note: we use .func here to avoid getting these added to the history. + padded = patch.pad.func(patch, **{fft_dim: "correlate"}) + patch = padded.dft.func(padded, fft_dim, real=fft_dim if is_real else None) + # Get the sources. + source = patch.get_coord(dim).values if source is None else source + source_fft = _get_source_fft(patch, dim, source, source_axis, samples) + # Need to insert new axis so the arrays broadcast correctly. + fft_patch_array = patch.data[..., None] + fft_prod = fft_patch_array * np.conj(source_fft) + # Create frequency domain patch with results + source = getattr(source, "magnitude", source) # strips units + new_coord = dc.get_coord(values=np.atleast_1d(source)) + dim_name = f"source_{dim}" + cm = patch.coords.update(**{dim_name: (dim_name, new_coord)}) + out = patch.update(data=fft_prod, coords=cm) + # Undo fft if this function did one, shift, and update coord. + if not input_dft: + idft = out.idft.func(out) + out = idft.correlate_shift.func(idft, fft_dim) return out diff --git a/dascore/proc/whiten.py b/dascore/proc/whiten.py index fa2ce16b..5b760556 100644 --- a/dascore/proc/whiten.py +++ b/dascore/proc/whiten.py @@ -21,7 +21,7 @@ def whiten( patch: PatchType, smooth_size: float | None = None, tukey_alpha: float = 0.1, - ifft: bool = True, + idft: bool = True, **kwargs, ) -> PatchType: """ @@ -42,7 +42,7 @@ def whiten( its value is 0.1. 
        See more details at
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.tukey.html
-    ifft
+    idft
         If False, returns the whitened result in the frequency domain without
         converting it back to the time domain. Defaults to True.
     **kwargs
@@ -62,7 +62,7 @@

     3) Amplitude is NOT preserved

-    4) If ifft = False, since for the purely real input data the negative
+    4) If idft = False, since for the purely real input data the negative
     frequency terms are just the complex conjugates of the corresponding
     positive-frequency terms, the output does not include the negative
     frequency terms, and therefore the length of the transformed axis
@@ -227,7 +227,7 @@ def plot_spectrum(x, T, ax, phase=False):

     norm_amp *= tiled_win

-    if ifft:
+    if idft:
         # revert back to time-domain, using the phase of the original signal
         whitened_data = np.real(
             nft.irfft(norm_amp * np.exp(1j * phase), n=comp_nsamp, axis=dim_ind)
diff --git a/dascore/transform/fourier.py b/dascore/transform/fourier.py
index 9c54d26e..66cb2835 100644
--- a/dascore/transform/fourier.py
+++ b/dascore/transform/fourier.py
@@ -127,8 +127,7 @@ def dft(
     - Non-dimensional coordinates associated with transformed coordinates
       will be dropped in the output.

-    - See the [FFT note](`notes/dft_notes.qmd`) in the Notes section
-      of DASCore's documentation for more details.
+    - See the [FFT notes](`docs/notes/dft_notes.qmd`) for more details.

     See Also
     --------
diff --git a/dascore/units.py b/dascore/units.py
index f8ab58a8..133532c6 100644
--- a/dascore/units.py
+++ b/dascore/units.py
@@ -133,6 +133,9 @@ def convert_units(
     to_units, from_units = get_quantity(to_units), get_quantity(from_units)
     if from_units is None:
         return data
+    elif to_units is None:
+        msg = "Cannot convert units because to_units is not specified."
+        raise UnitError(msg)
     try:
         mult1, add, mult2 = _get_conversion_factors(from_units, to_units)
     except DimensionalityError as e:
diff --git a/dascore/utils/patch.py b/dascore/utils/patch.py
index f67159dd..7407fe8e 100644
--- a/dascore/utils/patch.py
+++ b/dascore/utils/patch.py
@@ -684,20 +684,21 @@ def check_coords(

 def _merge_aligned_coords(cm1, cm2):
     """Merge aligned coordinates, preferring complete (non-partial) coords."""
-    assert cm1.dims == cm2.dims, "coordinates are not aligned"
+    assert cm1.dims == cm2.dims, "dimensions are not aligned"
     out = {}
     for name in set(cm1.coord_map) & set(cm2.coord_map):
         coord1 = cm1.coord_map[name]
         coord2 = cm2.coord_map[name]
+        dim1, dim2 = cm1.dim_map.get(name), cm2.dim_map.get(name)
         # Coords already equal, just use first.
-        if coord1.approx_equal(coord2):
-            out[name] = coord1
+        if coord1.approx_equal(coord2) and dim1 == dim2:
+            out[name] = (dim1, coord1)
         # Deal with partial (non) coords.
         non_count = sum([coord1._partial, coord2._partial])
         if non_count == 1:
-            out[name] = coord1 if coord2._partial else coord2
+            out[name] = (dim1, coord1 if coord2._partial else coord2)
         elif non_count == 2:
-            out[name] = coord1 if coord1.size > coord2.size else coord2
+            out[name] = (dim1, coord1 if coord1.size > coord2.size else coord2)
         assert name in out
     return cm1.update(**out)
diff --git a/tests/test_core/test_coords.py b/tests/test_core/test_coords.py
index 81e6800b..8cf4e3e9 100644
--- a/tests/test_core/test_coords.py
+++ b/tests/test_core/test_coords.py
@@ -473,8 +473,7 @@ def test_select_end_end_time(self, coord):

     def test_intra_sample_select(self, coord):
         """
-        Selecting ranges that fall within samples should raise.
-
+        Selecting ranges that fall within samples should become degenerate.
         This is consistent with pandas indices.
""" values = coord.values @@ -522,7 +521,7 @@ def test_select_bounds(self, long_coord): assert set(values).issuperset(set(new.values)) def test_select_out_of_bounds_too_early(self, coord): - """Applying a select out of bounds (too early) should raise an Error.""" + """Applying a select out of bounds (too early) should return partial.""" diff = (coord.max() - coord.min()) / (len(coord) - 1) # get a range which is for sure before data. v1 = coord.min() - np.abs(100 * diff) @@ -540,7 +539,7 @@ def test_select_out_of_bounds_too_early(self, coord): assert new == coord def test_select_out_of_bounds_too_late(self, coord): - """Applying a select out of bounds (too late) should raise an Error.""" + """Applying a select out of bounds (too late) should return partial.""" diff = (coord.max() - coord.min()) / (len(coord) - 1) # get a range which is for sure after data. v1 = coord.max() + np.abs(100 * diff) @@ -932,6 +931,7 @@ def test_dtype(self, evenly_sampled_coord, evenly_sampled_date_coord): def test_select_tuple_ints(self, evenly_sampled_coord): """Ensure a tuple works as a limit.""" assert evenly_sampled_coord.select((50, None))[1] == slice(50, None) + evenly_sampled_coord.select((0, None)) assert evenly_sampled_coord.select((0, None))[1] == slice(None, None) assert evenly_sampled_coord.select((-10, None))[1] == slice(None, None) assert evenly_sampled_coord.select((None, None))[1] == slice(None, None) @@ -1676,12 +1676,47 @@ def test_between_values(self, evenly_sampled_coord): def test_units(self, evenly_sampled_float_coord_with_units): """Ensure values with units work.""" coord = evenly_sampled_float_coord_with_units + val1 = np.array([10, 20]) * get_quantity("m") + val2 = val1.to(get_quantity("ft")) + ind1 = coord.get_next_index(val1) + ind2 = coord.get_next_index(val2) + assert len(val1) == len(ind1) + assert np.all(ind1 == ind2) + + def test_unit_array(self, evenly_sampled_float_coord_with_units): + """Ensure an array of units returns inds of equal shape.""" + coord = evenly_sampled_float_coord_with_units.set_units("m") val1 = 10 * get_quantity("m") val2 = val1.to(get_quantity("ft")) ind1 = coord.get_next_index(val1) ind2 = coord.get_next_index(val2) assert ind1 == ind2 + def test_array_input(self, coord): + """Get next coord should work with an array.""" + inds = np.array([1, 2, 3, 4]) + try: + out1 = coord.get_next_index(inds, samples=True) + except CoordError: + pytest.skip(f"{coord} doesn't support get_next_index.") + values = coord.values[inds] + out2 = coord.get_next_index(values) + assert np.all(out1 == inds) and np.all(out2 == out1) + + def test_out_of_bounds_array_range_coord(self, evenly_sampled_coord): + """Ensure get index works with array that has out of bounds values.""" + coord = evenly_sampled_coord + values = coord.values + inputs = np.arange(3) + np.max(values) + out = coord.get_next_index(inputs, allow_out_of_bounds=True) + assert np.allclose(out, 99) + + def test_non_sorted_coord_raises(self, random_coord): + """Ensure a non-sorted coord raises.""" + msg = "Coords must be sorted" + with pytest.raises(CoordError, match=msg): + random_coord.get_next_index(10) + class TestUpdate: """Tests for updating coordinates.""" @@ -1844,3 +1879,16 @@ def test_snap_awkward(self, awkward_off_by_one): """Ensure off by one error don't occur with snap.""" out = awkward_off_by_one.snap() assert out.shape == awkward_off_by_one.shape + + def test_dtype_start_int_others_float(self): + """ + Ensure the correct dtype is given when the start value is an int + but the other values are not. 
+ """ + coord = get_coord( + start=0, + step=1.1, + stop=11.2, + units="m", + ) + assert np.issubdtype(coord.dtype, np.floating) diff --git a/tests/test_proc/test_basic.py b/tests/test_proc/test_basic.py index 25c4a669..227ac240 100644 --- a/tests/test_proc/test_basic.py +++ b/tests/test_proc/test_basic.py @@ -7,10 +7,11 @@ import numpy as np import pandas as pd import pytest +from scipy.fft import next_fast_len import dascore as dc from dascore import get_example_patch -from dascore.exceptions import PatchBroadcastError, UnitError +from dascore.exceptions import ParameterError, PatchBroadcastError, UnitError from dascore.proc.basic import apply_operator from dascore.units import furlongs, get_quantity, m, s from dascore.utils.misc import _merge_tuples @@ -325,6 +326,11 @@ def test_patches_non_coords_different_len(self, random_patch): assert np.allclose(out.data, 1) assert out.shape == patch_2.shape + def test_non_dim_coords(self, random_dft_patch): + """Ensure ufuncs can still be applied to coords with non dim coords.""" + out = random_dft_patch * random_dft_patch + assert set(out.coords.coord_map) == set(random_dft_patch.coords.coord_map) + class TestPatchBroadcasting: """Tests for patches broadcasting to allow operations on each other.""" @@ -592,12 +598,68 @@ def test_pad_multiple_dimensions_samples_true(self, random_patch, samples=True): atol=1e-6, ) + @pytest.mark.parametrize("length", [10, 13, 200, 293]) + def test_fft_pad(self, random_patch, length): + """Tests for padding to next fast length.""" + # First trim patch to match length + patch_in = random_patch.select(time=(0, length), samples=True) + axis = patch_in.dims.index("time") + length_old = patch_in.shape[axis] + next_fast = next_fast_len(length_old) + out = patch_in.pad(time="fft") + assert out.dims == patch_in.dims, "dims should not have changed." + assert out.shape[axis] == next_fast + + def test_correlate_pad(self, random_patch): + """Ensure the fft correlate pad returns sensible values.""" + axis = random_patch.dims.index("time") + length_old = random_patch.shape[axis] + next_fast = next_fast_len(length_old * 2 - 1) + out = random_patch.pad(time="correlate") + assert out.shape[axis] == next_fast + + def test_pad_no_expand(self, random_patch): + """Ensure we can pad without expanding dimensions.""" + out = random_patch.pad(time=10, samples=True, expand_coords=False) + coord_array = out.get_array("time") + # Check the first and last nans + assert np.all(pd.isnull(coord_array[:10])) + assert np.all(pd.isnull(coord_array[-10:])) + # Check the middle values are equal to old values. + old_array = random_patch.get_array("time") + # The datatype should not have changed + assert coord_array.dtype == old_array.dtype + assert np.all(old_array == coord_array[10:-10]) + + def test_pad_no_expand_int_coord(self, random_patch): + """Ensure we can pad an integer coordinate.""" + # Ensure distances in an in coordinate. + coord = random_patch.get_coord("distance") + new_vals = np.arange(len(coord), dtype=np.int64) + patch = random_patch.update_coords(distance=new_vals) + # Apply padding, ensure NaN values appear. 
+        padded = patch.pad(distance=4, expand_coords=False)
+        coord_array = padded.get_array("distance")
+        assert np.all(pd.isnull(coord_array[:4]))
+        assert np.all(pd.isnull(coord_array[-4:]))
+
     def test_error_on_sequence_constant_values(self, random_patch):
-        """Test that providing a sequence for constant_values raises a TypeError."""
-        with pytest.raises(TypeError):
+        """Test that providing a sequence for constant_values raises an error."""
+        with pytest.raises(ParameterError):
             random_patch.pad(time=(0, 5), constant_values=(0, 0))


+class TestConj:
+    """Tests for complex conjugate."""
+
+    def test_imaginary_part_reversed(self, random_dft_patch):
+        """Ensure the imaginary part of the array is negated."""
+        imag1 = np.imag(random_dft_patch.data)
+        conj = random_dft_patch.conj()
+        imag2 = np.imag(conj.data)
+        assert np.allclose(imag1, -imag2)
+
+
 class TestRoll:
     """Test cases for patch roll method."""
diff --git a/tests/test_proc/test_correlate.py b/tests/test_proc/test_correlate.py
index b8adec8a..41ba5067 100644
--- a/tests/test_proc/test_correlate.py
+++ b/tests/test_proc/test_correlate.py
@@ -4,10 +4,29 @@
 import pytest

 import dascore as dc
-from dascore.units import m, s
+from dascore.exceptions import UnitError
+from dascore.units import m
 from dascore.utils.time import to_float


+class TestCorrelateShift:
+    """Tests for the correlate_shift function."""
+
+    def test_auto_correlation(self, random_dft_patch):
+        """Perform auto-correlation and undo shifting."""
+        dft_conj = random_dft_patch.conj()
+        dft_sq = random_dft_patch * dft_conj
+        idft = dft_sq.idft()
+        auto_patch = idft.correlate_shift(dim="time")
+        assert np.allclose(np.imag(auto_patch.data), 0)
+        assert "lag_time" in auto_patch.dims
+        coord_array = auto_patch.get_array("lag_time")
+        # Ensure the max value happens at zero lag time.
+        time_ax = auto_patch.dims.index("lag_time")
+        argmax = np.argmax(np.real(auto_patch.data), axis=time_ax)
+        assert np.allclose(coord_array[argmax], 0)
+
+
 class TestCorrelateInternal:
     """Tests for the intra-patch correlation function."""
@@ -62,13 +81,14 @@ def test_basic_correlation(self, corr_patch):
         dstep2 = corr_patch.get_coord("distance").step
         assert dstep1 == dstep2

-    def test_correlation_with_lag(self, corr_patch):
-        """Ensure correlation works with a lag specified."""
-        lag = 1.9
-        out = corr_patch.correlate(distance=0, samples=True, lag=lag)
-        coord = out.get_coord("lag_time").values
-        assert to_float(coord[0]) >= -lag
-        assert to_float(coord[-1]) <= lag
+    def test_transpose_independent(self, corr_patch):
+        """The order of the dims shouldn't affect the result."""
+        new_order = corr_patch.dims[::-1]
+        patch1 = corr_patch
+        patch2 = corr_patch.transpose(*new_order)
+        corr1 = patch1.correlate(time=3, samples=True)
+        corr2 = patch2.correlate(time=3, samples=True)
+        assert corr1.transpose(*corr2.dims).equals(corr2)

     def test_time_lags(self, ricker_moveout_patch):
         """Ensure time lags are consistent with expected velocities."""
@@ -79,11 +99,11 @@
         lag_times = to_float(corr.get_coord("lag_time").values[argmax])
         # get calculated times, they should be close to lag times
         expected_times = distances / self.moveout_velocity
-        assert np.allclose(lag_times, expected_times)
+        assert np.allclose(lag_times.flatten(), expected_times)

     def test_units(self, random_patch):
-        """Ensure units can be passed as kwarg and lag params."""
-        c_patch = random_patch.correlate(distance=10 * m, lag=2 * s)
+        """Ensure units can be passed as kwarg params."""
+        c_patch = random_patch.correlate(distance=10 * m)
         assert isinstance(c_patch, dc.Patch)

     def test_complex_patch(self, ricker_moveout_patch):
@@ -97,4 +117,44 @@
         lag_times = to_float(corr.get_coord("lag_time").values[argmax])
         # get calculated times, they should be close to lag times
         expected_times = distances / self.moveout_velocity
-        assert np.allclose(lag_times, expected_times)
+        assert np.allclose(lag_times.flatten(), expected_times)
+
+    def test_correlation_freq_domain_patch(self, corr_patch):
+        """
+        Test correlation when the input patch is already in the frequency
+        domain.
+        """
+        # perform FFT on the original patch to simulate frequency domain data
+        fft_patch = corr_patch.dft("time")
+        out = fft_patch.correlate(distance=0, samples=True)
+        # The patch should still be complex.
+        assert np.issubdtype(out.data.dtype, np.complexfloating)
+        # Check if the shape of the output is the same as the original patch
+        # plus a length-1 dimension for the source.
+        assert out.shape == tuple([*corr_patch.shape, 1])
+
+    def test_correlate_decimated_patch(self, corr_patch):
+        """Ensure a decimated patch can be correlated."""
+        out = corr_patch.decimate(distance=2, filter_type=None).correlate(
+            distance=1, samples=True
+        )
+        assert isinstance(out, dc.Patch)
+
+    def test_correlate_units_raises(self, corr_patch):
+        """When the patch doesn't have units an error should raise."""
+        patch = corr_patch.set_units(distance=None)
+        with pytest.raises(UnitError):
+            patch.correlate(distance=0 * m)
+
+    def test_correlate_units(self, corr_patch):
+        """When the patch has units it should work to specify them."""
+        patch = corr_patch.set_units(distance="m")
+        out1 = patch.correlate(distance=1 * m)
+        assert isinstance(out1, dc.Patch)
+        out2 = patch.correlate(distance=np.array([1, 2]) * m)
+        assert isinstance(out2, dc.Patch)
+
+    def test_lag_deprecated(self, corr_patch):
+        """Ensure the lag parameter is deprecated."""
+        with pytest.warns(DeprecationWarning):
+            corr_patch.correlate(time=1, lag=10, samples=True)
diff --git a/tests/test_proc/test_whiten.py b/tests/test_proc/test_whiten.py
index 61853389..32a6f9b3 100644
--- a/tests/test_proc/test_whiten.py
+++ b/tests/test_proc/test_whiten.py
@@ -195,15 +195,15 @@ def test_whiten_along_distance(self, test_patch):
             whitened_patch.coords.get_array("distance"),
         )

-    def test_whiten_ifft_false(self, test_patch):
+    def test_whiten_idft_false(self, test_patch):
         """
         Ensure the whiten function can return the result in the frequency domain
-        when the ifft flag is set to True.
+        when the idft flag is set to False.
         """
         # whiten the patch and return in frequency domain
-        whitened_patch_freq_domain = test_patch.whiten(smooth_size=5, ifft=False)
+        whitened_patch_freq_domain = test_patch.whiten(smooth_size=5, idft=False)
         # check if the returned data is in the frequency domain
         assert np.iscomplexobj(
             whitened_patch_freq_domain.data
-        ), "Expected the data to be complex, indicating freq. domain representation."
+        ), "Expected the output to be complex, indicating freq. domain representation."
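A minimal usage sketch of the renamed idft flag, mirroring the test above (a sketch only: the example patch and the smooth_size value are illustrative assumptions, not part of this patch):

    import numpy as np
    import dascore as dc

    patch = dc.get_example_patch()
    # Default (idft=True): the whitened data comes back in the time domain.
    time_patch = patch.whiten(smooth_size=5)
    # idft=False: the result stays in the frequency domain, so it is complex.
    freq_patch = patch.whiten(smooth_size=5, idft=False)
    assert np.iscomplexobj(freq_patch.data)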
diff --git a/tests/test_units.py b/tests/test_units.py
index e28c11c6..c7a0c274 100644
--- a/tests/test_units.py
+++ b/tests/test_units.py
@@ -255,3 +255,9 @@ def test_convert_offset_units_multiple_mags(self):
         f_array = (array * (9 * 2.5 / 5) + 32.0) / 6
         out = convert_units(array, from_units="2.5*degC", to_units="6*degF")
         assert np.allclose(f_array, out)
+
+    def test_no_output_units_raises(self):
+        """Ensure an error is raised if output units are None."""
+        msg = "is not specified"
+        with pytest.raises(UnitError, match=msg):
+            convert_units(1, from_units="m", to_units=None)
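For reference, a minimal end-to-end sketch of the multi-source correlation workflow this patch adds, using only the API shown in the hunks above (the example patch and channel indices are illustrative):

    import dascore as dc

    patch = dc.get_example_patch()

    # One-step path: correlate all channels against two master channels.
    # A new "source_distance" dimension is appended to the output and the
    # correlated dimension becomes "lag_time".
    cc = patch.correlate(distance=[0, 5], samples=True)

    # Manual frequency-domain path, mirroring Example 5 of the correlate
    # docstring: pad to a fast fft length (at least 2*n - 1), transform,
    # correlate, then transform back and undo the fft shift/weighting.
    padded = patch.pad(time="correlate")
    dft = padded.dft("time", real=True)
    cc_freq = dft.correlate(distance=[0, 5], samples=True)
    cc_manual = cc_freq.idft().correlate_shift("time")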