Skip to content

Commit

Permalink
Pad (#373)
Browse files Browse the repository at this point in the history
* pad func, start test

* finished tests

* modified docs

* fixed test_overask_raises

* removed check for dims, improved docs
  • Loading branch information
ahmadtourei authored May 12, 2024
1 parent ed448d6 commit d6d757b
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 8 deletions.
6 changes: 3 additions & 3 deletions dascore/core/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def update(self, **kwargs):
out = out.convert_units(units)
return out

def get_sample_count(self, value, samples=False) -> int:
def get_sample_count(self, value, samples=False, enforce_lt_coord=False) -> int:
"""
Return the number of samples represented by a value.
Expand All @@ -589,13 +589,13 @@ def get_sample_count(self, value, samples=False) -> int:
compat_val = self._get_compatible_value(value, relative=True)
duration = compat_val - self.min()
samples = int(np.ceil(duration / self.step))
if samples > len(self):
if enforce_lt_coord and samples > len(self):
msg = (
f"value of {value} results in a window larger than coordinate "
f"length of {len(self)}"
)
raise ParameterError(msg)
return min(samples, len(self))
return samples

def get_next_index(self, value, samples=False, allow_out_of_bounds=False) -> int:
"""
Expand Down
1 change: 1 addition & 0 deletions dascore/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def iselect(self, *args, **kwargs):
imag = dascore.proc.imag
angle = dascore.proc.angle
resample = dascore.proc.resample
pad = dascore.proc.pad

def iresample(self, *args, **kwargs):
"""Deprecated method."""
Expand Down
94 changes: 93 additions & 1 deletion dascore/proc/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from dascore.constants import PatchType
from dascore.core.attrs import PatchAttrs, merge_compatible_coords_attrs
from dascore.core.coordmanager import CoordManager, get_coord_manager
from dascore.core.coords import get_coord
from dascore.exceptions import UnitError
from dascore.units import DimensionalityError, Quantity, Unit, get_quantity
from dascore.utils.models import ArrayLike
from dascore.utils.patch import patch_function
from dascore.utils.patch import get_multiple_dim_value_from_kwargs, patch_function


def set_dims(self: PatchType, **kwargs: str) -> PatchType:
Expand Down Expand Up @@ -480,3 +481,94 @@ def dropna(patch: PatchType, dim, how: Literal["any", "all"] = "any") -> PatchTy
cm = patch.coords.update(**{dim: coord[to_keep]})
attrs = patch.attrs.update(coords={})
return patch.new(data=new_data, coords=cm, attrs=attrs)


@patch_function()
def pad(
    patch: PatchType,
    mode: Literal["constant"] = "constant",
    constant_values: Any = 0,
    expand_coords=False,
    samples=False,
    **kwargs,
) -> PatchType:
    """
    Pad the patch data along specified dimensions.

    Parameters
    ----------
    patch
        The patch to pad.
    mode : str, optional
        The mode of padding, by default 'constant'.
    constant_values : scalar, optional
        A single scalar value used as the padding value across all dimensions.
        Defaults to 0.
    expand_coords : bool, optional
        Determines how coordinates are adjusted when padding is applied.
        If set to True, the coordinates will be expanded to maintain their
        order and even sampling, by extrapolating based on the coordinate's
        step size. This requires the padded coordinates to be evenly sampled.
        If set to False, the new coordinates introduced by padding will be
        filled with NaN values, preserving the original coordinate values but
        not the order or sampling rate.
    samples : bool, optional
        If True, the values in kwargs are numbers of samples. If False
        (the default), they are converted to sample counts using the
        coordinate, which must then be evenly sampled.
    **kwargs
        Used to specify dimension and number of elements to add, either a
        single scalar (same padding on both sides) or a sequence
        (before, after).

    Raises
    ------
    TypeError
        If `constant_values` is a list or tuple rather than a scalar.

    Examples
    --------
    >>> import dascore as dc
    >>> patch = dc.get_example_patch()
    >>> # zero pad `time` dimension: 2 time units (e.g., sec) worth of zeros
    >>> # before and 3 time units worth of zeros after
    >>> padded_patch_1 = patch.pad(time = (2, 3))
    >>> # pad `distance` dimension with 4 samples of ones before and after
    >>> padded_patch_2 = patch.pad(distance = 4, constant_values = 1, samples=True)
    """
    if isinstance(constant_values, list | tuple):
        raise TypeError("constant_values must be a scalar, not a sequence.")

    pad_width = [(0, 0)] * len(patch.shape)
    dimfo = get_multiple_dim_value_from_kwargs(patch, kwargs)
    new_coords = {}

    for info in dimfo.values():
        axis, dim, value = info["axis"], info["dim"], info["value"]

        # A lone scalar (python or numpy number) means "same padding on both
        # sides"; sequences and other objects are assumed to be (before, after).
        if np.isscalar(value):
            value = (value, value)

        # Fetch this dimension's coordinate up front so that every branch
        # below uses the coordinate of *this* dimension. Even sampling is
        # needed both to convert units to samples and to extrapolate coords.
        if expand_coords or not samples:
            coord = patch.get_coord(dim, require_evenly_sampled=True)
        else:
            coord = patch.get_coord(dim)

        # Ensure kwargs are in samples
        if not samples:
            value = (
                coord.get_sample_count(value[0], samples=samples),
                coord.get_sample_count(value[1], samples=samples),
            )
        pad_width[axis] = value

        # Get new coordinate
        if expand_coords:
            # Extrapolate outwards by whole steps; the extra step on the end
            # presumably accounts for an exclusive `stop` in get_coord.
            new_start = coord.min() - value[0] * coord.step
            new_end = coord.max() + (value[1] + 1) * coord.step
            new_coord = get_coord(
                start=new_start, stop=new_end, step=coord.step, units=coord.units
            )
        else:
            # NaN cannot be stored in integer arrays, hence the float cast.
            old_values = coord.values.astype(np.float64)
            added_nan_values = np.pad(
                old_values, pad_width=value, constant_values=np.nan
            )
            new_coord = coord.update(data=added_nan_values)
        new_coords[dim] = new_coord

    # Pad patch's data
    new_data = np.pad(patch.data, pad_width, mode=mode, constant_values=constant_values)

    # Update coord manager
    new_coords = patch.coords.update(**new_coords)

    return patch.new(data=new_data, coords=new_coords)
8 changes: 6 additions & 2 deletions dascore/proc/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,12 @@ def _get_engine(step, engine, patch):
dim, axis, value = get_dim_value_from_kwargs(patch, kwargs)
roll_hist = f"rolling({dim}={value}, step={step}, center={center}, engine={engine})"
coord = patch.get_coord(dim)
window = coord.get_sample_count(value, samples=samples)
step = 1 if step is None else coord.get_sample_count(step, samples=samples)
window = coord.get_sample_count(value, samples=samples, enforce_lt_coord=True)
step = (
1
if step is None
else coord.get_sample_count(step, samples=samples, enforce_lt_coord=True)
)
if window == 0 or step == 0:
msg = "Window or step size can't be zero. Use any positive values."
raise ParameterError(msg)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_core/test_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,10 +1193,14 @@ def test_overask_raises(self, evenly_sampled_coord):
# test for using samples kwarg: window or step size is larger than coordinate
msg = "results in a window larger than coordinate"
with pytest.raises(ParameterError, match=msg):
evenly_sampled_coord.get_sample_count(max_len + 2, samples=True)
evenly_sampled_coord.get_sample_count(
max_len + 2, samples=True, enforce_lt_coord=True
)
# test for using normal mode
with pytest.raises(ParameterError, match=msg):
evenly_sampled_coord.get_sample_count(duration + 2 * step)
evenly_sampled_coord.get_sample_count(
duration + 2 * step, enforce_lt_coord=True
)

def test_non_int_raises_with_samples(self, evenly_sampled_coord):
"""Non integer values should raise when sample=True."""
Expand Down
82 changes: 82 additions & 0 deletions tests/test_proc/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,85 @@ def test_3d(self, patch_3d_with_null):
axis = out.dims.index("time")
out = patch.dropna("time", how="any")
assert out.shape[axis] == patch.shape[axis] - 2


class TestPad:
    """Tests for the padding functionality in a patch."""

    def test_pad_time_dimension_samples_true(self, random_patch, samples=True):
        """Test padding the time dimension with zeros before and after."""
        padded_patch = random_patch.pad(time=(2, 3), samples=samples)
        # Check if the padding is applied correctly
        original_shape = random_patch.shape
        new_shape = padded_patch.shape
        time_axis = random_patch.dims.index("time")
        assert new_shape[time_axis] == original_shape[time_axis] + 5
        # Ensure that padded values are zeros
        assert np.all(padded_patch.select(time=(None, 2), samples=samples).data == 0)
        assert np.all(padded_patch.select(time=(-3, None), samples=samples).data == 0)

    def test_pad_distance_dimension(self, random_patch):
        """Test padding the distance with same number of zeros on both sides."""
        pad_distance = 7
        padded_patch = random_patch.pad(distance=pad_distance)
        original_shape = random_patch.shape
        new_shape = padded_patch.shape
        distance_axis = random_patch.dims.index("distance")
        ch_spacing = random_patch.attrs["distance_step"]
        # A pad given in coordinate units spans ceil(pad / step) samples per
        # side, so divide by the channel spacing rather than multiply by it.
        n_pad = int(np.ceil(pad_distance / ch_spacing))
        assert new_shape[distance_axis] == original_shape[distance_axis] + 2 * n_pad
        # Ensure that padded values are zeros
        assert np.all(
            padded_patch.select(distance=(None, n_pad), samples=True).data == 0
        )
        assert np.all(
            padded_patch.select(distance=(-n_pad, None), samples=True).data == 0
        )

    def test_pad_distance_dimension_expand_coords(
        self, random_patch, expand_coords=True
    ):
        """Test padding the distance while extrapolating its coordinate."""
        pad_distance = 4
        padded_patch = random_patch.pad(
            distance=pad_distance, expand_coords=expand_coords
        )
        original_shape = random_patch.shape
        new_shape = padded_patch.shape
        distance_axis = random_patch.dims.index("distance")
        ch_spacing = random_patch.attrs["distance_step"]
        n_pad = int(np.ceil(pad_distance / ch_spacing))
        assert new_shape[distance_axis] == original_shape[distance_axis] + 2 * n_pad
        # Ensure that padded values are zeros. Select half a step outside the
        # original extent so the check does not depend on the step being 1
        # or on exact floating point equality at the boundary.
        coord = random_patch.get_coord("distance")
        half_step = ch_spacing / 2
        before = padded_patch.select(distance=(None, coord.min() - half_step))
        after = padded_patch.select(distance=(coord.max() + half_step, None))
        assert np.all(before.data == 0)
        assert np.all(after.data == 0)

    def test_pad_multiple_dimensions_samples_true(self, random_patch, samples=True):
        """Test padding multiple dimensions with different pad values."""
        padded_patch = random_patch.pad(
            time=(6, 7), distance=(1, 4), constant_values=np.pi, samples=samples
        )
        # Check dimensions individually
        time_axis = random_patch.dims.index("time")
        distance_axis = random_patch.dims.index("distance")
        assert padded_patch.shape[time_axis] == random_patch.shape[time_axis] + 13
        assert (
            padded_patch.shape[distance_axis] == random_patch.shape[distance_axis] + 5
        )
        # Check padding values
        assert np.allclose(
            padded_patch.select(time=(None, 6), samples=samples).data, np.pi, atol=1e-6
        )
        assert np.allclose(
            padded_patch.select(time=(-7, None), samples=samples).data, np.pi, atol=1e-6
        )
        assert np.allclose(
            padded_patch.select(distance=(None, 1), samples=samples).data,
            np.pi,
            atol=1e-6,
        )
        assert np.allclose(
            padded_patch.select(distance=(-4, None), samples=samples).data,
            np.pi,
            atol=1e-6,
        )

    def test_error_on_sequence_constant_values(self, random_patch):
        """Test that providing a sequence for constant_values raises a TypeError."""
        with pytest.raises(TypeError):
            random_patch.pad(time=(0, 5), constant_values=(0, 0))

0 comments on commit d6d757b

Please sign in to comment.