Skip to content

Commit

Permalink
add wav write support for patches with non-distance dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
d-chambers committed Feb 7, 2025
1 parent 4117475 commit 0172c51
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
25 changes: 18 additions & 7 deletions dascore/io/wav/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,36 @@ def write(
data, sr = self._get_wav_data(patch, resample_frequency)
if resource.name.endswith(".wav"):
write(filename=str(resource), rate=int(sr), data=data)
else: # write data to directory, one file for each distance
else: # write data to directory, one file for each non-time
resource.mkdir(exist_ok=True, parents=True)
distances = patch.coords.get_array("distance")
for ind, dist in enumerate(distances):
non_time_name = next(
iter(
set(patch.dims)
- {
"time",
}
)
)
non_time = patch.coords.get_array(non_time_name)
for ind, val in enumerate(non_time):
sub_data = np.take(data, ind, axis=1)
sub_path = resource / f"{dist}.wav"
sub_path = resource / f"{non_time_name}_{val}.wav"
write(filename=str(sub_path), rate=int(sr), data=sub_data)

@staticmethod
def _get_wav_data(patch, resample):
"""Pre-condition patch data for writing. Return array and sample rate."""
check_patch_coords(patch, ("time", "distance"))
# Ensure we have a 2D patch which has a time dimension.
check_patch_coords(patch, ("time",))
assert len(patch.dims) == 2, "only 2D patches supported for this function."
time = patch.get_coord("time").step

# handle resampling and normalization
pat = patch.transpose("time", "distance")
pat = patch.transpose("time", ...)
if resample is not None:
pat = pat.resample(time=1 / resample)
# normalize and detrend
pat = pat.detrend("time", "linear").normalize("time", norm="max")
data = pat.data
sample_rate = resample or np.round(ONE_SECOND / pat.attrs["time_step"])
sample_rate = resample or np.round(ONE_SECOND / time)
return data.astype(np.float32), int(sample_rate)
15 changes: 15 additions & 0 deletions tests/test_io/test_wav/test_wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def wave_dir(self, audio_patch, tmp_path_factory):
dc.write(audio_patch, new, "wav")
return new

@pytest.fixture(scope="class")
def audio_patch_non_distance_dim(self, audio_patch):
"""Create a patch that has a non-distance dimension in addition to time."""
patch = audio_patch.rename_coords(distance="microphone")
return patch

def test_directory(self, wave_dir, audio_patch):
"""Sanity checks on wav directory."""
assert wave_dir.exists()
Expand All @@ -43,3 +49,12 @@ def test_resample(self, audio_patch, tmp_path_factory):
dc.write(audio_patch, path, "wav", resample_frequency=1000)
(sr, ar) = read_wav(str(path))
assert sr == 1000

def test_write_non_distance_dims(
self, audio_patch_non_distance_dim, tmp_path_factory
):
"""Ensure any non-time dimension still works."""
path = tmp_path_factory.mktemp("wav_resample")
patch = audio_patch_non_distance_dim
patch.io.write(path, "wav")
assert path.exists()

0 comments on commit 0172c51

Please sign in to comment.