From d6d757baaaac1c7840b89e279188cc7d49150655 Mon Sep 17 00:00:00 2001 From: Ahmad Tourei <92008628+ahmadtourei@users.noreply.github.com> Date: Sun, 12 May 2024 16:27:17 -0600 Subject: [PATCH] Pad (#373) * pad func, start test * finished tests * modified docs * fixed test_overask_raises * removed check for dims, improved docs --- dascore/core/coords.py | 6 +-- dascore/core/patch.py | 1 + dascore/proc/basic.py | 94 +++++++++++++++++++++++++++++++++- dascore/proc/rolling.py | 8 ++- tests/test_core/test_coords.py | 8 ++- tests/test_proc/test_basic.py | 82 +++++++++++++++++++++++++++++ 6 files changed, 191 insertions(+), 8 deletions(-) diff --git a/dascore/core/coords.py b/dascore/core/coords.py index 7acb899d..b4f03064 100644 --- a/dascore/core/coords.py +++ b/dascore/core/coords.py @@ -563,7 +563,7 @@ def update(self, **kwargs): out = out.convert_units(units) return out - def get_sample_count(self, value, samples=False) -> int: + def get_sample_count(self, value, samples=False, enforce_lt_coord=False) -> int: """ Return the number of samples represented by a value. 
@@ -589,13 +589,13 @@ def get_sample_count(self, value, samples=False) -> int:
             compat_val = self._get_compatible_value(value, relative=True)
             duration = compat_val - self.min()
             samples = int(np.ceil(duration / self.step))
-        if samples > len(self):
+        if enforce_lt_coord and samples > len(self):
             msg = (
                 f"value of {value} results in a window larger than coordinate "
                 f"length of {len(self)}"
             )
             raise ParameterError(msg)
-        return min(samples, len(self))
+        return samples
 
     def get_next_index(self, value, samples=False, allow_out_of_bounds=False) -> int:
         """
diff --git a/dascore/core/patch.py b/dascore/core/patch.py
index 518c2d77..58aa86fb 100644
--- a/dascore/core/patch.py
+++ b/dascore/core/patch.py
@@ -240,6 +240,7 @@ def iselect(self, *args, **kwargs):
     imag = dascore.proc.imag
     angle = dascore.proc.angle
     resample = dascore.proc.resample
+    pad = dascore.proc.pad
 
     def iresample(self, *args, **kwargs):
         """Deprecated method."""
diff --git a/dascore/proc/basic.py b/dascore/proc/basic.py
index 2c1c68cd..9adc579d 100644
--- a/dascore/proc/basic.py
+++ b/dascore/proc/basic.py
@@ -11,10 +11,11 @@
 from dascore.constants import PatchType
 from dascore.core.attrs import PatchAttrs, merge_compatible_coords_attrs
 from dascore.core.coordmanager import CoordManager, get_coord_manager
+from dascore.core.coords import get_coord
 from dascore.exceptions import UnitError
 from dascore.units import DimensionalityError, Quantity, Unit, get_quantity
 from dascore.utils.models import ArrayLike
-from dascore.utils.patch import patch_function
+from dascore.utils.patch import get_multiple_dim_value_from_kwargs, patch_function
 
 
 def set_dims(self: PatchType, **kwargs: str) -> PatchType:
@@ -480,3 +481,94 @@ def dropna(patch: PatchType, dim, how: Literal["any", "all"] = "any") -> PatchTy
     cm = patch.coords.update(**{dim: coord[to_keep]})
     attrs = patch.attrs.update(coords={})
     return patch.new(data=new_data, coords=cm, attrs=attrs)
+
+
+@patch_function()
+def pad(
+    patch: PatchType,
+    mode: Literal["constant"] = "constant",
+    constant_values: Any = 0,
+    expand_coords=False,
+    samples=False,
+    **kwargs,
+) -> PatchType:
+    """
+    Pad the patch data along specified dimensions.
+
+    Parameters
+    ----------
+    mode : str, optional
+        The mode of padding, by default 'constant'.
+    constant_values : scalar, optional
+        A single scalar value used as the padding value across all dimensions.
+        Defaults to 0.
+    expand_coords : bool, optional
+        Determines how coordinates are adjusted when padding is applied.
+        If True, the coordinate is extended by extrapolating with its step
+        size, preserving order and even sampling; the coordinate must be
+        evenly sampled.
+        If False, the original coordinate values are kept and the padded
+        positions are filled with NaN, so neither order nor sampling rate
+        is preserved.
+    samples : bool, optional
+        If True, the values in kwargs are numbers of samples. If False
+        (default), they are in the coordinate's units and are converted to
+        a sample count using the coordinate's step.
+    **kwargs:
+        Used to specify the dimension(s) and the amount of padding: a single
+        value (applied to both sides) or a (before, after) tuple.
+
+    Raises
+    ------
+    TypeError
+        If constant_values is a list or tuple rather than a scalar.
+
+    Examples
+    --------
+    >>> import dascore as dc
+    >>> patch = dc.get_example_patch()
+    >>> # zero pad `time` by 2 time units (e.g., sec) before and 3 after
+    >>> padded_patch_1 = patch.pad(time=(2, 3))
+    >>> # pad `distance` with 4 samples of value 1 before and after
+    >>> padded_patch_2 = patch.pad(distance=4, constant_values=1, samples=True)
+    """
+    if isinstance(constant_values, list | tuple):
+        raise TypeError("constant_values must be a scalar, not a sequence.")
+
+    pad_width = [(0, 0)] * len(patch.shape)
+    dimfo = get_multiple_dim_value_from_kwargs(patch, kwargs)
+    new_coords = {}
+    # Even sampling is needed to convert units to samples or extend a coord.
+    need_even = expand_coords or not samples
+
+    for info in dimfo.values():
+        axis, dim, value = info["axis"], info["dim"], info["value"]
+        # Fetch this dim's coord first so no branch sees a missing/stale one.
+        coord = patch.get_coord(dim, require_evenly_sampled=need_even)
+        # A scalar of any type (int, float, quantity, ...) pads both sides.
+        if np.ndim(value) == 0:
+            value = (value, value)
+        # Ensure the pad amounts are expressed in samples.
+        if not samples:
+            value = (
+                coord.get_sample_count(value[0], samples=samples),
+                coord.get_sample_count(value[1], samples=samples),
+            )
+        pad_width[axis] = value
+
+        if expand_coords:
+            new_start = coord.min() - value[0] * coord.step
+            new_end = coord.max() + (value[1] + 1) * coord.step
+            new_coord = get_coord(
+                start=new_start, stop=new_end, step=coord.step, units=coord.units
+            )
+        else:
+            # Cast to float so the padded positions can hold NaN.
+            old_values = coord.values.astype(np.float64)
+            padded = np.pad(old_values, pad_width=value, constant_values=np.nan)
+            new_coord = coord.update(data=padded)
+        new_coords[dim] = new_coord
+
+    new_data = np.pad(patch.data, pad_width, mode=mode, constant_values=constant_values)
+    new_coords = patch.coords.update(**new_coords)
+    return patch.new(data=new_data, coords=new_coords)
diff --git a/dascore/proc/rolling.py b/dascore/proc/rolling.py
index 42e6c884..d8e3e2f4 100644
--- a/dascore/proc/rolling.py
+++ b/dascore/proc/rolling.py
@@ -306,8 +306,12 @@ def _get_engine(step, engine, patch):
     dim, axis, value = get_dim_value_from_kwargs(patch, kwargs)
     roll_hist = f"rolling({dim}={value}, step={step}, center={center}, engine={engine})"
     coord = patch.get_coord(dim)
-    window = coord.get_sample_count(value, samples=samples)
-    step = 1 if step is None else coord.get_sample_count(step, samples=samples)
+    window = coord.get_sample_count(value, samples=samples, enforce_lt_coord=True)
+    step = (
+        1
+        if step is None
+        else coord.get_sample_count(step, samples=samples, enforce_lt_coord=True)
+    )
     if window == 0 or step == 0:
         msg = "Window or step size can't be zero. Use any positive values."
         raise ParameterError(msg)
diff --git a/tests/test_core/test_coords.py b/tests/test_core/test_coords.py
index 482ab2bf..9e1bee5b 100644
--- a/tests/test_core/test_coords.py
+++ b/tests/test_core/test_coords.py
@@ -1193,10 +1193,14 @@ def test_overask_raises(self, evenly_sampled_coord):
         # test for using samples kwargWindow or step size is larger than
         msg = "results in a window larger than coordinate"
         with pytest.raises(ParameterError, match=msg):
-            evenly_sampled_coord.get_sample_count(max_len + 2, samples=True)
+            evenly_sampled_coord.get_sample_count(
+                max_len + 2, samples=True, enforce_lt_coord=True
+            )
         # test for using normal mode
         with pytest.raises(ParameterError, match=msg):
-            evenly_sampled_coord.get_sample_count(duration + 2 * step)
+            evenly_sampled_coord.get_sample_count(
+                duration + 2 * step, enforce_lt_coord=True
+            )
 
     def test_non_int_raises_with_samples(self, evenly_sampled_coord):
         """Non integer values should raise when sample=True."""
diff --git a/tests/test_proc/test_basic.py b/tests/test_proc/test_basic.py
index fcc9c753..50430a1d 100644
--- a/tests/test_proc/test_basic.py
+++ b/tests/test_proc/test_basic.py
@@ -377,3 +377,85 @@ def test_3d(self, patch_3d_with_null):
         axis = out.dims.index("time")
         out = patch.dropna("time", how="any")
         assert out.shape[axis] == patch.shape[axis] - 2
+
+
+class TestPad:
+    """Tests for the padding functionality in a patch."""
+
+    def test_pad_time_dimension_samples_true(self, random_patch, samples=True):
+        """Test padding the time dimension with zeros before and after."""
+        padded_patch = random_patch.pad(time=(2, 3), samples=samples)
+        # Check if the padding is applied correctly
+        original_shape = random_patch.shape
+        new_shape = padded_patch.shape
+        time_axis = random_patch.dims.index("time")
+        assert new_shape[time_axis] == original_shape[time_axis] + 5
+        # Ensure that padded values are zeros
+        assert np.all(padded_patch.select(time=(None, 2), samples=samples).data == 0)
+        assert np.all(padded_patch.select(time=(-3, None), samples=samples).data == 0)
+
+    def test_pad_distance_dimension(self, random_patch):
+        """Test padding the distance with same number of zeros on both sides."""
+        padded_patch = random_patch.pad(distance=7)
+        axis = random_patch.dims.index("distance")
+        # 7 distance units per side correspond to ceil(7 / step) samples.
+        n = int(np.ceil(7 / random_patch.attrs["distance_step"]))
+        assert padded_patch.shape[axis] == random_patch.shape[axis] + 2 * n
+        # Ensure that padded values are zeros
+        assert np.all(padded_patch.select(distance=(None, n), samples=True).data == 0)
+        assert np.all(padded_patch.select(distance=(-n, None), samples=True).data == 0)
+
+    def test_pad_distance_dimension_expand_coords(
+        self, random_patch, expand_coords=True
+    ):
+        """Test padding distance on both sides while extending its coordinate."""
+        padded_patch = random_patch.pad(distance=4, expand_coords=expand_coords)
+        axis = random_patch.dims.index("distance")
+        dist_max = random_patch.attrs["distance_max"]
+        # 4 distance units per side correspond to ceil(4 / step) samples.
+        n = int(np.ceil(4 / random_patch.attrs["distance_step"]))
+        assert padded_patch.shape[axis] == random_patch.shape[axis] + 2 * n
+        # Ensure that padded values are zeros
+        assert np.all(padded_patch.select(distance=(None, -1)).data == 0)
+        assert np.all(padded_patch.select(distance=(dist_max + 1, None)).data == 0)
+
+    def test_pad_samples_with_expand_coords(self, random_patch):
+        """Sample-based padding must also work when coordinates are expanded."""
+        out = random_patch.pad(time=(2, 3), samples=True, expand_coords=True)
+        old, new = random_patch.get_coord("time"), out.get_coord("time")
+        assert len(new) == len(old) + 5
+
+    def test_pad_multiple_dimensions_samples_true(self, random_patch, samples=True):
+        """Test padding multiple dimensions with different pad values."""
+        padded_patch = random_patch.pad(
+            time=(6, 7), distance=(1, 4), constant_values=np.pi, samples=samples
+        )
+        # Check dimensions individually
+        time_axis = random_patch.dims.index("time")
+        distance_axis = random_patch.dims.index("distance")
+        assert padded_patch.shape[time_axis] == random_patch.shape[time_axis] + 13
+        assert (
+            padded_patch.shape[distance_axis] == random_patch.shape[distance_axis] + 5
+        )
+        # Check padding values
+        assert np.allclose(
+            padded_patch.select(time=(None, 6), samples=samples).data, np.pi, atol=1e-6
+        )
+        assert np.allclose(
+            padded_patch.select(time=(-7, None), samples=samples).data, np.pi, atol=1e-6
+        )
+        assert np.allclose(
+            padded_patch.select(distance=(None, 1), samples=samples).data,
+            np.pi,
+            atol=1e-6,
+        )
+        assert np.allclose(
+            padded_patch.select(distance=(-4, None), samples=samples).data,
+            np.pi,
+            atol=1e-6,
+        )
+
+    def test_error_on_sequence_constant_values(self, random_patch):
+        """Test that providing a sequence for constant_values raises a TypeError."""
+        with pytest.raises(TypeError):
+            random_patch.pad(time=(0, 5), constant_values=(0, 0))