diff --git a/README.md b/README.md
index ef533946..ae13059f 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@
|
- 🤖 Custom models
+ 🤖 Add your model
|
@@ -110,32 +110,32 @@ See `diart.stream -h` for more options.
### From python
-Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:
+Use `StreamingInference` to run a pipeline on an audio source and write the results to disk:
```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
from diart.sinks import RTTMWriter
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
mic = MicrophoneAudioSource(pipeline.config.sample_rate)
-inference = RealTimeInference(pipeline, mic, do_plot=True)
+inference = StreamingInference(pipeline, mic, do_plot=True)
inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
prediction = inference()
```
For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)).
-## 🤖 Custom models
+## 🤖 Add your model
-Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):
+Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both subclasses of PyTorch's `nn.Module`):
```python
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
def model_loader():
@@ -168,19 +168,19 @@ class MyEmbeddingModel(EmbeddingModel):
return self.model(waveform, weights)
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
segmentation=MySegmentationModel(),
embedding=MyEmbeddingModel()
)
-pipeline = OnlineSpeakerDiarization(config)
+pipeline = SpeakerDiarization(config)
mic = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference(pipeline, mic)
+inference = StreamingInference(pipeline, mic)
prediction = inference()
```
## 📈 Tune hyper-parameters
-Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
+Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs.
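+
+As a quick sketch, the same optimization can also be driven from Python (paths and values below are placeholders; the `Optimizer` arguments mirror those used by `diart.tune`):
+
+```python
+from pathlib import Path
+
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
+from diart.optim import Optimizer
+from diart.pipelines.hparams import HyperParameter
+
+optimizer = Optimizer(
+    pipeline_class=SpeakerDiarization,
+    speech_path="/wav/dir",
+    reference_path="/rttm/dir",
+    study_or_path=Path("/output/dir"),
+    batch_size=32,
+    # Tune the clustering thresholds of the diarization pipeline
+    hparams=[HyperParameter.from_name(name) for name in ("tau_active", "rho_update", "delta_new")],
+    base_config=SpeakerDiarizationConfig(),
+)
+optimizer(num_iter=100, show_progress=True)
+```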
### From the command line
@@ -281,7 +281,7 @@ diart.serve --host 0.0.0.0 --port 7007
diart.client microphone --host --port 7007
```
-**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
+**Note:** make sure that the client uses the same `step` and `sample_rate` as the server with `--step` and `-sr`.
See `-h` for more options.
@@ -290,13 +290,13 @@ See `-h` for more options.
For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:
```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
from diart.sources import WebSocketAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
-inference = RealTimeInference(pipeline, source)
+inference = StreamingInference(pipeline, source)
inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
prediction = inference()
```
@@ -354,14 +354,14 @@ or using the inference API:
```python
from diart.inference import Benchmark, Parallelize
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import SegmentationModel
benchmark = Benchmark("/wav/dir", "/rttm/dir")
name = "pyannote/segmentation@Interspeech2021"
segmentation = SegmentationModel.from_pyannote(name)
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
# Set the model used in the paper
segmentation=segmentation,
step=0.5,
@@ -370,12 +370,12 @@ config = PipelineConfig(
rho_update=0.422,
delta_new=1.517
)
-benchmark(OnlineSpeakerDiarization, config)
+benchmark(SpeakerDiarization, config)
# Run the same benchmark in parallel
p_benchmark = Parallelize(benchmark, num_workers=4)
if __name__ == "__main__": # Needed for multiprocessing
- p_benchmark(OnlineSpeakerDiarization, config)
+ p_benchmark(SpeakerDiarization, config)
```
This pre-calculates model outputs in batches, so it runs a lot faster.
diff --git a/requirements.txt b/requirements.txt
index 50662023..3241ddc4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,7 @@ pandas>=1.4.2
torch>=1.12.1
torchvision>=0.14.0
torchaudio>=0.12.1,<1.0
+torchmetrics>=0.11.1
pyannote.audio>=2.1.1
pyannote.core>=4.5
pyannote.database>=4.1.1
diff --git a/setup.cfg b/setup.cfg
index 594c876e..c70eac0b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,11 +2,11 @@
name=diart
version=0.7.0
author=Juan Manuel Coria
-description=Speaker diarization in real time
+description=Streaming speaker diarization in real time
long_description=file: README.md
long_description_content_type=text/markdown
keywords=speaker diarization, streaming, online, real time, rxpy
-url=https://github.com/juanmc2005/StreamingSpeakerDiarization
+url=https://github.com/juanmc2005/diart
license=MIT
classifiers=
Development Status :: 4 - Beta
@@ -31,6 +31,7 @@ install_requires=
torch>=1.12.1
torchvision>=0.14.0
torchaudio>=0.12.1,<1.0
+ torchmetrics>=0.11.1
pyannote.audio>=2.1.1
pyannote.core>=4.5
pyannote.database>=4.1.1
diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index c9692638..842ba267 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1,6 +1,10 @@
-from .blocks import (
- OnlineSpeakerDiarization,
- BasePipeline,
+from .pipelines import (
+ Pipeline,
PipelineConfig,
- BasePipelineConfig,
+ SpeakerDiarization,
+ SpeakerDiarizationConfig,
+ VoiceActivityDetection,
+ VoiceActivityDetectionConfig,
+ Transcription,
+ TranscriptionConfig,
)
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index 59a6ef36..96fae0e7 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -5,7 +5,7 @@
FirstOnlyStrategy,
DelayedAggregation,
)
-from .clustering import OnlineSpeakerClustering
+from .clustering import IncrementalSpeakerClustering
from .embedding import (
SpeakerEmbedding,
OverlappedSpeechPenalty,
@@ -13,6 +13,5 @@
OverlapAwareSpeakerEmbedding,
)
from .segmentation import SpeakerSegmentation
-from .diarization import OnlineSpeakerDiarization, BasePipeline
-from .config import BasePipelineConfig, PipelineConfig
from .utils import Binarize, Resample, AdjustVolume
+from .asr import SpeechRecognition
diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py
new file mode 100644
index 00000000..83dc0d90
--- /dev/null
+++ b/src/diart/blocks/asr.py
@@ -0,0 +1,66 @@
+from pathlib import Path
+from typing import Optional, Union, List, Text
+
+import torch
+from einops import rearrange
+
+from .. import models as m
+from ..features import TemporalFeatureFormatter, TemporalFeatures
+
+
+class SpeechRecognition:
+ def __init__(self, model: m.SpeechRecognitionModel, device: Optional[torch.device] = None):
+ self.model = model
+ self.model.eval()
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cpu")
+ self.model.to(self.device)
+ self.formatter = TemporalFeatureFormatter()
+
+ @staticmethod
+ def from_whisper(
+ name: Text,
+ download_path: Optional[Union[Text, Path]] = None,
+ in_memory: bool = False,
+ fp16: bool = False,
+ no_speech_threshold: float = 0.6,
+ compression_ratio_threshold: Optional[float] = 2.4,
+ logprob_threshold: Optional[float] = -1,
+ decode_with_fallback: bool = False,
+ device: Optional[Union[Text, torch.device]] = None,
+ ) -> 'SpeechRecognition':
+ asr_model = m.SpeechRecognitionModel.from_whisper(
+ name,
+ download_path,
+ in_memory,
+ fp16,
+ no_speech_threshold,
+ compression_ratio_threshold,
+ logprob_threshold,
+ decode_with_fallback,
+ )
+ return SpeechRecognition(asr_model, device)
+
+ def __call__(self, waveform: TemporalFeatures) -> List[m.TranscriptionResult]:
+ """
+ Compute the transcription of input audio.
+
+ Parameters
+ ----------
+ waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
+ Audio to transcribe
+
+ Returns
+ -------
+ transcriptions: List[TranscriptionResult]
+ A list of timestamped transcriptions
+ """
+ with torch.no_grad():
+ wave = rearrange(
+ self.formatter.cast(waveform),
+ "batch sample channel -> batch channel sample"
+ )
+ output = self.model(wave.to(self.device))
+ return output
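+
+
+# Usage sketch (illustrative; assumes the optional whisper dependency is installed):
+#
+#     import numpy as np
+#     asr = SpeechRecognition.from_whisper("small")
+#     chunk = np.zeros((3 * 16000, 1), dtype=np.float32)  # (samples, channels) at 16kHz
+#     results = asr(chunk)  # -> List[TranscriptionResult]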
diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py
index 882001b9..4b737175 100644
--- a/src/diart/blocks/clustering.py
+++ b/src/diart/blocks/clustering.py
@@ -7,7 +7,7 @@
from ..mapping import SpeakerMap, SpeakerMapBuilder
-class OnlineSpeakerClustering:
+class IncrementalSpeakerClustering:
"""Implements constrained incremental online clustering of speakers and manages cluster centers.
Parameters
diff --git a/src/diart/blocks/config.py b/src/diart/blocks/config.py
deleted file mode 100644
index d8e2a656..00000000
--- a/src/diart/blocks/config.py
+++ /dev/null
@@ -1,153 +0,0 @@
-from typing import Any, Optional, Union, Tuple
-
-import numpy as np
-import torch
-from typing_extensions import Literal
-
-from .. import models as m
-from .. import utils
-from ..audio import FilePath, AudioLoader
-
-
-class BasePipelineConfig:
- @property
- def duration(self) -> float:
- raise NotImplementedError
-
- @property
- def step(self) -> float:
- raise NotImplementedError
-
- @property
- def latency(self) -> float:
- raise NotImplementedError
-
- @property
- def sample_rate(self) -> int:
- raise NotImplementedError
-
- @staticmethod
- def from_dict(data: Any) -> 'BasePipelineConfig':
- raise NotImplementedError
-
- def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
- file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
- right = utils.get_padding_right(self.latency, self.step)
- left = utils.get_padding_left(file_duration + right, self.duration)
- return left, right
-
- def optimal_block_size(self) -> int:
- return int(np.rint(self.step * self.sample_rate))
-
-
-class PipelineConfig(BasePipelineConfig):
- def __init__(
- self,
- segmentation: Optional[m.SegmentationModel] = None,
- embedding: Optional[m.EmbeddingModel] = None,
- duration: Optional[float] = None,
- step: float = 0.5,
- latency: Optional[Union[float, Literal["max", "min"]]] = None,
- tau_active: float = 0.6,
- rho_update: float = 0.3,
- delta_new: float = 1,
- gamma: float = 3,
- beta: float = 10,
- max_speakers: int = 20,
- device: Optional[torch.device] = None,
- **kwargs,
- ):
- # Default segmentation model is pyannote/segmentation
- self.segmentation = segmentation
- if self.segmentation is None:
- self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
-
- # Default duration is the one given by the segmentation model
- self._duration = duration
-
- # Expected sample rate is given by the segmentation model
- self._sample_rate: Optional[int] = None
-
- # Default embedding model is pyannote/embedding
- self.embedding = embedding
- if self.embedding is None:
- self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
-
- # Latency defaults to the step duration
- self._step = step
- self._latency = latency
- if self._latency is None or self._latency == "min":
- self._latency = self._step
- elif self._latency == "max":
- self._latency = self._duration
-
- self.tau_active = tau_active
- self.rho_update = rho_update
- self.delta_new = delta_new
- self.gamma = gamma
- self.beta = beta
- self.max_speakers = max_speakers
-
- self.device = device
- if self.device is None:
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- @staticmethod
- def from_dict(data: Any) -> 'PipelineConfig':
- # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
- device = utils.get(data, "device", None)
- if device is None:
- device = torch.device("cpu") if utils.get(data, "cpu", False) else None
-
- # Instantiate models
- hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
- segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
- segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
- embedding = utils.get(data, "embedding", "pyannote/embedding")
- embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
-
- # Hyper-parameters and their aliases
- tau = utils.get(data, "tau_active", None)
- if tau is None:
- tau = utils.get(data, "tau", 0.6)
- rho = utils.get(data, "rho_update", None)
- if rho is None:
- rho = utils.get(data, "rho", 0.3)
- delta = utils.get(data, "delta_new", None)
- if delta is None:
- delta = utils.get(data, "delta", 1)
-
- return PipelineConfig(
- segmentation=segmentation,
- embedding=embedding,
- duration=utils.get(data, "duration", None),
- step=utils.get(data, "step", 0.5),
- latency=utils.get(data, "latency", None),
- tau_active=tau,
- rho_update=rho,
- delta_new=delta,
- gamma=utils.get(data, "gamma", 3),
- beta=utils.get(data, "beta", 10),
- max_speakers=utils.get(data, "max_speakers", 20),
- device=device,
- )
-
- @property
- def duration(self) -> float:
- if self._duration is None:
- self._duration = self.segmentation.duration
- return self._duration
-
- @property
- def step(self) -> float:
- return self._step
-
- @property
- def latency(self) -> float:
- return self._latency
-
- @property
- def sample_rate(self) -> int:
- if self._sample_rate is None:
- self._sample_rate = self.segmentation.sample_rate
- return self._sample_rate
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
deleted file mode 100644
index 7f0e162c..00000000
--- a/src/diart/blocks/diarization.py
+++ /dev/null
@@ -1,151 +0,0 @@
-from typing import Optional, Tuple, Sequence
-
-import numpy as np
-import torch
-from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
-
-from .aggregation import DelayedAggregation
-from .clustering import OnlineSpeakerClustering
-from .embedding import OverlapAwareSpeakerEmbedding
-from .segmentation import SpeakerSegmentation
-from .utils import Binarize
-from .config import BasePipelineConfig, PipelineConfig
-
-
-class BasePipeline:
- @staticmethod
- def get_config_class() -> type:
- raise NotImplementedError
-
- @property
- def config(self) -> BasePipelineConfig:
- raise NotImplementedError
-
- def reset(self):
- raise NotImplementedError
-
- def set_timestamp_shift(self, shift: float):
- raise NotImplementedError
-
- def __call__(
- self,
- waveforms: Sequence[SlidingWindowFeature]
- ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
- raise NotImplementedError
-
-
-class OnlineSpeakerDiarization(BasePipeline):
- def __init__(self, config: Optional[PipelineConfig] = None):
- self._config = PipelineConfig() if config is None else config
-
- msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
- assert self._config.step <= self._config.latency <= self._config.duration, msg
-
- self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device)
- self.embedding = OverlapAwareSpeakerEmbedding(
- self._config.embedding, self._config.gamma, self._config.beta, norm=1, device=self._config.device
- )
- self.pred_aggregation = DelayedAggregation(
- self._config.step,
- self._config.latency,
- strategy="hamming",
- cropping_mode="loose",
- )
- self.audio_aggregation = DelayedAggregation(
- self._config.step,
- self._config.latency,
- strategy="first",
- cropping_mode="center",
- )
- self.binarize = Binarize(self._config.tau_active)
-
- # Internal state, handle with care
- self.timestamp_shift = 0
- self.clustering = None
- self.chunk_buffer, self.pred_buffer = [], []
- self.reset()
-
- @staticmethod
- def get_config_class() -> type:
- return PipelineConfig
-
- @property
- def config(self) -> PipelineConfig:
- return self._config
-
- def set_timestamp_shift(self, shift: float):
- self.timestamp_shift = shift
-
- def reset(self):
- self.set_timestamp_shift(0)
- self.clustering = OnlineSpeakerClustering(
- self.config.tau_active,
- self.config.rho_update,
- self.config.delta_new,
- "cosine",
- self.config.max_speakers,
- )
- self.chunk_buffer, self.pred_buffer = [], []
-
- def __call__(
- self,
- waveforms: Sequence[SlidingWindowFeature]
- ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
- batch_size = len(waveforms)
- msg = "Pipeline expected at least 1 input"
- assert batch_size >= 1, msg
-
- # Create batch from chunk sequence, shape (batch, samples, channels)
- batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
-
- expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
- msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
- assert batch.shape[1] == expected_num_samples, msg
-
- # Extract segmentation and embeddings
- segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
- embeddings = self.embedding(batch, segmentations) # shape (batch, speakers, emb_dim)
-
- seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
-
- outputs = []
- for wav, seg, emb in zip(waveforms, segmentations, embeddings):
- # Add timestamps to segmentation
- sw = SlidingWindow(
- start=wav.extent.start,
- duration=seg_resolution,
- step=seg_resolution,
- )
- seg = SlidingWindowFeature(seg.cpu().numpy(), sw)
-
- # Update clustering state and permute segmentation
- permuted_seg = self.clustering(seg, emb)
-
- # Update sliding buffer
- self.chunk_buffer.append(wav)
- self.pred_buffer.append(permuted_seg)
-
- # Aggregate buffer outputs for this time step
- agg_waveform = self.audio_aggregation(self.chunk_buffer)
- agg_prediction = self.pred_aggregation(self.pred_buffer)
- agg_prediction = self.binarize(agg_prediction)
-
- # Shift prediction timestamps if required
- if self.timestamp_shift != 0:
- shifted_agg_prediction = Annotation(agg_prediction.uri)
- for segment, track, speaker in agg_prediction.itertracks(yield_label=True):
- new_segment = Segment(
- segment.start + self.timestamp_shift,
- segment.end + self.timestamp_shift,
- )
- shifted_agg_prediction[new_segment, track] = speaker
- agg_prediction = shifted_agg_prediction
-
- outputs.append((agg_prediction, agg_waveform))
-
- # Make place for new chunks in buffer if required
- if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
- self.chunk_buffer = self.chunk_buffer[1:]
- self.pred_buffer = self.pred_buffer[1:]
-
- return outputs
diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py
index b6a3f9ff..3c87edab 100644
--- a/src/diart/console/benchmark.py
+++ b/src/diart/console/benchmark.py
@@ -1,21 +1,31 @@
import argparse
from pathlib import Path
-import diart.argdoc as argdoc
import pandas as pd
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
+from diart import argdoc
+from diart import utils
from diart.inference import Benchmark, Parallelize
def run():
parser = argparse.ArgumentParser()
parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+ parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+ help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
+ parser.add_argument("--whisper", default="small", type=str,
+ help=f"Whisper model for transcription pipeline. Defaults to 'small'")
+ parser.add_argument("--language", default="en", type=str,
+ help=f"Transcribe in this language. Defaults to 'en' (English)")
parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
parser.add_argument("--embedding", default="pyannote/embedding", type=str,
help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
parser.add_argument("--reference", type=Path,
help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+ parser.add_argument("--duration", default=5, type=float,
+ help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
+ parser.add_argument("--asr-duration", default=3, type=float,
+ help=f"Duration of the transcription window (in seconds). Defaults to 3")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -34,6 +44,8 @@ def run():
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
args = parser.parse_args()
+ pipeline_class = utils.get_pipeline_class(args.pipeline)
+
benchmark = Benchmark(
args.root,
args.reference,
@@ -43,11 +55,11 @@ def run():
batch_size=args.batch_size,
)
- config = PipelineConfig.from_dict(vars(args))
+ config = pipeline_class.get_config_class().from_dict(vars(args))
if args.num_workers > 0:
benchmark = Parallelize(benchmark, args.num_workers)
- report = benchmark(OnlineSpeakerDiarization, config)
+ report = benchmark(pipeline_class, config)
if args.output is not None and isinstance(report, pd.DataFrame):
report.to_csv(args.output / "benchmark_report.csv")
diff --git a/src/diart/console/client.py b/src/diart/console/client.py
index 084dbc13..816c7e0f 100644
--- a/src/diart/console/client.py
+++ b/src/diart/console/client.py
@@ -3,11 +3,11 @@
from threading import Thread
from typing import Text, Optional
-import diart.argdoc as argdoc
-import diart.sources as src
-import diart.utils as utils
import numpy as np
import rx.operators as ops
+from diart import argdoc
+from diart import sources as src
+from diart import utils
from websocket import WebSocket
@@ -45,7 +45,8 @@ def run():
parser.add_argument("--host", required=True, type=str, help="Server host")
parser.add_argument("--port", required=True, type=int, help="Server port")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
- parser.add_argument("-sr", "--sample-rate", default=16000, type=int, help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000")
+ parser.add_argument("-sr", "--sample-rate", default=16000, type=int,
+ help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000")
parser.add_argument("-o", "--output-file", type=Path, help="Output RTTM file. Defaults to no writing")
args = parser.parse_args()
diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index 2f632d57..fe668c5b 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -1,21 +1,31 @@
import argparse
from pathlib import Path
-import diart.argdoc as argdoc
-import diart.sources as src
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
-from diart.inference import RealTimeInference
-from diart.sinks import RTTMWriter
+from diart import argdoc
+from diart import sources as src
+from diart import utils
+from diart.inference import StreamingInference
+from diart.pipelines import Pipeline
def run():
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host")
parser.add_argument("--port", default=7007, type=int, help="Server port")
+ parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+ help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
+ parser.add_argument("--whisper", default="small", type=str,
+ help=f"Whisper model for transcription pipeline. Defaults to 'small'")
+ parser.add_argument("--language", default="en", type=str,
+ help=f"Transcribe in this language. Defaults to 'en' (English)")
parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
parser.add_argument("--embedding", default="pyannote/embedding", type=str,
help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
+ parser.add_argument("--duration", type=float,
+ help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
+ parser.add_argument("--asr-duration", default=3, type=float,
+ help=f"Duration of the transcription window (in seconds). Defaults to 3")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -31,29 +41,29 @@ def run():
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
args = parser.parse_args()
- # Define online speaker diarization pipeline
- config = PipelineConfig.from_dict(vars(args))
- pipeline = OnlineSpeakerDiarization(config)
+ # Resolve pipeline
+ pipeline_class = utils.get_pipeline_class(args.pipeline)
+ config = pipeline_class.get_config_class().from_dict(vars(args))
+ pipeline: Pipeline = pipeline_class(config)
# Create websocket audio source
audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port)
# Run online inference
- inference = RealTimeInference(
+ inference = StreamingInference(
pipeline,
audio_source,
batch_size=1,
do_profile=False,
- do_plot=False,
show_progress=True,
)
# Write to disk if required
if args.output is not None:
- inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm"))
+ inference.attach_observers(pipeline.suggest_writer(audio_source.uri, args.output))
- # Send back responses as RTTM text lines
- inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm()))
+ # Send back responses as text
+ inference.attach_hooks(lambda result: audio_source.send(utils.serialize_prediction(result)))
# Run server and pipeline
inference()
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index d7218f07..af8e2cf8 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -1,20 +1,30 @@
import argparse
from pathlib import Path
-import diart.argdoc as argdoc
-import diart.sources as src
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
-from diart.inference import RealTimeInference
-from diart.sinks import RTTMWriter
+from diart import argdoc
+from diart import sources as src
+from diart import utils
+from diart.inference import StreamingInference
+from diart.pipelines import Pipeline, PipelineConfig
def run():
parser = argparse.ArgumentParser()
parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:'")
+ parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+ help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
+ parser.add_argument("--whisper", default="small", type=str,
+ help=f"Whisper model for transcription pipeline. Defaults to 'small'")
+ parser.add_argument("--language", default="en", type=str,
+ help=f"Transcribe in this language. Defaults to 'en' (English)")
parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
parser.add_argument("--embedding", default="pyannote/embedding", type=str,
help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
+ parser.add_argument("--duration", default=5, type=float,
+ help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
+ parser.add_argument("--asr-duration", default=3, type=float,
+ help=f"Duration of the transcription window (in seconds). Defaults to 3")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -27,14 +37,16 @@ def run():
parser.add_argument("--cpu", dest="cpu", action="store_true",
help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str,
- help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
+ help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' "
+ f"or parent directory if SOURCE is a file")
parser.add_argument("--hf-token", default="true", type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
args = parser.parse_args()
- # Define online speaker diarization pipeline
- config = PipelineConfig.from_dict(vars(args))
- pipeline = OnlineSpeakerDiarization(config)
+ # Resolve pipeline
+ pipeline_class = utils.get_pipeline_class(args.pipeline)
+ config: PipelineConfig = pipeline_class.get_config_class().from_dict(vars(args))
+ pipeline: Pipeline = pipeline_class(config)
# Manage audio source
block_size = config.optimal_block_size()
@@ -51,15 +63,21 @@ def run():
audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device)
# Run online inference
- inference = RealTimeInference(
+ inference = StreamingInference(
pipeline,
audio_source,
batch_size=1,
do_profile=True,
- do_plot=not args.no_plot,
show_progress=True,
)
- inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm"))
+
+ # Attach observers for required side effects
+ observers = [pipeline.suggest_writer(audio_source.uri, args.output)]
+ if not args.no_plot:
+ observers.append(pipeline.suggest_display())
+ inference.attach_observers(*observers)
+
+ # Run pipeline
inference()
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index 4ad8852a..ea34ed97 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -1,10 +1,11 @@
import argparse
from pathlib import Path
-import diart.argdoc as argdoc
import optuna
-from diart.blocks import PipelineConfig, OnlineSpeakerDiarization
-from diart.optim import Optimizer, HyperParameter
+from diart import argdoc
+from diart import utils
+from diart.pipelines.hparams import HyperParameter
+from diart.optim import Optimizer
from optuna.samplers import TPESampler
@@ -13,10 +14,20 @@ def run():
parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
parser.add_argument("--reference", required=True, type=str,
help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+ parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+ help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
+ parser.add_argument("--whisper", default="small", type=str,
+ help=f"Whisper model for transcription pipeline. Defaults to 'small'")
+ parser.add_argument("--language", default="en", type=str,
+ help=f"Transcribe in this language. Defaults to 'en' (English)")
parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
parser.add_argument("--embedding", default="pyannote/embedding", type=str,
help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
+ parser.add_argument("--duration", default=5, type=float,
+ help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
+ parser.add_argument("--asr-duration", default=3, type=float,
+ help=f"Duration of the transcription window (in seconds). Defaults to 3")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -29,26 +40,36 @@ def run():
parser.add_argument("--cpu", dest="cpu", action="store_true",
help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"),
- help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new")
+ help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. "
+ "Defaults to tau_active, rho_update and delta_new")
parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials")
parser.add_argument("--storage", type=str,
- help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name")
+ help="Optuna storage string. If provided, continue a previous study instead of creating one. "
+ "The database name must match the study name")
parser.add_argument("--output", type=str, help="Working directory")
parser.add_argument("--hf-token", default="true", type=str,
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
args = parser.parse_args()
+ # Retrieve pipeline class
+ pipeline_class = utils.get_pipeline_class(args.pipeline)
+
# Create the base configuration for each trial
- base_config = PipelineConfig.from_dict(vars(args))
+ base_config = pipeline_class.get_config_class().from_dict(vars(args))
# Create hyper-parameters to optimize
+ possible_hparams = pipeline_class.hyper_parameters()
hparams = [HyperParameter.from_name(name) for name in args.hparams]
+ hparams = [hp for hp in hparams if hp in possible_hparams]
+ msg = f"No hyper-parameters to optimize. " \
+ f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}"
+ assert hparams, msg
# Use a custom storage if given
if args.output is not None:
msg = "Both `output` and `storage` were set, but only one was expected"
assert args.storage is None, msg
- args.output = Path(args.output)
+ args.output = Path(args.output).expanduser()
args.output.mkdir(parents=True, exist_ok=True)
study_or_path = args.output
elif args.storage is not None:
@@ -60,11 +81,11 @@ def run():
# Run optimization
Optimizer(
+ pipeline_class=pipeline_class,
speech_path=args.root,
reference_path=args.reference,
study_or_path=study_or_path,
batch_size=args.batch_size,
- pipeline_class=OnlineSpeakerDiarization,
hparams=hparams,
base_config=base_config,
)(num_iter=args.num_iter, show_progress=True)
diff --git a/src/diart/inference.py b/src/diart/inference.py
index f4b65f5f..1589feed 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -2,46 +2,43 @@
from multiprocessing import Pool, freeze_support, RLock, current_process
from pathlib import Path
from traceback import print_exc
-from typing import Union, Text, Optional, Callable, Tuple, List
+from typing import Union, Text, Optional, Callable, Tuple, List, Any, Dict
-import diart.operators as dops
-import diart.sources as src
import numpy as np
import pandas as pd
import rx
import rx.operators as ops
import torch
-from diart import utils
-from diart.blocks import BasePipeline, Resample, BasePipelineConfig
-from diart.progress import ProgressBar, RichProgressBar, TQDMProgressBar
-from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException
from pyannote.core import Annotation, SlidingWindowFeature
-from pyannote.database.util import load_rttm
-from pyannote.metrics.diarization import DiarizationErrorRate
from rx.core import Observer
from tqdm import tqdm
+from . import blocks
+from . import operators as dops
+from . import sources as src
+from . import utils
+from .metrics import Metric
+from .pipelines import Pipeline, PipelineConfig
+from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
+from .sinks import WindowClosedException
-class RealTimeInference:
- """Performs inference in real time given a pipeline and an audio source.
- Streams an audio source to an online speaker diarization pipeline.
- It allows users to attach a chain of operations in the form of hooks.
+
+class StreamingInference:
+ """Performs streaming inference given a pipeline and an audio source.
+ Side-effect hooks and observers can also be attached for customized behavior.
Parameters
----------
- pipeline: BasePipeline
- Configured speaker diarization pipeline.
+ pipeline: Pipeline
+ Pipeline to run on the audio source.
source: AudioSource
- Audio source to be read and streamed.
+ Audio source to read and stream.
batch_size: int
Number of inputs to send to the pipeline at once.
Defaults to 1.
do_profile: bool
If True, compute and report the processing time of the pipeline.
Defaults to True.
- do_plot: bool
- If True, draw predictions in a moving plot.
- Defaults to False.
show_progress: bool
If True, show a progress bar.
Defaults to True.
@@ -52,11 +49,10 @@ class RealTimeInference:
"""
def __init__(
self,
- pipeline: BasePipeline,
+ pipeline: Pipeline,
source: src.AudioSource,
batch_size: int = 1,
do_profile: bool = True,
- do_plot: bool = False,
show_progress: bool = True,
progress_bar: Optional[ProgressBar] = None,
):
@@ -64,11 +60,9 @@ def __init__(
self.source = source
self.batch_size = batch_size
self.do_profile = do_profile
- self.do_plot = do_plot
self.show_progress = show_progress
- self.accumulator = DiarizationPredictionAccumulator(self.source.uri)
- self.unit = "chunk" if self.batch_size == 1 else "batch"
self._observers = []
+ self._predictions = []
chunk_duration = self.pipeline.config.duration
step_duration = self.pipeline.config.step
@@ -88,11 +82,11 @@ def __init__(
self._pbar.create(
total=self.num_chunks,
description=f"Streaming {self.source.uri}",
- unit=self.unit
+ unit="chunk"
)
# Initialize chronometer for profiling
- self._chrono = utils.Chronometer(self.unit, self._pbar)
+ self._chrono = utils.Chronometer("batch", self._pbar)
self.stream = self.source.stream
@@ -102,7 +96,7 @@ def __init__(
f"but pipeline's is {sample_rate}. Will resample."
logging.warning(msg)
self.stream = self.stream.pipe(
- ops.map(Resample(self.source.sample_rate, sample_rate))
+ ops.map(blocks.Resample(self.source.sample_rate, sample_rate))
)
# Add rx operators to manage the inputs and outputs of the pipeline
@@ -122,7 +116,7 @@ def __init__(
self.stream = self.stream.pipe(
ops.flat_map(lambda results: rx.from_iterable(results)),
- ops.do(self.accumulator),
+ ops.do_action(lambda res: self._predictions.append(res[0] if isinstance(res, tuple) else res)),
)
if show_progress:
@@ -140,13 +134,13 @@ def _close_chronometer(self):
self._chrono.stop(do_count=False)
self._chrono.report()
- def attach_hooks(self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]):
+ def attach_hooks(self, *hooks: Callable[[Tuple[Any, SlidingWindowFeature]], None]):
"""Attach hooks to the pipeline.
Parameters
----------
- *hooks: (Tuple[Annotation, SlidingWindowFeature]) -> None
- Hook functions to consume emitted annotations and audio.
+ *hooks: (Tuple[Any, SlidingWindowFeature]) -> None
+ Hook functions to consume emitted predictions and audio.
"""
self.stream = self.stream.pipe(*[ops.do_action(hook) for hook in hooks])
@@ -156,7 +150,7 @@ def attach_observers(self, *observers: Observer):
Parameters
----------
*observers: Observer
- Observers to consume emitted annotations and audio.
+ Observers to consume emitted predictions and audio.
"""
self.stream = self.stream.pipe(*[ops.do(sink) for sink in observers])
self._observers.extend(observers)
@@ -181,55 +175,42 @@ def _handle_completion(self):
self._close_pbar()
self._close_chronometer()
- def __call__(self) -> Annotation:
- """Stream audio chunks from `source` to `pipeline`.
+ def __call__(self) -> List[Any]:
+ """Stream audio chunks from a source to a pipeline.
Returns
-------
- predictions: Annotation
- Speaker diarization pipeline predictions
+ predictions: List[Any]
+ Streaming pipeline predictions
"""
if self.show_progress:
self._pbar.start()
- config = self.pipeline.config
- observable = self.stream
- if self.do_plot:
- # Buffering is needed for the real-time plot, so we do this at the very end
- observable = self.stream.pipe(
- dops.buffer_output(
- duration=config.duration,
- step=config.step,
- latency=config.latency,
- sample_rate=config.sample_rate,
- ),
- ops.do(RealTimePlot(config.duration, config.latency)),
- )
- observable.subscribe(
+ self.stream.subscribe(
on_error=self._handle_error,
on_completed=self._handle_completion,
)
- # FIXME if read() isn't blocking, the prediction returned is empty
+ # FIXME if read() isn't blocking, predictions are empty
self.source.read()
- return self.accumulator.get_prediction()
+ return self._predictions
class Benchmark:
"""
- Run an online speaker diarization pipeline on a set of audio files in batches.
+ Run a pipeline on a set of audio files in batches.
Write predictions to a given output directory.
- If the reference is given, calculate the average diarization error rate.
+ If the reference is given, compute the average performance metric.
Parameters
----------
speech_path: Text or Path
Directory with audio files.
reference_path: Text, Path or None
- Directory with reference RTTM files (same names as audio files).
- If None, performance will not be calculated.
+ Directory with reference files (same names as audio files with different extension).
+ If None, performance will not be computed.
Defaults to None.
- output_path: Text, Path or None
- Output directory to store predictions in RTTM format.
+ output_path: Text, Path or None
+ Output directory to store predictions.
If None, predictions will not be written to disk.
Defaults to None.
show_progress: bool
@@ -239,12 +220,7 @@ class Benchmark:
Whether to print a performance report to stdout.
Defaults to True.
batch_size: int
- Inference batch size.
- If < 2, then it will run in real time.
- If >= 2, then it will pre-calculate segmentation and
- embeddings, running the rest in real time.
- The performance between this two modes does not differ.
- Defaults to 32.
+ Inference batch size. Defaults to 32.
"""
def __init__(
self,
@@ -267,9 +243,9 @@ def __init__(
self.reference_path = Path(self.reference_path).expanduser()
assert self.reference_path.is_dir(), "Reference path must be a directory"
- self.output_path = output_path
+ self.output_path: Optional[Union[Text, Path]] = output_path
if self.output_path is not None:
- self.output_path = Path(output_path).expanduser()
+ self.output_path: Path = Path(output_path).expanduser()
self.output_path.mkdir(parents=True, exist_ok=True)
self.show_progress = show_progress
@@ -288,27 +264,29 @@ def get_file_paths(self) -> List[Path]:
def run_single(
self,
- pipeline: BasePipeline,
+ pipeline: Pipeline,
filepath: Path,
progress_bar: ProgressBar,
- ) -> Annotation:
+ ) -> Tuple[Text, Any]:
"""Run a given pipeline on a given file.
- Note that this method does NOT reset the
+ This method does NOT reset the
state of the pipeline before execution.
Parameters
----------
- pipeline: BasePipeline
- Speaker diarization pipeline to run.
+ pipeline: Pipeline
+ Pipeline to run on the given file.
filepath: Path
Path to the target file.
progress_bar: diart.progress.ProgressBar
- An object to manage the progress of this run.
+ Object to display the progress of this run.
Returns
-------
- prediction: Annotation
- Pipeline prediction for the given file.
+ uri: Text
+ File URI.
+ prediction: Any
+ Aggregated pipeline prediction for the given file.
"""
padding = pipeline.config.get_file_padding(filepath)
source = src.FileAudioSource(
@@ -318,33 +296,38 @@ def run_single(
pipeline.config.optimal_block_size(),
)
pipeline.set_timestamp_shift(-padding[0])
- inference = RealTimeInference(
+ inference = StreamingInference(
pipeline,
source,
self.batch_size,
do_profile=False,
- do_plot=False,
show_progress=self.show_progress,
progress_bar=progress_bar,
)
- pred = inference()
- pred.uri = source.uri
+ # Run the pipeline and concatenate predictions
+ pred = pipeline.join_predictions(inference())
+ # Write prediction to disk if required
if self.output_path is not None:
- with open(self.output_path / f"{source.uri}.rttm", "w") as out_file:
- pred.write_rttm(out_file)
+ pipeline.write_prediction(source.uri, pred, self.output_path)
- return pred
+ return source.uri, pred
- def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[Annotation]]:
+ def evaluate(
+ self,
+ predictions: Dict[Text, Any],
+ metric: Metric,
+ ) -> Union[pd.DataFrame, Dict[Text, Any]]:
"""If a reference path was provided,
- compute the diarization error rate of a list of predictions.
+ compute the performance of the given predictions.
Parameters
----------
- predictions: List[Annotation]
+ predictions: Dict[Text, Any]
+ predictions: List[Any]
Predictions to evaluate.
+ metric: Metric
+ Evaluation metric.
Returns
-------
@@ -353,67 +336,84 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An
reference path was given. Otherwise return the same predictions.
"""
if self.reference_path is not None:
- metric = DiarizationErrorRate(collar=0, skip_overlap=False)
- progress_bar = TQDMProgressBar("Computing DER", leave=False)
+ # Initialize progress bar
+ progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False)
progress_bar.create(total=len(predictions), unit="file")
progress_bar.start()
- for hyp in predictions:
- ref = load_rttm(self.reference_path / f"{hyp.uri}.rttm").popitem()[1]
- metric(ref, hyp)
+
+ # Evaluate each prediction
+ uris = []
+ for uri, pred in predictions.items():
+ ref_file = list(self.reference_path.glob(f"{uri}.*"))
+ if ref_file:
+ ref = metric.load_reference(ref_file[0])
+ metric(ref, pred)
+ uris.append(uri)
+ else:
+ msg = f"Reference file for {uri} not found. Skipping evaluation."
+ logging.warning(msg)
progress_bar.update()
+
+ # Close progress bar safely
progress_bar.close()
- return metric.report(display=self.show_report)
+ # Return performance report
+ return metric.report(uris, self.show_report)
+
return predictions
def __call__(
self,
pipeline_class: type,
- config: BasePipelineConfig,
- ) -> Union[pd.DataFrame, List[Annotation]]:
- """Run a given pipeline on a set of audio files.
- Notice that the internal state of the pipeline is reset before benchmarking.
+ config: PipelineConfig,
+ metric: Optional[Metric] = None,
+ ) -> Union[pd.DataFrame, Dict[Text, Any]]:
+ """Run a pipeline on a set of audio files.
+ The internal state of the pipeline is reset before benchmarking.
Parameters
----------
pipeline_class: class
- Class from the BasePipeline hierarchy.
+ Class from the `Pipeline` hierarchy.
A pipeline from this class will be instantiated by each worker.
- config: BasePipelineConfig
- Diarization pipeline configuration.
+ config: PipelineConfig
+ Pipeline configuration.
+ metric: Optional[Metric]
+ Evaluation metric.
+ Defaults to the pipeline's suggested metric (see `Pipeline.suggest_metric()`)
Returns
-------
- performance: pandas.DataFrame or List[Annotation]
+ performance: pandas.DataFrame or Dict[Text, Any]
If reference annotations are given, a DataFrame with detailed
performance on each file as well as average performance.
-
- If no reference annotations, a list of predictions.
+ If no reference annotations, a dictionary mapping file URIs to predictions.
"""
audio_file_paths = self.get_file_paths()
num_audio_files = len(audio_file_paths)
pipeline = pipeline_class(config)
- predictions = []
+ predictions = {}
for i, filepath in enumerate(audio_file_paths):
pipeline.reset()
desc = f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})"
progress = TQDMProgressBar(desc, leave=False, do_close=True)
- predictions.append(self.run_single(pipeline, filepath, progress))
+ uri, pred = self.run_single(pipeline, filepath, progress)
+ predictions[uri] = pred
- return self.evaluate(predictions)
+ metric = pipeline.suggest_metric() if metric is None else metric
+ return self.evaluate(predictions, metric)
class Parallelize:
"""Wrapper to parallelize the execution of a `Benchmark` instance.
- Note that models will be copied in each worker instead of being reused.
+ Models will be copied in each worker instead of being reused.
Parameters
----------
benchmark: Benchmark
- Benchmark instance to execute in parallel.
+ Benchmark instance to run in parallel.
num_workers: int
- Number of parallel workers.
- Defaults to 0 (no parallelism).
+ Number of parallel workers. Defaults to 4.
"""
def __init__(
self,
@@ -426,20 +426,20 @@ def __init__(
def run_single_job(
self,
pipeline_class: type,
- config: BasePipelineConfig,
+ config: PipelineConfig,
filepath: Path,
description: Text,
- ):
- """Build and run a pipeline on a single file.
+ ) -> Tuple[Text, Any]:
+ """Instantiate and run a pipeline on a given file.
Configure execution to show progress alongside parallel runs.
Parameters
----------
pipeline_class: class
- Class from the BasePipeline hierarchy.
+ Class from the `Pipeline` hierarchy.
A pipeline from this class will be instantiated.
- config: BasePipelineConfig
- Diarization pipeline configuration.
+ config: PipelineConfig
+ Pipeline configuration.
filepath: Path
Path to the target file.
description: Text
@@ -447,8 +447,10 @@ def run_single_job(
Returns
-------
- prediction: Annotation
- Pipeline prediction for the given file.
+ uri: Text
+ File URI.
+ prediction: Any
+ Aggregated pipeline prediction for the given file.
"""
# The process ID inside the pool determines the position of the progress bar
idx_process = int(current_process().name.split('-')[1]) - 1
@@ -463,26 +465,29 @@ def run_single_job(
def __call__(
self,
pipeline_class: type,
- config: BasePipelineConfig,
- ) -> Union[pd.DataFrame, List[Annotation]]:
- """Run a given pipeline on a set of audio files in parallel.
- Each worker will build and run the pipeline on a different file.
+ config: PipelineConfig,
+ metric: Optional[Metric] = None,
+ ) -> Union[pd.DataFrame, Dict[Text, Any]]:
+ """Run a pipeline on a set of audio files in parallel.
+ Each worker instantiates and runs the pipeline on a different file.
Parameters
----------
pipeline_class: class
- Class from the BasePipeline hierarchy.
+ Class from the Pipeline hierarchy.
A pipeline from this class will be instantiated by each worker.
- config: BasePipelineConfig
- Diarization pipeline configuration.
+ config: PipelineConfig
+ Pipeline configuration.
+ metric: Optional[Metric]
+ Evaluation metric.
+ Defaults to the pipeline's suggested metric (see `Pipeline.suggest_metric()`)
Returns
-------
- performance: pandas.DataFrame or List[Annotation]
+ performance: pandas.DataFrame or Dict[Text, Any]
If reference annotations are given, a DataFrame with detailed
performance on each file as well as average performance.
-
- If no reference annotations, a list of predictions.
+ If no reference annotations, a dictionary mapping file URIs to predictions.
"""
audio_file_paths = self.benchmark.get_file_paths()
num_audio_files = len(audio_file_paths)
@@ -507,9 +512,15 @@ def __call__(
# Submit all jobs
jobs = [pool.apply_async(self.run_single_job, args=args) for args in arg_list]
- # Wait and collect results
+ # Wait for all jobs to finish
pool.close()
- predictions = [job.get() for job in jobs]
+
+ # Collect results
+ predictions = {}
+ for job in jobs:
+ uri, pred = job.get()
+ predictions[uri] = pred
# Evaluate results
- return self.benchmark.evaluate(predictions)
+ metric = pipeline_class.suggest_metric() if metric is None else metric
+ return self.benchmark.evaluate(predictions, metric)
diff --git a/src/diart/metrics.py b/src/diart/metrics.py
new file mode 100644
index 00000000..d5ae56e4
--- /dev/null
+++ b/src/diart/metrics.py
@@ -0,0 +1,96 @@
+from pathlib import Path
+from typing import Text, Any, List, Union
+
+import pandas as pd
+from pyannote.core import Annotation
+from pyannote.metrics import diarization as dia, detection as det
+from pyannote.metrics.base import BaseMetric as PyannoteBaseMetric
+from pyannote.database.util import load_rttm
+from torchmetrics import text
+
+
+class Metric:
+ @property
+ def name(self) -> Text:
+ raise NotImplementedError
+
+ def __call__(self, reference: Any, prediction: Any) -> float:
+ raise NotImplementedError
+
+ def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame:
+ raise NotImplementedError
+
+ def load_reference(self, filepath: Union[Text, Path]) -> Any:
+ raise NotImplementedError
+
+
+class PyannoteMetric(Metric):
+ def __init__(self, metric: PyannoteBaseMetric):
+ self._metric = metric
+
+ @property
+ def name(self) -> Text:
+ return self._metric.name
+
+ def __call__(self, reference: Annotation, prediction: Annotation) -> float:
+ return self._metric(reference, prediction)
+
+ def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame:
+ return self._metric.report(display)
+
+ def load_reference(self, filepath: Union[Text, Path]) -> Annotation:
+ return load_rttm(filepath).popitem()[1]
+
+
+class DiarizationErrorRate(PyannoteMetric):
+ def __init__(self, collar: float = 0, skip_overlap: bool = False):
+ super().__init__(dia.DiarizationErrorRate(collar, skip_overlap))
+
+
+class DetectionErrorRate(PyannoteMetric):
+ def __init__(self, collar: float = 0, skip_overlap: bool = False):
+ super().__init__(det.DetectionErrorRate(collar, skip_overlap))
+
+
+class WordErrorRate(Metric):
+ def __init__(self, unify_case: bool = False):
+ self.unify_case = unify_case
+ self._metric = text.WordErrorRate()
+ self._values = []
+
+ @property
+ def name(self) -> Text:
+ return "word error rate"
+
+ def __call__(self, reference: Text, prediction: Text) -> float:
+ if self.unify_case:
+ prediction = prediction.lower()
+ reference = reference.lower()
+ # Torchmetrics requires predictions first, then reference
+ value = self._metric(prediction, reference).item()
+ self._values.append(value)
+ return value
+
+ def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame:
+ num_uris, num_values = len(uris), len(self._values)
+ msg = f"URI list size must match values. Found {num_uris} but expected {num_values}"
+ assert num_uris == num_values, msg
+
+ rows = self._values + [self._metric.compute().item()]
+ index = uris + ["TOTAL"]
+ report = pd.DataFrame(rows, index=index, columns=[self.name])
+
+ if display:
+ print(report.to_string(
+ index=True,
+ sparsify=False,
+ justify="right",
+ float_format=lambda f: "{0:.2f}".format(f),
+ ))
+
+ return report
+
+ def load_reference(self, filepath: Union[Text, Path]) -> Text:
+ with open(filepath, "r") as file:
+ lines = [line.strip() for line in file.readlines()]
+ return " ".join(lines)
diff --git a/src/diart/models.py b/src/diart/models.py
index df66e166..485cb43a 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -1,7 +1,11 @@
-from typing import Optional, Text, Union, Callable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Text, Union, Callable, List, Any
+import numpy as np
import torch
import torch.nn as nn
+from pyannote.core import Segment
try:
import pyannote.audio.pipelines.utils as pyannote_loader
@@ -9,6 +13,19 @@
except ImportError:
_has_pyannote = False
+try:
+ import whisper
+ from whisper.tokenizer import get_tokenizer
+ _has_whisper = True
+ DecodingResult = whisper.DecodingResult
+ DecodingOptions = whisper.DecodingOptions
+ Tokenizer = whisper.tokenizer.Tokenizer
+except ImportError:
+ _has_whisper = False
+ DecodingResult = Any
+ DecodingOptions = Any
+ Tokenizer = Any
+
class PyannoteLoader:
def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
@@ -20,6 +37,26 @@ def __call__(self) -> nn.Module:
return pyannote_loader.get_model(self.model_info, self.hf_token)
+class WhisperLoader:
+ def __init__(
+ self,
+ name: Text,
+ download_path: Optional[Union[Text, Path]] = None,
+ in_memory: bool = False,
+ ):
+ self.name = name
+ self.download_path = download_path
+ self.in_memory = in_memory
+
+ def __call__(self) -> nn.Module:
+ return whisper.load_model(
+ name=self.name,
+ device="cpu",
+ download_root=self.download_path,
+ in_memory=self.in_memory,
+ )
+
+
class LazyModel(nn.Module):
def __init__(self, loader: Callable[[], nn.Module]):
super().__init__()
@@ -163,3 +200,320 @@ def forward(
weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model(waveform, weights=weights)
+
+
+@dataclass(frozen=True)
+class TranscriptionResult:
+ text: Text
+ chunks: List[Text]
+ timestamps: List[Segment]
+
+
+class SpeechRecognitionModel(LazyModel):
+ @staticmethod
+ def from_whisper(
+ name: Text,
+ download_path: Optional[Union[Text, Path]] = None,
+ in_memory: bool = False,
+ fp16: bool = False,
+ no_speech_threshold: float = 0.6,
+ compression_ratio_threshold: Optional[float] = 2.4,
+ logprob_threshold: Optional[float] = -1,
+ decode_with_fallback: bool = False,
+ ) -> 'SpeechRecognitionModel':
+        msg = "No whisper-timestamped installation found. " \
+ "Visit https://github.com/linto-ai/whisper-timestamped#installation to install"
+ assert _has_whisper, msg
+ return WhisperSpeechRecognitionModel(
+ name,
+ download_path,
+ in_memory,
+ fp16,
+ no_speech_threshold,
+ compression_ratio_threshold,
+ logprob_threshold,
+ decode_with_fallback,
+ )
+
+ @property
+ def duration(self) -> float:
+ raise NotImplementedError
+
+ @property
+ def sample_rate(self) -> int:
+ raise NotImplementedError
+
+ def set_language(self, language: Optional[Text] = None):
+ raise NotImplementedError
+
+ def set_beam_size(self, size: Optional[int] = None):
+ raise NotImplementedError
+
+ def forward(self, waveform: torch.Tensor) -> List[TranscriptionResult]:
+ """
+ Forward pass of the speech recognition model.
+
+ Parameters
+ ----------
+ waveform: torch.Tensor, shape (batch, channels, samples)
+ Batch of audio chunks to transcribe
+
+ Returns
+ -------
+ transcriptions: List[TranscriptionResult]
+ A list of timestamped transcriptions
+ """
+ raise NotImplementedError
+
+
+class WhisperDecoder:
+ def __init__(
+ self,
+ no_speech_threshold: float = 0.6,
+ compression_ratio_threshold: Optional[float] = 2.4,
+ logprob_threshold: Optional[float] = -1,
+ ):
+ self.no_speech_threshold = no_speech_threshold
+ self.compression_ratio_threshold = compression_ratio_threshold
+ self.logprob_threshold = logprob_threshold
+ self.temperatures = (0, 0.2, 0.4, 0.6, 0.8, 1)
+
+ @staticmethod
+ def get_temperature_options(initial: DecodingOptions, t: float) -> DecodingOptions:
+ t_options = {**vars(initial)}
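+        # Beam search options only apply at temperature 0; best_of only applies when sampling (t > 0)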
+ if t > 0:
+ t_options.pop("beam_size", None)
+ t_options.pop("patience", None)
+ else:
+ t_options.pop("best_of", None)
+ t_options["temperature"] = t
+ return DecodingOptions(**t_options)
+
+ @staticmethod
+ def decode(
+ model,
+ batch: torch.Tensor,
+ options: DecodingOptions
+ ) -> DecodingResult:
+ return model.decode(batch, options)
+
+ def check_compression(self) -> bool:
+ return self.compression_ratio_threshold is not None
+
+ def check_logprob(self) -> bool:
+ return self.logprob_threshold is not None
+
+ def needs_fallback(self, output: DecodingResult) -> bool:
+        if self.check_compression() and output.compression_ratio > self.compression_ratio_threshold:
+ # Transcription is too repetitive
+ return True
+        if self.check_logprob() and output.avg_logprob < self.logprob_threshold:
+ # Average log probability is too low
+ return True
+ return False
+
+ def decode_with_fallback(
+ self,
+ model,
+ batch: torch.Tensor,
+ options: DecodingOptions,
+ ) -> List[DecodingResult]:
+        """Transcribe a batch, retrying with increasing temperatures
+        when the estimated quality of the transcription is poor.
+
+ Parameters
+ ----------
+ model: whisper.Whisper
+ Whisper ASR model (contains 'decode' method).
+        batch: torch.Tensor, shape (batch, n_mels, frames)
+ Log mel spectrogram batch.
+ options: whisper.DecodingOptions
+ Configuration to decode transcription.
+
+ Returns
+ -------
+ result: List[whisper.DecodingResult]
+ Transcription results for this batch.
+ """
+ batch_size = batch.shape[0]
+ results = [None] * batch_size
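+        # Boolean mask of batch items that still need re-decoding at a higher temperature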
+        retry_idx = torch.ones(batch_size, dtype=torch.bool)
+
+ for t in self.temperatures:
+ # Transcribe with the given temperature
+ t_options = self.get_temperature_options(options, t)
+ outputs = model.decode(batch[retry_idx], t_options)
+
+ # Determine which outputs need to be transcribed again
+ output_idx = torch.where(retry_idx)[0]
+ for idx, out in zip(output_idx, outputs):
+ results[idx] = out
+ if not self.needs_fallback(out):
+ retry_idx[idx] = False
+
+ # No output needs fallback, get out of the loop early
+ if torch.sum(retry_idx).item() == 0:
+ break
+
+ return results
+
+ def split_with_timestamps(
+ self,
+ result: DecodingResult,
+ tokenizer: Tokenizer,
+ chunk_duration: float,
+ token_duration: float,
+ ) -> TranscriptionResult:
+ """Split a Whisper transcription into segments with their respective timestamps.
+        Return an empty transcription if the no-speech probability is high.
+
+ Parameters
+ ----------
+ result: whisper.DecodingResult
+ A single transcription output from Whisper.
+ tokenizer: whisper.tokenizer.Tokenizer
+ Tokenizer needed to decode outputs.
+ chunk_duration: float
+ Actual duration of each input chunk.
+ token_duration: float
+ Duration of each output token.
+
+ Returns
+ -------
+ result: TranscriptionResult
+ Transcription with identified segments and timestamps.
+ """
+        # If the model confidently detects no speech, return an empty transcription
+ if self.no_speech_threshold is not None:
+ no_speech = result.no_speech_prob > self.no_speech_threshold
+ low_confidence = self.logprob_threshold is None or result.avg_logprob < self.logprob_threshold
+ if no_speech and low_confidence:
+ return TranscriptionResult("", [""], [Segment(0, chunk_duration)])
+
+ tokens = torch.tensor(result.tokens)
+ chunks, timestamps = [], []
+ ts_tokens = tokens.ge(tokenizer.timestamp_begin)
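+        # True if the output ends with a single (unpaired) timestamp token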
+ single_ts_ending = ts_tokens[-2:].tolist() == [False, True]
+ consecutive = torch.where(ts_tokens[:-1] & ts_tokens[1:])[0] + 1
+ if len(consecutive) > 0:
+ # Output contains two consecutive timestamp tokens
+ slices = consecutive.tolist()
+ if single_ts_ending:
+ slices.append(len(tokens))
+
+ # Split into segments based on timestamp tokens
+ last_slice = 0
+ for current_slice in slices:
+ sliced_tokens = tokens[last_slice:current_slice]
+ start_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
+ end_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
+ text_tokens = [token for token in sliced_tokens if token < tokenizer.eot]
+ text = tokenizer.decode(text_tokens).strip()
+ timestamp = Segment(start_pos * token_duration, end_pos * token_duration)
+ if text and timestamp.start != timestamp.end:
+ chunks.append(text)
+ timestamps.append(timestamp)
+ last_slice = current_slice
+ else:
+ # There is a single segment, identify timestamps
+ duration = chunk_duration
+ ts = tokens[ts_tokens.nonzero().flatten()]
+ if len(ts) > 0 and ts[-1].item() != tokenizer.timestamp_begin:
+ # Use last timestamp as end time for the unique chunk
+ last_ts_pos = ts[-1].item() - tokenizer.timestamp_begin
+ duration = last_ts_pos * token_duration
+ text_tokens = [token for token in tokens if token < tokenizer.eot]
+ text = tokenizer.decode(text_tokens).strip()
+ if text:
+ chunks.append(text)
+ timestamps.append(Segment(0, duration))
+
+ return TranscriptionResult(result.text, chunks, timestamps)
+
+
+class WhisperSpeechRecognitionModel(SpeechRecognitionModel):
+ def __init__(
+ self,
+ name: Text,
+ download_path: Optional[Union[Text, Path]] = None,
+ in_memory: bool = False,
+ fp16: bool = False,
+ no_speech_threshold: float = 0.6,
+ compression_ratio_threshold: Optional[float] = 2.4,
+ logprob_threshold: Optional[float] = -1,
+ decode_with_fallback: bool = False,
+ ):
+ super().__init__(WhisperLoader(name, download_path, in_memory))
+ self.fp16 = fp16
+ self.beam_size = None
+ self.language = None
+ self.decode_with_fallback = decode_with_fallback
+ self.decoder = WhisperDecoder(
+ no_speech_threshold, compression_ratio_threshold, logprob_threshold
+ )
+ self._token_duration: Optional[float] = None
+
+ @property
+ def duration(self) -> float:
+ # Whisper's maximum duration per input is 30s
+ return whisper.audio.CHUNK_LENGTH
+
+ @property
+ def sample_rate(self) -> int:
+ return whisper.audio.SAMPLE_RATE
+
+ @property
+ def token_duration(self) -> float:
+ if self._token_duration is None:
+ # 2 mel frames per output token
+ input_stride = int(np.rint(whisper.audio.N_FRAMES / self.model.dims.n_audio_ctx))
+ # Output token duration is 0.02 seconds
+ self._token_duration = input_stride * whisper.audio.HOP_LENGTH / self.sample_rate
+ return self._token_duration
+
+ def set_language(self, language: Optional[Text] = None):
+ self.language = language
+
+ def set_beam_size(self, size: Optional[int] = None):
+ self.beam_size = size
+
+ def forward(self, waveform_batch: torch.Tensor) -> List[TranscriptionResult]:
+ # Remove channel dimension
+ batch = waveform_batch.squeeze(1)
+ num_chunk_samples = batch.shape[-1]
+ # Compute log mel spectrogram
+ batch = whisper.log_mel_spectrogram(batch)
+ # Add padding
+ dtype = torch.float16 if self.fp16 else torch.float32
+ batch = whisper.pad_or_trim(batch, whisper.audio.N_FRAMES).to(batch.device).type(dtype)
+
+ # Configure transcription decoding
+ options = whisper.DecodingOptions(
+ task="transcribe",
+ language=self.language,
+ beam_size=self.beam_size,
+ fp16=self.fp16,
+ )
+
+ # Transcribe batch with fallback if required
+ if self.decode_with_fallback:
+ decode_fn = self.decoder.decode_with_fallback
+ else:
+ decode_fn = self.decoder.decode
+ results = decode_fn(self.model, batch, options)
+
+ # Split into segments and add timestamps
+ tokenizer = get_tokenizer(
+ self.model.is_multilingual,
+ language=options.language,
+ task=options.task,
+ )
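+        # Actual chunk duration in whole seconds (chunks may be shorter than Whisper's 30s window)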
+ chunk_duration = int(np.rint(num_chunk_samples / self.sample_rate))
+ transcriptions = [
+ self.decoder.split_with_timestamps(
+ res, tokenizer, chunk_duration, self.token_duration
+ )
+ for res in results
+ ]
+
+ return transcriptions
diff --git a/src/diart/operators.py b/src/diart/operators.py
index 6d73fc9d..c67bf99d 100644
--- a/src/diart/operators.py
+++ b/src/diart/operators.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass
-from typing import Callable, Optional, List, Any, Tuple
+from typing import Callable, Optional, List, Any
import numpy as np
import rx
-from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature, Segment
+from pyannote.core import SlidingWindow, SlidingWindowFeature
from rx import operators as ops
from rx.core import Observable
@@ -97,192 +97,3 @@ def accumulate(state: List[Any], value: Any) -> List[Any]:
return new_state[1:]
return new_state
return rx.pipe(ops.scan(accumulate, []))
-
-
-@dataclass
-class PredictionWithAudio:
- prediction: Annotation
- waveform: Optional[SlidingWindowFeature] = None
-
- @property
- def has_audio(self) -> bool:
- return self.waveform is not None
-
-
-@dataclass
-class OutputAccumulationState:
- annotation: Optional[Annotation]
- waveform: Optional[SlidingWindowFeature]
- real_time: float
- next_sample: Optional[int]
-
- @staticmethod
- def initial() -> 'OutputAccumulationState':
- return OutputAccumulationState(None, None, 0, 0)
-
- @property
- def cropped_waveform(self) -> SlidingWindowFeature:
- return SlidingWindowFeature(
- self.waveform[:self.next_sample],
- self.waveform.sliding_window,
- )
-
- def to_tuple(self) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]:
- return self.annotation, self.cropped_waveform, self.real_time
-
-
-def accumulate_output(
- duration: float,
- step: float,
- patch_collar: float = 0.05,
-) -> Operator:
- """Accumulate predictions and audio to infinity: O(N) space complexity.
- Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations.
-
- Parameters
- ----------
- duration: float
- Buffer duration in seconds.
- step: float
- Duration of the chunks at each event in seconds.
- The first chunk may be bigger given the latency.
- patch_collar: float, optional
- Collar to merge speaker turns of the same speaker, in seconds.
- Defaults to 0.05 (i.e. 50ms).
- Returns
- -------
- A reactive x operator implementing this behavior.
- """
- def accumulate(
- state: OutputAccumulationState,
- value: Tuple[Annotation, Optional[SlidingWindowFeature]]
- ) -> OutputAccumulationState:
- value = PredictionWithAudio(*value)
- annotation, waveform = None, None
-
- # Determine the real time of the stream
- real_time = duration if state.annotation is None else state.real_time + step
-
- # Update total annotation with current predictions
- if state.annotation is None:
- annotation = value.prediction
- else:
- annotation = state.annotation.update(value.prediction).support(patch_collar)
-
- # Update total waveform if there's audio in the input
- new_next_sample = 0
- if value.has_audio:
- num_new_samples = value.waveform.data.shape[0]
- new_next_sample = state.next_sample + num_new_samples
- sw_holder = state
- if state.waveform is None:
- # Initialize the audio buffer with 10 times the size of the first chunk
- waveform, sw_holder = np.zeros((10 * num_new_samples, 1)), value
- elif new_next_sample < state.waveform.data.shape[0]:
- # The buffer still has enough space to accommodate the chunk
- waveform = state.waveform.data
- else:
- # The buffer is full, double its size
- waveform = np.concatenate(
- (state.waveform.data, np.zeros_like(state.waveform.data)), axis=0
- )
- # Copy chunk into buffer
- waveform[state.next_sample:new_next_sample] = value.waveform.data
- waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window)
-
- return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)
-
- return rx.pipe(
- ops.scan(accumulate, OutputAccumulationState.initial()),
- ops.map(OutputAccumulationState.to_tuple),
- )
-
-
-def buffer_output(
- duration: float,
- step: float,
- latency: float,
- sample_rate: int,
- patch_collar: float = 0.05,
-) -> Operator:
- """Store last predictions and audio inside a fixed buffer.
- Provides the best time/space complexity trade-off if the past data is not needed.
-
- Parameters
- ----------
- duration: float
- Buffer duration in seconds.
- step: float
- Duration of the chunks at each event in seconds.
- The first chunk may be bigger given the latency.
- latency: float
- Latency of the system in seconds.
- sample_rate: int
- Sample rate of the audio source.
- patch_collar: float, optional
- Collar to merge speaker turns of the same speaker, in seconds.
- Defaults to 0.05 (i.e. 50ms).
-
- Returns
- -------
- A reactive x operator implementing this behavior.
- """
- # Define some useful constants
- num_samples = int(round(duration * sample_rate))
- num_step_samples = int(round(step * sample_rate))
- resolution = 1 / sample_rate
-
- def accumulate(
- state: OutputAccumulationState,
- value: Tuple[Annotation, Optional[SlidingWindowFeature]]
- ) -> OutputAccumulationState:
- value = PredictionWithAudio(*value)
- annotation, waveform = None, None
-
- # Determine the real time of the stream and the start time of the buffer
- real_time = duration if state.annotation is None else state.real_time + step
- start_time = max(0., real_time - latency - duration)
-
- # Update annotation and constrain its bounds to the buffer
- if state.annotation is None:
- annotation = value.prediction
- else:
- annotation = state.annotation.update(value.prediction).support(patch_collar)
- if start_time > 0:
- annotation = annotation.extrude(Segment(0, start_time))
-
- # Update the audio buffer if there's audio in the input
- new_next_sample = state.next_sample + num_step_samples
- if value.has_audio:
- if state.waveform is None:
- # Determine the size of the first chunk
- expected_duration = duration + step - latency
- expected_samples = int(round(expected_duration * sample_rate))
- # Shift indicator to start copying new audio in the buffer
- new_next_sample = state.next_sample + expected_samples
- # Buffer size is duration + step
- waveform = np.zeros((num_samples + num_step_samples, 1))
- # Copy first chunk into buffer (slicing because of rounding errors)
- waveform[:expected_samples] = value.waveform.data[:expected_samples]
- elif state.next_sample <= num_samples:
- # The buffer isn't full, copy into next free buffer chunk
- waveform = state.waveform.data
- waveform[state.next_sample:new_next_sample] = value.waveform.data
- else:
- # The buffer is full, shift values to the left and copy into last buffer chunk
- waveform = np.roll(state.waveform.data, -num_step_samples, axis=0)
- # If running on a file, the online prediction may be shorter depending on the latency
- # The remaining audio at the end is appended, so value.waveform may be longer than num_step_samples
- # In that case, we simply ignore the appended samples.
- waveform[-num_step_samples:] = value.waveform.data[:num_step_samples]
-
- # Wrap waveform in a sliding window feature to include timestamps
- window = SlidingWindow(start=start_time, duration=resolution, step=resolution)
- waveform = SlidingWindowFeature(waveform, window)
-
- return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)
-
- return rx.pipe(
- ops.scan(accumulate, OutputAccumulationState.initial()),
- ops.map(OutputAccumulationState.to_tuple),
- )
diff --git a/src/diart/optim.py b/src/diart/optim.py
index 05800a05..be371a71 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -1,51 +1,33 @@
from collections import OrderedDict
-from dataclasses import dataclass
from pathlib import Path
from typing import Sequence, Text, Optional, Union
from optuna import TrialPruned, Study, create_study
from optuna.samplers import TPESampler
from optuna.trial import Trial, FrozenTrial
+from pyannote.metrics.base import BaseMetric
from tqdm import trange, tqdm
+from typing_extensions import Literal
from .audio import FilePath
-from .blocks import BasePipelineConfig, PipelineConfig, OnlineSpeakerDiarization
from .inference import Benchmark
-
-
-@dataclass
-class HyperParameter:
- name: Text
- low: float
- high: float
-
- @staticmethod
- def from_name(name: Text) -> 'HyperParameter':
- if name == "tau_active":
- return TauActive
- if name == "rho_update":
- return RhoUpdate
- if name == "delta_new":
- return DeltaNew
- raise ValueError(f"Hyper-parameter '{name}' not recognized")
-
-
-TauActive = HyperParameter("tau_active", low=0, high=1)
-RhoUpdate = HyperParameter("rho_update", low=0, high=1)
-DeltaNew = HyperParameter("delta_new", low=0, high=2)
+from .pipelines import PipelineConfig
+from .pipelines.hparams import HyperParameter
class Optimizer:
def __init__(
self,
+ pipeline_class: type,
speech_path: Union[Text, Path],
reference_path: Union[Text, Path],
study_or_path: Union[FilePath, Study],
batch_size: int = 32,
- pipeline_class: type = OnlineSpeakerDiarization,
hparams: Optional[Sequence[HyperParameter]] = None,
- base_config: Optional[BasePipelineConfig] = None,
+ base_config: Optional[PipelineConfig] = None,
do_kickstart_hparams: bool = True,
+ metric: Optional[BaseMetric] = None,
+ direction: Literal["minimize", "maximize"] = "minimize",
):
self.pipeline_class = pipeline_class
# FIXME can we run this benchmark in parallel?
@@ -58,15 +40,17 @@ def __init__(
batch_size=batch_size,
)
+ self.metric = metric
+ self.direction = direction
self.base_config = base_config
self.do_kickstart_hparams = do_kickstart_hparams
if self.base_config is None:
- self.base_config = PipelineConfig()
+ self.base_config = self.pipeline_class.get_config_class()()
self.do_kickstart_hparams = False
self.hparams = hparams
if self.hparams is None:
- self.hparams = [TauActive, RhoUpdate, DeltaNew]
+ self.hparams = self.pipeline_class.hyper_parameters()
# Make sure hyper-parameters exist in the configuration class given
possible_hparams = vars(self.base_config)
@@ -85,7 +69,7 @@ def __init__(
storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"),
sampler=TPESampler(),
study_name=study_or_path.stem,
- direction="minimize",
+ direction=self.direction,
load_if_exists=True,
)
else:
@@ -105,7 +89,7 @@ def _callback(self, study: Study, trial: FrozenTrial):
return
self._progress.update(1)
self._progress.set_description(f"Trial {trial.number + 1}")
- values = {"best_der": study.best_value}
+ values = {"best_perf": study.best_value}
for name, value in study.best_params.items():
values[f"best_{name}"] = value
self._progress.set_postfix(OrderedDict(values))
@@ -113,6 +97,9 @@ def _callback(self, study: Study, trial: FrozenTrial):
def objective(self, trial: Trial) -> float:
# Set suggested values for optimized hyper-parameters
trial_config = vars(self.base_config)
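+        # duration, step and latency are exposed as properties, so they are not in vars() and must be copied explicitly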
+ trial_config["duration"] = self.base_config.duration
+ trial_config["step"] = self.base_config.step
+ trial_config["latency"] = self.base_config.latency
for hparam in self.hparams:
trial_config[hparam.name] = trial.suggest_uniform(
hparam.name, hparam.low, hparam.high
@@ -125,16 +112,21 @@ def objective(self, trial: Trial) -> float:
# Instantiate the new configuration for the trial
config = self.base_config.__class__(**trial_config)
+ # Determine the evaluation metric
+ metric = self.metric
+ if metric is None:
+ metric = self.pipeline_class.suggest_metric()
+
# Run pipeline over the dataset
- report = self.benchmark(self.pipeline_class, config)
+ report = self.benchmark(self.pipeline_class, config, metric)
- # Extract DER from report
- return report.loc["TOTAL", "diarization error rate"]["%"]
+ # Extract target metric from report
+ return report.loc["TOTAL", metric.name].item()
def __call__(self, num_iter: int, show_progress: bool = True):
self._progress = None
if show_progress:
- self._progress = trange(num_iter)
+ self._progress = trange(num_iter, unit="trial")
last_trial = -1
if self.study.trials:
last_trial = self.study.trials[-1].number
diff --git a/src/diart/pipelines/__init__.py b/src/diart/pipelines/__init__.py
new file mode 100644
index 00000000..55676f66
--- /dev/null
+++ b/src/diart/pipelines/__init__.py
@@ -0,0 +1,5 @@
+from .base import Pipeline, PipelineConfig
+from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .speaker_transcription import SpeakerAwareTranscription, SpeakerAwareTranscriptionConfig
+from .transcription import Transcription, TranscriptionConfig
+from .voice import VoiceActivityDetection, VoiceActivityDetectionConfig
diff --git a/src/diart/pipelines/base.py b/src/diart/pipelines/base.py
new file mode 100644
index 00000000..de0582a7
--- /dev/null
+++ b/src/diart/pipelines/base.py
@@ -0,0 +1,84 @@
+from pathlib import Path
+from typing import Any, Tuple, Sequence, Text, List, Union
+
+import numpy as np
+from pyannote.core import SlidingWindowFeature
+from rx.core import Observer
+
+from .hparams import HyperParameter
+from .. import utils
+from ..audio import FilePath, AudioLoader
+from ..metrics import Metric
+
+
+class PipelineConfig:
+ @property
+ def duration(self) -> float:
+ raise NotImplementedError
+
+ @property
+ def step(self) -> float:
+ raise NotImplementedError
+
+ @property
+ def latency(self) -> float:
+ raise NotImplementedError
+
+ @property
+ def sample_rate(self) -> int:
+ raise NotImplementedError
+
+ @staticmethod
+ def from_dict(data: Any) -> 'PipelineConfig':
+ raise NotImplementedError
+
+ def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
+ file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
+ right = utils.get_padding_right(self.latency, self.step)
+ left = utils.get_padding_left(file_duration + right, self.duration)
+ return left, right
+
+ def optimal_block_size(self) -> int:
+ return int(np.rint(self.step * self.sample_rate))
+
+
+class Pipeline:
+ @staticmethod
+ def get_config_class() -> type:
+ raise NotImplementedError
+
+ @staticmethod
+ def suggest_metric() -> Metric:
+ raise NotImplementedError
+
+ @staticmethod
+ def hyper_parameters() -> Sequence[HyperParameter]:
+ raise NotImplementedError
+
+ @property
+ def config(self) -> PipelineConfig:
+ raise NotImplementedError
+
+ def reset(self):
+ raise NotImplementedError
+
+ def set_timestamp_shift(self, shift: float):
+ raise NotImplementedError
+
+ def join_predictions(self, predictions: List[Any]) -> Any:
+ raise NotImplementedError
+
+ def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Path]):
+ raise NotImplementedError
+
+ def suggest_display(self) -> Observer:
+ raise NotImplementedError
+
+ def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+ raise NotImplementedError
+
+ def __call__(
+ self,
+ waveforms: Sequence[SlidingWindowFeature],
+ ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
+ raise NotImplementedError
diff --git a/src/diart/pipelines/diarization.py b/src/diart/pipelines/diarization.py
new file mode 100644
index 00000000..114a4223
--- /dev/null
+++ b/src/diart/pipelines/diarization.py
@@ -0,0 +1,277 @@
+from pathlib import Path
+from typing import Optional, Tuple, Sequence, Union, Any, Text, List
+
+import numpy as np
+import torch
+from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
+from rx.core import Observer
+from typing_extensions import Literal
+
+from . import base
+from .hparams import HyperParameter, TauActive, RhoUpdate, DeltaNew
+from .. import blocks
+from .. import models as m
+from .. import sinks
+from .. import utils
+from ..metrics import Metric, DiarizationErrorRate
+
+
+class SpeakerDiarizationConfig(base.PipelineConfig):
+ def __init__(
+ self,
+ segmentation: Optional[m.SegmentationModel] = None,
+ embedding: Optional[m.EmbeddingModel] = None,
+ duration: Optional[float] = None,
+ step: float = 0.5,
+ latency: Optional[Union[float, Literal["max", "min"]]] = None,
+ tau_active: float = 0.5,
+ rho_update: float = 0.3,
+ delta_new: float = 1,
+ gamma: float = 3,
+ beta: float = 10,
+ max_speakers: int = 20,
+ merge_collar: float = 0.05,
+ device: Optional[torch.device] = None,
+ **kwargs,
+ ):
+ # Default segmentation model is pyannote/segmentation
+ self.segmentation = segmentation
+ if self.segmentation is None:
+ self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+ self._duration = duration
+ self._sample_rate: Optional[int] = None
+
+ # Default embedding model is pyannote/embedding
+ self.embedding = embedding
+ if self.embedding is None:
+ self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
+
+ # Latency defaults to the step duration
+ self._step = step
+ self._latency = latency
+ if self._latency is None or self._latency == "min":
+ self._latency = self._step
+ elif self._latency == "max":
+            self._latency = self.duration
+
+ self.tau_active = tau_active
+ self.rho_update = rho_update
+ self.delta_new = delta_new
+ self.gamma = gamma
+ self.beta = beta
+ self.max_speakers = max_speakers
+ self.merge_collar = merge_collar
+
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ @staticmethod
+ def from_dict(data: Any) -> 'SpeakerDiarizationConfig':
+ # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+ device = utils.get(data, "device", None)
+ if device is None:
+ device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+ # Instantiate models
+ hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
+ segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
+ segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
+ embedding = utils.get(data, "embedding", "pyannote/embedding")
+ embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
+
+ # Hyper-parameters and their aliases
+ tau = utils.get(data, "tau_active", None)
+ if tau is None:
+ tau = utils.get(data, "tau", 0.5)
+ rho = utils.get(data, "rho_update", None)
+ if rho is None:
+ rho = utils.get(data, "rho", 0.3)
+ delta = utils.get(data, "delta_new", None)
+ if delta is None:
+ delta = utils.get(data, "delta", 1)
+
+ return SpeakerDiarizationConfig(
+ segmentation=segmentation,
+ embedding=embedding,
+ duration=utils.get(data, "duration", None),
+ step=utils.get(data, "step", 0.5),
+ latency=utils.get(data, "latency", None),
+ tau_active=tau,
+ rho_update=rho,
+ delta_new=delta,
+ gamma=utils.get(data, "gamma", 3),
+ beta=utils.get(data, "beta", 10),
+ max_speakers=utils.get(data, "max_speakers", 20),
+ merge_collar=utils.get(data, "merge_collar", 0.05),
+ device=device,
+ )
+
+ @property
+ def duration(self) -> float:
+ # Default duration is the one given by the segmentation model
+ if self._duration is None:
+ self._duration = self.segmentation.duration
+ return self._duration
+
+ @property
+ def step(self) -> float:
+ return self._step
+
+ @property
+ def latency(self) -> float:
+ return self._latency
+
+ @property
+ def sample_rate(self) -> int:
+ # Expected sample rate is given by the segmentation model
+ if self._sample_rate is None:
+ self._sample_rate = self.segmentation.sample_rate
+ return self._sample_rate
+
+
+class SpeakerDiarization(base.Pipeline):
+ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
+ self._config = SpeakerDiarizationConfig() if config is None else config
+
+ msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
+ assert self._config.step <= self._config.latency <= self._config.duration, msg
+
+ self.segmentation = blocks.SpeakerSegmentation(self._config.segmentation, self._config.device)
+ self.embedding = blocks.OverlapAwareSpeakerEmbedding(
+ self._config.embedding, self._config.gamma, self._config.beta, norm=1, device=self._config.device
+ )
+ self.pred_aggregation = blocks.DelayedAggregation(
+ self._config.step,
+ self._config.latency,
+ strategy="hamming",
+ cropping_mode="loose",
+ )
+ self.audio_aggregation = blocks.DelayedAggregation(
+ self._config.step,
+ self._config.latency,
+ strategy="first",
+ cropping_mode="center",
+ )
+ self.binarize = blocks.Binarize(self._config.tau_active)
+
+ # Internal state, handle with care
+ self.timestamp_shift = 0
+ self.clustering = None
+ self.chunk_buffer, self.pred_buffer = [], []
+ self.reset()
+
+ @staticmethod
+ def get_config_class() -> type:
+ return SpeakerDiarizationConfig
+
+ @staticmethod
+ def suggest_metric() -> Metric:
+ return DiarizationErrorRate(collar=0, skip_overlap=False)
+
+ @staticmethod
+ def hyper_parameters() -> Sequence[HyperParameter]:
+ return [TauActive, RhoUpdate, DeltaNew]
+
+ @property
+ def config(self) -> SpeakerDiarizationConfig:
+ return self._config
+
+ def set_timestamp_shift(self, shift: float):
+ self.timestamp_shift = shift
+
+ def join_predictions(self, predictions: List[Annotation]) -> Annotation:
+ result = Annotation(uri=predictions[0].uri)
+ for pred in predictions:
+ result.update(pred)
+ return result.support(self.config.merge_collar)
+
+ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Text, Path]):
+ with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file:
+ prediction.write_rttm(out_file)
+
+ def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+ return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
+
+ def suggest_display(self) -> Observer:
+ return sinks.StreamingPlot(
+ self.config.duration,
+ self.config.step,
+ self.config.latency,
+ self.config.sample_rate
+ )
+
+ def reset(self):
+ self.set_timestamp_shift(0)
+ self.clustering = blocks.IncrementalSpeakerClustering(
+ self.config.tau_active,
+ self.config.rho_update,
+ self.config.delta_new,
+ "cosine",
+ self.config.max_speakers,
+ )
+ self.chunk_buffer, self.pred_buffer = [], []
+
+ def __call__(
+ self,
+ waveforms: Sequence[SlidingWindowFeature],
+ ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
+ batch_size = len(waveforms)
+ msg = "Pipeline expected at least 1 input"
+ assert batch_size >= 1, msg
+
+ # Create batch from chunk sequence, shape (batch, samples, channels)
+ batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
+
+ expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+ msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
+ assert batch.shape[1] == expected_num_samples, msg
+
+ # Extract segmentation and embeddings
+ segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
+ embeddings = self.embedding(batch, segmentations) # shape (batch, speakers, emb_dim)
+
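+        # Duration (in seconds) covered by each frame of the segmentation output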
+ seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
+
+ outputs = []
+ for wav, seg, emb in zip(waveforms, segmentations, embeddings):
+ # Add timestamps to segmentation
+ sw = SlidingWindow(
+ start=wav.extent.start,
+ duration=seg_resolution,
+ step=seg_resolution,
+ )
+ seg = SlidingWindowFeature(seg.cpu().numpy(), sw)
+
+ # Update clustering state and permute segmentation
+ permuted_seg = self.clustering(seg, emb)
+
+ # Update sliding buffer
+ self.chunk_buffer.append(wav)
+ self.pred_buffer.append(permuted_seg)
+
+ # Aggregate buffer outputs for this time step
+ agg_waveform = self.audio_aggregation(self.chunk_buffer)
+ agg_prediction = self.pred_aggregation(self.pred_buffer)
+ agg_prediction = self.binarize(agg_prediction)
+
+ # Shift prediction timestamps if required
+ if self.timestamp_shift != 0:
+ shifted_agg_prediction = Annotation(agg_prediction.uri)
+ for segment, track, speaker in agg_prediction.itertracks(yield_label=True):
+ new_segment = Segment(
+ segment.start + self.timestamp_shift,
+ segment.end + self.timestamp_shift,
+ )
+ shifted_agg_prediction[new_segment, track] = speaker
+ agg_prediction = shifted_agg_prediction
+
+ outputs.append((agg_prediction, agg_waveform))
+
+            # Make room for new chunks in the buffer if required
+ if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
+ self.chunk_buffer = self.chunk_buffer[1:]
+ self.pred_buffer = self.pred_buffer[1:]
+
+ return outputs
diff --git a/src/diart/pipelines/hparams.py b/src/diart/pipelines/hparams.py
new file mode 100644
index 00000000..740a1edf
--- /dev/null
+++ b/src/diart/pipelines/hparams.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from typing import Text
+
+
+@dataclass
+class HyperParameter:
+ name: Text
+ low: float
+ high: float
+
+ @staticmethod
+ def from_name(name: Text) -> 'HyperParameter':
+ if name == "tau_active":
+ return TauActive
+ if name == "rho_update":
+ return RhoUpdate
+ if name == "delta_new":
+ return DeltaNew
+ raise ValueError(f"Hyper-parameter '{name}' not recognized")
+
+
+TauActive = HyperParameter("tau_active", low=0, high=1)
+RhoUpdate = HyperParameter("rho_update", low=0, high=1)
+DeltaNew = HyperParameter("delta_new", low=0, high=2)
\ No newline at end of file
diff --git a/src/diart/pipelines/speaker_transcription.py b/src/diart/pipelines/speaker_transcription.py
new file mode 100644
index 00000000..ad2303fa
--- /dev/null
+++ b/src/diart/pipelines/speaker_transcription.py
@@ -0,0 +1,342 @@
+from pathlib import Path
+from typing import Any, Optional, Union, Sequence, Tuple, Text, List
+
+import numpy as np
+import torch
+from diart.metrics import Metric
+from pyannote.core import SlidingWindowFeature, SlidingWindow, Annotation, Segment
+from rx.core import Observer
+from typing_extensions import Literal
+
+from .base import Pipeline, PipelineConfig
+from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .hparams import HyperParameter, TauActive, RhoUpdate, DeltaNew
+from .. import models as m
+from .. import sinks
+from .. import blocks
+from .. import utils
+from ..metrics import WordErrorRate
+
+
+class SpeakerAwareTranscriptionConfig(PipelineConfig):
+ def __init__(
+ self,
+ asr: Optional[m.SpeechRecognitionModel] = None,
+ segmentation: Optional[m.SegmentationModel] = None,
+ embedding: Optional[m.EmbeddingModel] = None,
+ duration: Optional[float] = None,
+ asr_duration: float = 3,
+ step: float = 0.5,
+ latency: Optional[Union[float, Literal["max", "min"]]] = None,
+ tau_active: float = 0.5,
+ rho_update: float = 0.3,
+ delta_new: float = 1,
+ language: Optional[Text] = None,
+ beam_size: Optional[int] = None,
+ gamma: float = 3,
+ beta: float = 10,
+ max_speakers: int = 20,
+ merge_collar: float = 0.05,
+ diarization_device: Optional[torch.device] = None,
+ asr_device: Optional[torch.device] = None,
+ **kwargs,
+ ):
+ # Default segmentation model is pyannote/segmentation
+ self.segmentation = segmentation
+ if self.segmentation is None:
+ self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+ self._duration = duration
+ self._sample_rate: Optional[int] = None
+
+ # Default embedding model is pyannote/embedding
+ self.embedding = embedding
+ if self.embedding is None:
+ self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
+
+ # Latency defaults to the step duration
+ self._step = step
+ self._latency = latency
+ if self._latency is None or self._latency == "min":
+ self._latency = self._step
+ elif self._latency == "max":
+ self._latency = self.duration
+
+ self.tau_active = tau_active
+ self.rho_update = rho_update
+ self.delta_new = delta_new
+ self.gamma = gamma
+ self.beta = beta
+ self.max_speakers = max_speakers
+ self.merge_collar = merge_collar
+ self.asr_duration = asr_duration
+
+ self.diarization_device = diarization_device
+ if self.diarization_device is None:
+ self.diarization_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ self.language = language
+ self.beam_size = beam_size
+
+ self.asr_device = asr_device
+ if self.asr_device is None:
+ self.asr_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Default ASR model is Whisper small (244M parameters)
+ self.asr = asr
+ if self.asr is None:
+ self.asr = m.SpeechRecognitionModel.from_whisper("small")
+ self.asr.set_language(self.language)
+ self.asr.set_beam_size(self.beam_size)
+
+ def to_diarization_config(self) -> SpeakerDiarizationConfig:
+ return SpeakerDiarizationConfig(
+ segmentation=self.segmentation,
+ embedding=self.embedding,
+ duration=self.duration,
+ step=self.step,
+ latency=self.latency,
+ tau_active=self.tau_active,
+ rho_update=self.rho_update,
+ delta_new=self.delta_new,
+ gamma=self.gamma,
+ beta=self.beta,
+ max_speakers=self.max_speakers,
+ merge_collar=self.merge_collar,
+ device=self.diarization_device,
+ )
+
+ @property
+ def duration(self) -> float:
+ # Default duration is the one given by the segmentation model
+ if self._duration is None:
+ self._duration = self.segmentation.duration
+ return self._duration
+
+ @property
+ def step(self) -> float:
+ return self._step
+
+ @property
+ def latency(self) -> float:
+ return self._latency
+
+ @property
+ def sample_rate(self) -> int:
+ if self._sample_rate is None:
+ dia_sample_rate = self.segmentation.sample_rate
+ asr_sample_rate = self.asr.sample_rate
+ msg = "Sample rates for speech recognition and speaker segmentation models must match"
+ assert dia_sample_rate == asr_sample_rate, msg
+ self._sample_rate = dia_sample_rate
+ return self._sample_rate
+
+ @staticmethod
+ def from_dict(data: Any) -> 'SpeakerAwareTranscriptionConfig':
+ # Resolve arguments exactly like diarization
+ dia_config = SpeakerDiarizationConfig.from_dict(data)
+
+ # Default ASR model is Whisper small (244M parameters)
+ whisper_size = utils.get(data, "whisper", "small")
+ asr = m.SpeechRecognitionModel.from_whisper(whisper_size)
+
+ return SpeakerAwareTranscriptionConfig(
+ asr=asr,
+ segmentation=dia_config.segmentation,
+ embedding=dia_config.embedding,
+ duration=dia_config.duration,
+ asr_duration=utils.get(data, "asr_duration", 3),
+ step=dia_config.step,
+ latency=dia_config.latency,
+ tau_active=dia_config.tau_active,
+ rho_update=dia_config.rho_update,
+ delta_new=dia_config.delta_new,
+ language=utils.get(data, "language", None),
+ beam_size=utils.get(data, "beam_size", None),
+ gamma=dia_config.gamma,
+ beta=dia_config.beta,
+ max_speakers=dia_config.max_speakers,
+ merge_collar=dia_config.merge_collar,
+ diarization_device=dia_config.device,
+ # TODO handle different devices
+ asr_device=dia_config.device,
+ )
+
+
+class SpeakerAwareTranscription(Pipeline):
+ def __init__(self, config: Optional[SpeakerAwareTranscriptionConfig] = None):
+ self._config = SpeakerAwareTranscriptionConfig() if config is None else config
+ self.diarization = SpeakerDiarization(self.config.to_diarization_config())
+ self.asr = blocks.SpeechRecognition(self.config.asr, self.config.asr_device)
+
+ # Internal state, handle with care
+ self.audio_buffer, self.dia_buffer = None, None
+
+ @staticmethod
+ def get_config_class() -> type:
+ return SpeakerAwareTranscriptionConfig
+
+ @staticmethod
+ def suggest_metric() -> Metric:
+ # TODO per-speaker WER?
+ return WordErrorRate()
+
+ @staticmethod
+ def hyper_parameters() -> Sequence[HyperParameter]:
+ return [TauActive, RhoUpdate, DeltaNew]
+
+ @property
+ def config(self) -> SpeakerAwareTranscriptionConfig:
+ return self._config
+
+ def reset(self):
+ self.diarization.reset()
+ self.audio_buffer, self.dia_buffer = None, None
+
+ def set_timestamp_shift(self, shift: float):
+ self.diarization.set_timestamp_shift(shift)
+
+ def join_predictions(self, predictions: List[Text]) -> Text:
+ return "\n".join(predictions)
+
+ def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]):
+ with open(Path(dir_path) / f"{uri}.txt", "w") as out_file:
+ out_file.write(prediction)
+
+ def suggest_display(self) -> Observer:
+ return sinks.RichScreen()
+
+ def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+ return sinks.TextWriter(Path(output_dir) / f"{uri}.txt")
+
+ def _update_buffers(self, diarization_output: Sequence[Tuple[Annotation, SlidingWindowFeature]]):
+ # Separate diarization and aligned audio chunks
+ first_chunk = diarization_output[0][1]
+ output_start = first_chunk.extent.start
+ resolution = first_chunk.sliding_window.duration
+ diarization, chunk_data = Annotation(), []
+ for dia, chunk in diarization_output:
+ diarization = diarization.update(dia)
+ chunk_data.append(chunk.data)
+
+ # Update diarization output buffer
+ if self.dia_buffer is None:
+ self.dia_buffer = diarization
+ else:
+ self.dia_buffer = self.dia_buffer.update(diarization)
+ self.dia_buffer = self.dia_buffer.support(self.config.merge_collar)
+
+ # Update audio buffer
+ if self.audio_buffer is None:
+ window = SlidingWindow(resolution, resolution, output_start)
+ self.audio_buffer = SlidingWindowFeature(np.concatenate(chunk_data, axis=0), window)
+ else:
+ chunk_data.insert(0, self.audio_buffer.data)
+ self.audio_buffer = SlidingWindowFeature(
+ np.concatenate(chunk_data, axis=0),
+ self.audio_buffer.sliding_window,
+ )
+
+ def _extract_asr_inputs(self) -> Tuple[List[SlidingWindowFeature], List[Annotation]]:
+ chunk_duration = self.config.asr_duration
+ buffer_duration = self.audio_buffer.extent.duration
+ batch_size = int(buffer_duration / chunk_duration)
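+        # Only transcribe as many full ASR windows as currently fit in the buffer; the remainder stays buffered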
+ buffer_start = self.audio_buffer.extent.start
+ resolution = self.audio_buffer.sliding_window.duration
+
+ # Extract audio chunks with their diarization
+ asr_inputs, input_dia, last_end_time = [], [], None
+ for i in range(batch_size):
+ start = buffer_start + i * chunk_duration
+ last_end_time = start + chunk_duration
+ region = Segment(start, last_end_time)
+ chunk = self.audio_buffer.crop(region, fixed=chunk_duration)
+ window = SlidingWindow(resolution, resolution, start)
+ asr_inputs.append(SlidingWindowFeature(chunk, window))
+ input_dia.append(self.dia_buffer.crop(region))
+
+ # Remove extracted chunks from buffers
+ if asr_inputs:
+ new_buffer_bounds = Segment(last_end_time, self.audio_buffer.extent.end)
+ new_buffer = self.audio_buffer.crop(new_buffer_bounds, fixed=new_buffer_bounds.duration)
+ window = SlidingWindow(resolution, resolution, last_end_time)
+ self.audio_buffer = SlidingWindowFeature(new_buffer, window)
+ self.dia_buffer = self.dia_buffer.extrude(Segment(0, last_end_time))
+
+ return asr_inputs, input_dia
+
+ def _get_speaker_transcriptions(
+ self,
+ input_diarization: List[Annotation],
+ asr_inputs: List[SlidingWindowFeature],
+ asr_outputs: List[m.TranscriptionResult],
+ ) -> Text:
+ transcriptions = []
+ for i, waveform in enumerate(asr_inputs):
+ if waveform is None:
+ continue
+ buffer_shift = waveform.sliding_window.start
+ for text, timestamp in zip(asr_outputs[i].chunks, asr_outputs[i].timestamps):
+ if not text.strip():
+ continue
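+                # Convert chunk-local timestamps to the global stream timeline before cropping the diarization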
+ target_region = Segment(
+ buffer_shift + timestamp.start,
+ buffer_shift + timestamp.end,
+ )
+ dia = input_diarization[i].crop(target_region)
+ speakers = dia.labels()
+ num_speakers = len(speakers)
+ if num_speakers == 0:
+ # Include transcription but don't assign a speaker
+ transcriptions.append(text)
+ elif num_speakers == 1:
+ # Typical case, annotate text with the only speaker
+ transcriptions.append(f"[{speakers[0]}]{text}")
+ else:
+ # Multiple speakers for the same text block, choose the most active one
+ max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
+ transcriptions.append(f"[{speakers[max_spk]}]{text}")
+ return " ".join(transcriptions).strip()
+
+ def __call__(
+ self,
+ waveforms: Sequence[SlidingWindowFeature],
+ ) -> Sequence[Text]:
+ # Compute diarization output
+ diarization_output = self.diarization(waveforms)
+ self._update_buffers(diarization_output)
+
+ # Extract audio to transcribe from the buffer
+ asr_inputs, asr_input_dia = self._extract_asr_inputs()
+ if not asr_inputs:
+ return ["" for _ in waveforms]
+
+ # Detect non-speech chunks
+ has_voice = torch.tensor([dia.get_timeline().duration() > 0 for dia in asr_input_dia])
+ has_voice = torch.where(has_voice)[0]
+ # Return empty strings if no speech in the entire batch
+ if len(has_voice) == 0:
+ return ["" for _ in waveforms]
+
+ # Create ASR batch, shape (batch, samples, channels)
+ batch = torch.stack([torch.from_numpy(w.data) for w in asr_inputs])
+
+ # Transcribe batch
+ asr_outputs = self.asr(batch[has_voice])
+        # Map batch indices to ASR output indices, since non-speech chunks were filtered out of the batch
+        mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)}
+        asr_outputs = [
+            asr_outputs[mapping[i]] if i in has_voice else None
+            for i in range(batch.shape[0])
+        ]
+
+ # Attach speaker labels to ASR output and concatenate
+ transcription = self._get_speaker_transcriptions(
+ asr_input_dia, asr_inputs, asr_outputs
+ )
+
+ # Fill output sequence with empty strings
+ batch_size = len(waveforms)
+ output = [transcription]
+ if batch_size > 1:
+ output += [""] * (batch_size - 1)
+
+ return output
diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py
new file mode 100644
index 00000000..49b60175
--- /dev/null
+++ b/src/diart/pipelines/transcription.py
@@ -0,0 +1,184 @@
+from pathlib import Path
+from typing import Sequence, Optional, Any, Union, List, Text, Tuple
+
+import numpy as np
+import torch
+from pyannote.core import SlidingWindowFeature
+from rx.core import Observer
+
+from . import base
+from .hparams import HyperParameter, TauActive
+from .. import blocks
+from .. import models as m
+from .. import sinks
+from .. import utils
+from ..metrics import Metric, WordErrorRate
+
+
+class TranscriptionConfig(base.PipelineConfig):
+ def __init__(
+ self,
+ asr: Optional[m.SpeechRecognitionModel] = None,
+ segmentation: Optional[m.SegmentationModel] = None,
+ tau_active: float = 0.5,
+ duration: Optional[float] = 3,
+ language: Optional[Text] = None,
+        beam_size: Optional[int] = None,
+ device: Optional[torch.device] = None,
+ **kwargs,
+ ):
+ self.language = language
+ self.beam_size = beam_size
+
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Default ASR model is Whisper small (244M parameters)
+ self.asr = asr
+ if self.asr is None:
+ self.asr = m.SpeechRecognitionModel.from_whisper("small")
+ self.asr.set_language(self.language)
+ self.asr.set_beam_size(self.beam_size)
+
+ self.segmentation = segmentation
+ self.tau_active = tau_active
+
+ self._duration = duration
+ self._sample_rate: Optional[int] = None
+
+ @property
+ def duration(self) -> float:
+ if self._duration is None:
+ self._duration = self.asr.duration
+ return self._duration
+
+ @property
+ def step(self) -> float:
+ return self.duration
+
+ @property
+ def latency(self) -> float:
+ return self.duration
+
+ @property
+ def sample_rate(self) -> int:
+ if self._sample_rate is None:
+ self._sample_rate = self.asr.sample_rate
+ return self._sample_rate
+
+ @staticmethod
+ def from_dict(data: Any) -> 'TranscriptionConfig':
+ # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+ device = utils.get(data, "device", None)
+ if device is None:
+ device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+ # Default ASR model is Whisper small (244M parameters)
+ whisper_size = utils.get(data, "whisper", "small")
+ asr = m.SpeechRecognitionModel.from_whisper(whisper_size)
+
+ # No VAD segmentation by default
+ segmentation = utils.get(data, "segmentation", None)
+ if segmentation is not None:
+ hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
+ segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
+
+ # Tau hyper-parameter and its alias
+ tau = utils.get(data, "tau_active", None)
+ if tau is None:
+ tau = utils.get(data, "tau", 0.5)
+
+ return TranscriptionConfig(
+ asr=asr,
+ segmentation=segmentation,
+ tau_active=tau,
+ duration=utils.get(data, "duration", 3),
+ language=utils.get(data, "language", None),
+ beam_size=utils.get(data, "beam_size", None),
+ device=device,
+ )
+
+
+class Transcription(base.Pipeline):
+ def __init__(self, config: Optional[TranscriptionConfig] = None):
+ self._config = TranscriptionConfig() if config is None else config
+ self.asr = blocks.SpeechRecognition(self.config.asr, self.config.device)
+ self.segmentation = None
+ if self.config.segmentation is not None:
+ self.segmentation = blocks.SpeakerSegmentation(self.config.segmentation, self.config.device)
+
+ @staticmethod
+ def get_config_class() -> type:
+ return TranscriptionConfig
+
+ @staticmethod
+ def suggest_metric() -> Metric:
+ return WordErrorRate()
+
+ @staticmethod
+ def hyper_parameters() -> Sequence[HyperParameter]:
+ return [TauActive]
+
+ @property
+ def config(self) -> TranscriptionConfig:
+ return self._config
+
+ def reset(self):
+ # No internal state. Nothing to do
+ pass
+
+ def set_timestamp_shift(self, shift: float):
+ # No timestamped output. Nothing to do
+ pass
+
+ def join_predictions(self, predictions: List[Text]) -> Text:
+ return "\n".join(predictions)
+
+ def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]):
+ with open(Path(dir_path) / f"{uri}.txt", "w") as out_file:
+ out_file.write(prediction)
+
+ def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+ return sinks.TextWriter(Path(output_dir) / f"{uri}.txt")
+
+ def suggest_display(self) -> Observer:
+ return sinks.RichScreen()
+
+ def __call__(
+ self,
+ waveforms: Sequence[SlidingWindowFeature],
+ ) -> Sequence[Tuple[Text, SlidingWindowFeature]]:
+ batch_size = len(waveforms)
+ msg = "Pipeline expected at least 1 input"
+ assert batch_size >= 1, msg
+
+ # Create batch from chunk sequence, shape (batch, samples, channels)
+ batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
+
+ expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+ msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
+ assert batch.shape[1] == expected_num_samples, msg
+
+ # Run voice detection if required
+ if self.segmentation is None:
+ has_voice = torch.arange(0, batch_size)
+ else:
+ segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
+ has_voice = torch.max(segmentations, dim=-1)[0] # shape (batch, frames)
+ has_voice = torch.any(has_voice >= self.config.tau_active, dim=-1) # shape (batch,)
+ has_voice = torch.where(has_voice)[0]
+
+ # Return empty strings if no speech in the entire batch
+ if len(has_voice) == 0:
+ return [("", wav) for wav in waveforms]
+
+ # Transcribe batch
+ outputs = self.asr(batch[has_voice])
+ mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)}
+
+ # No-speech outputs are empty strings
+ return [
+ (outputs[mapping[i]].text if i in has_voice else "", waveforms[i])
+ for i in range(batch_size)
+ ]
diff --git a/src/diart/pipelines/voice.py b/src/diart/pipelines/voice.py
new file mode 100644
index 00000000..05eaa216
--- /dev/null
+++ b/src/diart/pipelines/voice.py
@@ -0,0 +1,233 @@
+from pathlib import Path
+from typing import Any, Optional, Union, Sequence, Tuple, Text, List
+
+import numpy as np
+import torch
+from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment
+from rx.core import Observer
+from typing_extensions import Literal
+
+from . import base
+from .hparams import HyperParameter, TauActive
+from .. import blocks
+from .. import models as m
+from .. import sinks
+from .. import utils
+from ..metrics import Metric, DetectionErrorRate
+
+
+class VoiceActivityDetectionConfig(base.PipelineConfig):
+ def __init__(
+ self,
+ segmentation: Optional[m.SegmentationModel] = None,
+ duration: Optional[float] = None,
+ step: float = 0.5,
+ latency: Optional[Union[float, Literal["max", "min"]]] = None,
+ tau_active: float = 0.5,
+ merge_collar: float = 0.05,
+ device: Optional[torch.device] = None,
+ **kwargs,
+ ):
+ # Default segmentation model is pyannote/segmentation
+ self.segmentation = segmentation
+ if self.segmentation is None:
+ self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+ self._duration = duration
+ self._step = step
+ self._sample_rate: Optional[int] = None
+
+ # Latency defaults to the step duration
+ self._latency = latency
+ if self._latency is None or self._latency == "min":
+ self._latency = self._step
+ elif self._latency == "max":
+            self._latency = self.duration
+
+ self.tau_active = tau_active
+ self.merge_collar = merge_collar
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ @property
+ def duration(self) -> float:
+ # Default duration is the one given by the segmentation model
+ if self._duration is None:
+ self._duration = self.segmentation.duration
+ return self._duration
+
+ @property
+ def step(self) -> float:
+ return self._step
+
+ @property
+ def latency(self) -> float:
+ return self._latency
+
+ @property
+ def sample_rate(self) -> int:
+ # Expected sample rate is given by the segmentation model
+ if self._sample_rate is None:
+ self._sample_rate = self.segmentation.sample_rate
+ return self._sample_rate
+
+ @staticmethod
+ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig':
+ # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+ device = utils.get(data, "device", None)
+ if device is None:
+ device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+ # Instantiate segmentation model
+ hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
+ segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
+ segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
+
+ # Tau active and its alias
+ tau = utils.get(data, "tau_active", None)
+ if tau is None:
+ tau = utils.get(data, "tau", 0.5)
+
+ return VoiceActivityDetectionConfig(
+ segmentation=segmentation,
+ duration=utils.get(data, "duration", None),
+ step=utils.get(data, "step", 0.5),
+ latency=utils.get(data, "latency", None),
+ tau_active=tau,
+ merge_collar=utils.get(data, "merge_collar", 0.05),
+ device=device,
+ )
+
+
+class VoiceActivityDetection(base.Pipeline):
+ def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
+ self._config = VoiceActivityDetectionConfig() if config is None else config
+
+ msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
+ assert self._config.step <= self._config.latency <= self._config.duration, msg
+
+ self.segmentation = blocks.SpeakerSegmentation(self._config.segmentation, self._config.device)
+ self.pred_aggregation = blocks.DelayedAggregation(
+ self._config.step,
+ self._config.latency,
+ strategy="hamming",
+ cropping_mode="loose",
+ )
+ self.audio_aggregation = blocks.DelayedAggregation(
+ self._config.step,
+ self._config.latency,
+ strategy="first",
+ cropping_mode="center",
+ )
+ self.binarize = blocks.Binarize(self._config.tau_active)
+
+ # Internal state, handle with care
+ self.timestamp_shift = 0
+ self.chunk_buffer, self.pred_buffer = [], []
+
+ @staticmethod
+ def get_config_class() -> type:
+ return VoiceActivityDetectionConfig
+
+ @staticmethod
+ def suggest_metric() -> Metric:
+ return DetectionErrorRate(collar=0, skip_overlap=False)
+
+ @staticmethod
+ def hyper_parameters() -> Sequence[HyperParameter]:
+ return [TauActive]
+
+ @property
+ def config(self) -> VoiceActivityDetectionConfig:
+ return self._config
+
+ def reset(self):
+ self.set_timestamp_shift(0)
+ self.chunk_buffer, self.pred_buffer = [], []
+
+ def set_timestamp_shift(self, shift: float):
+ self.timestamp_shift = shift
+
+ def join_predictions(self, predictions: List[Annotation]) -> Annotation:
+ result = Annotation(uri=predictions[0].uri)
+ for pred in predictions:
+ result.update(pred)
+ return result.support(self.config.merge_collar)
+
+ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Text, Path]):
+ with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file:
+ prediction.write_rttm(out_file)
+
+ def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+ return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
+
+ def suggest_display(self) -> Observer:
+ return sinks.StreamingPlot(
+ self.config.duration,
+ self.config.step,
+ self.config.latency,
+ self.config.sample_rate
+ )
+
+ def __call__(
+ self,
+ waveforms: Sequence[SlidingWindowFeature],
+ ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
+ batch_size = len(waveforms)
+ msg = "Pipeline expected at least 1 input"
+ assert batch_size >= 1, msg
+
+ # Create batch from chunk sequence, shape (batch, samples, channels)
+ batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
+
+ expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+ msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
+ assert batch.shape[1] == expected_num_samples, msg
+
+ # Extract segmentation
+ segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
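+ # Voice activity is the strongest speaker activation at each frame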
+ voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[0] # shape (batch, frames, 1)
+
+ seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
+
+ outputs = []
+ for wav, vad in zip(waveforms, voice_detection):
+ # Add timestamps to segmentation
+ sw = SlidingWindow(
+ start=wav.extent.start,
+ duration=seg_resolution,
+ step=seg_resolution,
+ )
+ vad = SlidingWindowFeature(vad.cpu().numpy(), sw)
+
+ # Update sliding buffer
+ self.chunk_buffer.append(wav)
+ self.pred_buffer.append(vad)
+
+ # Aggregate buffer outputs for this time step
+ agg_waveform = self.audio_aggregation(self.chunk_buffer)
+ agg_prediction = self.pred_aggregation(self.pred_buffer)
+ agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False)
+
+ # Shift prediction timestamps if required
+ if self.timestamp_shift != 0:
+ shifted_agg_prediction = Timeline(uri=agg_prediction.uri)
+ for segment in agg_prediction:
+ new_segment = Segment(
+ segment.start + self.timestamp_shift,
+ segment.end + self.timestamp_shift,
+ )
+ shifted_agg_prediction.add(new_segment)
+ agg_prediction = shifted_agg_prediction
+
+ # Convert timeline into annotation with single speaker "speech"
+ agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech"))
+ outputs.append((agg_prediction, agg_waveform))
+
+ # Drop the oldest chunk once the buffer holds all overlapping windows, making room for the next one
+ if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
+ self.chunk_buffer = self.chunk_buffer[1:]
+ self.pred_buffer = self.pred_buffer[1:]
+
+ return outputs
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index cf480bed..fbc46904 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -1,25 +1,24 @@
+import re
from pathlib import Path
-from typing import Union, Text, Optional, Tuple
+from typing import Union, Text, Optional, Tuple, Any, List
import matplotlib.pyplot as plt
-from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
+import numpy as np
+import rich
+from pyannote.core import Annotation, Segment, SlidingWindowFeature, SlidingWindow, notebook
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate
from rx.core import Observer
-from typing_extensions import Literal
class WindowClosedException(Exception):
pass
-def _extract_annotation(value: Union[Tuple, Annotation]) -> Annotation:
+def _extract_prediction(value: Union[Tuple, Any]) -> Any:
if isinstance(value, tuple):
return value[0]
- if isinstance(value, Annotation):
- return value
- msg = f"Expected tuple or Annotation, but got {type(value)}"
- raise ValueError(msg)
+ return value
class RTTMWriter(Observer):
@@ -43,10 +42,11 @@ def patch(self):
annotation.support(self.patch_collar).write_rttm(file)
def on_next(self, value: Union[Tuple, Annotation]):
- annotation = _extract_annotation(value)
- annotation.uri = self.uri
+ prediction = _extract_prediction(value)
+ # Write prediction in RTTM format
+ prediction.uri = self.uri
with open(self.path, 'a') as file:
- annotation.write_rttm(file)
+ prediction.write_rttm(file)
def on_error(self, error: Exception):
self.patch()
@@ -55,30 +55,47 @@ def on_completed(self):
self.patch()
-class DiarizationPredictionAccumulator(Observer):
+class TextWriter(Observer):
+ def __init__(self, path: Union[Path, Text]):
+ super().__init__()
+ self.path = Path(path).expanduser()
+ if self.path.exists():
+ self.path.unlink()
+
+ def on_error(self, error: Exception):
+ pass
+
+ def on_next(self, value: Union[Tuple, Text]):
+ # Write transcription to file
+ prediction = _extract_prediction(value)
+ with open(self.path, 'a') as file:
+ file.write(prediction + "\n")
+
+
+class DiarizationAccumulator(Observer):
def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05):
super().__init__()
self.uri = uri
self.patch_collar = patch_collar
- self._annotation = None
+ self._prediction: Optional[Annotation] = None
def patch(self):
"""Stitch same-speaker turns that are close to each other"""
- if self._annotation is not None:
- self._annotation = self._annotation.support(self.patch_collar)
+ if self._prediction is not None:
+ self._prediction = self._prediction.support(self.patch_collar)
def get_prediction(self) -> Annotation:
# Patch again in case this is called before on_completed
self.patch()
- return self._annotation
+ return self._prediction
def on_next(self, value: Union[Tuple, Annotation]):
- annotation = _extract_annotation(value)
- annotation.uri = self.uri
- if self._annotation is None:
- self._annotation = annotation
+ prediction = _extract_prediction(value)
+ prediction.uri = self.uri
+ if self._prediction is None:
+ self._prediction = prediction
else:
- self._annotation.update(annotation)
+ self._prediction.update(prediction)
def on_error(self, error: Exception):
self.patch()
@@ -87,25 +104,69 @@ def on_completed(self):
self.patch()
-class RealTimePlot(Observer):
+class RichScreen(Observer):
+ def __init__(self, speaker_colors: Optional[List[Text]] = None):
+ super().__init__()
+ self.colors = speaker_colors
+ if self.colors is None:
+ self.colors = [
+ "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1",
+ "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
+ ]
+ self.num_colors = len(self.colors)
+ self._speaker_to_color = {}
+
+ def on_error(self, error: Exception):
+ pass
+
+ def on_next(self, value: Union[Tuple, Text]):
+ prediction = _extract_prediction(value)
+ if not prediction.strip():
+ return
+ # Extract speaker tags, expected as bracketed labels (e.g. '[speaker0]')
+ speakers = sorted(re.findall(r'\[.*?]', prediction))
+ # Colorize based on speakers
+ colorized = prediction
+ for spk in speakers:
+ name = spk[1:-1]
+ if name not in self._speaker_to_color:
+ next_color_idx = len(self._speaker_to_color) % self.num_colors
+ self._speaker_to_color[name] = self.colors[next_color_idx]
+ colorized = colorized.replace(spk, f"[{self._speaker_to_color[name]}]")
+ # Print result
+ rich.print(colorized)
+
+
+class StreamingPlot(Observer):
def __init__(
self,
duration: float,
+ step: float,
latency: float,
- visualization: Literal["slide", "accumulate"] = "slide",
+ sample_rate: float,
reference: Optional[Union[Path, Text]] = None,
+ patch_collar: float = 0.05,
):
super().__init__()
- assert visualization in ["slide", "accumulate"]
- self.visualization = visualization
self.reference = reference
if self.reference is not None:
self.reference = list(load_rttm(reference).values())[0]
self.window_duration = duration
+ self.window_step = step
self.latency = latency
+ self.sample_rate = sample_rate
+ self.patch_collar = patch_collar
+
+ self.num_window_samples = int(np.rint(self.window_duration * self.sample_rate))
+ self.num_step_samples = int(np.rint(self.window_step * self.sample_rate))
+ self.audio_resolution = 1 / self.sample_rate
+
self.figure, self.axs, self.num_axs = None, None, -1
# This flag allows to catch the matplotlib window closed event and make the next call stop iterating
self.window_closed = False
+ self.real_time = 0
+ self.pred_buffer, self.audio_buffer = None, None
+ self.next_sample = 0
def _on_window_closed(self, event):
self.window_closed = True
@@ -127,35 +188,91 @@ def _clear_axs(self):
for i in range(self.num_axs):
self.axs[i].clear()
- def get_plot_bounds(self, real_time: float) -> Segment:
- start_time = 0
- end_time = real_time - self.latency
- if self.visualization == "slide":
- start_time = max(0., end_time - self.window_duration)
+ def get_plot_bounds(self) -> Segment:
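+ # The plot shows the last 'window_duration' seconds, ending 'latency' seconds behind real time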
+ end_time = self.real_time - self.latency
+ start_time = max(0., end_time - self.window_duration)
return Segment(start_time, end_time)
- def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
+ def on_error(self, error: Exception):
+ # Do nothing on error
+ pass
+
+ def on_next(
+ self,
+ values: Tuple[Annotation, SlidingWindowFeature]
+ ):
if self.window_closed:
raise WindowClosedException
- prediction, waveform, real_time = values
+ prediction, waveform = values
+
+ # TODO break this aggregation code into methods
+
+ # Determine the real time of the stream and the start time of the buffer
+ self.real_time = waveform.extent.end
+ start_time = max(0., self.real_time - self.latency - self.window_duration)
+
+ # Update prediction buffer and constrain its bounds
+ if self.pred_buffer is None:
+ self.pred_buffer = prediction
+ else:
+ self.pred_buffer = self.pred_buffer.update(prediction)
+ self.pred_buffer = self.pred_buffer.support(self.patch_collar)
+ if start_time > 0:
+ self.pred_buffer = self.pred_buffer.extrude(Segment(0, start_time))
+
+ # Update the audio buffer with the incoming waveform
+ new_next_sample = self.next_sample + self.num_step_samples
+ if self.audio_buffer is None:
+ # Determine the size of the first chunk
+ expected_duration = self.window_duration + self.window_step - self.latency
+ expected_samples = int(np.rint(expected_duration * self.sample_rate))
+ # Move the write index so the next chunk is copied right after this first one
+ new_next_sample = self.next_sample + expected_samples
+ # Buffer size is duration + step
+ new_buffer = np.zeros((self.num_window_samples + self.num_step_samples, 1))
+ # Copy first chunk into buffer (slicing because of rounding errors)
+ new_buffer[:expected_samples] = waveform.data[:expected_samples]
+ elif self.next_sample <= self.num_window_samples:
+ # The buffer isn't full, copy into next free buffer chunk
+ new_buffer = self.audio_buffer.data
+ new_buffer[self.next_sample:new_next_sample] = waveform.data
+ else:
+ # The buffer is full, shift values to the left and copy into last buffer chunk
+ new_buffer = np.roll(self.audio_buffer.data, -self.num_step_samples, axis=0)
+ # When running on a file, the streaming prediction may be shorter than the audio, depending on the latency.
+ # The remaining audio at the end is appended to the last chunk, so 'waveform' may be longer than 'num_step_samples'.
+ # In that case, the extra samples are simply ignored.
+ new_buffer[-self.num_step_samples:] = waveform.data[:self.num_step_samples]
+
+ # Wrap waveform in a sliding window feature to include timestamps
+ window = SlidingWindow(start=start_time, duration=self.audio_resolution, step=self.audio_resolution)
+ self.audio_buffer = SlidingWindowFeature(new_buffer, window)
+ self.next_sample = new_next_sample
+
# Initialize figure if first call
if self.figure is None:
self._init_figure()
# Clear previous plots
self._clear_axs()
# Set plot bounds
- notebook.crop = self.get_plot_bounds(real_time)
+ notebook.crop = self.get_plot_bounds()
- # Plot current values
+ # Align prediction and reference if possible
if self.reference is not None:
metric = DiarizationErrorRate()
- mapping = metric.optimal_mapping(self.reference, prediction)
- prediction.rename_labels(mapping=mapping, copy=False)
- notebook.plot_annotation(prediction, self.axs[0])
+ mapping = metric.optimal_mapping(self.reference, self.pred_buffer)
+ self.pred_buffer.rename_labels(mapping=mapping, copy=False)
+
+ # Plot prediction
+ notebook.plot_annotation(self.pred_buffer, self.axs[0])
self.axs[0].set_title("Output")
- notebook.plot_feature(waveform, self.axs[1])
+
+ # Plot waveform
+ notebook.plot_feature(self.audio_buffer, self.axs[1])
self.axs[1].set_title("Audio")
+
+ # Plot reference if available
if self.num_axs == 3:
notebook.plot_annotation(self.reference, self.axs[2])
self.axs[2].set_title("Reference")
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 0f5dedf7..76149bb4 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -5,12 +5,12 @@
import numpy as np
import sounddevice as sd
import torch
-from diart import utils
from einops import rearrange
from rx.subject import Subject
from torchaudio.io import StreamReader
from websocket_server import WebsocketServer
+from . import utils
from .audio import FilePath, AudioLoader
@@ -247,8 +247,9 @@ def send(self, message: AnyStr):
message: AnyStr
Bytes or string to send.
"""
- if len(message) > 0:
- self.server.send_message(self.client, message)
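+ # Strip whitespace and skip empty messages; messages are newline-terminated so clients can split them by line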
+ msg = message.strip()
+ if len(msg) > 0:
+ self.server.send_message(self.client, msg + "\n")
class TorchStreamAudioSource(AudioSource):
diff --git a/src/diart/utils.py b/src/diart/utils.py
index e90861c7..725d9bf4 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -1,12 +1,14 @@
import base64
import time
-from typing import Optional, Text, Union, Any, Dict
+from typing import Optional, Text, Union, Any, Dict, Tuple
import matplotlib.pyplot as plt
import numpy as np
-from diart.progress import ProgressBar
from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
+from . import pipelines
+from .progress import ProgressBar
+
class Chronometer:
def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None):
@@ -74,10 +76,30 @@ def get_padding_left(stream_duration: float, chunk_duration: float) -> float:
return 0
+def repeat_label(label: Text):
+ while True:
+ yield label
+
+
+def get_pipeline_class(class_name: Text) -> type:
+ pipeline_class = getattr(pipelines, class_name, None)
+ msg = f"Pipeline '{class_name}' doesn't exist"
+ assert pipeline_class is not None, msg
+ return pipeline_class
+
+
def get_padding_right(latency: float, step: float) -> float:
return latency - step
+def serialize_prediction(value: Union[Tuple, Annotation, Text]) -> Text:
+ if isinstance(value, tuple):
+ value = value[0]
+ if isinstance(value, Annotation):
+ return value.to_rttm()
+ return value
+
+
def visualize_feature(duration: Optional[float] = None):
def apply(feature: SlidingWindowFeature):
if duration is None: