From 0172c51c2587b39c75e74a0e3a56ce46afec8049 Mon Sep 17 00:00:00 2001 From: Derrick Chambers Date: Fri, 7 Feb 2025 10:43:42 -0800 Subject: [PATCH] add wav write support for patches with non-distance dimensions --- dascore/io/wav/core.py | 25 ++++++++++++++++++------- tests/test_io/test_wav/test_wav.py | 15 +++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/dascore/io/wav/core.py b/dascore/io/wav/core.py index ee95677f..d457fd29 100644 --- a/dascore/io/wav/core.py +++ b/dascore/io/wav/core.py @@ -60,25 +60,36 @@ def write( data, sr = self._get_wav_data(patch, resample_frequency) if resource.name.endswith(".wav"): write(filename=str(resource), rate=int(sr), data=data) - else: # write data to directory, one file for each distance + else: # write data to directory, one file for each non-time resource.mkdir(exist_ok=True, parents=True) - distances = patch.coords.get_array("distance") - for ind, dist in enumerate(distances): + non_time_name = next( + iter( + set(patch.dims) + - { + "time", + } + ) + ) + non_time = patch.coords.get_array(non_time_name) + for ind, val in enumerate(non_time): sub_data = np.take(data, ind, axis=1) - sub_path = resource / f"{dist}.wav" + sub_path = resource / f"{non_time_name}_{val}.wav" write(filename=str(sub_path), rate=int(sr), data=sub_data) @staticmethod def _get_wav_data(patch, resample): """Pre-condition patch data for writing. Return array and sample rate.""" - check_patch_coords(patch, ("time", "distance")) + # Ensure we have a 2D patch which has a time dimension. + check_patch_coords(patch, ("time",)) assert len(patch.dims) == 2, "only 2D patches supported for this function." + time = patch.get_coord("time").step + # handle resampling and normalization - pat = patch.transpose("time", "distance") + pat = patch.transpose("time", ...) if resample is not None: pat = pat.resample(time=1 / resample) # normalize and detrend pat = pat.detrend("time", "linear").normalize("time", norm="max") data = pat.data - sample_rate = resample or np.round(ONE_SECOND / pat.attrs["time_step"]) + sample_rate = resample or np.round(ONE_SECOND / time) return data.astype(np.float32), int(sample_rate) diff --git a/tests/test_io/test_wav/test_wav.py b/tests/test_io/test_wav/test_wav.py index 99670b9a..62043eb0 100644 --- a/tests/test_io/test_wav/test_wav.py +++ b/tests/test_io/test_wav/test_wav.py @@ -25,6 +25,12 @@ def wave_dir(self, audio_patch, tmp_path_factory): dc.write(audio_patch, new, "wav") return new + @pytest.fixture(scope="class") + def audio_patch_non_distance_dim(self, audio_patch): + """Create a patch that has a non-distance dimension in addition to time.""" + patch = audio_patch.rename_coords(distance="microphone") + return patch + def test_directory(self, wave_dir, audio_patch): """Sanity checks on wav directory.""" assert wave_dir.exists() @@ -43,3 +49,12 @@ def test_resample(self, audio_patch, tmp_path_factory): dc.write(audio_patch, path, "wav", resample_frequency=1000) (sr, ar) = read_wav(str(path)) assert sr == 1000 + + def test_write_non_distance_dims( + self, audio_patch_non_distance_dim, tmp_path_factory + ): + """Ensure any non-time dimension still works.""" + path = tmp_path_factory.mktemp("wav_resample") + patch = audio_patch_non_distance_dim + patch.io.write(path, "wav") + assert path.exists()