From 884c41fe77041a3cc7047edc53cb5c33f4055aa8 Mon Sep 17 00:00:00 2001 From: Derrick Chambers Date: Fri, 20 Dec 2024 23:30:09 -0800 Subject: [PATCH] fix_463 (#464) * fix_463 * neubrex decimation check (#465) - also fixes CI mamba setup. * segyio dependency issue note made next to install instructions * clarify fetch in tutorials (#468) * fix_463 --------- Co-authored-by: eileenrmartin Co-authored-by: Ahmad Tourei <92008628+ahmadtourei@users.noreply.github.com> --- dascore/proc/filter.py | 9 ++++++ dascore/units.py | 37 +++++++++++++++++++++++++ tests/test_proc/test_filter.py | 10 ++++++- tests/test_units.py | 50 ++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 1 deletion(-) diff --git a/dascore/proc/filter.py b/dascore/proc/filter.py index dc814cc9..5df2b6b0 100644 --- a/dascore/proc/filter.py +++ b/dascore/proc/filter.py @@ -26,6 +26,7 @@ get_filter_units, get_inverted_quant, invert_quantity, + quant_sequence_to_quant_array, ) from dascore.utils.docs import compose_docstring from dascore.utils.misc import ( @@ -562,6 +563,14 @@ def _get_slope_array(dft_patch, directional, freq_dims): def _maybe_transform_units(filt, dft_patch, freq_dims): """Handle units on filter.""" + # Hand the units/partial units in sequence. + units = getattr(filt, "units", None) + try: + filt = np.array(filt) + except ValueError: + filt = quant_sequence_to_quant_array(filt) + if units: + filt = filt * dc.get_quantity(units) if not isinstance(filt, dc.units.Quantity): return filt array, units = filt.magnitude, filt.units diff --git a/dascore/units.py b/dascore/units.py index 4b64c4fe..f573c9be 100644 --- a/dascore/units.py +++ b/dascore/units.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Sequence from functools import cache from typing import TypeVar @@ -11,6 +12,7 @@ from pint import DimensionalityError, Quantity, UndefinedUnitError, Unit import dascore as dc +from dascore.compat import is_array from dascore.exceptions import UnitError from dascore.utils.misc import unbyte from dascore.utils.time import dtype_time_like, is_datetime64, is_timedelta64, to_float @@ -304,6 +306,41 @@ def _check_to_units(to_unit, dim): return out1, out2 +def quant_sequence_to_quant_array(sequence: Sequence[Quantity]) -> Quantity: + """ + Convert a sequence of Quantities (eg list) to a Quantity array. + + Will simplify all quantities. Raises an error if not all elements have + the same units. + + Parameters + ---------- + sequence + A sequence of Quantities. + + Notes + ----- + This is probably not efficient for large lists. + """ + if is_array(sequence): + # This is a numpy array, just return multiplied by quantity. + return sequence * get_quantity("dimensionless") + # iterate the sequence and manually convert to base units. + try: + base_unit_sequence = [x.to_base_units() for x in sequence] + except AttributeError: + msg = "Not all values in sequence are quantities." + raise UnitError(msg) + if not len(base_unit_sequence): + return np.array([]) * get_quantity("dimensionless") + units = {x.units for x in base_unit_sequence} + if len(units) != 1: + msg = "Not all values in sequence have compatible units." + raise UnitError(msg) + array = np.array([x.magnitude for x in base_unit_sequence]) + return array * next(iter(units)) + + def __getattr__(name): """ Allows arbitrary units (quantities) to be imported from this module. diff --git a/tests/test_proc/test_filter.py b/tests/test_proc/test_filter.py index 8ad275c2..9551a295 100644 --- a/tests/test_proc/test_filter.py +++ b/tests/test_proc/test_filter.py @@ -451,7 +451,7 @@ def test_bad_dims(self, example_patch): patch.slope_filter(filt=filt, dims=("time", "distance")) def test_units_raise_no_unit_coords(self, example_patch): - """Ensure A UnitError is raised if one of hte coords does't have units.""" + """Ensure A UnitError is raised if one of the coords doesn't have units.""" patch = example_patch.set_units(distance="") filt = np.array([1e3, 1.5e3, 5e3, 10e3]) * get_unit("m/s") with pytest.raises(UnitError): @@ -476,3 +476,11 @@ def test_inverted_units(self, example_patch): out1 = example_patch.slope_filter(filt=slowness) out2 = example_patch.slope_filter(filt=filt * get_unit("m/s")) assert np.allclose(out1.data, out2.data) + + def test_units_list(self, example_patch): + """Ensure units as a list still work (see #463).""" + speed = 5_000 * dc.get_quantity("m/s") + filt = [speed * 0.90, speed * 0.95, speed * 1.05, speed * 1.1] + # The test passes if this line doesn't raise an error. + out = example_patch.slope_filter(filt) + assert isinstance(out, dc.Patch) diff --git a/tests/test_units.py b/tests/test_units.py index 5e6c4c07..34946ce2 100644 --- a/tests/test_units.py +++ b/tests/test_units.py @@ -17,6 +17,7 @@ get_quantity_str, get_unit, invert_quantity, + quant_sequence_to_quant_array, ) @@ -303,3 +304,52 @@ def test_array_quantity(self): array = np.arange(10) * get_quantity("m") out = convert_units(array, to_units="ft") np.allclose(array.magnitude, out * 3.28084) + + +class TestQuantSequenceToQuantArray: + """Ensure we can convert a quantity sequence to an array.""" + + def test_valid_sequence_same_units(self): + """Test with a valid sequence of quantities with the same units.""" + meter = get_quantity("m") + sequence = [1 * meter, 2 * meter, 3 * meter] + result = quant_sequence_to_quant_array(sequence) + expected = np.array([1, 2, 3]) * meter + np.testing.assert_array_equal(result.magnitude, expected.magnitude) + assert result.units == expected.units + + def test_valid_sequence_different_units(self): + """Test sequence of quantities with compatible but different units.""" + m, cm, km = get_quantity("m"), get_quantity("cm"), get_quantity("km") + + sequence = [1 * m, 100 * cm, 0.001 * km] + result = quant_sequence_to_quant_array(sequence) + expected = np.array([1, 1, 1]) * m + assert np.allclose(result.magnitude, expected.magnitude) + assert result.units == expected.units + + def test_incompatible_units(self): + """Test with a sequence of quantities with incompatible units.""" + sequence = [1 * get_quantity("m"), 1 * get_quantity("s")] + msg = "Not all values in sequence have compatible units." + with pytest.raises(UnitError, match=msg): + quant_sequence_to_quant_array(sequence) + + def test_non_quantity_elements(self): + """Test with a sequence containing non-quantity elements.""" + sequence = [1 * get_quantity("m"), 5] + msg = "Not all values in sequence are quantities." + with pytest.raises(UnitError, match=msg): + quant_sequence_to_quant_array(sequence) + + def test_empty_sequence(self): + """Test with an empty sequence.""" + sequence = [] + out = quant_sequence_to_quant_array(sequence) + assert isinstance(out, Quantity) + + def test_numpy_array_input(self): + """Test with a numpy array input.""" + sequence = np.array([1, 2, 3]) + out = quant_sequence_to_quant_array(sequence) + assert isinstance(out, Quantity)