diff --git a/dascore/io/wav/core.py b/dascore/io/wav/core.py index d457fd29..c6d6fd92 100644 --- a/dascore/io/wav/core.py +++ b/dascore/io/wav/core.py @@ -26,9 +26,9 @@ def write( Parameters ---------- resource - If a path that ends with .wav, write all the distance channels - to a single file. If not, assume the path is a directory and write - each distance channel to its own wav file. + If a path that ends with .wav, write all non-time channels + to a single file. If not, assume the path is a directory and + write each non-time channel to its own wav file. resample_frequency A resample frequency in Hz. If None, do not perform resampling. Often DAS has non-int sampling rates, so the default resampling @@ -45,7 +45,7 @@ def write( and normalized before writing. - If a single wavefile is specified with the path argument, and - the output the patch has more than one len along the distance + the output the patch has more than one len along the non-time dimension, a multi-channel wavefile is created. There may be some players that do not support multi-channel wavefiles. @@ -62,14 +62,8 @@ def write( write(filename=str(resource), rate=int(sr), data=data) else: # write data to directory, one file for each non-time resource.mkdir(exist_ok=True, parents=True) - non_time_name = next( - iter( - set(patch.dims) - - { - "time", - } - ) - ) + non_time_set = set(patch.dims) - {"time"} + non_time_name = next(iter(non_time_set)) non_time = patch.coords.get_array(non_time_name) for ind, val in enumerate(non_time): sub_data = np.take(data, ind, axis=1) diff --git a/tests/test_io/test_wav/test_wav.py b/tests/test_io/test_wav/test_wav.py index 62043eb0..c3a91dba 100644 --- a/tests/test_io/test_wav/test_wav.py +++ b/tests/test_io/test_wav/test_wav.py @@ -9,6 +9,8 @@ import dascore as dc +ONE_SECOND = dc.to_timedelta64(1) + class TestWriteWav: """Tests for writing wav format to disk.""" @@ -58,3 +60,12 @@ def test_write_non_distance_dims( patch = audio_patch_non_distance_dim patch.io.write(path, "wav") assert path.exists() + # Verify number of WAV files + wavs = list(path.rglob("*.wav")) + assert len(wavs) == len(patch.coords.get_array("microphone")) + # Verify file naming + for mic_val in patch.coords.get_array("microphone"): + assert path / f"microphone_{mic_val}.wav" in wavs + # Verify content of first file + sr, data = read_wav(str(wavs[0])) + assert sr == int(ONE_SECOND / patch.get_coord("time").step)