From 2f4e8b5090ff2c5209e81481117baa55028e6ae4 Mon Sep 17 00:00:00 2001 From: Mohammad <91499974+m15kh@users.noreply.github.com> Date: Fri, 8 Nov 2024 21:19:54 +0330 Subject: [PATCH] added method for a save wavform as wav format --- src/diart/inference.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/diart/inference.py b/src/diart/inference.py index a09593d3..5925aee7 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -14,6 +14,8 @@ from pyannote.metrics.base import BaseMetric from rx.core import Observer from tqdm import tqdm +import soundfile as sf + from . import blocks from . import operators as dops @@ -159,7 +161,7 @@ def _close_chronometer(self): def attach_hooks( self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None] ): - """Attach hooks to the pipeline. + """Attach hooks to the pipeline. Parameters ---------- @@ -168,6 +170,31 @@ def attach_hooks( """ self.stream = self.stream.pipe(*[ops.do_action(hook) for hook in hooks]) + + def save_waveform_hook(self, save_path: str = "output"): + """Create a hook function to save waveform data. + + Parameters + ---------- + save_path: str + The directory path where waveforms will be saved. + """ + count = 0 + + def save_waveform(results): + nonlocal count + prediction, waveform = results + + if prediction: + filename = f"{save_path}/waveform{count}.wav" + sf.write(filename, waveform.data, samplerate=16000) + print(f"Waveform saved to {filename}") + count += 1 + + return save_waveform + + + def attach_observers(self, *observers: Observer): """Attach rx observers to the pipeline.