Skip to content

Commit

Permalink
Modify chunk, select, and waterfall (#392)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: derrick chambers <djachambeador@gmail.com>
  • Loading branch information
ahmadtourei and d-chambers authored Jun 7, 2024
1 parent bca67c4 commit b2cd6a1
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 12 deletions.
11 changes: 5 additions & 6 deletions dascore/core/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,6 @@ def get_slice_tuple(
An object for determining select range.
"""
select_tuple = sanitize_range_param(select)

p1, p2 = (
self._get_compatible_value(x, relative=relative) for x in select_tuple
)
Expand Down Expand Up @@ -578,7 +577,7 @@ def get_sample_count(self, value, samples=False, enforce_lt_coord=False) -> int:
If True, value is already in units of samples.
"""
if not self.evenly_sampled:
msg = "Coordinate is not evenly sampled, cant get sample count."
msg = "Coordinate is not evenly sampled, can't get sample count."
raise CoordError(msg)
if samples:
if not isinstance(value, int | np.integer):
Expand Down Expand Up @@ -751,13 +750,13 @@ def select(self, args, relative=False) -> tuple[BaseCoord, slice | ArrayLike]:
if self.reverse_sorted:
start, stop = stop, start
# we add 1 to stop in slice since its upper limit is exclusive
out = slice(start, (stop + 1) if stop is not None else stop)
if self._slice_degenerate(out):
data = slice(start, (stop + 1) if stop is not None else stop)
if self._slice_degenerate(data):
return self.empty(), slice(0, 0)
new_start = self[start] if start is not None else self.start
new_end = self[stop] + self.step if stop is not None else self.stop
new = self.new(start=new_start, stop=new_end)
return new, out
new_coords = self.new(start=new_start, stop=new_end)
return new_coords, data

def sort(self, reverse=False) -> tuple[BaseCoord, slice | ArrayLike]:
"""Sort the contents of the coord. Return new coord and slice for sorting."""
Expand Down
1 change: 1 addition & 0 deletions dascore/core/spool.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def chunk(
chunker = ChunkManager(
overlap=overlap,
keep_partial=keep_partial,
snap_coords=snap_coords,
group_columns=self._group_columns,
tolerance=tolerance,
conflict=conflict,
Expand Down
21 changes: 21 additions & 0 deletions dascore/proc/select.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Function for querying Patchs."""
from __future__ import annotations

import numpy as np

from dascore.constants import PatchType
from dascore.exceptions import ParameterError
from dascore.utils.patch import patch_function


Expand Down Expand Up @@ -56,6 +59,24 @@ def select(
>>> # Select last time row/column
>>> new_distance2 = patch.select(time=-1, samples=True)
"""
# Raise error when samples=True but kwargs are not integers
if samples:
for value in kwargs.values():
start, stop = (
(value.start, value.stop)
if isinstance(value, slice)
else (value[0], value[1])
if isinstance(value, tuple)
else (value, value)
)

if not all(
isinstance(v, (int | np.integer | type(None) | type(Ellipsis)))
for v in (start, stop)
):
msg = "When samples=True, values must be integers."
raise ParameterError(msg)

new_coords, data = patch.coords.select(
**kwargs,
array=patch.data,
Expand Down
2 changes: 1 addition & 1 deletion dascore/proc/taper.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def taper(
{taper_type}.
**kwargs
Used to specify the dimension along which to taper and the percentage
of total length of the dimension or abolsute units. If a single value
of total length of the dimension or absolute units. If a single value
is passed, the taper will be applied to both ends. A length two tuple
can specify different values for each end, or no taper on one end.
Expand Down
4 changes: 3 additions & 1 deletion dascore/utils/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,21 +132,23 @@ class ChunkManager:
Notes
-----
This class is used internally by `dc.Spool.chunk`.
This class is used internally by `dc.BaseSpool.chunk`.
"""

def __init__(
self,
overlap: timeable_types | numeric_types | None = None,
group_columns: Collection[str] | None = None,
keep_partial=False,
snap_coords=True,
tolerance=1.5,
conflict="raise",
**kwargs,
):
self._overlap = overlap
self._group_columns = group_columns
self._keep_partials = keep_partial
self._snap_coords = snap_coords
self._tolerance = tolerance
self._name, self._value = self._validate_kwargs(kwargs)
self._attr_conflict = conflict
Expand Down
8 changes: 6 additions & 2 deletions dascore/utils/pd.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,12 @@ def get_interval_columns(df, name, arrays=False):
names = f"{name}_min", f"{name}_max", f"{name}_step"
missing_cols = set(names) - set(df.columns)
if missing_cols:
msg = f"Dataframe is missing {missing_cols} to chunk on {name}"
raise KeyError(msg)
dims = get_dim_names_from_columns(df)
msg = (
f"Cannot chunk spool or dataframe on {missing_cols}, "
f"valid dimensions or columns to chunk on are {dims}"
)
raise ParameterError(msg)
start, stop, step = df[names[0]], df[names[1]], df[names[2]]
if not arrays:
return start, stop, step
Expand Down
7 changes: 6 additions & 1 deletion dascore/viz/waterfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def waterfall(
cmap="bwr",
scale: float | Sequence[float] | None = None,
scale_type: Literal["relative", "absolute"] = "relative",
log=False,
show=False,
) -> plt.Axes:
"""
Expand All @@ -68,6 +69,8 @@ def waterfall(
are:
relative - scale based on half the dynamic range in patch
absolute - scale based on absolute values provided to `scale`
log
If True, visualize the common logarithm of the absolute values of patch data.
show
If True, show the plot, else just return axis.
Expand All @@ -80,7 +83,7 @@ def waterfall(
"""
ax = _get_ax(ax)
cmap = _get_cmap(cmap)
data = patch.data
data = np.log10(np.absolute(patch.data)) if log else patch.data
dims = patch.dims
assert len(dims) == 2, "Can only make waterfall plot of 2D Patch"
dims_r = tuple(reversed(dims))
Expand All @@ -102,6 +105,8 @@ def waterfall(
data_type = str(patch.attrs["data_type"])
data_units = get_quantity_str(patch.attrs.data_units) or ""
dunits = f" ({data_units})" if (data_type and data_units) else f"{data_units}"
if log:
dunits = f"{dunits} - log_10"
label = f"{data_type}{dunits}"
cb.set_label(label)
ax.invert_yaxis() # invert y axis so origin is at top
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core/test_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def test_slice_with_step_raises(self, evenly_sampled_coord):
with pytest.raises(ParameterError, match=match):
evenly_sampled_coord.select(slice(1, 10, 2))

def test_slice_with_step(self, evenly_sampled_coord):
def test_slice_works_as_tuple(self, evenly_sampled_coord):
"""Ensure slice works like tuple."""
coord = evenly_sampled_coord
vmin, vmax = coord.min(), coord.max()
Expand Down
24 changes: 24 additions & 0 deletions tests/test_proc/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import dascore as dc
from dascore.exceptions import ParameterError


class TestSelect:
Expand Down Expand Up @@ -121,6 +122,29 @@ def test_iselect_deprecated(self, random_patch):
with pytest.warns(DeprecationWarning, match=msg):
_ = random_patch.iselect(time=(10, -10))

def test_dist_float_values_samples_true(self, random_patch, samples=True):
"""Ensure range values are integer when samples=True."""
start = 1.1
end = 5.7
with pytest.raises(ParameterError, match="When samples=True"):
random_patch.select(distance=(start, end), samples=samples)

def test_time_float_values_samples_true(self, random_patch, samples=True):
"""Ensure range values are integer when samples=True."""
start = 1.1
end = int(5.7)
with pytest.raises(ParameterError, match="When samples=True"):
random_patch.select(time=(start, end), samples=samples)

def test_float_values_samples_true(self, random_patch, samples=True):
"""Ensure range values are integer when samples=True."""
start = int(1.1)
end = 5.7
with pytest.raises(ParameterError, match="When samples=True"):
random_patch.select(
time=(start, end), distance=(start, end), samples=samples
)


class TestSelectHistory:
"""Test behavior of history added by select."""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_utils/test_chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ def test_raises_zero_length_chunk(self):
with pytest.raises(ParameterError, match="must be greater than 0"):
ChunkManager(time=0)

def test_raises_invalid_key_in_kwargs(self, contiguous_df):
"""Ensure an invalid key in kwargs raises an error."""
chunk_manager = ChunkManager(Time=10)
chunk_manager.patch = type("Patch", (object,), {"dims": ["time", "distance"]})()
with pytest.raises(ParameterError, match="Cannot chunk spool or"):
chunk_manager.chunk(contiguous_df)


class TestChunkToMerge:
"""Tests for using chunking to merge contiguous, or overlapping, data."""
Expand Down
22 changes: 22 additions & 0 deletions tests/test_viz/test_waterfall.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import dascore as dc
from dascore.units import get_quantity_str
from dascore.utils.time import is_datetime64, to_timedelta64


Expand Down Expand Up @@ -156,3 +157,24 @@ def test_show(self, random_patch, monkeypatch):
"""Ensure show path is callable."""
monkeypatch.setattr(plt, "show", lambda: None)
random_patch.viz.waterfall(show=True)

def test_log(self, random_patch):
"""Ensure log is callable."""
ax = random_patch.viz.waterfall(log=True)

# Retrieve the colorbar label
cb = ax.get_figure().get_axes()[-1]
cb_label = cb.get_ylabel()

# Retrieve the expected data type and data units
data_type = str(random_patch.attrs["data_type"])
data_units = get_quantity_str(random_patch.attrs.data_units) or ""
expected_dunits = f" ({data_units})" if data_units else ""

# Construct the expected label
expected_label = f"{data_type}{expected_dunits} - log_10"

# Check if the colorbar label matches the expected label
assert (
cb_label == expected_label
), f"Expected '{expected_label}', but got '{cb_label}'"

0 comments on commit b2cd6a1

Please sign in to comment.