diff --git a/dascore/core/spool.py b/dascore/core/spool.py index 796938a9..a9ef01e0 100644 --- a/dascore/core/spool.py +++ b/dascore/core/spool.py @@ -14,6 +14,7 @@ from typing_extensions import Self import dascore as dc +from dascore.compat import is_array from dascore.constants import ( PROGRESS_LEVELS, WARN_LEVELS, @@ -56,7 +57,7 @@ class BaseSpool(abc.ABC): _rich_style = "bold" @abc.abstractmethod - def __getitem__(self, item: int) -> PatchType: + def __getitem__(self, item: int | slice | np.ndarray) -> PatchType: """Returns a patch from the spool.""" @abc.abstractmethod @@ -376,6 +377,27 @@ def __init__( self._select_kwargs = {} if select_kwargs is None else select_kwargs self._merge_kwargs = {} if merge_kwargs is None else merge_kwargs + def _select_from_array(self, array) -> Self: + """Create a new spool whose contents are selected by a bool or int array.""" + if np.issubdtype(array.dtype, np.bool_): # boolean select + df = self._df[array] + elif np.issubdtype(array.dtype, np.integer): + df = self._df.iloc[array] + else: + msg = "Only bool or int dtypes are supported for spool array selection." + raise ValueError(msg) + source = self._source_df + inst = self._instruction_df + select_kwargs, merge_kwargs = self._select_kwargs, self._merge_kwargs + new = self.new_from_df( + df, + source_df=source, + instruction_df=inst, + select_kwargs=select_kwargs, + merge_kwargs=merge_kwargs, + ) + return new + def __getitem__(self, item) -> PatchType | BaseSpool: if isinstance(item, slice): # a slice was used, return a sub-spool new_df = self._df.iloc[item] @@ -387,6 +409,8 @@ def __getitem__(self, item) -> PatchType | BaseSpool: instruction_df=new_inst, source_df=new_source, ) + elif is_array(item): # An array was passed; use numpy-style selection.
+ return self._select_from_array(np.asarray(item)) else: # a single index was used, should return a single patch out = self._unbox_patch(self._get_patches_from_index(item)) return out diff --git a/docs/tutorial/spool.qmd b/docs/tutorial/spool.qmd index 657f3b12..53106fe5 100644 --- a/docs/tutorial/spool.qmd +++ b/docs/tutorial/spool.qmd @@ -101,6 +101,37 @@ for patch in spool: new_spool = spool[1:] ``` +An array can also be used (just like in NumPy) to select or rearrange spool contents. For example, a boolean array can be used to deselect patches: + +```{python} +import dascore as dc +import numpy as np + +spool = dc.get_example_spool() + +# Get a bool array; True means the patch is kept, False means it is discarded. +bool_array = np.ones(len(spool), dtype=np.bool_) +bool_array[1] = False + +# Remove patch at position 1 from spool. +new = spool[bool_array] +``` + +An integer array can be used to select or rearrange patches: + +```{python} +import dascore as dc +import numpy as np + +spool = dc.get_example_spool() + +# Get an array of integers which indicate the indices of the included patches. +bool_array = np.array([2, 0]) + +# Create a new spool containing patch 2 followed by patch 0. +new = spool[bool_array] +``` + # get_contents The [`get_contents`](`dascore.core.spool.BaseSpool.get_contents`) method returns a dataframe listing the spool contents. This method may not be supported on all spools, especially those interfacing with large remote resources.
diff --git a/tests/test_core/test_spool.py b/tests/test_core/test_spool.py index 5fdea624..4eace752 100644 --- a/tests/test_core/test_spool.py +++ b/tests/test_core/test_spool.py @@ -179,6 +179,62 @@ def test_skip_slice(self, random_spool): assert new_spool[1].equals(random_spool[2]) +class TestSpoolBoolArraySelect: + """Tests for selecting patches using a boolean array.""" + + def test_bool_all_true(self, random_spool): + """All True should return an equal spool.""" + bool_array = np.ones(len(random_spool), dtype=np.bool_) + out = random_spool[bool_array] + assert out == random_spool + + def test_bool_all_false(self, random_spool): + """All False should return an empty spool.""" + bool_array = np.zeros(len(random_spool), dtype=np.bool_) + out = random_spool[bool_array] + assert len(out) == 0 + + def test_bool_some_true(self, random_spool): + """A partially True mask should return only the selected patches.""" + bool_array = np.ones(len(random_spool), dtype=np.bool_) + bool_array[1] = False + out = random_spool[bool_array] + assert len(out) == sum(bool_array) + df1 = out.get_contents() + df2 = random_spool.get_contents()[bool_array] + assert df1.equals(df2) + + +class TestSpoolIntArraySelect: + """Tests for selecting patches using an integer array.""" + + def test_uniform(self, random_spool): + """An arange over the spool length should return an equal spool.""" + array = np.arange(len(random_spool)) + spool = random_spool[array] + assert spool == random_spool + + def test_out_of_bounds_raises(self, random_spool): + """Ensure int values greater than the spool length raise.""" + array = np.arange(len(random_spool)) + array[0] = len(random_spool) + 10 + with pytest.raises(IndexError): + random_spool[array] + + def test_bad_array_type(self, random_spool): + """Ensure an array that is neither bool nor int raises.""" + array = np.arange(len(random_spool)) + 0.01 + with pytest.raises(ValueError, match="Only bool or int dtypes"): + random_spool[array] + + def test_rearrange(self, random_spool): +
"""Ensure patch order can be changed.""" + array = np.array([len(random_spool) - 1, 0]) + out = random_spool[array] + assert out[0] == random_spool[-1] + assert out[-1] == random_spool[0] + + class TestSpoolIterable: """Tests for iterating Spools."""