Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added method for a save wavform as wav format #251

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/diart/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from pyannote.metrics.base import BaseMetric
from rx.core import Observer
from tqdm import tqdm
import soundfile as sf


from . import blocks
from . import operators as dops
Expand Down Expand Up @@ -159,7 +161,7 @@ def _close_chronometer(self):
def attach_hooks(
self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]
):
"""Attach hooks to the pipeline.
"""Attach hooks to the pipeline.

Parameters
----------
Expand All @@ -168,6 +170,31 @@ def attach_hooks(
"""
self.stream = self.stream.pipe(*[ops.do_action(hook) for hook in hooks])


def save_waveform_hook(self, save_path: str = "output"):
"""Create a hook function to save waveform data.

Parameters
----------
save_path: str
The directory path where waveforms will be saved.
"""
count = 0

def save_waveform(results):
nonlocal count
prediction, waveform = results

if prediction:
filename = f"{save_path}/waveform{count}.wav"
sf.write(filename, waveform.data, samplerate=16000)
print(f"Waveform saved to {filename}")
count += 1

return save_waveform



def attach_observers(self, *observers: Observer):
"""Attach rx observers to the pipeline.

Expand Down
Loading