Skip to content

Commit

Permalink
fix_463
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Dec 1, 2024
1 parent ad2ecb5 commit 852bb57
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 1 deletion.
9 changes: 9 additions & 0 deletions dascore/proc/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
get_filter_units,
get_inverted_quant,
invert_quantity,
quant_sequence_to_quant_array,
)
from dascore.utils.docs import compose_docstring
from dascore.utils.misc import (
Expand Down Expand Up @@ -562,6 +563,14 @@ def _get_slope_array(dft_patch, directional, freq_dims):

def _maybe_transform_units(filt, dft_patch, freq_dims):
"""Handle units on filter."""
# Hand the units/partial units in sequence.
units = getattr(filt, "units", None)
try:
filt = np.array(filt)
except ValueError:
filt = quant_sequence_to_quant_array(filt)
if units:
filt = filt * dc.get_quantity(units)
if not isinstance(filt, dc.units.Quantity):
return filt
array, units = filt.magnitude, filt.units
Expand Down
37 changes: 37 additions & 0 deletions dascore/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections.abc import Sequence
from functools import cache
from typing import TypeVar

Expand All @@ -11,6 +12,7 @@
from pint import DimensionalityError, Quantity, UndefinedUnitError, Unit

import dascore as dc
from dascore.compat import is_array
from dascore.exceptions import UnitError
from dascore.utils.misc import unbyte
from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float
Expand Down Expand Up @@ -304,6 +306,41 @@ def _check_to_units(to_unit, dim):
return out1, out2


def quant_sequence_to_quant_array(sequence: Sequence[Quantity]) -> Quantity:
"""
Convert a sequence of Quantities (eg list) to a Quantity array.
Will simplify all quantities. Raises an error if not all elements have
the same units.
Parameters
----------
sequence
A sequence of Quantities.
Notes
-----
This is probably not efficient for large lists.
"""
if is_array(sequence):
# This is a numpy array, just return multiplied by quantity.
return sequence * get_quantity("dimensionless")
# iterate the sequence and manually convert to base units.
try:
base_unit_sequence = [x.to_base_units() for x in sequence]
except AttributeError:
msg = "Not all values in sequence are quantities."
raise UnitError(msg)
if not len(base_unit_sequence):
return np.array([]) * get_quantity("dimensionless")
units = {x.units for x in base_unit_sequence}
if len(units) != 1:
msg = "Not all values in sequence have compatible units."
raise UnitError(msg)
array = np.array([x.magnitude for x in base_unit_sequence])
return array * next(iter(units))


def __getattr__(name):
"""
Allows arbitrary units (quantities) to be imported from this module.
Expand Down
10 changes: 9 additions & 1 deletion tests/test_proc/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def test_bad_dims(self, example_patch):
patch.slope_filter(filt=filt, dims=("time", "distance"))

def test_units_raise_no_unit_coords(self, example_patch):
"""Ensure A UnitError is raised if one of hte coords does't have units."""
"""Ensure A UnitError is raised if one of the coords doesn't have units."""
patch = example_patch.set_units(distance="")
filt = np.array([1e3, 1.5e3, 5e3, 10e3]) * get_unit("m/s")
with pytest.raises(UnitError):
Expand All @@ -476,3 +476,11 @@ def test_inverted_units(self, example_patch):
out1 = example_patch.slope_filter(filt=slowness)
out2 = example_patch.slope_filter(filt=filt * get_unit("m/s"))
assert np.allclose(out1.data, out2.data)

def test_units_list(self, example_patch):
"""Ensure units as a list still work (see #463)."""
speed = 5_000 * dc.get_quantity("m/s")
filt = [speed * 0.90, speed * 0.95, speed * 1.05, speed * 1.1]
# The test passes if this line doesn't raise an error.
out = example_patch.slope_filter(filt)
assert isinstance(out, dc.Patch)
50 changes: 50 additions & 0 deletions tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_quantity_str,
get_unit,
invert_quantity,
quant_sequence_to_quant_array,
)


Expand Down Expand Up @@ -303,3 +304,52 @@ def test_array_quantity(self):
array = np.arange(10) * get_quantity("m")
out = convert_units(array, to_units="ft")
np.allclose(array.magnitude, out * 3.28084)


class TestQuantSequenceToQuantArray:
"""Ensure we can convert a quantity sequence to an array."""

def test_valid_sequence_same_units(self):
"""Test with a valid sequence of quantities with the same units."""
meter = get_quantity("m")
sequence = [1 * meter, 2 * meter, 3 * meter]
result = quant_sequence_to_quant_array(sequence)
expected = np.array([1, 2, 3]) * meter
np.testing.assert_array_equal(result.magnitude, expected.magnitude)
assert result.units == expected.units

def test_valid_sequence_different_units(self):
"""Test sequence of quantities with compatible but different units."""
m, cm, km = get_quantity("m"), get_quantity("cm"), get_quantity("km")

sequence = [1 * m, 100 * cm, 0.001 * km]
result = quant_sequence_to_quant_array(sequence)
expected = np.array([1, 1, 1]) * m
assert np.allclose(result.magnitude, expected.magnitude)
assert result.units == expected.units

def test_incompatible_units(self):
"""Test with a sequence of quantities with incompatible units."""
sequence = [1 * get_quantity("m"), 1 * get_quantity("s")]
msg = "Not all values in sequence have compatible units."
with pytest.raises(UnitError, match=msg):
quant_sequence_to_quant_array(sequence)

def test_non_quantity_elements(self):
"""Test with a sequence containing non-quantity elements."""
sequence = [1 * get_quantity("m"), 5]
msg = "Not all values in sequence are quantities."
with pytest.raises(UnitError, match=msg):
quant_sequence_to_quant_array(sequence)

def test_empty_sequence(self):
"""Test with an empty sequence."""
sequence = []
out = quant_sequence_to_quant_array(sequence)
assert isinstance(out, Quantity)

def test_numpy_array_input(self):
"""Test with a numpy array input."""
sequence = np.array([1, 2, 3])
out = quant_sequence_to_quant_array(sequence)
assert isinstance(out, Quantity)

0 comments on commit 852bb57

Please sign in to comment.