From bca2873e3ebaa901fdcaa37096f7e1904359f2ca Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 19 Apr 2023 17:41:41 +0200
Subject: [PATCH 01/23] New feature: streaming voice activity detection.
 Pipeline name changes

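This patch adds a streaming `VoiceActivityDetection` pipeline built on the
same blocks as speaker diarization, and renames the existing classes
(`OnlineSpeakerDiarization` -> `SpeakerDiarization`, `PipelineConfig` ->
`SpeakerDiarizationConfig`, `RealTimeInference` -> `StreamingInference`).
A minimal usage sketch following the README examples; the output path is
just a placeholder:

```python
from diart import VoiceActivityDetection
from diart.sources import MicrophoneAudioSource
from diart.inference import StreamingInference
from diart.sinks import RTTMWriter

# The default config loads pyannote/segmentation and reduces its output
# to a single "speech" activity per chunk
pipeline = VoiceActivityDetection()
mic = MicrophoneAudioSource(pipeline.config.sample_rate)

# Stream the microphone and write the detected speech regions in RTTM format
inference = StreamingInference(pipeline, mic, do_plot=False)
inference.attach_observers(RTTMWriter(mic.uri, "/output/vad.rttm"))
prediction = inference()
```

The console scripts gain a `--pipeline` flag to select the pipeline class by
name, e.g. `diart.stream microphone --pipeline VoiceActivityDetection`.

`Optimizer` now takes the pipeline class as its first argument and can tune
any `StreamingPipeline` with a configurable metric and direction. A sketch
with placeholder paths (the study directory is assumed to exist):

```python
from pathlib import Path

from diart import VoiceActivityDetection, VoiceActivityDetectionConfig
from diart.optim import Optimizer
from pyannote.metrics.detection import DetectionErrorRate

optimizer = Optimizer(
    pipeline_class=VoiceActivityDetection,
    speech_path="/wav/dir",
    reference_path="/rttm/dir",
    study_or_path=Path("/output/study"),
    base_config=VoiceActivityDetectionConfig(step=0.5),
    metric=DetectionErrorRate(),  # optional, defaults to the pipeline's suggested metric
    direction="minimize",
)
optimizer(num_iter=100, show_progress=True)
```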
---
 src/diart/__init__.py           |  10 +-
 src/diart/blocks/__init__.py    |   5 +-
 src/diart/blocks/base.py        |  92 ++++++++++++++
 src/diart/blocks/config.py      | 153 -----------------------
 src/diart/blocks/diarization.py | 145 ++++++++++++++++++----
 src/diart/blocks/vad.py         | 208 ++++++++++++++++++++++++++++++++
 src/diart/console/benchmark.py  |  12 +-
 src/diart/console/client.py     |   6 +-
 src/diart/console/serve.py      |  19 +--
 src/diart/console/stream.py     |  19 +--
 src/diart/console/tune.py       |  26 +++-
 src/diart/inference.py          |  86 +++++++------
 src/diart/optim.py              |  56 ++++-----
 src/diart/sinks.py              |  47 +++++---
 src/diart/sources.py            |   2 +-
 src/diart/utils.py              |  16 ++-
 16 files changed, 605 insertions(+), 297 deletions(-)
 create mode 100644 src/diart/blocks/base.py
 delete mode 100644 src/diart/blocks/config.py
 create mode 100644 src/diart/blocks/vad.py

diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index c9692638..e29287a0 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1,6 +1,8 @@
 from .blocks import (
-    OnlineSpeakerDiarization,
-    BasePipeline,
-    PipelineConfig,
-    BasePipelineConfig,
+    SpeakerDiarization,
+    StreamingPipeline,
+    SpeakerDiarizationConfig,
+    StreamingConfig,
+    VoiceActivityDetection,
+    VoiceActivityDetectionConfig,
 )
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index 59a6ef36..e6e8c479 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -13,6 +13,7 @@
     OverlapAwareSpeakerEmbedding,
 )
 from .segmentation import SpeakerSegmentation
-from .diarization import OnlineSpeakerDiarization, BasePipeline
-from .config import BasePipelineConfig, PipelineConfig
+from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .base import StreamingConfig, StreamingPipeline
 from .utils import Binarize, Resample, AdjustVolume
+from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
new file mode 100644
index 00000000..28f313eb
--- /dev/null
+++ b/src/diart/blocks/base.py
@@ -0,0 +1,92 @@
+from typing import Any, Tuple, Sequence, Text
+from dataclasses import dataclass
+
+import numpy as np
+from pyannote.core import SlidingWindowFeature
+from pyannote.metrics.base import BaseMetric
+
+from .. import utils
+from ..audio import FilePath, AudioLoader
+
+
+@dataclass
+class HyperParameter:
+    name: Text
+    low: float
+    high: float
+
+    @staticmethod
+    def from_name(name: Text) -> 'HyperParameter':
+        if name == "tau_active":
+            return TauActive
+        if name == "rho_update":
+            return RhoUpdate
+        if name == "delta_new":
+            return DeltaNew
+        raise ValueError(f"Hyper-parameter '{name}' not recognized")
+
+
+TauActive = HyperParameter("tau_active", low=0, high=1)
+RhoUpdate = HyperParameter("rho_update", low=0, high=1)
+DeltaNew = HyperParameter("delta_new", low=0, high=2)
+
+
+class StreamingConfig:
+    @property
+    def duration(self) -> float:
+        raise NotImplementedError
+
+    @property
+    def step(self) -> float:
+        raise NotImplementedError
+
+    @property
+    def latency(self) -> float:
+        raise NotImplementedError
+
+    @property
+    def sample_rate(self) -> int:
+        raise NotImplementedError
+
+    @staticmethod
+    def from_dict(data: Any) -> 'StreamingConfig':
+        raise NotImplementedError
+
+    def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
+        file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
+        right = utils.get_padding_right(self.latency, self.step)
+        left = utils.get_padding_left(file_duration + right, self.duration)
+        return left, right
+
+    def optimal_block_size(self) -> int:
+        return int(np.rint(self.step * self.sample_rate))
+
+
+class StreamingPipeline:
+    @staticmethod
+    def get_config_class() -> type:
+        raise NotImplementedError
+
+    @staticmethod
+    def suggest_metric() -> BaseMetric:
+        raise NotImplementedError
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[HyperParameter]:
+        raise NotImplementedError
+
+    @property
+    def config(self) -> StreamingConfig:
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+    def set_timestamp_shift(self, shift: float):
+        raise NotImplementedError
+
+    def __call__(
+        self,
+        waveforms: Sequence[SlidingWindowFeature]
+    ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
+        raise NotImplementedError
diff --git a/src/diart/blocks/config.py b/src/diart/blocks/config.py
deleted file mode 100644
index d8e2a656..00000000
--- a/src/diart/blocks/config.py
+++ /dev/null
@@ -1,153 +0,0 @@
-from typing import Any, Optional, Union, Tuple
-
-import numpy as np
-import torch
-from typing_extensions import Literal
-
-from .. import models as m
-from .. import utils
-from ..audio import FilePath, AudioLoader
-
-
-class BasePipelineConfig:
-    @property
-    def duration(self) -> float:
-        raise NotImplementedError
-
-    @property
-    def step(self) -> float:
-        raise NotImplementedError
-
-    @property
-    def latency(self) -> float:
-        raise NotImplementedError
-
-    @property
-    def sample_rate(self) -> int:
-        raise NotImplementedError
-
-    @staticmethod
-    def from_dict(data: Any) -> 'BasePipelineConfig':
-        raise NotImplementedError
-
-    def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
-        file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
-        right = utils.get_padding_right(self.latency, self.step)
-        left = utils.get_padding_left(file_duration + right, self.duration)
-        return left, right
-
-    def optimal_block_size(self) -> int:
-        return int(np.rint(self.step * self.sample_rate))
-
-
-class PipelineConfig(BasePipelineConfig):
-    def __init__(
-        self,
-        segmentation: Optional[m.SegmentationModel] = None,
-        embedding: Optional[m.EmbeddingModel] = None,
-        duration: Optional[float] = None,
-        step: float = 0.5,
-        latency: Optional[Union[float, Literal["max", "min"]]] = None,
-        tau_active: float = 0.6,
-        rho_update: float = 0.3,
-        delta_new: float = 1,
-        gamma: float = 3,
-        beta: float = 10,
-        max_speakers: int = 20,
-        device: Optional[torch.device] = None,
-        **kwargs,
-    ):
-        # Default segmentation model is pyannote/segmentation
-        self.segmentation = segmentation
-        if self.segmentation is None:
-            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
-
-        # Default duration is the one given by the segmentation model
-        self._duration = duration
-
-        # Expected sample rate is given by the segmentation model
-        self._sample_rate: Optional[int] = None
-
-        # Default embedding model is pyannote/embedding
-        self.embedding = embedding
-        if self.embedding is None:
-            self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
-
-        # Latency defaults to the step duration
-        self._step = step
-        self._latency = latency
-        if self._latency is None or self._latency == "min":
-            self._latency = self._step
-        elif self._latency == "max":
-            self._latency = self._duration
-
-        self.tau_active = tau_active
-        self.rho_update = rho_update
-        self.delta_new = delta_new
-        self.gamma = gamma
-        self.beta = beta
-        self.max_speakers = max_speakers
-
-        self.device = device
-        if self.device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    @staticmethod
-    def from_dict(data: Any) -> 'PipelineConfig':
-        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
-        device = utils.get(data, "device", None)
-        if device is None:
-            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
-
-        # Instantiate models
-        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
-        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
-        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
-        embedding = utils.get(data, "embedding", "pyannote/embedding")
-        embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
-
-        # Hyper-parameters and their aliases
-        tau = utils.get(data, "tau_active", None)
-        if tau is None:
-            tau = utils.get(data, "tau", 0.6)
-        rho = utils.get(data, "rho_update", None)
-        if rho is None:
-            rho = utils.get(data, "rho", 0.3)
-        delta = utils.get(data, "delta_new", None)
-        if delta is None:
-            delta = utils.get(data, "delta", 1)
-
-        return PipelineConfig(
-            segmentation=segmentation,
-            embedding=embedding,
-            duration=utils.get(data, "duration", None),
-            step=utils.get(data, "step", 0.5),
-            latency=utils.get(data, "latency", None),
-            tau_active=tau,
-            rho_update=rho,
-            delta_new=delta,
-            gamma=utils.get(data, "gamma", 3),
-            beta=utils.get(data, "beta", 10),
-            max_speakers=utils.get(data, "max_speakers", 20),
-            device=device,
-        )
-
-    @property
-    def duration(self) -> float:
-        if self._duration is None:
-            self._duration = self.segmentation.duration
-        return self._duration
-
-    @property
-    def step(self) -> float:
-        return self._step
-
-    @property
-    def latency(self) -> float:
-        return self._latency
-
-    @property
-    def sample_rate(self) -> int:
-        if self._sample_rate is None:
-            self._sample_rate = self.segmentation.sample_rate
-        return self._sample_rate
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index 7f0e162c..f2a25119 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -1,42 +1,137 @@
-from typing import Optional, Tuple, Sequence
+from typing import Optional, Tuple, Sequence, Union, Any
 
 import numpy as np
 import torch
 from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
+from pyannote.metrics.base import BaseMetric
+from pyannote.metrics.diarization import DiarizationErrorRate
+from typing_extensions import Literal
 
 from .aggregation import DelayedAggregation
+from . import base
 from .clustering import OnlineSpeakerClustering
 from .embedding import OverlapAwareSpeakerEmbedding
 from .segmentation import SpeakerSegmentation
 from .utils import Binarize
-from .config import BasePipelineConfig, PipelineConfig
+from .. import models as m
+from .. import utils
 
 
-class BasePipeline:
+class SpeakerDiarizationConfig(base.StreamingConfig):
+    def __init__(
+        self,
+        segmentation: Optional[m.SegmentationModel] = None,
+        embedding: Optional[m.EmbeddingModel] = None,
+        duration: Optional[float] = None,
+        step: float = 0.5,
+        latency: Optional[Union[float, Literal["max", "min"]]] = None,
+        tau_active: float = 0.6,
+        rho_update: float = 0.3,
+        delta_new: float = 1,
+        gamma: float = 3,
+        beta: float = 10,
+        max_speakers: int = 20,
+        device: Optional[torch.device] = None,
+        **kwargs,
+    ):
+        # Default segmentation model is pyannote/segmentation
+        self.segmentation = segmentation
+        if self.segmentation is None:
+            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+        self._duration = duration
+        self._sample_rate: Optional[int] = None
+
+        # Default embedding model is pyannote/embedding
+        self.embedding = embedding
+        if self.embedding is None:
+            self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
+
+        # Latency defaults to the step duration
+        self._step = step
+        self._latency = latency
+        if self._latency is None or self._latency == "min":
+            self._latency = self._step
+        elif self._latency == "max":
+            self._latency = self._duration
+
+        self.tau_active = tau_active
+        self.rho_update = rho_update
+        self.delta_new = delta_new
+        self.gamma = gamma
+        self.beta = beta
+        self.max_speakers = max_speakers
+
+        self.device = device
+        if self.device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     @staticmethod
-    def get_config_class() -> type:
-        raise NotImplementedError
+    def from_dict(data: Any) -> 'SpeakerDiarizationConfig':
+        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+        device = utils.get(data, "device", None)
+        if device is None:
+            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+        # Instantiate models
+        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
+        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
+        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
+        embedding = utils.get(data, "embedding", "pyannote/embedding")
+        embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
+
+        # Hyper-parameters and their aliases
+        tau = utils.get(data, "tau_active", None)
+        if tau is None:
+            tau = utils.get(data, "tau", 0.6)
+        rho = utils.get(data, "rho_update", None)
+        if rho is None:
+            rho = utils.get(data, "rho", 0.3)
+        delta = utils.get(data, "delta_new", None)
+        if delta is None:
+            delta = utils.get(data, "delta", 1)
+
+        return SpeakerDiarizationConfig(
+            segmentation=segmentation,
+            embedding=embedding,
+            duration=utils.get(data, "duration", None),
+            step=utils.get(data, "step", 0.5),
+            latency=utils.get(data, "latency", None),
+            tau_active=tau,
+            rho_update=rho,
+            delta_new=delta,
+            gamma=utils.get(data, "gamma", 3),
+            beta=utils.get(data, "beta", 10),
+            max_speakers=utils.get(data, "max_speakers", 20),
+            device=device,
+        )
 
     @property
-    def config(self) -> BasePipelineConfig:
-        raise NotImplementedError
+    def duration(self) -> float:
+        # Default duration is the one given by the segmentation model
+        if self._duration is None:
+            self._duration = self.segmentation.duration
+        return self._duration
 
-    def reset(self):
-        raise NotImplementedError
+    @property
+    def step(self) -> float:
+        return self._step
 
-    def set_timestamp_shift(self, shift: float):
-        raise NotImplementedError
+    @property
+    def latency(self) -> float:
+        return self._latency
 
-    def __call__(
-        self,
-        waveforms: Sequence[SlidingWindowFeature]
-    ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
-        raise NotImplementedError
+    @property
+    def sample_rate(self) -> int:
+        # Expected sample rate is given by the segmentation model
+        if self._sample_rate is None:
+            self._sample_rate = self.segmentation.sample_rate
+        return self._sample_rate
 
 
-class OnlineSpeakerDiarization(BasePipeline):
-    def __init__(self, config: Optional[PipelineConfig] = None):
-        self._config = PipelineConfig() if config is None else config
+class SpeakerDiarization(base.StreamingPipeline):
+    def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
+        self._config = SpeakerDiarizationConfig() if config is None else config
 
         msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
         assert self._config.step <= self._config.latency <= self._config.duration, msg
@@ -67,10 +162,18 @@ def __init__(self, config: Optional[PipelineConfig] = None):
 
     @staticmethod
     def get_config_class() -> type:
-        return PipelineConfig
+        return SpeakerDiarizationConfig
+
+    @staticmethod
+    def suggest_metric() -> BaseMetric:
+        return DiarizationErrorRate(collar=0, skip_overlap=False)
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[base.HyperParameter]:
+        return [base.TauActive, base.RhoUpdate, base.DeltaNew]
 
     @property
-    def config(self) -> PipelineConfig:
+    def config(self) -> SpeakerDiarizationConfig:
         return self._config
 
     def set_timestamp_shift(self, shift: float):
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
new file mode 100644
index 00000000..def833b6
--- /dev/null
+++ b/src/diart/blocks/vad.py
@@ -0,0 +1,208 @@
+from typing import Any, Optional, Union, Sequence, Tuple
+
+import numpy as np
+import torch
+from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment
+from pyannote.metrics.base import BaseMetric
+from pyannote.metrics.detection import DetectionErrorRate
+from typing_extensions import Literal
+
+from .aggregation import DelayedAggregation
+from . import base
+from .segmentation import SpeakerSegmentation
+from .utils import Binarize
+from .. import models as m
+from .. import utils
+
+
+class VoiceActivityDetectionConfig(base.StreamingConfig):
+    def __init__(
+        self,
+        segmentation: Optional[m.SegmentationModel] = None,
+        duration: Optional[float] = None,
+        step: float = 0.5,
+        latency: Optional[Union[float, Literal["max", "min"]]] = None,
+        tau_active: float = 0.6,
+        device: Optional[torch.device] = None,
+        **kwargs,
+    ):
+        # Default segmentation model is pyannote/segmentation
+        self.segmentation = segmentation
+        if self.segmentation is None:
+            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+        self._duration = duration
+        self._step = step
+        self._sample_rate: Optional[int] = None
+
+        # Latency defaults to the step duration
+        self._latency = latency
+        if self._latency is None or self._latency == "min":
+            self._latency = self._step
+        elif self._latency == "max":
+            self._latency = self._duration
+
+        self.tau_active = tau_active
+        self.device = device
+        if self.device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    @property
+    def duration(self) -> float:
+        # Default duration is the one given by the segmentation model
+        if self._duration is None:
+            self._duration = self.segmentation.duration
+        return self._duration
+
+    @property
+    def step(self) -> float:
+        return self._step
+
+    @property
+    def latency(self) -> float:
+        return self._latency
+
+    @property
+    def sample_rate(self) -> int:
+        # Expected sample rate is given by the segmentation model
+        if self._sample_rate is None:
+            self._sample_rate = self.segmentation.sample_rate
+        return self._sample_rate
+
+    @staticmethod
+    def from_dict(data: Any) -> 'VoiceActivityDetectionConfig':
+        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+        device = utils.get(data, "device", None)
+        if device is None:
+            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+        # Instantiate segmentation model
+        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
+        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
+        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
+
+        # Tau active and its alias
+        tau = utils.get(data, "tau_active", None)
+        if tau is None:
+            tau = utils.get(data, "tau", 0.6)
+
+        return VoiceActivityDetectionConfig(
+            segmentation=segmentation,
+            duration=utils.get(data, "duration", None),
+            step=utils.get(data, "step", 0.5),
+            latency=utils.get(data, "latency", None),
+            tau_active=tau,
+            device=device,
+        )
+
+
+class VoiceActivityDetection(base.StreamingPipeline):
+    def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
+        self._config = VoiceActivityDetectionConfig() if config is None else config
+
+        msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
+        assert self._config.step <= self._config.latency <= self._config.duration, msg
+
+        self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device)
+        self.pred_aggregation = DelayedAggregation(
+            self._config.step,
+            self._config.latency,
+            strategy="hamming",
+            cropping_mode="loose",
+        )
+        self.audio_aggregation = DelayedAggregation(
+            self._config.step,
+            self._config.latency,
+            strategy="first",
+            cropping_mode="center",
+        )
+        self.binarize = Binarize(self._config.tau_active)
+
+        # Internal state, handle with care
+        self.timestamp_shift = 0
+        self.chunk_buffer, self.pred_buffer = [], []
+
+    @staticmethod
+    def get_config_class() -> type:
+        return VoiceActivityDetectionConfig
+
+    @staticmethod
+    def suggest_metric() -> BaseMetric:
+        return DetectionErrorRate(collar=0, skip_overlap=False)
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[base.HyperParameter]:
+        return [base.TauActive]
+
+    @property
+    def config(self) -> base.StreamingConfig:
+        return self._config
+
+    def reset(self):
+        self.set_timestamp_shift(0)
+        self.chunk_buffer, self.pred_buffer = [], []
+
+    def set_timestamp_shift(self, shift: float):
+        self.timestamp_shift = shift
+
+    def __call__(
+        self,
+        waveforms: Sequence[SlidingWindowFeature],
+    ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
+        batch_size = len(waveforms)
+        msg = "Pipeline expected at least 1 input"
+        assert batch_size >= 1, msg
+
+        # Create batch from chunk sequence, shape (batch, samples, channels)
+        batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
+
+        expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+        msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
+        assert batch.shape[1] == expected_num_samples, msg
+
+        # Extract segmentation
+        segmentations = self.segmentation(batch)  # shape (batch, frames, speakers)
+        voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[0]  # shape (batch, frames, 1)
+
+        seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
+
+        outputs = []
+        for wav, vad in zip(waveforms, voice_detection):
+            # Add timestamps to segmentation
+            sw = SlidingWindow(
+                start=wav.extent.start,
+                duration=seg_resolution,
+                step=seg_resolution,
+            )
+            vad = SlidingWindowFeature(vad.cpu().numpy(), sw)
+
+            # Update sliding buffer
+            self.chunk_buffer.append(wav)
+            self.pred_buffer.append(vad)
+
+            # Aggregate buffer outputs for this time step
+            agg_waveform = self.audio_aggregation(self.chunk_buffer)
+            agg_prediction = self.pred_aggregation(self.pred_buffer)
+            agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False)
+
+            # Shift prediction timestamps if required
+            if self.timestamp_shift != 0:
+                shifted_agg_prediction = Timeline(uri=agg_prediction.uri)
+                for segment in agg_prediction:
+                    new_segment = Segment(
+                        segment.start + self.timestamp_shift,
+                        segment.end + self.timestamp_shift,
+                    )
+                    shifted_agg_prediction.add(new_segment)
+                agg_prediction = shifted_agg_prediction
+
+            # Convert timeline into annotation with single speaker "speech"
+            agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech"))
+            outputs.append((agg_prediction, agg_waveform))
+
+            # Make room for new chunks in the buffer if needed
+            if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
+                self.chunk_buffer = self.chunk_buffer[1:]
+                self.pred_buffer = self.pred_buffer[1:]
+
+        return outputs
diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py
index b6a3f9ff..27d524c5 100644
--- a/src/diart/console/benchmark.py
+++ b/src/diart/console/benchmark.py
@@ -1,15 +1,17 @@
 import argparse
 from pathlib import Path
 
-import diart.argdoc as argdoc
 import pandas as pd
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
+from diart import argdoc
+from diart import utils
 from diart.inference import Benchmark, Parallelize
 
 
 def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+    parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+                        help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -34,6 +36,8 @@ def run():
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
+    pipeline_class = utils.get_pipeline_class(args.pipeline)
+
     benchmark = Benchmark(
         args.root,
         args.reference,
@@ -43,11 +47,11 @@ def run():
         batch_size=args.batch_size,
     )
 
-    config = PipelineConfig.from_dict(vars(args))
+    config = pipeline_class.get_config_class().from_dict(vars(args))
     if args.num_workers > 0:
         benchmark = Parallelize(benchmark, args.num_workers)
 
-    report = benchmark(OnlineSpeakerDiarization, config)
+    report = benchmark(pipeline_class, config)
 
     if args.output is not None and isinstance(report, pd.DataFrame):
         report.to_csv(args.output / "benchmark_report.csv")
diff --git a/src/diart/console/client.py b/src/diart/console/client.py
index 084dbc13..db4915fa 100644
--- a/src/diart/console/client.py
+++ b/src/diart/console/client.py
@@ -3,11 +3,11 @@
 from threading import Thread
 from typing import Text, Optional
 
-import diart.argdoc as argdoc
-import diart.sources as src
-import diart.utils as utils
 import numpy as np
 import rx.operators as ops
+from diart import argdoc
+from diart import sources as src
+from diart import utils
 from websocket import WebSocket
 
 
diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index 2f632d57..46bb9328 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -1,10 +1,10 @@
 import argparse
 from pathlib import Path
 
-import diart.argdoc as argdoc
-import diart.sources as src
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
-from diart.inference import RealTimeInference
+from diart import argdoc
+from diart import sources as src
+from diart import utils
+from diart.inference import StreamingInference
 from diart.sinks import RTTMWriter
 
 
@@ -12,6 +12,8 @@ def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host")
     parser.add_argument("--port", default=7007, type=int, help="Server port")
+    parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+                        help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -31,15 +33,16 @@ def run():
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
-    # Define online speaker diarization pipeline
-    config = PipelineConfig.from_dict(vars(args))
-    pipeline = OnlineSpeakerDiarization(config)
+    # Resolve pipeline
+    pipeline_class = utils.get_pipeline_class(args.pipeline)
+    config = pipeline_class.get_config_class().from_dict(vars(args))
+    pipeline = pipeline_class(config)
 
     # Create websocket audio source
     audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port)
 
     # Run online inference
-    inference = RealTimeInference(
+    inference = StreamingInference(
         pipeline,
         audio_source,
         batch_size=1,
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index d7218f07..e0c670c5 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -1,16 +1,18 @@
 import argparse
 from pathlib import Path
 
-import diart.argdoc as argdoc
-import diart.sources as src
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
-from diart.inference import RealTimeInference
+from diart import argdoc
+from diart import sources as src
+from diart import utils
+from diart.inference import StreamingInference
 from diart.sinks import RTTMWriter
 
 
 def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:<DEVICE_ID>'")
+    parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+                        help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -32,9 +34,10 @@ def run():
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
-    # Define online speaker diarization pipeline
-    config = PipelineConfig.from_dict(vars(args))
-    pipeline = OnlineSpeakerDiarization(config)
+    # Resolve pipeline
+    pipeline_class = utils.get_pipeline_class(args.pipeline)
+    config = pipeline_class.get_config_class().from_dict(vars(args))
+    pipeline = pipeline_class(config)
 
     # Manage audio source
     block_size = config.optimal_block_size()
@@ -51,7 +54,7 @@ def run():
         audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device)
 
     # Run online inference
-    inference = RealTimeInference(
+    inference = StreamingInference(
         pipeline,
         audio_source,
         batch_size=1,
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index 4ad8852a..a1f1b63a 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -1,10 +1,11 @@
 import argparse
 from pathlib import Path
 
-import diart.argdoc as argdoc
 import optuna
-from diart.blocks import PipelineConfig, OnlineSpeakerDiarization
-from diart.optim import Optimizer, HyperParameter
+from diart import argdoc
+from diart import utils
+from diart.blocks.base import HyperParameter
+from diart.optim import Optimizer
 from optuna.samplers import TPESampler
 
 
@@ -13,6 +14,8 @@ def run():
     parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
     parser.add_argument("--reference", required=True, type=str,
                         help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+    parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+                        help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -38,17 +41,28 @@ def run():
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
+    # Retrieve pipeline class
+    pipeline_class = utils.get_pipeline_class(args.pipeline)
+
     # Create the base configuration for each trial
-    base_config = PipelineConfig.from_dict(vars(args))
+    base_config = pipeline_class.get_config_class().from_dict(vars(args))
 
     # Create hyper-parameters to optimize
+    possible_hparams = pipeline_class.hyper_parameters()
     hparams = [HyperParameter.from_name(name) for name in args.hparams]
+    hparams = [hp for hp in hparams if hp in possible_hparams]
+    if not hparams:
+        print(
+            f"No hyper-parameters to optimize. "
+            f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}"
+        )
+        exit(1)
 
     # Use a custom storage if given
     if args.output is not None:
         msg = "Both `output` and `storage` were set, but only one was expected"
         assert args.storage is None, msg
-        args.output = Path(args.output)
+        args.output = Path(args.output).expanduser()
         args.output.mkdir(parents=True, exist_ok=True)
         study_or_path = args.output
     elif args.storage is not None:
@@ -60,11 +74,11 @@ def run():
 
     # Run optimization
     Optimizer(
+        pipeline_class=pipeline_class,
         speech_path=args.root,
         reference_path=args.reference,
         study_or_path=study_or_path,
         batch_size=args.batch_size,
-        pipeline_class=OnlineSpeakerDiarization,
         hparams=hparams,
         base_config=base_config,
     )(num_iter=args.num_iter, show_progress=True)
diff --git a/src/diart/inference.py b/src/diart/inference.py
index f4b65f5f..6afda89e 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -4,32 +4,33 @@
 from traceback import print_exc
 from typing import Union, Text, Optional, Callable, Tuple, List
 
-import diart.operators as dops
-import diart.sources as src
 import numpy as np
 import pandas as pd
 import rx
 import rx.operators as ops
 import torch
-from diart import utils
-from diart.blocks import BasePipeline, Resample, BasePipelineConfig
-from diart.progress import ProgressBar, RichProgressBar, TQDMProgressBar
-from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException
 from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.database.util import load_rttm
-from pyannote.metrics.diarization import DiarizationErrorRate
+from pyannote.metrics.base import BaseMetric
 from rx.core import Observer
 from tqdm import tqdm
 
+from . import blocks
+from . import operators as dops
+from . import sources as src
+from . import utils
+from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
+from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException
 
-class RealTimeInference:
+
+class StreamingInference:
     """Performs inference in real time given a pipeline and an audio source.
     Streams an audio source to an online speaker diarization pipeline.
     It allows users to attach a chain of operations in the form of hooks.
 
     Parameters
     ----------
-    pipeline: BasePipeline
+    pipeline: StreamingPipeline
         Configured speaker diarization pipeline.
     source: AudioSource
         Audio source to be read and streamed.
@@ -52,7 +53,7 @@ class RealTimeInference:
     """
     def __init__(
         self,
-        pipeline: BasePipeline,
+        pipeline: blocks.StreamingPipeline,
         source: src.AudioSource,
         batch_size: int = 1,
         do_profile: bool = True,
@@ -66,7 +67,7 @@ def __init__(
         self.do_profile = do_profile
         self.do_plot = do_plot
         self.show_progress = show_progress
-        self.accumulator = DiarizationPredictionAccumulator(self.source.uri)
+        self.accumulator = PredictionAccumulator(self.source.uri)
         self.unit = "chunk" if self.batch_size == 1 else "batch"
         self._observers = []
 
@@ -102,7 +103,7 @@ def __init__(
                   f"but pipeline's is {sample_rate}. Will resample."
             logging.warning(msg)
             self.stream = self.stream.pipe(
-                ops.map(Resample(self.source.sample_rate, sample_rate))
+                ops.map(blocks.Resample(self.source.sample_rate, sample_rate))
             )
 
         # Add rx operators to manage the inputs and outputs of the pipeline
@@ -202,7 +203,7 @@ def __call__(self) -> Annotation:
                     latency=config.latency,
                     sample_rate=config.sample_rate,
                 ),
-                ops.do(RealTimePlot(config.duration, config.latency)),
+                ops.do(StreamingPlot(config.duration, config.latency)),
             )
         observable.subscribe(
             on_error=self._handle_error,
@@ -288,7 +289,7 @@ def get_file_paths(self) -> List[Path]:
 
     def run_single(
         self,
-        pipeline: BasePipeline,
+        pipeline: blocks.StreamingPipeline,
         filepath: Path,
         progress_bar: ProgressBar,
     ) -> Annotation:
@@ -298,7 +299,7 @@ def run_single(
 
         Parameters
         ----------
-        pipeline: BasePipeline
+        pipeline: StreamingPipeline
             Speaker diarization pipeline to run.
         filepath: Path
             Path to the target file.
@@ -318,7 +319,7 @@ def run_single(
             pipeline.config.optimal_block_size(),
         )
         pipeline.set_timestamp_shift(-padding[0])
-        inference = RealTimeInference(
+        inference = StreamingInference(
             pipeline,
             source,
             self.batch_size,
@@ -337,7 +338,11 @@ def run_single(
 
         return pred
 
-    def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[Annotation]]:
+    def evaluate(
+        self,
+        predictions: List[Annotation],
+        metric: BaseMetric,
+    ) -> Union[pd.DataFrame, List[Annotation]]:
         """If a reference path was provided,
         compute the diarization error rate of a list of predictions.
 
@@ -345,6 +350,8 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An
         ----------
         predictions: List[Annotation]
             Predictions to evaluate.
+        metric: BaseMetric
+            Evaluation metric from pyannote.metrics.
 
         Returns
         -------
@@ -353,8 +360,7 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An
             reference path was given. Otherwise return the same predictions.
         """
         if self.reference_path is not None:
-            metric = DiarizationErrorRate(collar=0, skip_overlap=False)
-            progress_bar = TQDMProgressBar("Computing DER", leave=False)
+            progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False)
             progress_bar.create(total=len(predictions), unit="file")
             progress_bar.start()
             for hyp in predictions:
@@ -368,18 +374,22 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An
     def __call__(
         self,
         pipeline_class: type,
-        config: BasePipelineConfig,
+        config: blocks.StreamingConfig,
+        metric: Optional[BaseMetric] = None,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files.
-        Notice that the internal state of the pipeline is reset before benchmarking.
+        The internal state of the pipeline is reset before benchmarking.
 
         Parameters
         ----------
         pipeline_class: class
-            Class from the BasePipeline hierarchy.
+            Class from the StreamingPipeline hierarchy.
             A pipeline from this class will be instantiated by each worker.
-        config: BasePipelineConfig
-            Diarization pipeline configuration.
+        config: StreamingConfig
+            Streaming pipeline configuration.
+        metric: Optional[BaseMetric]
+            Evaluation metric from pyannote.metrics.
+            Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`).
 
         Returns
         -------
@@ -400,7 +410,8 @@ def __call__(
             progress = TQDMProgressBar(desc, leave=False, do_close=True)
             predictions.append(self.run_single(pipeline, filepath, progress))
 
-        return self.evaluate(predictions)
+        metric = pipeline.suggest_metric() if metric is None else metric
+        return self.evaluate(predictions, metric)
 
 
 class Parallelize:
@@ -426,20 +437,20 @@ def __init__(
     def run_single_job(
         self,
         pipeline_class: type,
-        config: BasePipelineConfig,
+        config: blocks.StreamingConfig,
         filepath: Path,
         description: Text,
-    ):
+    ) -> Annotation:
         """Build and run a pipeline on a single file.
         Configure execution to show progress alongside parallel runs.
 
         Parameters
         ----------
         pipeline_class: class
-            Class from the BasePipeline hierarchy.
+            Class from the StreamingPipeline hierarchy.
             A pipeline from this class will be instantiated.
-        config: BasePipelineConfig
-            Diarization pipeline configuration.
+        config: StreamingConfig
+            Streaming pipeline configuration.
         filepath: Path
             Path to the target file.
         description: Text
@@ -463,7 +474,8 @@ def run_single_job(
     def __call__(
         self,
         pipeline_class: type,
-        config: BasePipelineConfig,
+        config: blocks.StreamingConfig,
+        metric: Optional[BaseMetric] = None,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files in parallel.
         Each worker will build and run the pipeline on a different file.
@@ -471,10 +483,13 @@ def __call__(
         Parameters
         ----------
         pipeline_class: class
-            Class from the BasePipeline hierarchy.
+            Class from the StreamingPipeline hierarchy.
             A pipeline from this class will be instantiated by each worker.
-        config: BasePipelineConfig
-            Diarization pipeline configuration.
+        config: StreamingConfig
+            Streaming pipeline configuration.
+        metric: Optional[BaseMetric]
+            Evaluation metric from pyannote.metrics.
+            Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`).
 
         Returns
         -------
@@ -512,4 +527,5 @@ def __call__(
         predictions = [job.get() for job in jobs]
 
         # Evaluate results
-        return self.benchmark.evaluate(predictions)
+        metric = pipeline_class.suggest_metric() if metric is None else metric
+        return self.benchmark.evaluate(predictions, metric)
diff --git a/src/diart/optim.py b/src/diart/optim.py
index 05800a05..f7a96a6e 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -1,51 +1,32 @@
 from collections import OrderedDict
-from dataclasses import dataclass
 from pathlib import Path
 from typing import Sequence, Text, Optional, Union
 
 from optuna import TrialPruned, Study, create_study
 from optuna.samplers import TPESampler
 from optuna.trial import Trial, FrozenTrial
+from pyannote.metrics.base import BaseMetric
 from tqdm import trange, tqdm
+from typing_extensions import Literal
 
+from . import blocks
 from .audio import FilePath
-from .blocks import BasePipelineConfig, PipelineConfig, OnlineSpeakerDiarization
 from .inference import Benchmark
 
 
-@dataclass
-class HyperParameter:
-    name: Text
-    low: float
-    high: float
-
-    @staticmethod
-    def from_name(name: Text) -> 'HyperParameter':
-        if name == "tau_active":
-            return TauActive
-        if name == "rho_update":
-            return RhoUpdate
-        if name == "delta_new":
-            return DeltaNew
-        raise ValueError(f"Hyper-parameter '{name}' not recognized")
-
-
-TauActive = HyperParameter("tau_active", low=0, high=1)
-RhoUpdate = HyperParameter("rho_update", low=0, high=1)
-DeltaNew = HyperParameter("delta_new", low=0, high=2)
-
-
 class Optimizer:
     def __init__(
         self,
+        pipeline_class: type,
         speech_path: Union[Text, Path],
         reference_path: Union[Text, Path],
         study_or_path: Union[FilePath, Study],
         batch_size: int = 32,
-        pipeline_class: type = OnlineSpeakerDiarization,
-        hparams: Optional[Sequence[HyperParameter]] = None,
-        base_config: Optional[BasePipelineConfig] = None,
+        hparams: Optional[Sequence[blocks.base.HyperParameter]] = None,
+        base_config: Optional[blocks.StreamingConfig] = None,
         do_kickstart_hparams: bool = True,
+        metric: Optional[BaseMetric] = None,
+        direction: Literal["minimize", "maximize"] = "minimize",
     ):
         self.pipeline_class = pipeline_class
         # FIXME can we run this benchmark in parallel?
@@ -58,15 +39,17 @@ def __init__(
             batch_size=batch_size,
         )
 
+        self.metric = metric
+        self.direction = direction
         self.base_config = base_config
         self.do_kickstart_hparams = do_kickstart_hparams
         if self.base_config is None:
-            self.base_config = PipelineConfig()
+            self.base_config = self.pipeline_class.get_config_class()()
             self.do_kickstart_hparams = False
 
         self.hparams = hparams
         if self.hparams is None:
-            self.hparams = [TauActive, RhoUpdate, DeltaNew]
+            self.hparams = self.pipeline_class.hyper_parameters()
 
         # Make sure hyper-parameters exist in the configuration class given
         possible_hparams = vars(self.base_config)
@@ -85,7 +68,7 @@ def __init__(
                 storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"),
                 sampler=TPESampler(),
                 study_name=study_or_path.stem,
-                direction="minimize",
+                direction=self.direction,
                 load_if_exists=True,
             )
         else:
@@ -105,7 +88,7 @@ def _callback(self, study: Study, trial: FrozenTrial):
             return
         self._progress.update(1)
         self._progress.set_description(f"Trial {trial.number + 1}")
-        values = {"best_der": study.best_value}
+        values = {"best_perf": study.best_value}
         for name, value in study.best_params.items():
             values[f"best_{name}"] = value
         self._progress.set_postfix(OrderedDict(values))
@@ -125,11 +108,16 @@ def objective(self, trial: Trial) -> float:
         # Instantiate the new configuration for the trial
         config = self.base_config.__class__(**trial_config)
 
+        # Determine the evaluation metric
+        metric = self.metric
+        if metric is None:
+            metric = self.pipeline_class.suggest_metric()
+
         # Run pipeline over the dataset
-        report = self.benchmark(self.pipeline_class, config)
+        report = self.benchmark(self.pipeline_class, config, metric)
 
-        # Extract DER from report
-        return report.loc["TOTAL", "diarization error rate"]["%"]
+        # Extract target metric from report
+        return report.loc["TOTAL", metric.name]["%"]
 
     def __call__(self, num_iter: int, show_progress: bool = True):
         self._progress = None
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index cf480bed..63c170d0 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -8,12 +8,14 @@
 from rx.core import Observer
 from typing_extensions import Literal
 
+from . import utils
+
 
 class WindowClosedException(Exception):
     pass
 
 
-def _extract_annotation(value: Union[Tuple, Annotation]) -> Annotation:
+def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation:
     if isinstance(value, tuple):
         return value[0]
     if isinstance(value, Annotation):
@@ -43,10 +45,11 @@ def patch(self):
                 annotation.support(self.patch_collar).write_rttm(file)
 
     def on_next(self, value: Union[Tuple, Annotation]):
-        annotation = _extract_annotation(value)
-        annotation.uri = self.uri
+        prediction = _extract_prediction(value)
+        # Write prediction in RTTM format
+        prediction.uri = self.uri
         with open(self.path, 'a') as file:
-            annotation.write_rttm(file)
+            prediction.write_rttm(file)
 
     def on_error(self, error: Exception):
         self.patch()
@@ -55,30 +58,30 @@ def on_completed(self):
         self.patch()
 
 
-class DiarizationPredictionAccumulator(Observer):
+class PredictionAccumulator(Observer):
     def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05):
         super().__init__()
         self.uri = uri
         self.patch_collar = patch_collar
-        self._annotation = None
+        self._prediction: Optional[Annotation] = None
 
     def patch(self):
         """Stitch same-speaker turns that are close to each other"""
-        if self._annotation is not None:
-            self._annotation = self._annotation.support(self.patch_collar)
+        if self._prediction is not None:
+            self._prediction = self._prediction.support(self.patch_collar)
 
     def get_prediction(self) -> Annotation:
         # Patch again in case this is called before on_completed
         self.patch()
-        return self._annotation
+        return self._prediction
 
     def on_next(self, value: Union[Tuple, Annotation]):
-        annotation = _extract_annotation(value)
-        annotation.uri = self.uri
-        if self._annotation is None:
-            self._annotation = annotation
+        prediction = _extract_prediction(value)
+        prediction.uri = self.uri
+        if self._prediction is None:
+            self._prediction = prediction
         else:
-            self._annotation.update(annotation)
+            self._prediction.update(prediction)
 
     def on_error(self, error: Exception):
         self.patch()
@@ -87,7 +90,7 @@ def on_completed(self):
         self.patch()
 
 
-class RealTimePlot(Observer):
+class StreamingPlot(Observer):
     def __init__(
         self,
         duration: float,
@@ -134,11 +137,15 @@ def get_plot_bounds(self, real_time: float) -> Segment:
             start_time = max(0., end_time - self.window_duration)
         return Segment(start_time, end_time)
 
-    def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
+    def on_next(
+        self,
+        values: Tuple[Annotation, SlidingWindowFeature, float]
+    ):
         if self.window_closed:
             raise WindowClosedException
 
         prediction, waveform, real_time = values
+
         # Initialize figure if first call
         if self.figure is None:
             self._init_figure()
@@ -147,15 +154,21 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
         # Set plot bounds
         notebook.crop = self.get_plot_bounds(real_time)
 
-        # Plot current values
+        # Align prediction and reference if possible
         if self.reference is not None:
             metric = DiarizationErrorRate()
             mapping = metric.optimal_mapping(self.reference, prediction)
             prediction.rename_labels(mapping=mapping, copy=False)
+
+        # Plot prediction
         notebook.plot_annotation(prediction, self.axs[0])
         self.axs[0].set_title("Output")
+
+        # Plot waveform
         notebook.plot_feature(waveform, self.axs[1])
         self.axs[1].set_title("Audio")
+
+        # Plot reference if available
         if self.num_axs == 3:
             notebook.plot_annotation(self.reference, self.axs[2])
             self.axs[2].set_title("Reference")
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 0f5dedf7..b34d5cf3 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -5,12 +5,12 @@
 import numpy as np
 import sounddevice as sd
 import torch
-from diart import utils
 from einops import rearrange
 from rx.subject import Subject
 from torchaudio.io import StreamReader
 from websocket_server import WebsocketServer
 
+from . import utils
 from .audio import FilePath, AudioLoader
 
 
diff --git a/src/diart/utils.py b/src/diart/utils.py
index e90861c7..e825ef29 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -4,9 +4,11 @@
 
 import matplotlib.pyplot as plt
 import numpy as np
-from diart.progress import ProgressBar
 from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
 
+from .progress import ProgressBar
+from . import blocks
+
 
 class Chronometer:
     def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None):
@@ -74,6 +76,18 @@ def get_padding_left(stream_duration: float, chunk_duration: float) -> float:
     return 0
 
 
+def repeat_label(label: Text):
+    while True:
+        yield label
+
+
+def get_pipeline_class(class_name: Text) -> type:
+    pipeline_class = getattr(blocks, class_name, None)
+    msg = f"Pipeline '{class_name}' doesn't exist"
+    assert pipeline_class is not None, msg
+    return pipeline_class
+
+
 def get_padding_right(latency: float, step: float) -> float:
     return latency - step
 

From 74470617b482d2d546a83bff5a0996d90d0df079 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 19 Apr 2023 17:43:51 +0200
Subject: [PATCH 02/23] Update link in setup.cfg

---
 setup.cfg | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 594c876e..e67e4426 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,11 +2,11 @@
 name=diart
 version=0.7.0
 author=Juan Manuel Coria
-description=Speaker diarization in real time
+description=Streaming speaker diarization in real time
 long_description=file: README.md
 long_description_content_type=text/markdown
 keywords=speaker diarization, streaming, online, real time, rxpy
-url=https://github.com/juanmc2005/StreamingSpeakerDiarization
+url=https://github.com/juanmc2005/diart
 license=MIT
 classifiers=
     Development Status :: 4 - Beta

From 498539438861a4d7472b75387e7b0cb4e6768dc3 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 19 Apr 2023 17:51:41 +0200
Subject: [PATCH 03/23] Update code snippets in README

---
 README.md | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/README.md b/README.md
index ef533946..57ca293a 100644
--- a/README.md
+++ b/README.md
@@ -110,17 +110,17 @@ See `diart.stream -h` for more options.
 
 ### From python
 
-Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:
+Use `StreamingInference` to run a pipeline on an audio source and write the results to disk:
 
 ```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
 from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 from diart.sinks import RTTMWriter
 
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
 mic = MicrophoneAudioSource(pipeline.config.sample_rate)
-inference = RealTimeInference(pipeline, mic, do_plot=True)
+inference = StreamingInference(pipeline, mic, do_plot=True)
 inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
 prediction = inference()
 ```
@@ -129,13 +129,13 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n
 
 ## 🤖 Custom models
 
-Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):
+Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module` subclasses):
 
 ```python
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
 from diart.models import EmbeddingModel, SegmentationModel
 from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 
 
 def model_loader():
@@ -168,19 +168,19 @@ class MyEmbeddingModel(EmbeddingModel):
         return self.model(waveform, weights)
 
     
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
     segmentation=MySegmentationModel(),
     embedding=MyEmbeddingModel()
 )
-pipeline = OnlineSpeakerDiarization(config)
+pipeline = SpeakerDiarization(config)
 mic = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference(pipeline, mic)
+inference = StreamingInference(pipeline, mic)
 prediction = inference()
 ```
 
 ## 📈 Tune hyper-parameters
 
-Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
+Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs.
 
 ### From the command line
 
@@ -281,7 +281,7 @@ diart.serve --host 0.0.0.0 --port 7007
 diart.client microphone --host <server-address> --port 7007
 ```
 
-**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
+**Note:** make sure that the client uses the same `step` and `sample_rate` as the server with `--step` and `-sr`.
 
 See `-h` for more options.
 
@@ -290,13 +290,13 @@ See `-h` for more options.
 For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:
 
 ```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
 from diart.sources import WebSocketAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
 source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
-inference = RealTimeInference(pipeline, source)
+inference = StreamingInference(pipeline, source)
 inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
 prediction = inference()
 ```
@@ -354,14 +354,14 @@ or using the inference API:
 
 ```python
 from diart.inference import Benchmark, Parallelize
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
 from diart.models import SegmentationModel
 
 benchmark = Benchmark("/wav/dir", "/rttm/dir")
 
 name = "pyannote/segmentation@Interspeech2021"
 segmentation = SegmentationModel.from_pyannote(name)
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
     # Set the model used in the paper
     segmentation=segmentation,
     step=0.5,
@@ -370,12 +370,12 @@ config = PipelineConfig(
     rho_update=0.422,
     delta_new=1.517
 )
-benchmark(OnlineSpeakerDiarization, config)
+benchmark(SpeakerDiarization, config)
 
 # Run the same benchmark in parallel
 p_benchmark = Parallelize(benchmark, num_workers=4)
 if __name__ == "__main__":  # Needed for multiprocessing
-    p_benchmark(OnlineSpeakerDiarization, config)
+    p_benchmark(SpeakerDiarization, config)
 ```
 
 This pre-calculates model outputs in batches, so it runs a lot faster.

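The refreshed snippets cover speaker diarization; the streaming voice activity detection pipeline follows the same pattern. A sketch, assuming `VoiceActivityDetection` and `VoiceActivityDetectionConfig` are re-exported from the package root and that the config exposes `sample_rate` like the diarization config does:

```python
from diart import VoiceActivityDetection, VoiceActivityDetectionConfig
from diart.sources import MicrophoneAudioSource
from diart.inference import StreamingInference
from diart.sinks import RTTMWriter

config = VoiceActivityDetectionConfig(tau_active=0.6)
pipeline = VoiceActivityDetection(config)
mic = MicrophoneAudioSource(config.sample_rate)
inference = StreamingInference(pipeline, mic)
# Speech regions are emitted as pyannote Annotations, so RTTM output still applies
inference.attach_observers(RTTMWriter(mic.uri, "/output/voice_activity.rttm"))
prediction = inference()
```
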
From 540ad0a97be45faaab2936b85cedbd18c9e456cc Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 19 Apr 2023 21:18:36 +0200
Subject: [PATCH 04/23] Add minor README modifications

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 57ca293a..ae13059f 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@
     </a>
     <span> | </span>
     <a href="#-custom-models">
-      🤖 Custom models
+      🤖 Add your model
     </a>
     <span> | </span>
     <a href="#-tune-hyper-parameters">
@@ -127,7 +127,7 @@ prediction = inference()
 
 For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)).
 
-## 🤖 Custom models
+## 🤖 Add your model
 
 Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module` subclasses):
 

From 8cc9925455a73f20d231560abf9833b17258db96 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Fri, 21 Apr 2023 12:23:02 +0200
Subject: [PATCH 05/23] Initial ASR implementation. Broken stuff

---
 src/diart/blocks/asr.py         | 218 ++++++++++++++++++++++++++++++++
 src/diart/blocks/base.py        |   3 +-
 src/diart/blocks/diarization.py |   3 +-
 src/diart/inference.py          |  39 +++---
 src/diart/models.py             | 136 +++++++++++++++++++-
 src/diart/sinks.py              |   4 +-
 6 files changed, 381 insertions(+), 22 deletions(-)
 create mode 100644 src/diart/blocks/asr.py

diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py
new file mode 100644
index 00000000..641096dd
--- /dev/null
+++ b/src/diart/blocks/asr.py
@@ -0,0 +1,218 @@
+import math
+from pathlib import Path
+from typing import Sequence, Optional, Any, Union, List, Text, Tuple, Dict, Hashable
+
+import numpy as np
+import torch
+from einops import rearrange
+from pyannote.core import SlidingWindowFeature, Annotation, Segment
+from pyannote.metrics.base import BaseMetric
+
+from . import base
+from .. import models as m
+from .. import utils
+from ..blocks.base import HyperParameter
+from ..features import TemporalFeatureFormatter, TemporalFeatures
+
+
+BeamSize = HyperParameter("beam_size", low=1, high=20)
+
+
+class SpeechRecognition:
+    def __init__(self, model: m.SpeechRecognitionModel, device: Optional[torch.device] = None):
+        self.model = model
+        self.model.eval()
+        self.device = device
+        if self.device is None:
+            self.device = torch.device("cpu")
+        self.model.to(self.device)
+        self.formatter = TemporalFeatureFormatter()
+
+    @staticmethod
+    def from_whisper(
+        name: Text,
+        download_path: Optional[Union[Text, Path]] = None,
+        in_memory: bool = False,
+        remember_transcriptions: bool = True,
+        device: Optional[Union[Text, torch.device]] = None,
+    ) -> 'SpeechRecognition':
+        asr_model = m.SpeechRecognitionModel.from_whisper(
+            name, download_path, in_memory, remember_transcriptions
+        )
+        return SpeechRecognition(asr_model, device)
+
+    def __call__(self, waveform: TemporalFeatures) -> List[m.Transcription]:
+        """
+        Compute the transcription of input audio.
+
+        Parameters
+        ----------
+        waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
+            Audio to transcribe
+
+        Returns
+        -------
+        transcriptions: List[Transcription]
+            A list of timestamped transcriptions
+        """
+        with torch.no_grad():
+            wave = rearrange(
+                self.formatter.cast(waveform),
+                "batch sample channel -> batch channel sample"
+            )
+            # output = self.model(wave.to(self.device)).cpu()
+            output = self.model(wave.to(self.device))
+        return output
+
+
+class TranscriptionConfig(base.StreamingConfig):
+    def __init__(
+        self,
+        asr: Optional[m.SpeechRecognitionModel] = None,
+        duration: Optional[float] = None,
+        language: Optional[Text] = None,
+        beam_size: int = 5,
+        device: Optional[torch.device] = None,
+    ):
+        self.device = device
+        if self.device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Default ASR model is Whisper small (244M parameters)
+        self.asr = asr
+        if self.asr is None:
+            self.asr = m.SpeechRecognitionModel.from_whisper("small")
+        self.asr.set_language(language)
+
+        self._duration = duration
+        self._sample_rate: Optional[int] = None
+
+        self.beam_size = beam_size
+
+    @property
+    def duration(self) -> float:
+        if self._duration is None:
+            self._duration = self.asr.duration
+        return self._duration
+
+    @property
+    def step(self) -> float:
+        return self.duration
+
+    @property
+    def latency(self) -> float:
+        return self.duration
+
+    @property
+    def sample_rate(self) -> int:
+        if self._sample_rate is None:
+            self._sample_rate = self.asr.sample_rate
+        return self._sample_rate
+
+    @staticmethod
+    def from_dict(data: Any) -> 'TranscriptionConfig':
+        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+        device = utils.get(data, "device", None)
+        if device is None:
+            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+        name = utils.get(data, "whisper", "small")
+        asr = m.SpeechRecognitionModel.from_whisper(name)
+
+        return TranscriptionConfig(
+            asr=asr,
+            duration=utils.get(data, "duration", None),
+            language=utils.get(data, "language", None),
+            beam_size=utils.get(data, "beam_size", 5),
+            device=device,
+        )
+
+
+class Transcription(base.StreamingPipeline):
+    def __init__(
+        self,
+        config: Optional[TranscriptionConfig] = None,
+    ):
+        self._config = TranscriptionConfig() if config is None else config
+        self.asr = SpeechRecognition(self.config.asr, self.config.device)
+
+    @staticmethod
+    def get_config_class() -> type:
+        return TranscriptionConfig
+
+    @staticmethod
+    def suggest_metric() -> BaseMetric:
+        # TODO word error rate
+        raise NotImplementedError
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[HyperParameter]:
+        return [BeamSize]
+
+    @property
+    def config(self) -> TranscriptionConfig:
+        return self._config
+
+    def reset(self):
+        # No internal state. Nothing to do
+        pass
+
+    def set_timestamp_shift(self, shift: float):
+        # No timestamped output. Nothing to do
+        pass
+
+    def __call__(
+        self,
+        waveforms: Sequence[SlidingWindowFeature],
+        diarization: Optional[Sequence[Annotation]] = None,
+        **kwargs
+    ) -> Sequence[Tuple[Text, SlidingWindowFeature]]:
+        # TODO implement batched inference
+        too_many_dia = diarization is not None and len(diarization) > 1
+        msg = "Batched inference is not yet supported for 'Transcription'. Please set batch size to 1"
+        if len(waveforms) > 1 or too_many_dia:
+            print(msg)
+            exit(1)
+
+        waveform = waveforms[0]
+
+        # Add fake batch dimension, shape (1, samples, channels)
+        batch = torch.from_numpy(waveform.data).unsqueeze(0)
+
+        expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+        msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
+        assert batch.shape[1] == expected_num_samples, msg
+
+        # Transcribe batch
+        # TODO only transcribe if there's speech
+        output = self.asr(batch)[0]
+
+        if diarization is None:
+            return [(output.text, waveform)]
+
+        diarization = diarization[0]
+
+        # Align transcription with diarization to determine speakers
+        full_transcription = []
+        buffer_shift = waveform.sliding_window.start
+        for text, timestamp in zip(output.chunks, output.timestamps):
+            target_region = Segment(
+                buffer_shift + timestamp.start,
+                buffer_shift + timestamp.end
+            )
+            dia = diarization.crop(target_region)
+            speakers = dia.labels()
+            num_speakers = len(speakers)
+            if num_speakers == 0:
+                # Include transcription but don't assign a speaker
+                full_transcription.append(text)
+            elif num_speakers == 1:
+                # Typical case, annotate text with the only speaker
+                full_transcription.append(f"[{speakers[0]}]{text}")
+            else:
+                # Multiple speakers for the same text block, choose the most active one
+                # TODO match at the level of words?
+                max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
+                full_transcription.append(f"[{speakers[max_spk]}]{text}")
+
+        return [(" ".join(full_transcription), waveform)]
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
index 28f313eb..d1a372c1 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/blocks/base.py
@@ -87,6 +87,7 @@ def set_timestamp_shift(self, shift: float):
 
     def __call__(
         self,
-        waveforms: Sequence[SlidingWindowFeature]
+        waveforms: Sequence[SlidingWindowFeature],
+        **kwargs,
     ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
         raise NotImplementedError
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index f2a25119..03169077 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -192,7 +192,8 @@ def reset(self):
 
     def __call__(
         self,
-        waveforms: Sequence[SlidingWindowFeature]
+        waveforms: Sequence[SlidingWindowFeature],
+        **kwargs,
     ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 6afda89e..14ab4736 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -2,7 +2,7 @@
 from multiprocessing import Pool, freeze_support, RLock, current_process
 from pathlib import Path
 from traceback import print_exc
-from typing import Union, Text, Optional, Callable, Tuple, List
+from typing import Union, Text, Optional, Callable, Tuple, List, Any
 
 import numpy as np
 import pandas as pd
@@ -20,7 +20,7 @@
 from . import sources as src
 from . import utils
 from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
-from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException
+from .sinks import DiarizationAccumulator, StreamingPlot, WindowClosedException
 
 
 class StreamingInference:
@@ -67,9 +67,9 @@ def __init__(
         self.do_profile = do_profile
         self.do_plot = do_plot
         self.show_progress = show_progress
-        self.accumulator = PredictionAccumulator(self.source.uri)
         self.unit = "chunk" if self.batch_size == 1 else "batch"
         self._observers = []
+        self._predictions = []
 
         chunk_duration = self.pipeline.config.duration
         step_duration = self.pipeline.config.step
@@ -123,7 +123,7 @@ def __init__(
 
         self.stream = self.stream.pipe(
             ops.flat_map(lambda results: rx.from_iterable(results)),
-            ops.do(self.accumulator),
+            ops.do_action(lambda pred_wav: self._predictions.append(pred_wav[0])),
         )
 
         if show_progress:
@@ -141,13 +141,13 @@ def _close_chronometer(self):
                 self._chrono.stop(do_count=False)
             self._chrono.report()
 
-    def attach_hooks(self, *hooks: Callable[[Tuple[Annotation, SlidingWindowFeature]], None]):
+    def attach_hooks(self, *hooks: Callable[[Tuple[Any, SlidingWindowFeature]], None]):
         """Attach hooks to the pipeline.
 
         Parameters
         ----------
-        *hooks: (Tuple[Annotation, SlidingWindowFeature]) -> None
-            Hook functions to consume emitted annotations and audio.
+        *hooks: (Tuple[Any, SlidingWindowFeature]) -> None
+            Hook functions to consume emitted predictions and audio.
         """
         self.stream = self.stream.pipe(*[ops.do_action(hook) for hook in hooks])
 
@@ -157,7 +157,7 @@ def attach_observers(self, *observers: Observer):
         Parameters
         ----------
         *observers: Observer
-            Observers to consume emitted annotations and audio.
+            Observers to consume emitted predictions and audio.
         """
         self.stream = self.stream.pipe(*[ops.do(sink) for sink in observers])
         self._observers.extend(observers)
@@ -182,13 +182,13 @@ def _handle_completion(self):
         self._close_pbar()
         self._close_chronometer()
 
-    def __call__(self) -> Annotation:
-        """Stream audio chunks from `source` to `pipeline`.
+    def __call__(self) -> List[Any]:
+        """Stream audio chunks from a source to a pipeline.
 
         Returns
         -------
-        predictions: Annotation
-            Speaker diarization pipeline predictions
+        predictions: List[Any]
+            Streaming pipeline predictions
         """
         if self.show_progress:
             self._pbar.start()
@@ -209,9 +209,9 @@ def __call__(self) -> Annotation:
             on_error=self._handle_error,
             on_completed=self._handle_completion,
         )
-        # FIXME if read() isn't blocking, the prediction returned is empty
+        # FIXME if read() isn't blocking, predictions are empty
         self.source.read()
-        return self.accumulator.get_prediction()
+        return self._predictions
 
 
 class Benchmark:
@@ -329,8 +329,15 @@ def run_single(
             progress_bar=progress_bar,
         )
 
-        pred = inference()
-        pred.uri = source.uri
+        # Accumulate predictions in memory
+        pred_accumulator = DiarizationAccumulator(source.uri)
+        inference.attach_observers(pred_accumulator)
+
+        # Run the pipeline on this file
+        inference()
+
+        # Extract prediction
+        pred = pred_accumulator.get_prediction()
 
         if self.output_path is not None:
             with open(self.output_path / f"{source.uri}.rttm", "w") as out_file:
diff --git a/src/diart/models.py b/src/diart/models.py
index df66e166..7a653837 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -1,7 +1,10 @@
-from typing import Optional, Text, Union, Callable
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Text, Union, Callable, List, Tuple, Dict
 
 import torch
 import torch.nn as nn
+from pyannote.core import Segment
 
 try:
     import pyannote.audio.pipelines.utils as pyannote_loader
@@ -9,6 +12,12 @@
 except ImportError:
     _has_pyannote = False
 
+try:
+    import whisper
+    _has_whisper = True
+except ImportError:
+    _has_whisper = False
+
 
 class PyannoteLoader:
     def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
@@ -20,6 +29,25 @@ def __call__(self) -> nn.Module:
         return pyannote_loader.get_model(self.model_info, self.hf_token)
 
 
+class WhisperLoader:
+    def __init__(
+        self,
+        name: Text,
+        download_path: Optional[Union[Text, Path]] = None,
+        in_memory: bool = False,
+    ):
+        self.name = name
+        self.download_path = download_path
+        self.in_memory = in_memory
+
+    def __call__(self) -> nn.Module:
+        return whisper.load_model(
+            name=self.name,
+            download_root=self.download_path,
+            in_memory=self.in_memory,
+        )
+
+
 class LazyModel(nn.Module):
     def __init__(self, loader: Callable[[], nn.Module]):
         super().__init__()
@@ -163,3 +191,109 @@ def forward(
         weights: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         return self.model(waveform, weights=weights)
+
+
+@dataclass(frozen=True)
+class Transcription:
+    text: Text
+    chunks: List[Text]
+    timestamps: List[Segment]
+
+
+class SpeechRecognitionModel(LazyModel):
+    @staticmethod
+    def from_whisper(
+        name: Text,
+        download_path: Optional[Union[Text, Path]] = None,
+        in_memory: bool = False,
+        remember_transcriptions: bool = True,
+    ) -> 'SpeechRecognitionModel':
+        msg = "No whisper-transcribed installation found. " \
+              "Visit https://github.com/linto-ai/whisper-timestamped#installation to install"
+        assert _has_whisper, msg
+        return WhisperSpeechRecognitionModel(
+            name, download_path, in_memory, remember_transcriptions
+        )
+
+    @property
+    def duration(self) -> float:
+        raise NotImplementedError
+
+    @property
+    def sample_rate(self) -> int:
+        raise NotImplementedError
+
+    def set_language(self, language: Optional[Text] = None):
+        raise NotImplementedError
+
+    def forward(self, waveform: torch.Tensor) -> List[Transcription]:
+        """
+        Forward pass of the speech recognition model.
+
+        Parameters
+        ----------
+        waveform: torch.Tensor, shape (batch, channels, samples)
+            Batch of audio chunks to transcribe
+
+        Returns
+        -------
+        transcriptions: List[Transcription]
+            A list of timestamped transcriptions
+        """
+        raise NotImplementedError
+
+
+class WhisperSpeechRecognitionModel(SpeechRecognitionModel):
+    def __init__(
+        self,
+        name: Text,
+        download_path: Optional[Union[Text, Path]] = None,
+        in_memory: bool = False,
+        remember_transcriptions: bool = True,
+    ):
+        super().__init__(WhisperLoader(name, download_path, in_memory))
+        self.remember_transcriptions = remember_transcriptions
+        self.language = None
+        self._cache = None
+
+    @property
+    def duration(self) -> float:
+        # Whisper's maximum duration per input is 30s
+        return whisper.audio.CHUNK_LENGTH
+
+    @property
+    def sample_rate(self) -> int:
+        return whisper.audio.SAMPLE_RATE
+
+    def set_language(self, language: Optional[Text] = None):
+        self.language = language
+
+    def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]:
+        results = []
+        for waveform in waveform_batch:
+            audio = whisper.pad_or_trim(waveform.type(torch.float32).reshape(-1))
+            transcription = whisper.transcribe(
+                self.model,
+                audio,
+                initial_prompt=self._cache,
+                verbose=None,
+                task="transcribe",
+                language=self.language,
+            )
+
+            # Extract chunks and timestamps
+            chunks, timestamps = [], []
+            for chunk in transcription["segments"]:
+                chunks.append(chunk["text"])
+                timestamps.append(Segment(chunk["start"], chunk["end"]))
+
+            # Create transcription object
+            transcription = Transcription(transcription["text"], chunks, timestamps)
+            results.append(transcription)
+
+            # Update transcription buffer
+            if self.remember_transcriptions:
+                # TODO handle overlapping transcriptions
+                self._cache = transcription.text
+
+        return results
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index 63c170d0..8d9217a1 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -8,8 +8,6 @@
 from rx.core import Observer
 from typing_extensions import Literal
 
-from . import utils
-
 
 class WindowClosedException(Exception):
     pass
@@ -58,7 +56,7 @@ def on_completed(self):
         self.patch()
 
 
-class PredictionAccumulator(Observer):
+class DiarizationAccumulator(Observer):
     def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05):
         super().__init__()
         self.uri = uri

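At this stage the transcription pipeline is not yet wired into the console scripts or the package root, but it can already be driven like the other streaming pipelines. A rough sketch, assuming whisper-timestamped is installed, the Whisper `small` checkpoint can be downloaded, and the classes are imported directly from `diart.blocks.asr`:

```python
import torch
from diart.blocks.asr import Transcription, TranscriptionConfig
from diart.sources import MicrophoneAudioSource
from diart.inference import StreamingInference

# Whisper "small" is the default ASR model; force CPU to keep the sketch GPU-free
config = TranscriptionConfig(language="en", beam_size=5, device=torch.device("cpu"))
pipeline = Transcription(config)

mic = MicrophoneAudioSource(config.sample_rate)
# Keep the default batch size of 1: batched transcription isn't supported yet
inference = StreamingInference(pipeline, mic)

# With this patch, inference() returns one prediction per chunk (here, plain text)
chunk_texts = inference()
```
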
From 1ae4934c11879c293defa9bc172df1e4f7b9763d Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Fri, 21 Apr 2023 16:52:41 +0200
Subject: [PATCH 06/23] First working transcription pipeline. Using diarization
 is possible but a bit quirky

---
 src/diart/blocks/asr.py         | 28 +++++-----
 src/diart/blocks/base.py        | 13 +++--
 src/diart/blocks/diarization.py | 23 ++++++--
 src/diart/blocks/vad.py         | 26 ++++++---
 src/diart/console/tune.py       |  9 ++--
 src/diart/inference.py          | 92 +++++++++++++++++--------------
 src/diart/metrics.py            | 96 +++++++++++++++++++++++++++++++++
 7 files changed, 214 insertions(+), 73 deletions(-)
 create mode 100644 src/diart/metrics.py

diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py
index 641096dd..0e034a3e 100644
--- a/src/diart/blocks/asr.py
+++ b/src/diart/blocks/asr.py
@@ -1,19 +1,17 @@
-import math
 from pathlib import Path
-from typing import Sequence, Optional, Any, Union, List, Text, Tuple, Dict, Hashable
+from typing import Sequence, Optional, Any, Union, List, Text, Tuple
 
 import numpy as np
 import torch
 from einops import rearrange
 from pyannote.core import SlidingWindowFeature, Annotation, Segment
-from pyannote.metrics.base import BaseMetric
 
 from . import base
 from .. import models as m
 from .. import utils
 from ..blocks.base import HyperParameter
 from ..features import TemporalFeatureFormatter, TemporalFeatures
-
+from ..metrics import Metric, WordErrorRate
 
 BeamSize = HyperParameter("beam_size", low=1, high=20)
 
@@ -141,9 +139,8 @@ def get_config_class() -> type:
         return TranscriptionConfig
 
     @staticmethod
-    def suggest_metric() -> BaseMetric:
-        # TODO word error rate
-        raise NotImplementedError
+    def suggest_metric() -> Metric:
+        return WordErrorRate()
 
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
@@ -161,6 +158,13 @@ def set_timestamp_shift(self, shift: float):
         # No timestamped output. Nothing to do
         pass
 
+    def join_predictions(self, predictions: List[Text]) -> Text:
+        return "\n".join(predictions)
+
+    def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]):
+        with open(Path(dir_path) / f"{uri}.txt", "w") as out_file:
+            out_file.write(prediction)
+
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
@@ -168,11 +172,9 @@ def __call__(
         **kwargs
     ) -> Sequence[Tuple[Text, SlidingWindowFeature]]:
         # TODO implement batched inference
-        too_many_dia = diarization is not None and len(diarization) > 1
+        only_one_dia = diarization is None or len(diarization) == 1
         msg = "Batched inference is not yet supported for 'Transcription'. Please set batch size to 1"
-        if len(waveforms) > 1 or too_many_dia:
-            print(msg)
-            exit(1)
+        assert len(waveforms) == 1 and only_one_dia, msg
 
         waveform = waveforms[0]
 
@@ -188,7 +190,7 @@ def __call__(
         output = self.asr(batch)[0]
 
         if diarization is None:
-            return [(output.text, waveform)]
+            return [(output.text.strip(), waveform)]
 
         diarization = diarization[0]
 
@@ -215,4 +217,4 @@ def __call__(
                 max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
                 full_transcription.append(f"[{speakers[max_spk]}]{text}")
 
-        return [(" ".join(full_transcription), waveform)]
+        return [(" ".join(full_transcription).strip(), waveform)]
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
index d1a372c1..40d1d22d 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/blocks/base.py
@@ -1,12 +1,13 @@
-from typing import Any, Tuple, Sequence, Text
 from dataclasses import dataclass
+from typing import Any, Tuple, Sequence, Text, List, Union
+from pathlib import Path
 
 import numpy as np
 from pyannote.core import SlidingWindowFeature
-from pyannote.metrics.base import BaseMetric
 
 from .. import utils
 from ..audio import FilePath, AudioLoader
+from ..metrics import Metric
 
 
 @dataclass
@@ -68,7 +69,7 @@ def get_config_class() -> type:
         raise NotImplementedError
 
     @staticmethod
-    def suggest_metric() -> BaseMetric:
+    def suggest_metric() -> Metric:
         raise NotImplementedError
 
     @staticmethod
@@ -85,6 +86,12 @@ def reset(self):
     def set_timestamp_shift(self, shift: float):
         raise NotImplementedError
 
+    def join_predictions(self, predictions: List[Any]) -> Any:
+        raise NotImplementedError
+
+    def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Path]):
+        raise NotImplementedError
+
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index 03169077..2f8de3f5 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -1,20 +1,20 @@
-from typing import Optional, Tuple, Sequence, Union, Any
+from pathlib import Path
+from typing import Optional, Tuple, Sequence, Union, Any, Text, List
 
 import numpy as np
 import torch
 from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
-from pyannote.metrics.base import BaseMetric
-from pyannote.metrics.diarization import DiarizationErrorRate
 from typing_extensions import Literal
 
-from .aggregation import DelayedAggregation
 from . import base
+from .aggregation import DelayedAggregation
 from .clustering import OnlineSpeakerClustering
 from .embedding import OverlapAwareSpeakerEmbedding
 from .segmentation import SpeakerSegmentation
 from .utils import Binarize
 from .. import models as m
 from .. import utils
+from ..metrics import Metric, DiarizationErrorRate
 
 
 class SpeakerDiarizationConfig(base.StreamingConfig):
@@ -31,6 +31,7 @@ def __init__(
         gamma: float = 3,
         beta: float = 10,
         max_speakers: int = 20,
+        merge_collar: float = 0.05,
         device: Optional[torch.device] = None,
         **kwargs,
     ):
@@ -61,6 +62,7 @@ def __init__(
         self.gamma = gamma
         self.beta = beta
         self.max_speakers = max_speakers
+        self.merge_collar = merge_collar
 
         self.device = device
         if self.device is None:
@@ -103,6 +105,7 @@ def from_dict(data: Any) -> 'SpeakerDiarizationConfig':
             gamma=utils.get(data, "gamma", 3),
             beta=utils.get(data, "beta", 10),
             max_speakers=utils.get(data, "max_speakers", 20),
+            merge_collar=utils.get(data, "merge_collar", 0.05),
             device=device,
         )
 
@@ -165,7 +168,7 @@ def get_config_class() -> type:
         return SpeakerDiarizationConfig
 
     @staticmethod
-    def suggest_metric() -> BaseMetric:
+    def suggest_metric() -> Metric:
         return DiarizationErrorRate(collar=0, skip_overlap=False)
 
     @staticmethod
@@ -179,6 +182,16 @@ def config(self) -> SpeakerDiarizationConfig:
     def set_timestamp_shift(self, shift: float):
         self.timestamp_shift = shift
 
+    def join_predictions(self, predictions: List[Annotation]) -> Annotation:
+        result = Annotation(uri=predictions[0].uri)
+        for pred in predictions:
+            result.update(pred)
+        return result.support(self.config.merge_collar)
+
+    def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Text, Path]):
+        with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file:
+            prediction.write_rttm(out_file)
+
     def reset(self):
         self.set_timestamp_shift(0)
         self.clustering = OnlineSpeakerClustering(
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
index def833b6..47afb4da 100644
--- a/src/diart/blocks/vad.py
+++ b/src/diart/blocks/vad.py
@@ -1,18 +1,18 @@
-from typing import Any, Optional, Union, Sequence, Tuple
+from pathlib import Path
+from typing import Any, Optional, Union, Sequence, Tuple, Text, List
 
 import numpy as np
 import torch
 from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment
-from pyannote.metrics.base import BaseMetric
-from pyannote.metrics.detection import DetectionErrorRate
 from typing_extensions import Literal
 
-from .aggregation import DelayedAggregation
 from . import base
+from .aggregation import DelayedAggregation
 from .segmentation import SpeakerSegmentation
 from .utils import Binarize
 from .. import models as m
 from .. import utils
+from ..metrics import Metric, DetectionErrorRate
 
 
 class VoiceActivityDetectionConfig(base.StreamingConfig):
@@ -23,6 +23,7 @@ def __init__(
         step: float = 0.5,
         latency: Optional[Union[float, Literal["max", "min"]]] = None,
         tau_active: float = 0.6,
+        merge_collar: float = 0.05,
         device: Optional[torch.device] = None,
         **kwargs,
     ):
@@ -43,6 +44,7 @@ def __init__(
             self._latency = self._duration
 
         self.tau_active = tau_active
+        self.merge_collar = merge_collar
         self.device = device
         if self.device is None:
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -92,6 +94,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig':
             step=utils.get(data, "step", 0.5),
             latency=utils.get(data, "latency", None),
             tau_active=tau,
+            merge_collar=utils.get(data, "merge_collar", 0.05),
             device=device,
         )
 
@@ -127,7 +130,7 @@ def get_config_class() -> type:
         return VoiceActivityDetectionConfig
 
     @staticmethod
-    def suggest_metric() -> BaseMetric:
+    def suggest_metric() -> Metric:
         return DetectionErrorRate(collar=0, skip_overlap=False)
 
     @staticmethod
@@ -135,7 +138,7 @@ def hyper_parameters() -> Sequence[base.HyperParameter]:
         return [base.TauActive]
 
     @property
-    def config(self) -> base.StreamingConfig:
+    def config(self) -> VoiceActivityDetectionConfig:
         return self._config
 
     def reset(self):
@@ -145,9 +148,20 @@ def reset(self):
     def set_timestamp_shift(self, shift: float):
         self.timestamp_shift = shift
 
+    def join_predictions(self, predictions: List[Annotation]) -> Annotation:
+        result = Annotation(uri=predictions[0].uri)
+        for pred in predictions:
+            result.update(pred)
+        return result.support(self.config.merge_collar)
+
+    def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Text, Path]):
+        with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file:
+            prediction.write_rttm(out_file)
+
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
+        **kwargs,
     ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index a1f1b63a..6affda50 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -51,12 +51,9 @@ def run():
     possible_hparams = pipeline_class.hyper_parameters()
     hparams = [HyperParameter.from_name(name) for name in args.hparams]
     hparams = [hp for hp in hparams if hp in possible_hparams]
-    if not hparams:
-        print(
-            f"No hyper-parameters to optimize. "
-            f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}"
-        )
-        exit(1)
+    msg = f"No hyper-parameters to optimize. " \
+          f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}"
+    assert hparams, msg
 
     # Use a custom storage if given
     if args.output is not None:
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 14ab4736..ee22a3cb 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -2,7 +2,7 @@
 from multiprocessing import Pool, freeze_support, RLock, current_process
 from pathlib import Path
 from traceback import print_exc
-from typing import Union, Text, Optional, Callable, Tuple, List, Any
+from typing import Union, Text, Optional, Callable, Tuple, List, Any, Dict
 
 import numpy as np
 import pandas as pd
@@ -10,8 +10,6 @@
 import rx.operators as ops
 import torch
 from pyannote.core import Annotation, SlidingWindowFeature
-from pyannote.database.util import load_rttm
-from pyannote.metrics.base import BaseMetric
 from rx.core import Observer
 from tqdm import tqdm
 
@@ -19,8 +17,9 @@
 from . import operators as dops
 from . import sources as src
 from . import utils
+from .metrics import Metric
 from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
-from .sinks import DiarizationAccumulator, StreamingPlot, WindowClosedException
+from .sinks import StreamingPlot, WindowClosedException
 
 
 class StreamingInference:
@@ -292,7 +291,7 @@ def run_single(
         pipeline: blocks.StreamingPipeline,
         filepath: Path,
         progress_bar: ProgressBar,
-    ) -> Annotation:
+    ) -> Tuple[Text, Any]:
         """Run a given pipeline on a given file.
         Note that this method does NOT reset the
         state of the pipeline before execution.
@@ -329,36 +328,29 @@ def run_single(
             progress_bar=progress_bar,
         )
 
-        # Accumulate predictions in memory
-        pred_accumulator = DiarizationAccumulator(source.uri)
-        inference.attach_observers(pred_accumulator)
-
-        # Run the pipeline on this file
-        inference()
-
-        # Extract prediction
-        pred = pred_accumulator.get_prediction()
+        # Run the pipeline and concatenate predictions
+        pred = pipeline.join_predictions(inference())
 
+        # Write prediction to disk if required
         if self.output_path is not None:
-            with open(self.output_path / f"{source.uri}.rttm", "w") as out_file:
-                pred.write_rttm(out_file)
+            pipeline.write_prediction(source.uri, pred, self.output_path)
 
-        return pred
+        return source.uri, pred
 
     def evaluate(
         self,
-        predictions: List[Annotation],
-        metric: BaseMetric,
-    ) -> Union[pd.DataFrame, List[Annotation]]:
+        predictions: Dict[Text, Any],
+        metric: Metric,
+    ) -> Union[pd.DataFrame, Dict[Text, Any]]:
         """If a reference path was provided,
         compute the diarization error rate of a list of predictions.
 
         Parameters
         ----------
-        predictions: List[Annotation]
+        predictions: Dict[Text, Any]
             Predictions to evaluate.
-        metric: BaseMetric
-            Evaluation metric from pyannote.metrics.
+        metric: Metric
+            Evaluation metric.
 
         Returns
         -------
@@ -367,23 +359,37 @@ def evaluate(
             reference path was given. Otherwise return the same predictions.
         """
         if self.reference_path is not None:
+            # Initialize progress bar
             progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False)
             progress_bar.create(total=len(predictions), unit="file")
             progress_bar.start()
-            for hyp in predictions:
-                ref = load_rttm(self.reference_path / f"{hyp.uri}.rttm").popitem()[1]
-                metric(ref, hyp)
+
+            # Evaluate each prediction
+            uris = []
+            for uri, pred in predictions.items():
+                ref_file = list(self.reference_path.glob(f"{uri}.*"))
+                if ref_file:
+                    ref = metric.load_reference(ref_file[0])
+                    metric(ref, pred)
+                    uris.append(uri)
+                else:
+                    msg = f"Reference file for {uri} not found. Skipping evaluation."
+                    logging.warning(msg)
                 progress_bar.update()
+
+            # Close progress bar safely
             progress_bar.close()
-            return metric.report(display=self.show_report)
+            # Return performance report
+            return metric.report(uris, self.show_report)
+
         return predictions
 
     def __call__(
         self,
         pipeline_class: type,
         config: blocks.StreamingConfig,
-        metric: Optional[BaseMetric] = None,
-    ) -> Union[pd.DataFrame, List[Annotation]]:
+        metric: Optional[Metric] = None,
+    ) -> Union[pd.DataFrame, Dict[Text, Any]]:
         """Run a given pipeline on a set of audio files.
         The internal state of the pipeline is reset before benchmarking.
 
@@ -394,8 +400,8 @@ def __call__(
             A pipeline from this class will be instantiated by each worker.
         config: StreamingConfig
             Streaming pipeline configuration.
-        metric: Optional[BaseMetric]
-            Evaluation metric from pyannote.metrics.
+        metric: Optional[Metric]
+            Evaluation metric.
             Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`)
 
         Returns
@@ -410,12 +416,13 @@ def __call__(
         num_audio_files = len(audio_file_paths)
         pipeline = pipeline_class(config)
 
-        predictions = []
+        predictions = {}
         for i, filepath in enumerate(audio_file_paths):
             pipeline.reset()
             desc = f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})"
             progress = TQDMProgressBar(desc, leave=False, do_close=True)
-            predictions.append(self.run_single(pipeline, filepath, progress))
+            uri, pred = self.run_single(pipeline, filepath, progress)
+            predictions[uri] = pred
 
         metric = pipeline.suggest_metric() if metric is None else metric
         return self.evaluate(predictions, metric)
@@ -447,7 +454,7 @@ def run_single_job(
         config: blocks.StreamingConfig,
         filepath: Path,
         description: Text,
-    ) -> Annotation:
+    ) -> Tuple[Text, Any]:
         """Build and run a pipeline on a single file.
         Configure execution to show progress alongside parallel runs.
 
@@ -482,8 +489,8 @@ def __call__(
         self,
         pipeline_class: type,
         config: blocks.StreamingConfig,
-        metric: Optional[BaseMetric] = None,
-    ) -> Union[pd.DataFrame, List[Annotation]]:
+        metric: Optional[Metric] = None,
+    ) -> Union[pd.DataFrame, Dict[Text, Any]]:
         """Run a given pipeline on a set of audio files in parallel.
         Each worker will build and run the pipeline on a different file.
 
@@ -494,8 +501,8 @@ def __call__(
             A pipeline from this class will be instantiated by each worker.
         config: StreamingConfig
             Streaming pipeline configuration.
-        metric: Optional[BaseMetric]
-            Evaluation metric from pyannote.metrics.
+        metric: Optional[Metric]
+            Evaluation metric.
             Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`)
 
         Returns
@@ -529,9 +536,14 @@ def __call__(
         # Submit all jobs
         jobs = [pool.apply_async(self.run_single_job, args=args) for args in arg_list]
 
-        # Wait and collect results
+        # Wait for all jobs to finish
         pool.close()
-        predictions = [job.get() for job in jobs]
+
+        # Collect results
+        predictions = {}
+        for job in jobs:
+            uri, pred = job.get()
+            predictions[uri] = pred
 
         # Evaluate results
         metric = pipeline_class.suggest_metric() if metric is None else metric
diff --git a/src/diart/metrics.py b/src/diart/metrics.py
new file mode 100644
index 00000000..d5ae56e4
--- /dev/null
+++ b/src/diart/metrics.py
@@ -0,0 +1,96 @@
+from pathlib import Path
+from typing import Text, Any, List, Union
+
+import pandas as pd
+from pyannote.core import Annotation
+from pyannote.metrics import diarization as dia, detection as det
+from pyannote.metrics.base import BaseMetric as PyannoteBaseMetric
+from pyannote.database.util import load_rttm
+from torchmetrics import text
+
+
+class Metric:
+    @property
+    def name(self) -> Text:
+        raise NotImplementedError
+
+    def __call__(self, reference: Any, prediction: Any) -> float:
+        raise NotImplementedError
+
+    def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame:
+        raise NotImplementedError
+
+    def load_reference(self, filepath: Union[Text, Path]) -> Any:
+        raise NotImplementedError
+
+
+class PyannoteMetric(Metric):
+    def __init__(self, metric: PyannoteBaseMetric):
+        self._metric = metric
+
+    @property
+    def name(self) -> Text:
+        return self._metric.name
+
+    def __call__(self, reference: Annotation, prediction: Annotation) -> float:
+        return self._metric(reference, prediction)
+
+    def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame:
+        return self._metric.report(display)
+
+    def load_reference(self, filepath: Union[Text, Path]) -> Annotation:
+        return load_rttm(filepath).popitem()[1]
+
+
+class DiarizationErrorRate(PyannoteMetric):
+    def __init__(self, collar: float = 0, skip_overlap: bool = False):
+        super().__init__(dia.DiarizationErrorRate(collar, skip_overlap))
+
+
+class DetectionErrorRate(PyannoteMetric):
+    def __init__(self, collar: float = 0, skip_overlap: bool = False):
+        super().__init__(det.DetectionErrorRate(collar, skip_overlap))
+
+
+class WordErrorRate(Metric):
+    def __init__(self, unify_case: bool = False):
+        self.unify_case = unify_case
+        self._metric = text.WordErrorRate()
+        self._values = []
+
+    @property
+    def name(self) -> Text:
+        return "word error rate"
+
+    def __call__(self, reference: Text, prediction: Text) -> float:
+        if self.unify_case:
+            prediction = prediction.lower()
+            reference = reference.lower()
+        # Torchmetrics requires predictions first, then reference
+        value = self._metric(prediction, reference).item()
+        self._values.append(value)
+        return value
+
+    def report(self, uris: List[Text], display: bool = False) -> pd.DataFrame:
+        num_uris, num_values = len(uris), len(self._values)
+        msg = f"URI list size must match values. Found {num_uris} but expected {num_values}"
+        assert num_uris == num_values, msg
+
+        rows = self._values + [self._metric.compute().item()]
+        index = uris + ["TOTAL"]
+        report = pd.DataFrame(rows, index=index, columns=[self.name])
+
+        if display:
+            print(report.to_string(
+                index=True,
+                sparsify=False,
+                justify="right",
+                float_format=lambda f: "{0:.2f}".format(f),
+            ))
+
+        return report
+
+    def load_reference(self, filepath: Union[Text, Path]) -> Text:
+        with open(filepath, "r") as file:
+            lines = [line.strip() for line in file.readlines()]
+        return " ".join(lines)

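The new `diart.metrics` module wraps the pyannote metrics behind a common `Metric` interface and adds a word error rate backed by torchmetrics, which is what lets `Benchmark` evaluate transcription pipelines. A short sketch of the `WordErrorRate` accumulator on its own, with made-up strings standing in for reference transcripts and predictions:

```python
from diart.metrics import WordErrorRate

wer = WordErrorRate(unify_case=True)

# One call per file: (reference, prediction); each value is also accumulated internally
wer("the quick brown fox", "the quick brown fax")
wer("hello world", "hello world")

# One URI per accumulated value; a TOTAL row with the aggregate WER is appended
report = wer.report(uris=["file1", "file2"], display=True)
```
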
From d8d73428fd475b1cc9515862188c2ec634abc6e3 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Fri, 21 Apr 2023 17:20:06 +0200
Subject: [PATCH 07/23] Reduce Whisper VRAM footprint (around 400MB). Add fp16
 option

---
 src/diart/blocks/asr.py | 4 +++-
 src/diart/models.py     | 7 ++++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py
index 0e034a3e..976ddb1d 100644
--- a/src/diart/blocks/asr.py
+++ b/src/diart/blocks/asr.py
@@ -13,6 +13,7 @@
 from ..features import TemporalFeatureFormatter, TemporalFeatures
 from ..metrics import Metric, WordErrorRate
 
+
 BeamSize = HyperParameter("beam_size", low=1, high=20)
 
 
@@ -32,10 +33,11 @@ def from_whisper(
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
         remember_transcriptions: bool = True,
+        fp16: bool = False,
         device: Optional[Union[Text, torch.device]] = None,
     ) -> 'SpeechRecognition':
         asr_model = m.SpeechRecognitionModel.from_whisper(
-            name, download_path, in_memory, remember_transcriptions
+            name, download_path, in_memory, remember_transcriptions, fp16
         )
         return SpeechRecognition(asr_model, device)
 
diff --git a/src/diart/models.py b/src/diart/models.py
index 7a653837..bffebe0d 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -43,6 +43,7 @@ def __init__(
     def __call__(self) -> nn.Module:
         return whisper.load_model(
             name=self.name,
+            device="cpu",
             download_root=self.download_path,
             in_memory=self.in_memory,
         )
@@ -207,12 +208,13 @@ def from_whisper(
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
         remember_transcriptions: bool = True,
+        fp16: bool = False,
     ) -> 'SpeechRecognitionModel':
         msg = "No whisper-transcribed installation found. " \
               "Visit https://github.com/linto-ai/whisper-timestamped#installation to install"
         assert _has_whisper, msg
         return WhisperSpeechRecognitionModel(
-            name, download_path, in_memory, remember_transcriptions
+            name, download_path, in_memory, remember_transcriptions, fp16
         )
 
     @property
@@ -250,9 +252,11 @@ def __init__(
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
         remember_transcriptions: bool = True,
+        fp16: bool = False,
     ):
         super().__init__(WhisperLoader(name, download_path, in_memory))
         self.remember_transcriptions = remember_transcriptions
+        self.fp16 = fp16
         self.language = None
         self._cache = None
 
@@ -279,6 +283,7 @@ def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]:
                 verbose=None,
                 task="transcribe",
                 language=self.language,
+                fp16=self.fp16,
             )
 
             # Extract chunks and timestamps

From 2cfc35da8035f07c27c3336e0d44ee5e499cd983 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Fri, 21 Apr 2023 17:33:06 +0200
Subject: [PATCH 08/23] Change whisper input type based on fp16 parameter

---
 src/diart/models.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/diart/models.py b/src/diart/models.py
index bffebe0d..57921df6 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -275,7 +275,8 @@ def set_language(self, language: Optional[Text] = None):
     def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]:
         results = []
         for waveform in waveform_batch:
-            audio = whisper.pad_or_trim(waveform.type(torch.float32).reshape(-1))
+            dtype = torch.float16 if self.fp16 else torch.float32
+            audio = whisper.pad_or_trim(waveform.type(dtype).reshape(-1))
             transcription = whisper.transcribe(
                 self.model,
                 audio,

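Taken together, these last two patches make a single `fp16` flag control both the precision Whisper decodes with and the dtype of the audio it receives. A sketch of threading the flag through the transcription pipeline, assuming a CUDA device is available (half precision on CPU is generally not supported by Whisper):

```python
import torch
from diart.models import SpeechRecognitionModel
from diart.blocks.asr import Transcription, TranscriptionConfig

# fp16=True is forwarded to whisper.transcribe and casts the input audio to float16
asr = SpeechRecognitionModel.from_whisper("small", fp16=True)

config = TranscriptionConfig(asr=asr, device=torch.device("cuda"))
pipeline = Transcription(config)
```
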
From a40112c42543e4d70d37c2fe999530a161e2b7a9 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Sat, 22 Apr 2023 17:40:35 +0200
Subject: [PATCH 09/23] Implement batched inference for whisper. Re-implement
 decoding.

---
 src/diart/blocks/asr.py |  78 +++++++--------
 src/diart/models.py     | 214 +++++++++++++++++++++++++++++++++-------
 2 files changed, 216 insertions(+), 76 deletions(-)

diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py
index 976ddb1d..c485724e 100644
--- a/src/diart/blocks/asr.py
+++ b/src/diart/blocks/asr.py
@@ -32,12 +32,11 @@ def from_whisper(
         name: Text,
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
-        remember_transcriptions: bool = True,
         fp16: bool = False,
         device: Optional[Union[Text, torch.device]] = None,
     ) -> 'SpeechRecognition':
         asr_model = m.SpeechRecognitionModel.from_whisper(
-            name, download_path, in_memory, remember_transcriptions, fp16
+            name, download_path, in_memory, fp16
         )
         return SpeechRecognition(asr_model, device)
 
@@ -173,15 +172,12 @@ def __call__(
         diarization: Optional[Sequence[Annotation]] = None,
         **kwargs
     ) -> Sequence[Tuple[Text, SlidingWindowFeature]]:
-        # TODO implement batched inference
-        only_one_dia = diarization is None or len(diarization) == 1
-        msg = "Batched inference is not yet supported for 'Transcription'. Please set batch size to 1"
-        assert len(waveforms) == 1 and only_one_dia, msg
+        batch_size = len(waveforms)
+        msg = "Pipeline expected at least 1 input"
+        assert batch_size >= 1, msg
 
-        waveform = waveforms[0]
-
-        # Add fake batch dimension, shape (1, samples, channels)
-        batch = torch.from_numpy(waveform.data).unsqueeze(0)
+        # Create batch from chunk sequence, shape (batch, samples, channels)
+        batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
 
         expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
         msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
@@ -189,34 +185,34 @@ def __call__(
 
         # Transcribe batch
         # TODO only transcribe if there's speech
-        output = self.asr(batch)[0]
-
-        if diarization is None:
-            return [(output.text.strip(), waveform)]
-
-        diarization = diarization[0]
-
-        # Align transcription with diarization to determine speakers
-        full_transcription = []
-        buffer_shift = waveform.sliding_window.start
-        for text, timestamp in zip(output.chunks, output.timestamps):
-            target_region = Segment(
-                buffer_shift + timestamp.start,
-                buffer_shift + timestamp.end
-            )
-            dia = diarization.crop(target_region)
-            speakers = dia.labels()
-            num_speakers = len(speakers)
-            if num_speakers == 0:
-                # Include transcription but don't assign a speaker
-                full_transcription.append(text)
-            elif num_speakers == 1:
-                # Typical case, annotate text with the only speaker
-                full_transcription.append(f"[{speakers[0]}]{text}")
-            else:
-                # Multiple speakers for the same text block, choose the most active one
-                # TODO match at the level of words?
-                max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
-                full_transcription.append(f"[{speakers[max_spk]}]{text}")
-
-        return [(" ".join(full_transcription).strip(), waveform)]
+        outputs = self.asr(batch)
+
+        return [(out.text, wav) for out, wav in zip(outputs, waveforms)]
+
+        # TODO align text with speakers if diarization is not None
+
+        # diarization = diarization[0]
+        #
+        # # Align transcription with diarization to determine speakers
+        # full_transcription = []
+        # buffer_shift = waveform.sliding_window.start
+        # for text, timestamp in zip(outputs.chunks, outputs.timestamps):
+        #     target_region = Segment(
+        #         buffer_shift + timestamp.start,
+        #         buffer_shift + timestamp.end
+        #     )
+        #     dia = diarization.crop(target_region)
+        #     speakers = dia.labels()
+        #     num_speakers = len(speakers)
+        #     if num_speakers == 0:
+        #         # Include transcription but don't assign a speaker
+        #         full_transcription.append(text)
+        #     elif num_speakers == 1:
+        #         # Typical case, annotate text with the only speaker
+        #         full_transcription.append(f"[{speakers[0]}]{text}")
+        #     else:
+        #         # Multiple speakers for the same text block, choose the most active one
+        #         max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
+        #         full_transcription.append(f"[{speakers[max_spk]}]{text}")
+        #
+        # return [(" ".join(full_transcription).strip(), waveform)]
diff --git a/src/diart/models.py b/src/diart/models.py
index 57921df6..0879ef9a 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -1,7 +1,9 @@
+import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Text, Union, Callable, List, Tuple, Dict
+from typing import Optional, Text, Union, Callable, List, Any
 
+import numpy as np
 import torch
 import torch.nn as nn
 from pyannote.core import Segment
@@ -14,9 +16,16 @@
 
 try:
     import whisper
+    from whisper.tokenizer import get_tokenizer
     _has_whisper = True
+    DecodingResult = whisper.DecodingResult
+    DecodingOptions = whisper.DecodingOptions
+    Tokenizer = whisper.tokenizer.Tokenizer
 except ImportError:
     _has_whisper = False
+    DecodingResult = Any
+    DecodingOptions = Any
+    Tokenizer = Any
 
 
 class PyannoteLoader:
@@ -207,15 +216,12 @@ def from_whisper(
         name: Text,
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
-        remember_transcriptions: bool = True,
         fp16: bool = False,
     ) -> 'SpeechRecognitionModel':
         msg = "No whisper-transcribed installation found. " \
               "Visit https://github.com/linto-ai/whisper-timestamped#installation to install"
         assert _has_whisper, msg
-        return WhisperSpeechRecognitionModel(
-            name, download_path, in_memory, remember_transcriptions, fp16
-        )
+        return WhisperSpeechRecognitionModel(name, download_path, in_memory, fp16)
 
     @property
     def duration(self) -> float:
@@ -245,20 +251,141 @@ def forward(self, waveform: torch.Tensor) -> List[Transcription]:
         raise NotImplementedError
 
 
+class WhisperDecoder:
+    def __init__(
+        self,
+        compression_ratio_threshold: Optional[float] = 2.4,
+        logprob_threshold: Optional[float] = -1,
+    ):
+        self.compression_ratio_threshold = compression_ratio_threshold
+        self.logprob_threshold = logprob_threshold
+        self.temperatures = (0, 0.2, 0.4, 0.6, 0.8, 1)
+
+    @staticmethod
+    def get_temperature_options(initial: DecodingOptions, t: float) -> DecodingOptions:
+        t_options = {**vars(initial)}
+        if t > 0:
+            t_options.pop("beam_size", None)
+            t_options.pop("patience", None)
+        else:
+            t_options.pop("best_of", None)
+        t_options["temperature"] = t
+        return DecodingOptions(**t_options)
+
+    @staticmethod
+    def decode(
+        model,
+        batch: torch.Tensor,
+        options: DecodingOptions
+    ) -> DecodingResult:
+        return model.decode(batch, options)
+
+    def check_compression(self) -> bool:
+        return self.compression_ratio_threshold is not None
+
+    def check_logprob(self) -> bool:
+        return self.logprob_threshold is not None
+
+    def needs_fallback(self, output: DecodingResult) -> bool:
+        if self.check_compression() and output.compression_ratio > self.compression_ratio_threshold:
+            # Transcription is too repetitive
+            return True
+        if self.check_logprob() and output.avg_logprob < self.logprob_threshold:
+            # Average log probability is too low
+            return True
+        return False
+
+    def decode_with_fallback(
+        self,
+        model,
+        batch: torch.Tensor,
+        options: DecodingOptions,
+    ) -> DecodingResult:
+        batch_size = batch.shape[0]
+        results = [None] * batch_size
+        retry_idx = torch.ones(batch_size).type(torch.bool)
+
+        for t in self.temperatures:
+            # Transcribe with the given temperature
+            t_options = self.get_temperature_options(options, t)
+            outputs = model.decode(batch[retry_idx], t_options)
+
+            # Determine which outputs need to be transcribed again
+            #  based on quality estimates
+            output_idx = torch.where(retry_idx)[0]
+            for idx, out in zip(output_idx, outputs):
+                results[idx] = out
+                if not self.needs_fallback(out):
+                    retry_idx[idx] = False
+
+            # No output needs fallback, get out of the loop
+            if torch.sum(retry_idx).item() == 0:
+                break
+
+        return results
+
+    @staticmethod
+    def split_with_timestamps(
+        result: DecodingResult,
+        tokenizer: Tokenizer,
+        chunk_duration: float,
+        token_duration: float,
+    ) -> Transcription:
+        tokens = torch.tensor(result.tokens)
+        chunks, timestamps = [], []
+        ts_tokens = tokens.ge(tokenizer.timestamp_begin)
+        single_ts_ending = ts_tokens[-2:].tolist() == [False, True]
+        consecutive = torch.where(ts_tokens[:-1] & ts_tokens[1:])[0] + 1
+        if len(consecutive) > 0:
+            # Output contains two consecutive timestamp tokens
+            slices = consecutive.tolist()
+            if single_ts_ending:
+                slices.append(len(tokens))
+
+            last_slice = 0
+            for current_slice in slices:
+                sliced_tokens = tokens[last_slice:current_slice]
+                start_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
+                end_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
+                text_tokens = [token for token in sliced_tokens if token < tokenizer.eot]
+                text = tokenizer.decode(text_tokens).strip()
+                timestamp = Segment(start_pos * token_duration, end_pos * token_duration)
+                if text and timestamp.start != timestamp.end:
+                    chunks.append(text)
+                    timestamps.append(timestamp)
+                last_slice = current_slice
+        else:
+            duration = chunk_duration
+            ts = tokens[ts_tokens.nonzero().flatten()]
+            if len(ts) > 0 and ts[-1].item() != tokenizer.timestamp_begin:
+                # Use last timestamp as end time for the unique chunk
+                last_ts_pos = ts[-1].item() - tokenizer.timestamp_begin
+                duration = last_ts_pos * token_duration
+            text_tokens = [token for token in tokens if token < tokenizer.eot]
+            text = tokenizer.decode(text_tokens).strip()
+            if text:
+                chunks.append(text)
+                timestamps.append(Segment(0, duration))
+
+        return Transcription(result.text, chunks, timestamps)
+
+
 class WhisperSpeechRecognitionModel(SpeechRecognitionModel):
     def __init__(
         self,
         name: Text,
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
-        remember_transcriptions: bool = True,
         fp16: bool = False,
+        compression_ratio_threshold: Optional[float] = 2.4,
+        logprob_threshold: Optional[float] = -1,
     ):
         super().__init__(WhisperLoader(name, download_path, in_memory))
-        self.remember_transcriptions = remember_transcriptions
         self.fp16 = fp16
+        self.beam_size = None
         self.language = None
-        self._cache = None
+        self._token_duration: Optional[float] = None
+        self.decoder = WhisperDecoder(compression_ratio_threshold, logprob_threshold)
 
     @property
     def duration(self) -> float:
@@ -269,37 +396,54 @@ def duration(self) -> float:
     def sample_rate(self) -> int:
         return whisper.audio.SAMPLE_RATE
 
+    @property
+    def token_duration(self) -> float:
+        if self._token_duration is None:
+            # 2 mel frames per output token
+            input_stride = int(np.rint(whisper.audio.N_FRAMES / self.model.dims.n_audio_ctx))
+            # Output token duration is 0.02 seconds
+            self._token_duration = input_stride * whisper.audio.HOP_LENGTH / self.sample_rate
+        return self._token_duration
+
     def set_language(self, language: Optional[Text] = None):
         self.language = language
 
-    def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]:
-        results = []
-        for waveform in waveform_batch:
-            dtype = torch.float16 if self.fp16 else torch.float32
-            audio = whisper.pad_or_trim(waveform.type(dtype).reshape(-1))
-            transcription = whisper.transcribe(
-                self.model,
-                audio,
-                initial_prompt=self._cache,
-                verbose=None,
-                task="transcribe",
-                language=self.language,
-                fp16=self.fp16,
-            )
+    def set_beam_size(self, size: int):
+        self.beam_size = size
 
-            # Extract chunks and timestamps
-            chunks, timestamps = [], []
-            for chunk in transcription["segments"]:
-                chunks.append(chunk["text"])
-                timestamps.append(Segment(chunk["start"], chunk["end"]))
+    def set_fp16(self, value: bool):
+        self.fp16 = value
 
-            # Create transcription object
-            transcription = Transcription(transcription["text"], chunks, timestamps)
-            results.append(transcription)
+    def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]:
+        # Remove channel dimension
+        batch = waveform_batch.squeeze(1)
+        num_chunk_samples = batch.shape[-1]
+        # Compute log mel spectrogram
+        batch = whisper.log_mel_spectrogram(batch)
+        # Add padding
+        dtype = torch.float16 if self.fp16 else torch.float32
+        batch = whisper.pad_or_trim(batch, whisper.audio.N_FRAMES).to(batch.device).type(dtype)
+
+        # Transcribe batch
+        options = whisper.DecodingOptions(
+            task="transcribe",
+            language=self.language,
+            beam_size=self.beam_size,
+            fp16=self.fp16,
+        )
+        results = self.decoder.decode_with_fallback(self.model, batch, options)
+        tokenizer = get_tokenizer(
+            self.model.is_multilingual,
+            language=options.language,
+            task=options.task,
+        )
 
-            # Update transcription buffer
-            if self.remember_transcriptions:
-                # TODO handle overlapping transcriptions
-                self._cache = transcription.text
+        chunk_duration = int(np.rint(num_chunk_samples / self.sample_rate))
+        transcriptions = [
+            self.decoder.split_with_timestamps(
+                res, tokenizer, chunk_duration, self.token_duration
+            )
+            for res in results
+        ]
 
-        return results
+        return transcriptions
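
For reference, the token_duration property above resolves to 0.02 seconds per output token with the standard Whisper constants (the same numbers hold for every model size):

    # whisper.audio.N_FRAMES = 3000, whisper.audio.HOP_LENGTH = 160,
    # whisper.audio.SAMPLE_RATE = 16000, model.dims.n_audio_ctx = 1500
    input_stride = round(3000 / 1500)              # 2 mel frames per output token
    token_duration = input_stride * 160 / 16000    # = 0.02 seconds
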

From e8196a7ce8c7cc75f769f706265c735b539f770b Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Sat, 22 Apr 2023 18:31:48 +0200
Subject: [PATCH 10/23] Minor changes in transcription arguments

---
 src/diart/blocks/asr.py         | 17 +++++------------
 src/diart/blocks/base.py        |  1 -
 src/diart/blocks/diarization.py |  1 -
 src/diart/blocks/vad.py         |  1 -
 src/diart/models.py             |  8 ++++----
 5 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py
index c485724e..ecd13b8e 100644
--- a/src/diart/blocks/asr.py
+++ b/src/diart/blocks/asr.py
@@ -4,7 +4,7 @@
 import numpy as np
 import torch
 from einops import rearrange
-from pyannote.core import SlidingWindowFeature, Annotation, Segment
+from pyannote.core import SlidingWindowFeature
 
 from . import base
 from .. import models as m
@@ -13,7 +13,6 @@
 from ..features import TemporalFeatureFormatter, TemporalFeatures
 from ..metrics import Metric, WordErrorRate
 
-
 BeamSize = HyperParameter("beam_size", low=1, high=20)
 
 
@@ -70,7 +69,7 @@ def __init__(
         asr: Optional[m.SpeechRecognitionModel] = None,
         duration: Optional[float] = None,
         language: Optional[Text] = None,
-        beam_size: int = 5,
+        beam_size: int = None,
         device: Optional[torch.device] = None,
     ):
         self.device = device
@@ -82,12 +81,11 @@ def __init__(
         if self.asr is None:
             self.asr = m.SpeechRecognitionModel.from_whisper("small")
         self.asr.set_language(language)
+        self.asr.set_beam_size(beam_size)
 
         self._duration = duration
         self._sample_rate: Optional[int] = None
 
-        self.beam_size = beam_size
-
     @property
     def duration(self) -> float:
         if self._duration is None:
@@ -122,16 +120,13 @@ def from_dict(data: Any) -> 'TranscriptionConfig':
             asr=asr,
             duration=utils.get(data, "duration", None),
             language=utils.get(data, "language", None),
-            beam_size=utils.get(data, "beam_size", 5),
+            beam_size=utils.get(data, "beam_size", None),
             device=device,
         )
 
 
 class Transcription(base.StreamingPipeline):
-    def __init__(
-        self,
-        config: Optional[TranscriptionConfig] = None,
-    ):
+    def __init__(self, config: Optional[TranscriptionConfig] = None):
         self._config = TranscriptionConfig() if config is None else config
         self.asr = SpeechRecognition(self.config.asr, self.config.device)
 
@@ -169,8 +164,6 @@ def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Pa
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
-        diarization: Optional[Sequence[Annotation]] = None,
-        **kwargs
     ) -> Sequence[Tuple[Text, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
index 40d1d22d..6494a9bf 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/blocks/base.py
@@ -95,6 +95,5 @@ def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Pat
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
-        **kwargs,
     ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
         raise NotImplementedError
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index 2f8de3f5..fe3f4c98 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -206,7 +206,6 @@ def reset(self):
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
-        **kwargs,
     ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
index 47afb4da..42061c86 100644
--- a/src/diart/blocks/vad.py
+++ b/src/diart/blocks/vad.py
@@ -161,7 +161,6 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
-        **kwargs,
     ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
         batch_size = len(waveforms)
         msg = "Pipeline expected at least 1 input"
diff --git a/src/diart/models.py b/src/diart/models.py
index 0879ef9a..274478c8 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -234,6 +234,9 @@ def sample_rate(self) -> int:
     def set_language(self, language: Optional[Text] = None):
         raise NotImplementedError
 
+    def set_beam_size(self, size: Optional[int] = None):
+        raise NotImplementedError
+
     def forward(self, waveform: torch.Tensor) -> List[Transcription]:
         """
         Forward pass of the speech recognition model.
@@ -408,12 +411,9 @@ def token_duration(self) -> float:
     def set_language(self, language: Optional[Text] = None):
         self.language = language
 
-    def set_beam_size(self, size: int):
+    def set_beam_size(self, size: Optional[int] = None):
         self.beam_size = size
 
-    def set_fp16(self, value: bool):
-        self.fp16 = value
-
     def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]:
         # Remove channel dimension
         batch = waveform_batch.squeeze(1)
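
With this change the beam size is owned by the configuration and defaults to None, which whisper.DecodingOptions interprets as greedy decoding. A minimal sketch of opting into beam search (model name and language are illustrative):

    from diart import models as m
    from diart.blocks.asr import Transcription, TranscriptionConfig

    asr = m.SpeechRecognitionModel.from_whisper("small")
    config = TranscriptionConfig(asr=asr, language="en", beam_size=5)  # beam search with 5 beams
    pipeline = Transcription(config)
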

From 07dd9ae36781cdc25bd379c1f4f1cdaa4d3862da Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Sun, 23 Apr 2023 15:15:54 +0200
Subject: [PATCH 11/23] Greatly improve transcription pipeline by adding
 optional VAD

---
 src/diart/blocks/asr.py |  57 ++++++++++++++++++-----
 src/diart/models.py     | 101 ++++++++++++++++++++++++++++++++++------
 2 files changed, 132 insertions(+), 26 deletions(-)

diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py
index ecd13b8e..fdef7a15 100644
--- a/src/diart/blocks/asr.py
+++ b/src/diart/blocks/asr.py
@@ -9,6 +9,7 @@
 from . import base
 from .. import models as m
 from .. import utils
+from ..blocks import SpeakerSegmentation
 from ..blocks.base import HyperParameter
 from ..features import TemporalFeatureFormatter, TemporalFeatures
 from ..metrics import Metric, WordErrorRate
@@ -32,14 +33,25 @@ def from_whisper(
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
         fp16: bool = False,
+        no_speech_threshold: float = 0.6,
+        compression_ratio_threshold: Optional[float] = 2.4,
+        logprob_threshold: Optional[float] = -1,
+        decode_with_fallback: bool = False,
         device: Optional[Union[Text, torch.device]] = None,
     ) -> 'SpeechRecognition':
         asr_model = m.SpeechRecognitionModel.from_whisper(
-            name, download_path, in_memory, fp16
+            name,
+            download_path,
+            in_memory,
+            fp16,
+            no_speech_threshold,
+            compression_ratio_threshold,
+            logprob_threshold,
+            decode_with_fallback,
         )
         return SpeechRecognition(asr_model, device)
 
-    def __call__(self, waveform: TemporalFeatures) -> List[m.Transcription]:
+    def __call__(self, waveform: TemporalFeatures) -> List[m.TranscriptionResult]:
         """
         Compute the transcription of input audio.
 
@@ -67,6 +79,8 @@ class TranscriptionConfig(base.StreamingConfig):
     def __init__(
         self,
         asr: Optional[m.SpeechRecognitionModel] = None,
+        vad: Optional[m.SegmentationModel] = None,
+        tau_active: float = 0.5,
         duration: Optional[float] = None,
         language: Optional[Text] = None,
         beam_size: int = None,
@@ -83,6 +97,9 @@ def __init__(
         self.asr.set_language(language)
         self.asr.set_beam_size(beam_size)
 
+        self.vad = vad
+        self.tau_active = tau_active
+
         self._duration = duration
         self._sample_rate: Optional[int] = None
 
@@ -112,12 +129,10 @@ def from_dict(data: Any) -> 'TranscriptionConfig':
         device = utils.get(data, "device", None)
         if device is None:
             device = torch.device("cpu") if utils.get(data, "cpu", False) else None
-
-        name = utils.get(data, "whisper", "small")
-        asr = m.SpeechRecognitionModel.from_whisper(name)
-
         return TranscriptionConfig(
-            asr=asr,
+            asr=utils.get(data, "asr", None),
+            vad=utils.get(data, "vad", None),
+            tau_active=utils.get(data, "tau_active", None),
             duration=utils.get(data, "duration", None),
             language=utils.get(data, "language", None),
             beam_size=utils.get(data, "beam_size", None),
@@ -129,6 +144,9 @@ class Transcription(base.StreamingPipeline):
     def __init__(self, config: Optional[TranscriptionConfig] = None):
         self._config = TranscriptionConfig() if config is None else config
         self.asr = SpeechRecognition(self.config.asr, self.config.device)
+        self.segmentation = None
+        if self.config.vad is not None:
+            self.segmentation = SpeakerSegmentation(self.config.vad, self.config.device)
 
     @staticmethod
     def get_config_class() -> type:
@@ -176,11 +194,28 @@ def __call__(
         msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
         assert batch.shape[1] == expected_num_samples, msg
 
-        # Transcribe batch
-        # TODO only transcribe if there's speech
-        outputs = self.asr(batch)
+        # Run voice detection if required
+        if self.segmentation is None:
+            has_voice = torch.arange(0, batch_size)
+        else:
+            segmentations = self.segmentation(batch)  # shape (batch, frames, speakers)
+            has_voice = torch.max(segmentations, dim=-1)[0]  # shape (batch, frames)
+            has_voice = torch.any(has_voice >= self.config.tau_active, dim=-1)  # shape (batch,)
+            has_voice = torch.where(has_voice)[0]
+
+        # Return empty strings if no speech in the entire batch
+        if len(has_voice) == 0:
+            return [("", wav) for wav in waveforms]
 
-        return [(out.text, wav) for out, wav in zip(outputs, waveforms)]
+        # Transcribe batch
+        outputs = self.asr(batch[has_voice])
+        mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)}
+
+        # No-speech outputs are empty strings
+        return [
+            (outputs[mapping[i]].text if i in has_voice else "", waveforms[i])
+            for i in range(batch_size)
+        ]
 
         # TODO align text with speakers if diarization is not None
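
A minimal sketch of enabling the optional VAD gate introduced above, assuming the usual pyannote/segmentation checkpoint for the vad model; chunks is a Sequence[SlidingWindowFeature] prepared as in the earlier example:

    from diart import models as m
    from diart.blocks.asr import Transcription, TranscriptionConfig

    config = TranscriptionConfig(
        asr=m.SpeechRecognitionModel.from_whisper("small", decode_with_fallback=True),
        vad=m.SegmentationModel.from_pyannote("pyannote/segmentation"),
        tau_active=0.5,
    )
    pipeline = Transcription(config)
    # Chunks whose maximum speaker activation never reaches tau_active come back as ("", chunk)
    outputs = pipeline(chunks)
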
 
diff --git a/src/diart/models.py b/src/diart/models.py
index 274478c8..5afbec45 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -204,7 +204,7 @@ def forward(
 
 
 @dataclass(frozen=True)
-class Transcription:
+class TranscriptionResult:
     text: Text
     chunks: List[Text]
     timestamps: List[Segment]
@@ -217,11 +217,24 @@ def from_whisper(
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
         fp16: bool = False,
+        no_speech_threshold: float = 0.6,
+        compression_ratio_threshold: Optional[float] = 2.4,
+        logprob_threshold: Optional[float] = -1,
+        decode_with_fallback: bool = False,
     ) -> 'SpeechRecognitionModel':
         msg = "No whisper-transcribed installation found. " \
               "Visit https://github.com/linto-ai/whisper-timestamped#installation to install"
         assert _has_whisper, msg
-        return WhisperSpeechRecognitionModel(name, download_path, in_memory, fp16)
+        return WhisperSpeechRecognitionModel(
+            name,
+            download_path,
+            in_memory,
+            fp16,
+            no_speech_threshold,
+            compression_ratio_threshold,
+            logprob_threshold,
+            decode_with_fallback,
+        )
 
     @property
     def duration(self) -> float:
@@ -237,7 +250,7 @@ def set_language(self, language: Optional[Text] = None):
     def set_beam_size(self, size: Optional[int] = None):
         raise NotImplementedError
 
-    def forward(self, waveform: torch.Tensor) -> List[Transcription]:
+    def forward(self, waveform: torch.Tensor) -> List[TranscriptionResult]:
         """
         Forward pass of the speech recognition model.
 
@@ -248,7 +261,7 @@ def forward(self, waveform: torch.Tensor) -> List[Transcription]:
 
         Returns
         -------
-        transcriptions: List[Transcription]
+        transcriptions: List[TranscriptionResult]
             A list of timestamped transcriptions
         """
         raise NotImplementedError
@@ -257,9 +270,11 @@ def forward(self, waveform: torch.Tensor) -> List[Transcription]:
 class WhisperDecoder:
     def __init__(
         self,
+        no_speech_threshold: float = 0.6,
         compression_ratio_threshold: Optional[float] = 2.4,
         logprob_threshold: Optional[float] = -1,
     ):
+        self.no_speech_threshold = no_speech_threshold
         self.compression_ratio_threshold = compression_ratio_threshold
         self.logprob_threshold = logprob_threshold
         self.temperatures = (0, 0.2, 0.4, 0.6, 0.8, 1)
@@ -303,7 +318,24 @@ def decode_with_fallback(
         model,
         batch: torch.Tensor,
         options: DecodingOptions,
-    ) -> DecodingResult:
+    ) -> List[DecodingResult]:
+        """Transcribe batch and retry with ever-increasing
+        temperatures if the estimated quality of the transcription is not good.
+
+        Parameters
+        ----------
+        model: whisper.Whisper
+            Whisper ASR model (contains 'decode' method).
+        batch: torch.Tensor, shape (batch, n_mels, frames)
+            Log mel spectrogram batch.
+        options: whisper.DecodingOptions
+            Configuration to decode transcription.
+
+        Returns
+        -------
+        result: List[whisper.DecodingResult]
+            Transcription results for this batch.
+        """
         batch_size = batch.shape[0]
         results = [None] * batch_size
         retry_idx = torch.ones(batch_size).type(torch.bool)
@@ -314,26 +346,51 @@ def decode_with_fallback(
             outputs = model.decode(batch[retry_idx], t_options)
 
             # Determine which outputs need to be transcribed again
-            #  based on quality estimates
             output_idx = torch.where(retry_idx)[0]
             for idx, out in zip(output_idx, outputs):
                 results[idx] = out
                 if not self.needs_fallback(out):
                     retry_idx[idx] = False
 
-            # No output needs fallback, get out of the loop
+            # No output needs fallback, get out of the loop early
             if torch.sum(retry_idx).item() == 0:
                 break
 
         return results
 
-    @staticmethod
     def split_with_timestamps(
+        self,
         result: DecodingResult,
         tokenizer: Tokenizer,
         chunk_duration: float,
         token_duration: float,
-    ) -> Transcription:
+    ) -> TranscriptionResult:
+        """Split a Whisper transcription into segments with their respective timestamps.
+        Replace with empty string if no-speech probability is high.
+
+        Parameters
+        ----------
+        result: whisper.DecodingResult
+            A single transcription output from Whisper.
+        tokenizer: whisper.tokenizer.Tokenizer
+            Tokenizer needed to decode outputs.
+        chunk_duration: float
+            Actual duration of each input chunk.
+        token_duration: float
+            Duration of each output token.
+
+        Returns
+        -------
+        result: TranscriptionResult
+            Transcription with identified segments and timestamps.
+        """
+        # If the no-speech probability is high and confidence is low, return an empty transcription
+        if self.no_speech_threshold is not None:
+            no_speech = result.no_speech_prob > self.no_speech_threshold
+            low_confidence = self.logprob_threshold is None or result.avg_logprob < self.logprob_threshold
+            if no_speech and low_confidence:
+                return TranscriptionResult("", [""], [Segment(0, chunk_duration)])
+
         tokens = torch.tensor(result.tokens)
         chunks, timestamps = [], []
         ts_tokens = tokens.ge(tokenizer.timestamp_begin)
@@ -345,6 +402,7 @@ def split_with_timestamps(
             if single_ts_ending:
                 slices.append(len(tokens))
 
+            # Split into segments based on timestamp tokens
             last_slice = 0
             for current_slice in slices:
                 sliced_tokens = tokens[last_slice:current_slice]
@@ -358,6 +416,7 @@ def split_with_timestamps(
                     timestamps.append(timestamp)
                 last_slice = current_slice
         else:
+            # There is a single segment, identify timestamps
             duration = chunk_duration
             ts = tokens[ts_tokens.nonzero().flatten()]
             if len(ts) > 0 and ts[-1].item() != tokenizer.timestamp_begin:
@@ -370,7 +429,7 @@ def split_with_timestamps(
                 chunks.append(text)
                 timestamps.append(Segment(0, duration))
 
-        return Transcription(result.text, chunks, timestamps)
+        return TranscriptionResult(result.text, chunks, timestamps)
 
 
 class WhisperSpeechRecognitionModel(SpeechRecognitionModel):
@@ -380,15 +439,20 @@ def __init__(
         download_path: Optional[Union[Text, Path]] = None,
         in_memory: bool = False,
         fp16: bool = False,
+        no_speech_threshold: float = 0.6,
         compression_ratio_threshold: Optional[float] = 2.4,
         logprob_threshold: Optional[float] = -1,
+        decode_with_fallback: bool = False,
     ):
         super().__init__(WhisperLoader(name, download_path, in_memory))
         self.fp16 = fp16
         self.beam_size = None
         self.language = None
+        self.decode_with_fallback = decode_with_fallback
+        self.decoder = WhisperDecoder(
+            no_speech_threshold, compression_ratio_threshold, logprob_threshold
+        )
         self._token_duration: Optional[float] = None
-        self.decoder = WhisperDecoder(compression_ratio_threshold, logprob_threshold)
 
     @property
     def duration(self) -> float:
@@ -414,7 +478,7 @@ def set_language(self, language: Optional[Text] = None):
     def set_beam_size(self, size: Optional[int] = None):
         self.beam_size = size
 
-    def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]:
+    def forward(self, waveform_batch: torch.Tensor) -> List[TranscriptionResult]:
         # Remove channel dimension
         batch = waveform_batch.squeeze(1)
         num_chunk_samples = batch.shape[-1]
@@ -424,20 +488,27 @@ def forward(self, waveform_batch: torch.Tensor) -> List[Transcription]:
         dtype = torch.float16 if self.fp16 else torch.float32
         batch = whisper.pad_or_trim(batch, whisper.audio.N_FRAMES).to(batch.device).type(dtype)
 
-        # Transcribe batch
+        # Configure transcription decoding
         options = whisper.DecodingOptions(
             task="transcribe",
             language=self.language,
             beam_size=self.beam_size,
             fp16=self.fp16,
         )
-        results = self.decoder.decode_with_fallback(self.model, batch, options)
+
+        # Transcribe batch with fallback if required
+        if self.decode_with_fallback:
+            decode_fn = self.decoder.decode_with_fallback
+        else:
+            decode_fn = self.decoder.decode
+        results = decode_fn(self.model, batch, options)
+
+        # Split into segments and add timestamps
         tokenizer = get_tokenizer(
             self.model.is_multilingual,
             language=options.language,
             task=options.task,
         )
-
         chunk_duration = int(np.rint(num_chunk_samples / self.sample_rate))
         transcriptions = [
             self.decoder.split_with_timestamps(
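
To make the fallback rule concrete, a tiny sketch with the default thresholds and a hypothetical stand-in for whisper.DecodingResult (illustration only, not the real class):

    from diart.models import WhisperDecoder

    decoder = WhisperDecoder()  # compression_ratio_threshold=2.4, logprob_threshold=-1

    class FakeResult:  # hypothetical stand-in, only the two attributes needs_fallback reads
        def __init__(self, compression_ratio, avg_logprob):
            self.compression_ratio = compression_ratio
            self.avg_logprob = avg_logprob

    decoder.needs_fallback(FakeResult(2.6, -0.4))  # True: transcription too repetitive
    decoder.needs_fallback(FakeResult(1.8, -1.3))  # True: average log-probability too low
    decoder.needs_fallback(FakeResult(1.8, -0.4))  # False: result is kept, no retry
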

From 0bf25228d6b280e5be589fc93d7a1c1f7547d03f Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Sun, 23 Apr 2023 16:30:25 +0200
Subject: [PATCH 12/23] Move pipelines to diart.pipelines. Add torchmetrics as
 a dependency

---
 requirements.txt                              |   1 +
 setup.cfg                                     |   1 +
 src/diart/__init__.py                         |   8 +-
 src/diart/blocks/__init__.py                  |   6 +-
 src/diart/blocks/asr.py                       | 182 +----------------
 src/diart/blocks/clustering.py                |   2 +-
 src/diart/console/tune.py                     |   2 +-
 src/diart/inference.py                        |  11 +-
 src/diart/models.py                           |   1 -
 src/diart/optim.py                            |   7 +-
 src/diart/pipelines/__init__.py               |   4 +
 src/diart/{blocks => pipelines}/base.py       |  26 +--
 .../{blocks => pipelines}/diarization.py      |  23 +--
 src/diart/pipelines/hparams.py                |  24 +++
 src/diart/pipelines/transcription.py          | 184 ++++++++++++++++++
 .../{blocks/vad.py => pipelines/voice.py}     |  17 +-
 src/diart/utils.py                            |   4 +-
 17 files changed, 256 insertions(+), 247 deletions(-)
 create mode 100644 src/diart/pipelines/__init__.py
 rename src/diart/{blocks => pipelines}/base.py (77%)
 rename src/diart/{blocks => pipelines}/diarization.py (93%)
 create mode 100644 src/diart/pipelines/hparams.py
 create mode 100644 src/diart/pipelines/transcription.py
 rename src/diart/{blocks/vad.py => pipelines/voice.py} (94%)

diff --git a/requirements.txt b/requirements.txt
index 50662023..3241ddc4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,7 @@ pandas>=1.4.2
 torch>=1.12.1
 torchvision>=0.14.0
 torchaudio>=0.12.1,<1.0
+torchmetrics>=0.11.1
 pyannote.audio>=2.1.1
 pyannote.core>=4.5
 pyannote.database>=4.1.1
diff --git a/setup.cfg b/setup.cfg
index e67e4426..c70eac0b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -31,6 +31,7 @@ install_requires=
     torch>=1.12.1
     torchvision>=0.14.0
     torchaudio>=0.12.1,<1.0
+    torchmetrics>=0.11.1
     pyannote.audio>=2.1.1
     pyannote.core>=4.5
     pyannote.database>=4.1.1
diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index e29287a0..0d67c9c5 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1,8 +1,10 @@
-from .blocks import (
-    SpeakerDiarization,
+from .pipelines import (
     StreamingPipeline,
-    SpeakerDiarizationConfig,
     StreamingConfig,
+    SpeakerDiarization,
+    SpeakerDiarizationConfig,
     VoiceActivityDetection,
     VoiceActivityDetectionConfig,
+    Transcription,
+    TranscriptionConfig,
 )
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index e6e8c479..96fae0e7 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -5,7 +5,7 @@
     FirstOnlyStrategy,
     DelayedAggregation,
 )
-from .clustering import OnlineSpeakerClustering
+from .clustering import IncrementalSpeakerClustering
 from .embedding import (
     SpeakerEmbedding,
     OverlappedSpeechPenalty,
@@ -13,7 +13,5 @@
     OverlapAwareSpeakerEmbedding,
 )
 from .segmentation import SpeakerSegmentation
-from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
-from .base import StreamingConfig, StreamingPipeline
 from .utils import Binarize, Resample, AdjustVolume
-from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig
+from .asr import SpeechRecognition
diff --git a/src/diart/blocks/asr.py b/src/diart/blocks/asr.py
index fdef7a15..83dc0d90 100644
--- a/src/diart/blocks/asr.py
+++ b/src/diart/blocks/asr.py
@@ -1,20 +1,11 @@
 from pathlib import Path
-from typing import Sequence, Optional, Any, Union, List, Text, Tuple
+from typing import Optional, Union, List, Text
 
-import numpy as np
 import torch
 from einops import rearrange
-from pyannote.core import SlidingWindowFeature
 
-from . import base
 from .. import models as m
-from .. import utils
-from ..blocks import SpeakerSegmentation
-from ..blocks.base import HyperParameter
 from ..features import TemporalFeatureFormatter, TemporalFeatures
-from ..metrics import Metric, WordErrorRate
-
-BeamSize = HyperParameter("beam_size", low=1, high=20)
 
 
 class SpeechRecognition:
@@ -73,174 +64,3 @@ def __call__(self, waveform: TemporalFeatures) -> List[m.TranscriptionResult]:
             # output = self.model(wave.to(self.device)).cpu()
             output = self.model(wave.to(self.device))
         return output
-
-
-class TranscriptionConfig(base.StreamingConfig):
-    def __init__(
-        self,
-        asr: Optional[m.SpeechRecognitionModel] = None,
-        vad: Optional[m.SegmentationModel] = None,
-        tau_active: float = 0.5,
-        duration: Optional[float] = None,
-        language: Optional[Text] = None,
-        beam_size: int = None,
-        device: Optional[torch.device] = None,
-    ):
-        self.device = device
-        if self.device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        # Default ASR model is Whisper small (244M parameters)
-        self.asr = asr
-        if self.asr is None:
-            self.asr = m.SpeechRecognitionModel.from_whisper("small")
-        self.asr.set_language(language)
-        self.asr.set_beam_size(beam_size)
-
-        self.vad = vad
-        self.tau_active = tau_active
-
-        self._duration = duration
-        self._sample_rate: Optional[int] = None
-
-    @property
-    def duration(self) -> float:
-        if self._duration is None:
-            self._duration = self.asr.duration
-        return self._duration
-
-    @property
-    def step(self) -> float:
-        return self.duration
-
-    @property
-    def latency(self) -> float:
-        return self.duration
-
-    @property
-    def sample_rate(self) -> int:
-        if self._sample_rate is None:
-            self._sample_rate = self.asr.sample_rate
-        return self._sample_rate
-
-    @staticmethod
-    def from_dict(data: Any) -> 'TranscriptionConfig':
-        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
-        device = utils.get(data, "device", None)
-        if device is None:
-            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
-        return TranscriptionConfig(
-            asr=utils.get(data, "asr", None),
-            vad=utils.get(data, "vad", None),
-            tau_active=utils.get(data, "tau_active", None),
-            duration=utils.get(data, "duration", None),
-            language=utils.get(data, "language", None),
-            beam_size=utils.get(data, "beam_size", None),
-            device=device,
-        )
-
-
-class Transcription(base.StreamingPipeline):
-    def __init__(self, config: Optional[TranscriptionConfig] = None):
-        self._config = TranscriptionConfig() if config is None else config
-        self.asr = SpeechRecognition(self.config.asr, self.config.device)
-        self.segmentation = None
-        if self.config.vad is not None:
-            self.segmentation = SpeakerSegmentation(self.config.vad, self.config.device)
-
-    @staticmethod
-    def get_config_class() -> type:
-        return TranscriptionConfig
-
-    @staticmethod
-    def suggest_metric() -> Metric:
-        return WordErrorRate()
-
-    @staticmethod
-    def hyper_parameters() -> Sequence[HyperParameter]:
-        return [BeamSize]
-
-    @property
-    def config(self) -> TranscriptionConfig:
-        return self._config
-
-    def reset(self):
-        # No internal state. Nothing to do
-        pass
-
-    def set_timestamp_shift(self, shift: float):
-        # No timestamped output. Nothing to do
-        pass
-
-    def join_predictions(self, predictions: List[Text]) -> Text:
-        return "\n".join(predictions)
-
-    def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]):
-        with open(Path(dir_path) / f"{uri}.txt", "w") as out_file:
-            out_file.write(prediction)
-
-    def __call__(
-        self,
-        waveforms: Sequence[SlidingWindowFeature],
-    ) -> Sequence[Tuple[Text, SlidingWindowFeature]]:
-        batch_size = len(waveforms)
-        msg = "Pipeline expected at least 1 input"
-        assert batch_size >= 1, msg
-
-        # Create batch from chunk sequence, shape (batch, samples, channels)
-        batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
-
-        expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
-        msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
-        assert batch.shape[1] == expected_num_samples, msg
-
-        # Run voice detection if required
-        if self.segmentation is None:
-            has_voice = torch.arange(0, batch_size)
-        else:
-            segmentations = self.segmentation(batch)  # shape (batch, frames, speakers)
-            has_voice = torch.max(segmentations, dim=-1)[0]  # shape (batch, frames)
-            has_voice = torch.any(has_voice >= self.config.tau_active, dim=-1)  # shape (batch,)
-            has_voice = torch.where(has_voice)[0]
-
-        # Return empty strings if no speech in the entire batch
-        if len(has_voice) == 0:
-            return [("", wav) for wav in waveforms]
-
-        # Transcribe batch
-        outputs = self.asr(batch[has_voice])
-        mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)}
-
-        # No-speech outputs are empty strings
-        return [
-            (outputs[mapping[i]].text if i in has_voice else "", waveforms[i])
-            for i in range(batch_size)
-        ]
-
-        # TODO align text with speakers if diarization is not None
-
-        # diarization = diarization[0]
-        #
-        # # Align transcription with diarization to determine speakers
-        # full_transcription = []
-        # buffer_shift = waveform.sliding_window.start
-        # for text, timestamp in zip(outputs.chunks, outputs.timestamps):
-        #     target_region = Segment(
-        #         buffer_shift + timestamp.start,
-        #         buffer_shift + timestamp.end
-        #     )
-        #     dia = diarization.crop(target_region)
-        #     speakers = dia.labels()
-        #     num_speakers = len(speakers)
-        #     if num_speakers == 0:
-        #         # Include transcription but don't assign a speaker
-        #         full_transcription.append(text)
-        #     elif num_speakers == 1:
-        #         # Typical case, annotate text with the only speaker
-        #         full_transcription.append(f"[{speakers[0]}]{text}")
-        #     else:
-        #         # Multiple speakers for the same text block, choose the most active one
-        #         max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
-        #         full_transcription.append(f"[{speakers[max_spk]}]{text}")
-        #
-        # return [(" ".join(full_transcription).strip(), waveform)]
diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py
index 882001b9..4b737175 100644
--- a/src/diart/blocks/clustering.py
+++ b/src/diart/blocks/clustering.py
@@ -7,7 +7,7 @@
 from ..mapping import SpeakerMap, SpeakerMapBuilder
 
 
-class OnlineSpeakerClustering:
+class IncrementalSpeakerClustering:
     """Implements constrained incremental online clustering of speakers and manages cluster centers.
 
     Parameters
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index 6affda50..111d97c2 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -4,7 +4,7 @@
 import optuna
 from diart import argdoc
 from diart import utils
-from diart.blocks.base import HyperParameter
+from diart.pipelines.hparams import HyperParameter
 from diart.optim import Optimizer
 from optuna.samplers import TPESampler
 
diff --git a/src/diart/inference.py b/src/diart/inference.py
index ee22a3cb..258f773a 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -18,6 +18,7 @@
 from . import sources as src
 from . import utils
 from .metrics import Metric
+from .pipelines import StreamingPipeline, StreamingConfig
 from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
 from .sinks import StreamingPlot, WindowClosedException
 
@@ -52,7 +53,7 @@ class StreamingInference:
     """
     def __init__(
         self,
-        pipeline: blocks.StreamingPipeline,
+        pipeline: StreamingPipeline,
         source: src.AudioSource,
         batch_size: int = 1,
         do_profile: bool = True,
@@ -288,7 +289,7 @@ def get_file_paths(self) -> List[Path]:
 
     def run_single(
         self,
-        pipeline: blocks.StreamingPipeline,
+        pipeline: StreamingPipeline,
         filepath: Path,
         progress_bar: ProgressBar,
     ) -> Tuple[Text, Any]:
@@ -387,7 +388,7 @@ def evaluate(
     def __call__(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: StreamingConfig,
         metric: Optional[Metric] = None,
     ) -> Union[pd.DataFrame, Dict[Text, Any]]:
         """Run a given pipeline on a set of audio files.
@@ -451,7 +452,7 @@ def __init__(
     def run_single_job(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: StreamingConfig,
         filepath: Path,
         description: Text,
     ) -> Tuple[Text, Any]:
@@ -488,7 +489,7 @@ def run_single_job(
     def __call__(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: StreamingConfig,
         metric: Optional[Metric] = None,
     ) -> Union[pd.DataFrame, Dict[Text, Any]]:
         """Run a given pipeline on a set of audio files in parallel.
diff --git a/src/diart/models.py b/src/diart/models.py
index 5afbec45..485cb43a 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -1,4 +1,3 @@
-import time
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Text, Union, Callable, List, Any
diff --git a/src/diart/optim.py b/src/diart/optim.py
index f7a96a6e..0ea27910 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -9,9 +9,10 @@
 from tqdm import trange, tqdm
 from typing_extensions import Literal
 
-from . import blocks
 from .audio import FilePath
 from .inference import Benchmark
+from .pipelines import StreamingConfig
+from .pipelines.hparams import HyperParameter
 
 
 class Optimizer:
@@ -22,8 +23,8 @@ def __init__(
         reference_path: Union[Text, Path],
         study_or_path: Union[FilePath, Study],
         batch_size: int = 32,
-        hparams: Optional[Sequence[blocks.base.HyperParameter]] = None,
-        base_config: Optional[blocks.StreamingConfig] = None,
+        hparams: Optional[Sequence[HyperParameter]] = None,
+        base_config: Optional[StreamingConfig] = None,
         do_kickstart_hparams: bool = True,
         metric: Optional[BaseMetric] = None,
         direction: Literal["minimize", "maximize"] = "minimize",
diff --git a/src/diart/pipelines/__init__.py b/src/diart/pipelines/__init__.py
new file mode 100644
index 00000000..f430e4f7
--- /dev/null
+++ b/src/diart/pipelines/__init__.py
@@ -0,0 +1,4 @@
+from .base import StreamingPipeline, StreamingConfig
+from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .transcription import Transcription, TranscriptionConfig
+from .voice import VoiceActivityDetection, VoiceActivityDetectionConfig
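
After the move, the pipeline classes are importable from the top-level package or from the new diart.pipelines subpackage, while diart.blocks keeps only the processing blocks:

    # Pipelines
    from diart import SpeakerDiarization, VoiceActivityDetection, Transcription
    from diart.pipelines import SpeakerDiarizationConfig, TranscriptionConfig
    from diart.pipelines.hparams import HyperParameter, TauActive

    # Building blocks
    from diart.blocks import SpeakerSegmentation, SpeechRecognition, IncrementalSpeakerClustering
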
diff --git a/src/diart/blocks/base.py b/src/diart/pipelines/base.py
similarity index 77%
rename from src/diart/blocks/base.py
rename to src/diart/pipelines/base.py
index 6494a9bf..507bfe5d 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/pipelines/base.py
@@ -1,37 +1,15 @@
-from dataclasses import dataclass
-from typing import Any, Tuple, Sequence, Text, List, Union
 from pathlib import Path
+from typing import Any, Tuple, Sequence, Text, List, Union
 
 import numpy as np
 from pyannote.core import SlidingWindowFeature
 
+from .hparams import HyperParameter
 from .. import utils
 from ..audio import FilePath, AudioLoader
 from ..metrics import Metric
 
 
-@dataclass
-class HyperParameter:
-    name: Text
-    low: float
-    high: float
-
-    @staticmethod
-    def from_name(name: Text) -> 'HyperParameter':
-        if name == "tau_active":
-            return TauActive
-        if name == "rho_update":
-            return RhoUpdate
-        if name == "delta_new":
-            return DeltaNew
-        raise ValueError(f"Hyper-parameter '{name}' not recognized")
-
-
-TauActive = HyperParameter("tau_active", low=0, high=1)
-RhoUpdate = HyperParameter("rho_update", low=0, high=1)
-DeltaNew = HyperParameter("delta_new", low=0, high=2)
-
-
 class StreamingConfig:
     @property
     def duration(self) -> float:
diff --git a/src/diart/blocks/diarization.py b/src/diart/pipelines/diarization.py
similarity index 93%
rename from src/diart/blocks/diarization.py
rename to src/diart/pipelines/diarization.py
index fe3f4c98..7c901799 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/pipelines/diarization.py
@@ -7,11 +7,8 @@
 from typing_extensions import Literal
 
 from . import base
-from .aggregation import DelayedAggregation
-from .clustering import OnlineSpeakerClustering
-from .embedding import OverlapAwareSpeakerEmbedding
-from .segmentation import SpeakerSegmentation
-from .utils import Binarize
+from .hparams import HyperParameter, TauActive, RhoUpdate, DeltaNew
+from .. import blocks
 from .. import models as m
 from .. import utils
 from ..metrics import Metric, DiarizationErrorRate
@@ -139,23 +136,23 @@ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
         msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
         assert self._config.step <= self._config.latency <= self._config.duration, msg
 
-        self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device)
-        self.embedding = OverlapAwareSpeakerEmbedding(
+        self.segmentation = blocks.SpeakerSegmentation(self._config.segmentation, self._config.device)
+        self.embedding = blocks.OverlapAwareSpeakerEmbedding(
             self._config.embedding, self._config.gamma, self._config.beta, norm=1, device=self._config.device
         )
-        self.pred_aggregation = DelayedAggregation(
+        self.pred_aggregation = blocks.DelayedAggregation(
             self._config.step,
             self._config.latency,
             strategy="hamming",
             cropping_mode="loose",
         )
-        self.audio_aggregation = DelayedAggregation(
+        self.audio_aggregation = blocks.DelayedAggregation(
             self._config.step,
             self._config.latency,
             strategy="first",
             cropping_mode="center",
         )
-        self.binarize = Binarize(self._config.tau_active)
+        self.binarize = blocks.Binarize(self._config.tau_active)
 
         # Internal state, handle with care
         self.timestamp_shift = 0
@@ -172,8 +169,8 @@ def suggest_metric() -> Metric:
         return DiarizationErrorRate(collar=0, skip_overlap=False)
 
     @staticmethod
-    def hyper_parameters() -> Sequence[base.HyperParameter]:
-        return [base.TauActive, base.RhoUpdate, base.DeltaNew]
+    def hyper_parameters() -> Sequence[HyperParameter]:
+        return [TauActive, RhoUpdate, DeltaNew]
 
     @property
     def config(self) -> SpeakerDiarizationConfig:
@@ -194,7 +191,7 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te
 
     def reset(self):
         self.set_timestamp_shift(0)
-        self.clustering = OnlineSpeakerClustering(
+        self.clustering = blocks.IncrementalSpeakerClustering(
             self.config.tau_active,
             self.config.rho_update,
             self.config.delta_new,
diff --git a/src/diart/pipelines/hparams.py b/src/diart/pipelines/hparams.py
new file mode 100644
index 00000000..740a1edf
--- /dev/null
+++ b/src/diart/pipelines/hparams.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+from typing import Text
+
+
+@dataclass
+class HyperParameter:
+    name: Text
+    low: float
+    high: float
+
+    @staticmethod
+    def from_name(name: Text) -> 'HyperParameter':
+        if name == "tau_active":
+            return TauActive
+        if name == "rho_update":
+            return RhoUpdate
+        if name == "delta_new":
+            return DeltaNew
+        raise ValueError(f"Hyper-parameter '{name}' not recognized")
+
+
+TauActive = HyperParameter("tau_active", low=0, high=1)
+RhoUpdate = HyperParameter("rho_update", low=0, high=1)
+DeltaNew = HyperParameter("delta_new", low=0, high=2)
\ No newline at end of file
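
A minimal sketch of resolving the relocated hyper-parameters by name, e.g. when parsing a --hparams argument:

    from diart.pipelines.hparams import HyperParameter

    for name in ("tau_active", "rho_update", "delta_new"):
        hp = HyperParameter.from_name(name)
        print(hp.name, hp.low, hp.high)  # ranges: [0, 1], [0, 1] and [0, 2]
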
diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py
new file mode 100644
index 00000000..2bf4a270
--- /dev/null
+++ b/src/diart/pipelines/transcription.py
@@ -0,0 +1,184 @@
+from pathlib import Path
+from typing import Sequence, Optional, Any, Union, List, Text, Tuple
+
+import numpy as np
+import torch
+from pyannote.core import SlidingWindowFeature
+
+from . import base
+from .hparams import HyperParameter, TauActive
+from .. import blocks
+from .. import models as m
+from .. import utils
+from ..metrics import Metric, WordErrorRate
+
+
+class TranscriptionConfig(base.StreamingConfig):
+    def __init__(
+        self,
+        asr: Optional[m.SpeechRecognitionModel] = None,
+        vad: Optional[m.SegmentationModel] = None,
+        tau_active: float = 0.5,
+        duration: Optional[float] = None,
+        language: Optional[Text] = None,
+        beam_size: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ):
+        self.device = device
+        if self.device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Default ASR model is Whisper small (244M parameters)
+        self.asr = asr
+        if self.asr is None:
+            self.asr = m.SpeechRecognitionModel.from_whisper("small")
+        self.asr.set_language(language)
+        self.asr.set_beam_size(beam_size)
+
+        self.vad = vad
+        self.tau_active = tau_active
+
+        self._duration = duration
+        self._sample_rate: Optional[int] = None
+
+    @property
+    def duration(self) -> float:
+        if self._duration is None:
+            self._duration = self.asr.duration
+        return self._duration
+
+    @property
+    def step(self) -> float:
+        return self.duration
+
+    @property
+    def latency(self) -> float:
+        return self.duration
+
+    @property
+    def sample_rate(self) -> int:
+        if self._sample_rate is None:
+            self._sample_rate = self.asr.sample_rate
+        return self._sample_rate
+
+    @staticmethod
+    def from_dict(data: Any) -> 'TranscriptionConfig':
+        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+        device = utils.get(data, "device", None)
+        if device is None:
+            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+        return TranscriptionConfig(
+            asr=utils.get(data, "asr", None),
+            vad=utils.get(data, "vad", None),
+            tau_active=utils.get(data, "tau_active", None),
+            duration=utils.get(data, "duration", None),
+            language=utils.get(data, "language", None),
+            beam_size=utils.get(data, "beam_size", None),
+            device=device,
+        )
+
+
+class Transcription(base.StreamingPipeline):
+    def __init__(self, config: Optional[TranscriptionConfig] = None):
+        self._config = TranscriptionConfig() if config is None else config
+        self.asr = blocks.SpeechRecognition(self.config.asr, self.config.device)
+        self.segmentation = None
+        if self.config.vad is not None:
+            self.segmentation = blocks.SpeakerSegmentation(self.config.vad, self.config.device)
+
+    @staticmethod
+    def get_config_class() -> type:
+        return TranscriptionConfig
+
+    @staticmethod
+    def suggest_metric() -> Metric:
+        return WordErrorRate()
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[HyperParameter]:
+        return [TauActive]
+
+    @property
+    def config(self) -> TranscriptionConfig:
+        return self._config
+
+    def reset(self):
+        # No internal state. Nothing to do
+        pass
+
+    def set_timestamp_shift(self, shift: float):
+        # No timestamped output. Nothing to do
+        pass
+
+    def join_predictions(self, predictions: List[Text]) -> Text:
+        return "\n".join(predictions)
+
+    def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]):
+        with open(Path(dir_path) / f"{uri}.txt", "w") as out_file:
+            out_file.write(prediction)
+
+    def __call__(
+        self,
+        waveforms: Sequence[SlidingWindowFeature],
+    ) -> Sequence[Tuple[Text, SlidingWindowFeature]]:
+        batch_size = len(waveforms)
+        msg = "Pipeline expected at least 1 input"
+        assert batch_size >= 1, msg
+
+        # Create batch from chunk sequence, shape (batch, samples, channels)
+        batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
+
+        expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+        msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
+        assert batch.shape[1] == expected_num_samples, msg
+
+        # Run voice detection if required
+        if self.segmentation is None:
+            has_voice = torch.arange(0, batch_size)
+        else:
+            segmentations = self.segmentation(batch)  # shape (batch, frames, speakers)
+            has_voice = torch.max(segmentations, dim=-1)[0]  # shape (batch, frames)
+            has_voice = torch.any(has_voice >= self.config.tau_active, dim=-1)  # shape (batch,)
+            has_voice = torch.where(has_voice)[0]
+
+        # Return empty strings if no speech in the entire batch
+        if len(has_voice) == 0:
+            return [("", wav) for wav in waveforms]
+
+        # Transcribe batch
+        outputs = self.asr(batch[has_voice])
+        mapping = {i_voice.item(): i_output for i_output, i_voice in enumerate(has_voice)}
+
+        # No-speech outputs are empty strings
+        return [
+            (outputs[mapping[i]].text if i in has_voice else "", waveforms[i])
+            for i in range(batch_size)
+        ]
+
+        # TODO align text with speakers if diarization is not None
+
+        # diarization = diarization[0]
+        #
+        # # Align transcription with diarization to determine speakers
+        # full_transcription = []
+        # buffer_shift = waveform.sliding_window.start
+        # for text, timestamp in zip(outputs.chunks, outputs.timestamps):
+        #     target_region = Segment(
+        #         buffer_shift + timestamp.start,
+        #         buffer_shift + timestamp.end
+        #     )
+        #     dia = diarization.crop(target_region)
+        #     speakers = dia.labels()
+        #     num_speakers = len(speakers)
+        #     if num_speakers == 0:
+        #         # Include transcription but don't assign a speaker
+        #         full_transcription.append(text)
+        #     elif num_speakers == 1:
+        #         # Typical case, annotate text with the only speaker
+        #         full_transcription.append(f"[{speakers[0]}]{text}")
+        #     else:
+        #         # Multiple speakers for the same text block, choose the most active one
+        #         max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
+        #         full_transcription.append(f"[{speakers[max_spk]}]{text}")
+        #
+        # return [(" ".join(full_transcription).strip(), waveform)]
\ No newline at end of file
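
For reviewers, here is how the new Transcription pipeline above can be exercised on its own, outside the console scripts. This is a minimal sketch, not part of the patch: it assumes the patched diart package is importable and that TranscriptionConfig falls back to a default Whisper ASR model when asr is omitted (as the no-argument TranscriptionConfig() used by the pipeline constructor implies); the silent chunk is purely for illustration.

    import numpy as np
    from pyannote.core import SlidingWindow, SlidingWindowFeature

    from diart.pipelines.transcription import Transcription, TranscriptionConfig

    # Default config: Whisper ASR, no VAD model, so every chunk is transcribed
    config = TranscriptionConfig(duration=3, language="en")
    pipeline = Transcription(config)

    # Build one silent chunk of shape (samples, channels) at the ASR sample rate
    num_samples = int(np.rint(config.duration * config.sample_rate))
    resolution = 1 / config.sample_rate
    window = SlidingWindow(start=0.0, duration=resolution, step=resolution)
    chunk = SlidingWindowFeature(np.zeros((num_samples, 1), dtype=np.float32), window)

    # The pipeline maps a sequence of chunks to (text, chunk) pairs
    for text, _ in pipeline([chunk]):
        print(text or "<no speech>")
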
diff --git a/src/diart/blocks/vad.py b/src/diart/pipelines/voice.py
similarity index 94%
rename from src/diart/blocks/vad.py
rename to src/diart/pipelines/voice.py
index 42061c86..f93c28a6 100644
--- a/src/diart/blocks/vad.py
+++ b/src/diart/pipelines/voice.py
@@ -7,9 +7,8 @@
 from typing_extensions import Literal
 
 from . import base
-from .aggregation import DelayedAggregation
-from .segmentation import SpeakerSegmentation
-from .utils import Binarize
+from .hparams import HyperParameter, TauActive
+from .. import blocks
 from .. import models as m
 from .. import utils
 from ..metrics import Metric, DetectionErrorRate
@@ -106,20 +105,20 @@ def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
         msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
         assert self._config.step <= self._config.latency <= self._config.duration, msg
 
-        self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device)
-        self.pred_aggregation = DelayedAggregation(
+        self.segmentation = blocks.SpeakerSegmentation(self._config.segmentation, self._config.device)
+        self.pred_aggregation = blocks.DelayedAggregation(
             self._config.step,
             self._config.latency,
             strategy="hamming",
             cropping_mode="loose",
         )
-        self.audio_aggregation = DelayedAggregation(
+        self.audio_aggregation = blocks.DelayedAggregation(
             self._config.step,
             self._config.latency,
             strategy="first",
             cropping_mode="center",
         )
-        self.binarize = Binarize(self._config.tau_active)
+        self.binarize = blocks.Binarize(self._config.tau_active)
 
         # Internal state, handle with care
         self.timestamp_shift = 0
@@ -134,8 +133,8 @@ def suggest_metric() -> Metric:
         return DetectionErrorRate(collar=0, skip_overlap=False)
 
     @staticmethod
-    def hyper_parameters() -> Sequence[base.HyperParameter]:
-        return [base.TauActive]
+    def hyper_parameters() -> Sequence[HyperParameter]:
+        return [TauActive]
 
     @property
     def config(self) -> VoiceActivityDetectionConfig:
diff --git a/src/diart/utils.py b/src/diart/utils.py
index e825ef29..3714a99d 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -7,7 +7,7 @@
 from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
 
 from .progress import ProgressBar
-from . import blocks
+from . import pipelines
 
 
 class Chronometer:
@@ -82,7 +82,7 @@ def repeat_label(label: Text):
 
 
 def get_pipeline_class(class_name: Text) -> type:
-    pipeline_class = getattr(blocks, class_name, None)
+    pipeline_class = getattr(pipelines, class_name, None)
     msg = f"Pipeline '{class_name}' doesn't exist"
     assert pipeline_class is not None, msg
     return pipeline_class
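
A small sketch of what this lookup change enables (assuming the patched package): the --pipeline argument of the console scripts is now resolved by class name against diart.pipelines rather than diart.blocks.

    from diart import utils

    # "--pipeline SpeakerDiarization" on the command line maps to this lookup
    pipeline_class = utils.get_pipeline_class("SpeakerDiarization")
    config_class = pipeline_class.get_config_class()
    print(pipeline_class.__name__, config_class.__name__)
    # SpeakerDiarization SpeakerDiarizationConfig
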

From 42fe5f7505e823166a335cd7d4142379f43e2ec7 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Sun, 23 Apr 2023 18:23:39 +0200
Subject: [PATCH 13/23] Add websocket compatibility to transcription pipeline

---
 src/diart/console/serve.py           | 14 ++++++----
 src/diart/pipelines/base.py          |  5 ++++
 src/diart/pipelines/diarization.py   | 10 +++++--
 src/diart/pipelines/transcription.py | 40 +++++++++++++++++++++-------
 src/diart/pipelines/voice.py         | 10 +++++--
 src/diart/sinks.py                   | 23 +++++++++++-----
 src/diart/sources.py                 |  5 ++--
 src/diart/utils.py                   |  6 +++++
 8 files changed, 87 insertions(+), 26 deletions(-)

diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index 46bb9328..4e7ca1ce 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -5,7 +5,7 @@
 from diart import sources as src
 from diart import utils
 from diart.inference import StreamingInference
-from diart.sinks import RTTMWriter
+from diart.pipelines import StreamingPipeline
 
 
 def run():
@@ -14,6 +14,10 @@ def run():
     parser.add_argument("--port", default=7007, type=int, help="Server port")
     parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
                         help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
+    parser.add_argument("--whisper", default="small", type=str,
+                        help=f"Whisper model for transcription pipeline. Defaults to 'small'")
+    parser.add_argument("--language", default="en", type=str,
+                        help=f"Transcribe in this language. Defaults to 'en' (English)")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -36,7 +40,7 @@ def run():
     # Resolve pipeline
     pipeline_class = utils.get_pipeline_class(args.pipeline)
     config = pipeline_class.get_config_class().from_dict(vars(args))
-    pipeline = pipeline_class(config)
+    pipeline: StreamingPipeline = pipeline_class(config)
 
     # Create websocket audio source
     audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port)
@@ -53,10 +57,10 @@ def run():
 
     # Write to disk if required
     if args.output is not None:
-        inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm"))
+        inference.attach_observers(pipeline.suggest_writer(audio_source.uri, args.output))
 
-    # Send back responses as RTTM text lines
-    inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm()))
+    # Send back responses as text
+    inference.attach_hooks(lambda pred_wav: audio_source.send(utils.serialize_prediction(pred_wav[0])))
 
     # Run server and pipeline
     inference()
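
The hook above stays pipeline-agnostic because serialize_prediction, added to diart.utils later in this patch, turns diarization Annotation objects into RTTM text and passes transcription strings through untouched. A short illustrative sketch, assuming the patched diart.utils:

    from pyannote.core import Annotation, Segment

    from diart import utils

    # Diarization output: an Annotation is serialized to RTTM lines
    annotation = Annotation(uri="demo")
    annotation[Segment(0.0, 1.5)] = "speaker0"
    print(utils.serialize_prediction(annotation), end="")

    # Transcription output: plain text is returned unchanged
    print(utils.serialize_prediction("hello world"))
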
diff --git a/src/diart/pipelines/base.py b/src/diart/pipelines/base.py
index 507bfe5d..d47115d5 100644
--- a/src/diart/pipelines/base.py
+++ b/src/diart/pipelines/base.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 from pyannote.core import SlidingWindowFeature
+from rx.core import Observer
 
 from .hparams import HyperParameter
 from .. import utils
@@ -50,6 +51,10 @@ def get_config_class() -> type:
     def suggest_metric() -> Metric:
         raise NotImplementedError
 
+    @staticmethod
+    def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        raise NotImplementedError
+
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         raise NotImplementedError
diff --git a/src/diart/pipelines/diarization.py b/src/diart/pipelines/diarization.py
index 7c901799..dd9ef070 100644
--- a/src/diart/pipelines/diarization.py
+++ b/src/diart/pipelines/diarization.py
@@ -4,12 +4,14 @@
 import numpy as np
 import torch
 from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
+from rx.core import Observer
 from typing_extensions import Literal
 
 from . import base
 from .hparams import HyperParameter, TauActive, RhoUpdate, DeltaNew
 from .. import blocks
 from .. import models as m
+from .. import sinks
 from .. import utils
 from ..metrics import Metric, DiarizationErrorRate
 
@@ -22,7 +24,7 @@ def __init__(
         duration: Optional[float] = None,
         step: float = 0.5,
         latency: Optional[Union[float, Literal["max", "min"]]] = None,
-        tau_active: float = 0.6,
+        tau_active: float = 0.5,
         rho_update: float = 0.3,
         delta_new: float = 1,
         gamma: float = 3,
@@ -82,7 +84,7 @@ def from_dict(data: Any) -> 'SpeakerDiarizationConfig':
         # Hyper-parameters and their aliases
         tau = utils.get(data, "tau_active", None)
         if tau is None:
-            tau = utils.get(data, "tau", 0.6)
+            tau = utils.get(data, "tau", 0.5)
         rho = utils.get(data, "rho_update", None)
         if rho is None:
             rho = utils.get(data, "rho", 0.3)
@@ -168,6 +170,10 @@ def get_config_class() -> type:
     def suggest_metric() -> Metric:
         return DiarizationErrorRate(collar=0, skip_overlap=False)
 
+    @staticmethod
+    def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
+
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive, RhoUpdate, DeltaNew]
diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py
index 2bf4a270..6b3249c9 100644
--- a/src/diart/pipelines/transcription.py
+++ b/src/diart/pipelines/transcription.py
@@ -4,11 +4,13 @@
 import numpy as np
 import torch
 from pyannote.core import SlidingWindowFeature
+from rx.core import Observer
 
 from . import base
 from .hparams import HyperParameter, TauActive
 from .. import blocks
 from .. import models as m
+from .. import sinks
 from .. import utils
 from ..metrics import Metric, WordErrorRate
 
@@ -17,9 +19,9 @@ class TranscriptionConfig(base.StreamingConfig):
     def __init__(
         self,
         asr: Optional[m.SpeechRecognitionModel] = None,
-        vad: Optional[m.SegmentationModel] = None,
+        segmentation: Optional[m.SegmentationModel] = None,
         tau_active: float = 0.5,
-        duration: Optional[float] = None,
+        duration: Optional[float] = 3,
         language: Optional[Text] = None,
         beam_size: int = None,
         device: Optional[torch.device] = None,
@@ -35,7 +37,7 @@ def __init__(
         self.asr.set_language(language)
         self.asr.set_beam_size(beam_size)
 
-        self.vad = vad
+        self.segmentation = segmentation
         self.tau_active = tau_active
 
         self._duration = duration
@@ -67,11 +69,27 @@ def from_dict(data: Any) -> 'TranscriptionConfig':
         device = utils.get(data, "device", None)
         if device is None:
             device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+        # Default ASR model is Whisper small (244M parameters)
+        whisper_size = utils.get(data, "whisper", "small")
+        asr = m.SpeechRecognitionModel.from_whisper(whisper_size)
+
+        # No VAD segmentation by default
+        segmentation = utils.get(data, "segmentation", None)
+        if segmentation is not None:
+            hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
+            segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
+
+        # Tau hyper-parameter and its alias
+        tau = utils.get(data, "tau_active", None)
+        if tau is None:
+            tau = utils.get(data, "tau", 0.5)
+
         return TranscriptionConfig(
-            asr=utils.get(data, "asr", None),
-            vad=utils.get(data, "vad", None),
-            tau_active=utils.get(data, "tau_active", None),
-            duration=utils.get(data, "duration", None),
+            asr=asr,
+            segmentation=segmentation,
+            tau_active=tau,
+            duration=utils.get(data, "duration", 3),
             language=utils.get(data, "language", None),
             beam_size=utils.get(data, "beam_size", None),
             device=device,
@@ -83,8 +101,8 @@ def __init__(self, config: Optional[TranscriptionConfig] = None):
         self._config = TranscriptionConfig() if config is None else config
         self.asr = blocks.SpeechRecognition(self.config.asr, self.config.device)
         self.segmentation = None
-        if self.config.vad is not None:
-            self.segmentation = blocks.SpeakerSegmentation(self.config.vad, self.config.device)
+        if self.config.segmentation is not None:
+            self.segmentation = blocks.SpeakerSegmentation(self.config.segmentation, self.config.device)
 
     @staticmethod
     def get_config_class() -> type:
@@ -94,6 +112,10 @@ def get_config_class() -> type:
     def suggest_metric() -> Metric:
         return WordErrorRate()
 
+    @staticmethod
+    def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        return sinks.TextWriter(Path(output_dir) / f"{uri}.txt")
+
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive]
diff --git a/src/diart/pipelines/voice.py b/src/diart/pipelines/voice.py
index f93c28a6..f050a972 100644
--- a/src/diart/pipelines/voice.py
+++ b/src/diart/pipelines/voice.py
@@ -4,12 +4,14 @@
 import numpy as np
 import torch
 from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment
+from rx.core import Observer
 from typing_extensions import Literal
 
 from . import base
 from .hparams import HyperParameter, TauActive
 from .. import blocks
 from .. import models as m
+from .. import sinks
 from .. import utils
 from ..metrics import Metric, DetectionErrorRate
 
@@ -21,7 +23,7 @@ def __init__(
         duration: Optional[float] = None,
         step: float = 0.5,
         latency: Optional[Union[float, Literal["max", "min"]]] = None,
-        tau_active: float = 0.6,
+        tau_active: float = 0.5,
         merge_collar: float = 0.05,
         device: Optional[torch.device] = None,
         **kwargs,
@@ -85,7 +87,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig':
         # Tau active and its alias
         tau = utils.get(data, "tau_active", None)
         if tau is None:
-            tau = utils.get(data, "tau", 0.6)
+            tau = utils.get(data, "tau", 0.5)
 
         return VoiceActivityDetectionConfig(
             segmentation=segmentation,
@@ -132,6 +134,10 @@ def get_config_class() -> type:
     def suggest_metric() -> Metric:
         return DetectionErrorRate(collar=0, skip_overlap=False)
 
+    @staticmethod
+    def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
+
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive]
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index 8d9217a1..5d77e09e 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Union, Text, Optional, Tuple
+from typing import Union, Text, Optional, Tuple, Any
 
 import matplotlib.pyplot as plt
 from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
@@ -13,13 +13,10 @@ class WindowClosedException(Exception):
     pass
 
 
-def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation:
+def _extract_prediction(value: Union[Tuple, Any]) -> Any:
     if isinstance(value, tuple):
         return value[0]
-    if isinstance(value, Annotation):
-        return value
-    msg = f"Expected tuple or Annotation, but got {type(value)}"
-    raise ValueError(msg)
+    return value
 
 
 class RTTMWriter(Observer):
@@ -56,6 +53,20 @@ def on_completed(self):
         self.patch()
 
 
+class TextWriter(Observer):
+    def __init__(self, path: Union[Path, Text]):
+        super().__init__()
+        self.path = Path(path).expanduser()
+        if self.path.exists():
+            self.path.unlink()
+
+    def on_next(self, value: Union[Tuple, Text]):
+        # Write transcription to file
+        prediction = _extract_prediction(value)
+        with open(self.path, 'a') as file:
+            file.write(prediction + "\n")
+
+
 class DiarizationAccumulator(Observer):
     def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05):
         super().__init__()
diff --git a/src/diart/sources.py b/src/diart/sources.py
index b34d5cf3..76149bb4 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -247,8 +247,9 @@ def send(self, message: AnyStr):
         message: AnyStr
             Bytes or string to send.
         """
-        if len(message) > 0:
-            self.server.send_message(self.client, message)
+        msg = message.strip()
+        if len(msg) > 0:
+            self.server.send_message(self.client, msg + "\n")
 
 
 class TorchStreamAudioSource(AudioSource):
diff --git a/src/diart/utils.py b/src/diart/utils.py
index 3714a99d..a5f06e21 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -92,6 +92,12 @@ def get_padding_right(latency: float, step: float) -> float:
     return latency - step
 
 
+def serialize_prediction(value: Union[Annotation, Text]) -> Text:
+    if isinstance(value, Annotation):
+        return value.to_rttm()
+    return value
+
+
 def visualize_feature(duration: Optional[float] = None):
     def apply(feature: SlidingWindowFeature):
         if duration is None:

From 49616e58bfafbf0edba2a92ce8cf3e95d9f2edce Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Sun, 23 Apr 2023 19:49:46 +0200
Subject: [PATCH 14/23] Transcription pipeline is now fully compatible with
 diart.stream

---
 src/diart/console/serve.py           |   1 -
 src/diart/console/stream.py          |  20 ++-
 src/diart/inference.py               |  23 +---
 src/diart/operators.py               | 193 +--------------------------
 src/diart/pipelines/base.py          |  17 +--
 src/diart/pipelines/diarization.py   |  22 +--
 src/diart/pipelines/transcription.py |  17 +--
 src/diart/pipelines/voice.py         |  22 +--
 src/diart/sinks.py                   | 114 +++++++++++++---
 9 files changed, 161 insertions(+), 268 deletions(-)
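
In practical terms, this patch turns plotting from a do_plot flag on StreamingInference into an ordinary observer suggested by the pipeline itself (StreamingPlot for diarization and VAD, RichScreen for transcription). A sketch of the resulting wiring, mirroring the diart.stream changes below; the helper name is hypothetical and the pipeline and audio source are taken as given:

    from pathlib import Path

    from diart.inference import StreamingInference


    def run_with_side_effects(pipeline, audio_source, output_dir, plot=True):
        # Writer and display are now plain rx observers suggested by the pipeline
        inference = StreamingInference(pipeline, audio_source, batch_size=1, show_progress=True)
        observers = [pipeline.suggest_writer(audio_source.uri, Path(output_dir))]
        if plot:
            observers.append(pipeline.suggest_display())
        inference.attach_observers(*observers)
        return inference()
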

diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index 4e7ca1ce..5c7c68b0 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -51,7 +51,6 @@ def run():
         audio_source,
         batch_size=1,
         do_profile=False,
-        do_plot=False,
         show_progress=True,
     )
 
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index e0c670c5..3563dd2a 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -5,7 +5,7 @@
 from diart import sources as src
 from diart import utils
 from diart.inference import StreamingInference
-from diart.sinks import RTTMWriter
+from diart.pipelines import StreamingPipeline, StreamingConfig
 
 
 def run():
@@ -13,6 +13,10 @@ def run():
     parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:<DEVICE_ID>'")
     parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
                         help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
+    parser.add_argument("--whisper", default="small", type=str,
+                        help=f"Whisper model for transcription pipeline. Defaults to 'small'")
+    parser.add_argument("--language", default="en", type=str,
+                        help=f"Transcribe in this language. Defaults to 'en' (English)")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -36,8 +40,8 @@ def run():
 
     # Resolve pipeline
     pipeline_class = utils.get_pipeline_class(args.pipeline)
-    config = pipeline_class.get_config_class().from_dict(vars(args))
-    pipeline = pipeline_class(config)
+    config: StreamingConfig = pipeline_class.get_config_class().from_dict(vars(args))
+    pipeline: StreamingPipeline = pipeline_class(config)
 
     # Manage audio source
     block_size = config.optimal_block_size()
@@ -59,10 +63,16 @@ def run():
         audio_source,
         batch_size=1,
         do_profile=True,
-        do_plot=not args.no_plot,
         show_progress=True,
     )
-    inference.attach_observers(RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm"))
+
+    # Attach observers for required side effects
+    observers = [pipeline.suggest_writer(audio_source.uri, args.output)]
+    if not args.no_plot:
+        observers.append(pipeline.suggest_display())
+    inference.attach_observers(*observers)
+
+    # Run pipeline
     inference()
 
 
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 258f773a..b9f2a789 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -20,7 +20,7 @@
 from .metrics import Metric
 from .pipelines import StreamingPipeline, StreamingConfig
 from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
-from .sinks import StreamingPlot, WindowClosedException
+from .sinks import WindowClosedException
 
 
 class StreamingInference:
@@ -40,9 +40,6 @@ class StreamingInference:
     do_profile: bool
         If True, compute and report the processing time of the pipeline.
         Defaults to True.
-    do_plot: bool
-        If True, draw predictions in a moving plot.
-        Defaults to False.
     show_progress: bool
         If True, show a progress bar.
         Defaults to True.
@@ -57,7 +54,6 @@ def __init__(
         source: src.AudioSource,
         batch_size: int = 1,
         do_profile: bool = True,
-        do_plot: bool = False,
         show_progress: bool = True,
         progress_bar: Optional[ProgressBar] = None,
     ):
@@ -65,7 +61,6 @@ def __init__(
         self.source = source
         self.batch_size = batch_size
         self.do_profile = do_profile
-        self.do_plot = do_plot
         self.show_progress = show_progress
         self.unit = "chunk" if self.batch_size == 1 else "batch"
         self._observers = []
@@ -192,20 +187,7 @@ def __call__(self) -> List[Any]:
         """
         if self.show_progress:
             self._pbar.start()
-        config = self.pipeline.config
-        observable = self.stream
-        if self.do_plot:
-            # Buffering is needed for the real-time plot, so we do this at the very end
-            observable = self.stream.pipe(
-                dops.buffer_output(
-                    duration=config.duration,
-                    step=config.step,
-                    latency=config.latency,
-                    sample_rate=config.sample_rate,
-                ),
-                ops.do(StreamingPlot(config.duration, config.latency)),
-            )
-        observable.subscribe(
+        self.stream.subscribe(
             on_error=self._handle_error,
             on_completed=self._handle_completion,
         )
@@ -324,7 +306,6 @@ def run_single(
             source,
             self.batch_size,
             do_profile=False,
-            do_plot=False,
             show_progress=self.show_progress,
             progress_bar=progress_bar,
         )
diff --git a/src/diart/operators.py b/src/diart/operators.py
index 6d73fc9d..c67bf99d 100644
--- a/src/diart/operators.py
+++ b/src/diart/operators.py
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
-from typing import Callable, Optional, List, Any, Tuple
+from typing import Callable, Optional, List, Any
 
 import numpy as np
 import rx
-from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature, Segment
+from pyannote.core import SlidingWindow, SlidingWindowFeature
 from rx import operators as ops
 from rx.core import Observable
 
@@ -97,192 +97,3 @@ def accumulate(state: List[Any], value: Any) -> List[Any]:
             return new_state[1:]
         return new_state
     return rx.pipe(ops.scan(accumulate, []))
-
-
-@dataclass
-class PredictionWithAudio:
-    prediction: Annotation
-    waveform: Optional[SlidingWindowFeature] = None
-
-    @property
-    def has_audio(self) -> bool:
-        return self.waveform is not None
-
-
-@dataclass
-class OutputAccumulationState:
-    annotation: Optional[Annotation]
-    waveform: Optional[SlidingWindowFeature]
-    real_time: float
-    next_sample: Optional[int]
-
-    @staticmethod
-    def initial() -> 'OutputAccumulationState':
-        return OutputAccumulationState(None, None, 0, 0)
-
-    @property
-    def cropped_waveform(self) -> SlidingWindowFeature:
-        return SlidingWindowFeature(
-            self.waveform[:self.next_sample],
-            self.waveform.sliding_window,
-        )
-
-    def to_tuple(self) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]:
-        return self.annotation, self.cropped_waveform, self.real_time
-
-
-def accumulate_output(
-    duration: float,
-    step: float,
-    patch_collar: float = 0.05,
-) -> Operator:
-    """Accumulate predictions and audio to infinity: O(N) space complexity.
-    Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations.
-
-    Parameters
-    ----------
-    duration: float
-        Buffer duration in seconds.
-    step: float
-        Duration of the chunks at each event in seconds.
-        The first chunk may be bigger given the latency.
-    patch_collar: float, optional
-        Collar to merge speaker turns of the same speaker, in seconds.
-        Defaults to 0.05 (i.e. 50ms).
-    Returns
-    -------
-    A reactive x operator implementing this behavior.
-    """
-    def accumulate(
-        state: OutputAccumulationState,
-        value: Tuple[Annotation, Optional[SlidingWindowFeature]]
-    ) -> OutputAccumulationState:
-        value = PredictionWithAudio(*value)
-        annotation, waveform = None, None
-
-        # Determine the real time of the stream
-        real_time = duration if state.annotation is None else state.real_time + step
-
-        # Update total annotation with current predictions
-        if state.annotation is None:
-            annotation = value.prediction
-        else:
-            annotation = state.annotation.update(value.prediction).support(patch_collar)
-
-        # Update total waveform if there's audio in the input
-        new_next_sample = 0
-        if value.has_audio:
-            num_new_samples = value.waveform.data.shape[0]
-            new_next_sample = state.next_sample + num_new_samples
-            sw_holder = state
-            if state.waveform is None:
-                # Initialize the audio buffer with 10 times the size of the first chunk
-                waveform, sw_holder = np.zeros((10 * num_new_samples, 1)), value
-            elif new_next_sample < state.waveform.data.shape[0]:
-                # The buffer still has enough space to accommodate the chunk
-                waveform = state.waveform.data
-            else:
-                # The buffer is full, double its size
-                waveform = np.concatenate(
-                    (state.waveform.data, np.zeros_like(state.waveform.data)), axis=0
-                )
-            # Copy chunk into buffer
-            waveform[state.next_sample:new_next_sample] = value.waveform.data
-            waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window)
-
-        return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)
-
-    return rx.pipe(
-        ops.scan(accumulate, OutputAccumulationState.initial()),
-        ops.map(OutputAccumulationState.to_tuple),
-    )
-
-
-def buffer_output(
-    duration: float,
-    step: float,
-    latency: float,
-    sample_rate: int,
-    patch_collar: float = 0.05,
-) -> Operator:
-    """Store last predictions and audio inside a fixed buffer.
-    Provides the best time/space complexity trade-off if the past data is not needed.
-
-    Parameters
-    ----------
-    duration: float
-        Buffer duration in seconds.
-    step: float
-        Duration of the chunks at each event in seconds.
-        The first chunk may be bigger given the latency.
-    latency: float
-        Latency of the system in seconds.
-    sample_rate: int
-        Sample rate of the audio source.
-    patch_collar: float, optional
-        Collar to merge speaker turns of the same speaker, in seconds.
-        Defaults to 0.05 (i.e. 50ms).
-
-    Returns
-    -------
-    A reactive x operator implementing this behavior.
-    """
-    # Define some useful constants
-    num_samples = int(round(duration * sample_rate))
-    num_step_samples = int(round(step * sample_rate))
-    resolution = 1 / sample_rate
-
-    def accumulate(
-        state: OutputAccumulationState,
-        value: Tuple[Annotation, Optional[SlidingWindowFeature]]
-    ) -> OutputAccumulationState:
-        value = PredictionWithAudio(*value)
-        annotation, waveform = None, None
-
-        # Determine the real time of the stream and the start time of the buffer
-        real_time = duration if state.annotation is None else state.real_time + step
-        start_time = max(0., real_time - latency - duration)
-
-        # Update annotation and constrain its bounds to the buffer
-        if state.annotation is None:
-            annotation = value.prediction
-        else:
-            annotation = state.annotation.update(value.prediction).support(patch_collar)
-            if start_time > 0:
-                annotation = annotation.extrude(Segment(0, start_time))
-
-        # Update the audio buffer if there's audio in the input
-        new_next_sample = state.next_sample + num_step_samples
-        if value.has_audio:
-            if state.waveform is None:
-                # Determine the size of the first chunk
-                expected_duration = duration + step - latency
-                expected_samples = int(round(expected_duration * sample_rate))
-                # Shift indicator to start copying new audio in the buffer
-                new_next_sample = state.next_sample + expected_samples
-                # Buffer size is duration + step
-                waveform = np.zeros((num_samples + num_step_samples, 1))
-                # Copy first chunk into buffer (slicing because of rounding errors)
-                waveform[:expected_samples] = value.waveform.data[:expected_samples]
-            elif state.next_sample <= num_samples:
-                # The buffer isn't full, copy into next free buffer chunk
-                waveform = state.waveform.data
-                waveform[state.next_sample:new_next_sample] = value.waveform.data
-            else:
-                # The buffer is full, shift values to the left and copy into last buffer chunk
-                waveform = np.roll(state.waveform.data, -num_step_samples, axis=0)
-                # If running on a file, the online prediction may be shorter depending on the latency
-                # The remaining audio at the end is appended, so value.waveform may be longer than num_step_samples
-                # In that case, we simply ignore the appended samples.
-                waveform[-num_step_samples:] = value.waveform.data[:num_step_samples]
-
-            # Wrap waveform in a sliding window feature to include timestamps
-            window = SlidingWindow(start=start_time, duration=resolution, step=resolution)
-            waveform = SlidingWindowFeature(waveform, window)
-
-        return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)
-
-    return rx.pipe(
-        ops.scan(accumulate, OutputAccumulationState.initial()),
-        ops.map(OutputAccumulationState.to_tuple),
-    )
diff --git a/src/diart/pipelines/base.py b/src/diart/pipelines/base.py
index d47115d5..1a220203 100644
--- a/src/diart/pipelines/base.py
+++ b/src/diart/pipelines/base.py
@@ -47,14 +47,6 @@ class StreamingPipeline:
     def get_config_class() -> type:
         raise NotImplementedError
 
-    @staticmethod
-    def suggest_metric() -> Metric:
-        raise NotImplementedError
-
-    @staticmethod
-    def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer:
-        raise NotImplementedError
-
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         raise NotImplementedError
@@ -75,6 +67,15 @@ def join_predictions(self, predictions: List[Any]) -> Any:
     def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Path]):
         raise NotImplementedError
 
+    def suggest_metric(self) -> Metric:
+        raise NotImplementedError
+
+    def suggest_display(self) -> Observer:
+        raise NotImplementedError
+
+    def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        raise NotImplementedError
+
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
diff --git a/src/diart/pipelines/diarization.py b/src/diart/pipelines/diarization.py
index dd9ef070..57cb7739 100644
--- a/src/diart/pipelines/diarization.py
+++ b/src/diart/pipelines/diarization.py
@@ -166,14 +166,6 @@ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
     def get_config_class() -> type:
         return SpeakerDiarizationConfig
 
-    @staticmethod
-    def suggest_metric() -> Metric:
-        return DiarizationErrorRate(collar=0, skip_overlap=False)
-
-    @staticmethod
-    def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer:
-        return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
-
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive, RhoUpdate, DeltaNew]
@@ -195,6 +187,20 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te
         with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file:
             prediction.write_rttm(out_file)
 
+    def suggest_metric(self) -> Metric:
+        return DiarizationErrorRate(collar=0, skip_overlap=False)
+
+    def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
+
+    def suggest_display(self) -> Observer:
+        return sinks.StreamingPlot(
+            self.config.duration,
+            self.config.step,
+            self.config.latency,
+            self.config.sample_rate
+        )
+
     def reset(self):
         self.set_timestamp_shift(0)
         self.clustering = blocks.IncrementalSpeakerClustering(
diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py
index 6b3249c9..cb5f40e7 100644
--- a/src/diart/pipelines/transcription.py
+++ b/src/diart/pipelines/transcription.py
@@ -108,14 +108,6 @@ def __init__(self, config: Optional[TranscriptionConfig] = None):
     def get_config_class() -> type:
         return TranscriptionConfig
 
-    @staticmethod
-    def suggest_metric() -> Metric:
-        return WordErrorRate()
-
-    @staticmethod
-    def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer:
-        return sinks.TextWriter(Path(output_dir) / f"{uri}.txt")
-
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive]
@@ -139,6 +131,15 @@ def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Pa
         with open(Path(dir_path) / f"{uri}.txt", "w") as out_file:
             out_file.write(prediction)
 
+    def suggest_metric(self) -> Metric:
+        return WordErrorRate()
+
+    def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        return sinks.TextWriter(Path(output_dir) / f"{uri}.txt")
+
+    def suggest_display(self) -> Observer:
+        return sinks.RichScreen()
+
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
diff --git a/src/diart/pipelines/voice.py b/src/diart/pipelines/voice.py
index f050a972..b22806ab 100644
--- a/src/diart/pipelines/voice.py
+++ b/src/diart/pipelines/voice.py
@@ -130,14 +130,6 @@ def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
     def get_config_class() -> type:
         return VoiceActivityDetectionConfig
 
-    @staticmethod
-    def suggest_metric() -> Metric:
-        return DetectionErrorRate(collar=0, skip_overlap=False)
-
-    @staticmethod
-    def suggest_writer(uri: Text, output_dir: Union[Text, Path]) -> Observer:
-        return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
-
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive]
@@ -163,6 +155,20 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te
         with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file:
             prediction.write_rttm(out_file)
 
+    def suggest_metric(self) -> Metric:
+        return DetectionErrorRate(collar=0, skip_overlap=False)
+
+    def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
+
+    def suggest_display(self) -> Observer:
+        return sinks.StreamingPlot(
+            self.config.duration,
+            self.config.step,
+            self.config.latency,
+            self.config.sample_rate
+        )
+
     def __call__(
         self,
         waveforms: Sequence[SlidingWindowFeature],
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index 5d77e09e..be461fff 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -1,12 +1,14 @@
+import re
 from pathlib import Path
-from typing import Union, Text, Optional, Tuple, Any
+from typing import Union, Text, Optional, Tuple, Any, List
 
 import matplotlib.pyplot as plt
-from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
+import numpy as np
+import rich
+from pyannote.core import Annotation, Segment, SlidingWindowFeature, SlidingWindow, notebook
 from pyannote.database.util import load_rttm
 from pyannote.metrics.diarization import DiarizationErrorRate
 from rx.core import Observer
-from typing_extensions import Literal
 
 
 class WindowClosedException(Exception):
@@ -99,25 +101,59 @@ def on_completed(self):
         self.patch()
 
 
+class RichScreen(Observer):
+    def __init__(self, speaker_colors: Optional[List[Text]] = None):
+        super().__init__()
+        self.colors = speaker_colors
+        if self.colors is None:
+            self.colors = [
+                "bright_red", "bright_blue", "bright_green", "orange3", "deep_pink1",
+                "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
+            ]
+        self.num_colors = len(self.colors)
+
+    def on_next(self, value: Union[Tuple, Text]):
+        prediction = _extract_prediction(value)
+        # Extract speakers
+        speakers = sorted(re.findall(r'\[.*?]', prediction))
+        # Colorize based on speakers
+        colorized = prediction
+        for i, speaker in enumerate(speakers):
+            colorized = colorized.replace(speaker, f"[{self.colors[i % self.num_colors]}]")
+        # Print result
+        rich.print(colorized)
+
+
 class StreamingPlot(Observer):
     def __init__(
         self,
         duration: float,
+        step: float,
         latency: float,
-        visualization: Literal["slide", "accumulate"] = "slide",
+        sample_rate: float,
         reference: Optional[Union[Path, Text]] = None,
+        patch_collar: float = 0.05,
     ):
         super().__init__()
-        assert visualization in ["slide", "accumulate"]
-        self.visualization = visualization
         self.reference = reference
         if self.reference is not None:
             self.reference = list(load_rttm(reference).values())[0]
         self.window_duration = duration
+        self.window_step = step
         self.latency = latency
+        self.sample_rate = sample_rate
+        self.patch_collar = patch_collar
+
+        self.num_window_samples = int(np.rint(self.window_duration * self.sample_rate))
+        self.num_step_samples = int(np.rint(self.window_step * self.sample_rate))
+        self.audio_resolution = 1 / self.sample_rate
+
         self.figure, self.axs, self.num_axs = None, None, -1
         # This flag allows to catch the matplotlib window closed event and make the next call stop iterating
         self.window_closed = False
+        self.real_time = 0
+        self.pred_buffer, self.audio_buffer = None, None
+        self.next_sample = 0
 
     def _on_window_closed(self, event):
         self.window_closed = True
@@ -139,21 +175,63 @@ def _clear_axs(self):
         for i in range(self.num_axs):
             self.axs[i].clear()
 
-    def get_plot_bounds(self, real_time: float) -> Segment:
-        start_time = 0
-        end_time = real_time - self.latency
-        if self.visualization == "slide":
-            start_time = max(0., end_time - self.window_duration)
+    def get_plot_bounds(self) -> Segment:
+        end_time = self.real_time - self.latency
+        start_time = max(0., end_time - self.window_duration)
         return Segment(start_time, end_time)
 
     def on_next(
         self,
-        values: Tuple[Annotation, SlidingWindowFeature, float]
+        values: Tuple[Annotation, SlidingWindowFeature]
     ):
         if self.window_closed:
             raise WindowClosedException
 
-        prediction, waveform, real_time = values
+        prediction, waveform = values
+
+        # TODO break this aggregation code into methods
+
+        # Determine the real time of the stream and the start time of the buffer
+        self.real_time = waveform.extent.end
+        start_time = max(0., self.real_time - self.latency - self.window_duration)
+
+        # Update prediction buffer and constrain its bounds
+        if self.pred_buffer is None:
+            self.pred_buffer = prediction
+        else:
+            self.pred_buffer = self.pred_buffer.update(prediction)
+            self.pred_buffer = self.pred_buffer.support(self.patch_collar)
+            if start_time > 0:
+                self.pred_buffer = self.pred_buffer.extrude(Segment(0, start_time))
+
+        # Update the audio buffer if there's audio in the input
+        new_next_sample = self.next_sample + self.num_step_samples
+        if self.audio_buffer is None:
+            # Determine the size of the first chunk
+            expected_duration = self.window_duration + self.window_step - self.latency
+            expected_samples = int(np.rint(expected_duration * self.sample_rate))
+            # Shift indicator to start copying new audio in the buffer
+            new_next_sample = self.next_sample + expected_samples
+            # Buffer size is duration + step
+            new_buffer = np.zeros((self.num_window_samples + self.num_step_samples, 1))
+            # Copy first chunk into buffer (slicing because of rounding errors)
+            new_buffer[:expected_samples] = waveform.data[:expected_samples]
+        elif self.next_sample <= self.num_window_samples:
+            # The buffer isn't full, copy into next free buffer chunk
+            new_buffer = self.audio_buffer.data
+            new_buffer[self.next_sample:new_next_sample] = waveform.data
+        else:
+            # The buffer is full, shift values to the left and copy into last buffer chunk
+            new_buffer = np.roll(self.audio_buffer.data, -self.num_step_samples, axis=0)
+            # If running on a file, the online prediction may be shorter depending on the latency
+            # The remaining audio at the end is appended, so 'waveform' may be longer than 'num_step_samples'
+            # In that case, we simply ignore the appended samples.
+            new_buffer[-self.num_step_samples:] = waveform.data[:self.num_step_samples]
+
+        # Wrap waveform in a sliding window feature to include timestamps
+        window = SlidingWindow(start=start_time, duration=self.audio_resolution, step=self.audio_resolution)
+        self.audio_buffer = SlidingWindowFeature(new_buffer, window)
+        self.next_sample = new_next_sample
 
         # Initialize figure if first call
         if self.figure is None:
@@ -161,20 +239,20 @@ def on_next(
         # Clear previous plots
         self._clear_axs()
         # Set plot bounds
-        notebook.crop = self.get_plot_bounds(real_time)
+        notebook.crop = self.get_plot_bounds()
 
         # Align prediction and reference if possible
         if self.reference is not None:
             metric = DiarizationErrorRate()
-            mapping = metric.optimal_mapping(self.reference, prediction)
-            prediction.rename_labels(mapping=mapping, copy=False)
+            mapping = metric.optimal_mapping(self.reference, self.pred_buffer)
+            self.pred_buffer.rename_labels(mapping=mapping, copy=False)
 
         # Plot prediction
-        notebook.plot_annotation(prediction, self.axs[0])
+        notebook.plot_annotation(self.pred_buffer, self.axs[0])
         self.axs[0].set_title("Output")
 
         # Plot waveform
-        notebook.plot_feature(waveform, self.axs[1])
+        notebook.plot_feature(self.audio_buffer, self.axs[1])
         self.axs[1].set_title("Audio")
 
         # Plot reference if available
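
The buffering that buffer_output used to perform in operators.py is now kept as plain state on the observer. The toy NumPy sketch below isolates the full-buffer branch above (shift left by one step, overwrite the tail); the numbers are arbitrary and only for illustration:

    import numpy as np

    num_window_samples, num_step_samples = 8, 2
    buffer = np.arange(num_window_samples + num_step_samples, dtype=float).reshape(-1, 1)
    new_chunk = np.full((num_step_samples, 1), -1.0)

    # Keep the most recent (window + step) samples: roll left, then overwrite the tail
    buffer = np.roll(buffer, -num_step_samples, axis=0)
    buffer[-num_step_samples:] = new_chunk
    print(buffer.ravel())  # oldest step dropped, new chunk appended at the end
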

From babf49d9cec81a89ce593b86ea7113fb63a9bbe9 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Mon, 24 Apr 2023 11:22:50 +0200
Subject: [PATCH 15/23] Make transcription pipeline compatible with
 diart.benchmark and diart.tune. Fix major bug in Optimizer

---
 src/diart/__init__.py                |   4 +-
 src/diart/console/benchmark.py       |   6 ++
 src/diart/console/client.py          |   3 +-
 src/diart/console/serve.py           |   6 +-
 src/diart/console/stream.py          |  11 ++-
 src/diart/console/tune.py            |  12 ++-
 src/diart/inference.py               | 114 +++++++++++++--------------
 src/diart/optim.py                   |  11 ++-
 src/diart/pipelines/__init__.py      |   2 +-
 src/diart/pipelines/base.py          |  15 ++--
 src/diart/pipelines/diarization.py   |  11 +--
 src/diart/pipelines/transcription.py |  19 +++--
 src/diart/pipelines/voice.py         |  11 +--
 13 files changed, 125 insertions(+), 100 deletions(-)
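
With the renames in this patch (StreamingPipeline becomes Pipeline, StreamingConfig becomes PipelineConfig), a third-party pipeline is expected to look roughly like the sketch below. Passthrough and PassthroughConfig are hypothetical names; only the overridden members come from the base interface shown in this series, and members not needed for the example (hyper_parameters, join_predictions, suggest_*) are omitted.

    from typing import Optional, Sequence, Tuple

    from pyannote.core import SlidingWindowFeature

    from diart import Pipeline, PipelineConfig


    class PassthroughConfig(PipelineConfig):
        @property
        def duration(self) -> float:
            return 5.0

        @property
        def step(self) -> float:
            return 0.5

        @property
        def latency(self) -> float:
            return 0.5

        @property
        def sample_rate(self) -> int:
            return 16000


    class Passthrough(Pipeline):
        """Identity pipeline: each chunk is returned as its own 'prediction'."""

        def __init__(self, config: Optional[PassthroughConfig] = None):
            self._config = PassthroughConfig() if config is None else config

        @staticmethod
        def get_config_class() -> type:
            return PassthroughConfig

        @property
        def config(self) -> PassthroughConfig:
            return self._config

        def reset(self):
            pass  # no internal state

        def set_timestamp_shift(self, shift: float):
            pass  # no timestamped output

        def __call__(
            self, waveforms: Sequence[SlidingWindowFeature]
        ) -> Sequence[Tuple[SlidingWindowFeature, SlidingWindowFeature]]:
            return [(wav, wav) for wav in waveforms]
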

diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index 0d67c9c5..842ba267 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1,6 +1,6 @@
 from .pipelines import (
-    StreamingPipeline,
-    StreamingConfig,
+    Pipeline,
+    PipelineConfig,
     SpeakerDiarization,
     SpeakerDiarizationConfig,
     VoiceActivityDetection,
diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py
index 27d524c5..d8f04183 100644
--- a/src/diart/console/benchmark.py
+++ b/src/diart/console/benchmark.py
@@ -12,12 +12,18 @@ def run():
     parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
     parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
                         help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
+    parser.add_argument("--whisper", default="small", type=str,
+                        help=f"Whisper model for transcription pipeline. Defaults to 'small'")
+    parser.add_argument("--language", default="en", type=str,
+                        help=f"Transcribe in this language. Defaults to 'en' (English)")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
                         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
     parser.add_argument("--reference", type=Path,
                         help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+    parser.add_argument("--duration", default=5, type=float,
+                        help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
diff --git a/src/diart/console/client.py b/src/diart/console/client.py
index db4915fa..816c7e0f 100644
--- a/src/diart/console/client.py
+++ b/src/diart/console/client.py
@@ -45,7 +45,8 @@ def run():
     parser.add_argument("--host", required=True, type=str, help="Server host")
     parser.add_argument("--port", required=True, type=int, help="Server port")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
-    parser.add_argument("-sr", "--sample-rate", default=16000, type=int, help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000")
+    parser.add_argument("-sr", "--sample-rate", default=16000, type=int,
+                        help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000")
     parser.add_argument("-o", "--output-file", type=Path, help="Output RTTM file. Defaults to no writing")
     args = parser.parse_args()
 
diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index 5c7c68b0..0698ede0 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -5,7 +5,7 @@
 from diart import sources as src
 from diart import utils
 from diart.inference import StreamingInference
-from diart.pipelines import StreamingPipeline
+from diart.pipelines import Pipeline
 
 
 def run():
@@ -22,6 +22,8 @@ def run():
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
                         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
+    parser.add_argument("--duration", type=float,
+                        help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -40,7 +42,7 @@ def run():
     # Resolve pipeline
     pipeline_class = utils.get_pipeline_class(args.pipeline)
     config = pipeline_class.get_config_class().from_dict(vars(args))
-    pipeline: StreamingPipeline = pipeline_class(config)
+    pipeline: Pipeline = pipeline_class(config)
 
     # Create websocket audio source
     audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port)
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index 3563dd2a..1436eb8a 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -5,7 +5,7 @@
 from diart import sources as src
 from diart import utils
 from diart.inference import StreamingInference
-from diart.pipelines import StreamingPipeline, StreamingConfig
+from diart.pipelines import Pipeline, PipelineConfig
 
 
 def run():
@@ -21,6 +21,8 @@ def run():
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
                         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
+    parser.add_argument("--duration", default=5, type=float,
+                        help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -33,15 +35,16 @@ def run():
     parser.add_argument("--cpu", dest="cpu", action="store_true",
                         help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
     parser.add_argument("--output", type=str,
-                        help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
+                        help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' "
+                             f"or parent directory if SOURCE is a file")
     parser.add_argument("--hf-token", default="true", type=str,
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
     # Resolve pipeline
     pipeline_class = utils.get_pipeline_class(args.pipeline)
-    config: StreamingConfig = pipeline_class.get_config_class().from_dict(vars(args))
-    pipeline: StreamingPipeline = pipeline_class(config)
+    config: PipelineConfig = pipeline_class.get_config_class().from_dict(vars(args))
+    pipeline: Pipeline = pipeline_class(config)
 
     # Manage audio source
     block_size = config.optimal_block_size()
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index 111d97c2..f492c704 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -16,10 +16,16 @@ def run():
                         help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
     parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
                         help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
+    parser.add_argument("--whisper", default="small", type=str,
+                        help=f"Whisper model for transcription pipeline. Defaults to 'small'")
+    parser.add_argument("--language", default="en", type=str,
+                        help=f"Transcribe in this language. Defaults to 'en' (English)")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
                         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
+    parser.add_argument("--duration", default=5, type=float,
+                        help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -32,10 +38,12 @@ def run():
     parser.add_argument("--cpu", dest="cpu", action="store_true",
                         help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
     parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"),
-                        help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new")
+                        help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. "
+                             "Defaults to tau_active, rho_update and delta_new")
     parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials")
     parser.add_argument("--storage", type=str,
-                        help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name")
+                        help="Optuna storage string. If provided, continue a previous study instead of creating one. "
+                             "The database name must match the study name")
     parser.add_argument("--output", type=str, help="Working directory")
     parser.add_argument("--hf-token", default="true", type=str,
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
diff --git a/src/diart/inference.py b/src/diart/inference.py
index b9f2a789..061065fb 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -18,22 +18,21 @@
 from . import sources as src
 from . import utils
 from .metrics import Metric
-from .pipelines import StreamingPipeline, StreamingConfig
+from .pipelines import Pipeline, PipelineConfig
 from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
 from .sinks import WindowClosedException
 
 
 class StreamingInference:
-    """Performs inference in real time given a pipeline and an audio source.
-    Streams an audio source to an online speaker diarization pipeline.
-    It allows users to attach a chain of operations in the form of hooks.
+    """Performs streaming inference given a pipeline and an audio source.
+    Side-effect hooks and observers can also be attached for customized behavior.
 
     Parameters
     ----------
-    pipeline: StreamingPipeline
-        Configured speaker diarization pipeline.
+    pipeline: Pipeline
+        Pipeline to run on the incoming audio stream.
     source: AudioSource
-        Audio source to be read and streamed.
+        Audio source to read and stream.
     batch_size: int
         Number of inputs to send to the pipeline at once.
         Defaults to 1.
@@ -50,7 +49,7 @@ class StreamingInference:
     """
     def __init__(
         self,
-        pipeline: StreamingPipeline,
+        pipeline: Pipeline,
         source: src.AudioSource,
         batch_size: int = 1,
         do_profile: bool = True,
@@ -62,7 +61,6 @@ def __init__(
         self.batch_size = batch_size
         self.do_profile = do_profile
         self.show_progress = show_progress
-        self.unit = "chunk" if self.batch_size == 1 else "batch"
         self._observers = []
         self._predictions = []
 
@@ -84,11 +82,11 @@ def __init__(
             self._pbar.create(
                 total=self.num_chunks,
                 description=f"Streaming {self.source.uri}",
-                unit=self.unit
+                unit="chunk"
             )
 
         # Initialize chronometer for profiling
-        self._chrono = utils.Chronometer(self.unit, self._pbar)
+        self._chrono = utils.Chronometer("batch", self._pbar)
 
         self.stream = self.source.stream
 
@@ -198,21 +196,21 @@ def __call__(self) -> List[Any]:
 
 class Benchmark:
     """
-    Run an online speaker diarization pipeline on a set of audio files in batches.
+    Run a pipeline on a set of audio files in batches.
     Write predictions to a given output directory.
 
-    If the reference is given, calculate the average diarization error rate.
+    If the reference is given, compute the average performance metric.
 
     Parameters
     ----------
     speech_path: Text or Path
         Directory with audio files.
     reference_path: Text, Path or None
-        Directory with reference RTTM files (same names as audio files).
-        If None, performance will not be calculated.
+        Directory with reference files (same names as the audio files, but with a different extension).
+        If None, performance will not be computed.
         Defaults to None.
-    output_path: Text, Path or None
-        Output directory to store predictions in RTTM format.
+    output_path: Text, Path or None
+        Output directory to store predictions.
         If None, predictions will not be written to disk.
         Defaults to None.
     show_progress: bool
@@ -222,12 +220,7 @@ class Benchmark:
         Whether to print a performance report to stdout.
         Defaults to True.
     batch_size: int
-        Inference batch size.
-        If < 2, then it will run in real time.
-        If >= 2, then it will pre-calculate segmentation and
-        embeddings, running the rest in real time.
-        The performance between this two modes does not differ.
-        Defaults to 32.
+        Inference batch size. Defaults to 32.
     """
     def __init__(
         self,
@@ -252,7 +245,7 @@ def __init__(
 
         self.output_path = output_path
         if self.output_path is not None:
-            self.output_path = Path(output_path).expanduser()
+            self.output_path: Path = Path(output_path).expanduser()
             self.output_path.mkdir(parents=True, exist_ok=True)
 
         self.show_progress = show_progress
@@ -271,27 +264,29 @@ def get_file_paths(self) -> List[Path]:
 
     def run_single(
         self,
-        pipeline: StreamingPipeline,
+        pipeline: Pipeline,
         filepath: Path,
         progress_bar: ProgressBar,
     ) -> Tuple[Text, Any]:
         """Run a given pipeline on a given file.
-        Note that this method does NOT reset the
+        This method does NOT reset the
         state of the pipeline before execution.
 
         Parameters
         ----------
-        pipeline: StreamingPipeline
-            Speaker diarization pipeline to run.
+        pipeline: Pipeline
+            Pipeline to run on the given file.
         filepath: Path
             Path to the target file.
         progress_bar: diart.progress.ProgressBar
-            An object to manage the progress of this run.
+            Object to display the progress of this run.
 
         Returns
         -------
-        prediction: Annotation
-            Pipeline prediction for the given file.
+        uri: Text
+            File URI.
+        prediction: Any
+            Aggregated pipeline prediction for the given file.
         """
         padding = pipeline.config.get_file_padding(filepath)
         source = src.FileAudioSource(
@@ -325,7 +320,7 @@ def evaluate(
         metric: Metric,
     ) -> Union[pd.DataFrame, Dict[Text, Any]]:
         """If a reference path was provided,
-        compute the diarization error rate of a list of predictions.
+        compute the performance of a list of predictions.
 
         Parameters
         ----------
@@ -369,30 +364,29 @@ def evaluate(
     def __call__(
         self,
         pipeline_class: type,
-        config: StreamingConfig,
+        config: PipelineConfig,
         metric: Optional[Metric] = None,
     ) -> Union[pd.DataFrame, Dict[Text, Any]]:
-        """Run a given pipeline on a set of audio files.
+        """Run a pipeline on a set of audio files.
         The internal state of the pipeline is reset before benchmarking.
 
         Parameters
         ----------
         pipeline_class: class
-            Class from the StreamingPipeline hierarchy.
+            Class from the `Pipeline` hierarchy.
             A pipeline from this class will be instantiated by each worker.
-        config: StreamingConfig
-            Streaming pipeline configuration.
+        config: PipelineConfig
+            Pipeline configuration.
         metric: Optional[Metric]
             Evaluation metric.
             Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`)
 
         Returns
         -------
-        performance: pandas.DataFrame or List[Annotation]
+        performance: pandas.DataFrame or Dict[Text, Any]
             If reference annotations are given, a DataFrame with detailed
             performance on each file as well as average performance.
-
-            If no reference annotations, a list of predictions.
+            If no reference annotations are given, a dictionary mapping each file URI to its prediction.
         """
         audio_file_paths = self.get_file_paths()
         num_audio_files = len(audio_file_paths)
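
A minimal sketch of the call documented above, assuming a folder of audio files with matching reference files and constructor keywords matching the parameter names in the docstring (the metric defaults to the pipeline's suggested one):

    from diart.inference import Benchmark
    from diart.pipelines import SpeakerDiarization, SpeakerDiarizationConfig

    benchmark = Benchmark(
        "/data/conversations/audio",                 # hypothetical paths
        reference_path="/data/conversations/rttm",
        output_path="/data/conversations/predictions",
        batch_size=32,
    )
    config = SpeakerDiarizationConfig(step=0.5, latency=0.5)
    report = benchmark(SpeakerDiarization, config)   # DataFrame when references are given
    print(report.loc["TOTAL"])                        # average row of the detailed report
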
@@ -412,15 +406,14 @@ def __call__(
 
 class Parallelize:
     """Wrapper to parallelize the execution of a `Benchmark` instance.
-    Note that models will be copied in each worker instead of being reused.
+    Models will be copied in each worker instead of being reused.
 
     Parameters
     ----------
     benchmark: Benchmark
-        Benchmark instance to execute in parallel.
+        Benchmark instance to run in parallel.
     num_workers: int
-        Number of parallel workers.
-        Defaults to 0 (no parallelism).
+        Number of parallel workers. Defaults to 4.
     """
     def __init__(
         self,
@@ -433,20 +426,20 @@ def __init__(
     def run_single_job(
         self,
         pipeline_class: type,
-        config: StreamingConfig,
+        config: PipelineConfig,
         filepath: Path,
         description: Text,
     ) -> Tuple[Text, Any]:
-        """Build and run a pipeline on a single file.
+        """Instantiate and run a pipeline on a given file.
         Configure execution to show progress alongside parallel runs.
 
         Parameters
         ----------
         pipeline_class: class
-            Class from the StreamingPipeline hierarchy.
+            Class from the Pipeline hierarchy.
             A pipeline from this class will be instantiated.
-        config: StreamingConfig
-            Streaming pipeline configuration.
+        config: PipelineConfig
+            Pipeline configuration.
         filepath: Path
             Path to the target file.
         description: Text
@@ -454,8 +447,10 @@ def run_single_job(
 
         Returns
         -------
-        prediction: Annotation
-            Pipeline prediction for the given file.
+        uri: Text
+            File URI.
+        prediction: Any
+            Aggregated pipeline prediction for the given file.
         """
         # The process ID inside the pool determines the position of the progress bar
         idx_process = int(current_process().name.split('-')[1]) - 1
@@ -470,30 +465,29 @@ def run_single_job(
     def __call__(
         self,
         pipeline_class: type,
-        config: StreamingConfig,
+        config: PipelineConfig,
         metric: Optional[Metric] = None,
     ) -> Union[pd.DataFrame, Dict[Text, Any]]:
-        """Run a given pipeline on a set of audio files in parallel.
-        Each worker will build and run the pipeline on a different file.
+        """Run a pipeline on a set of audio files in parallel.
+        Each worker instantiates and runs the pipeline on a different file.
 
         Parameters
         ----------
         pipeline_class: class
-            Class from the StreamingPipeline hierarchy.
+            Class from the Pipeline hierarchy.
             A pipeline from this class will be instantiated by each worker.
-        config: StreamingConfig
-            Streaming pipeline configuration.
+        config: PipelineConfig
+            Pipeline configuration.
         metric: Optional[Metric]
             Evaluation metric.
             Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`)
 
         Returns
         -------
-        performance: pandas.DataFrame or List[Annotation]
+        performance: pandas.DataFrame or Dict[Text, Any]
             If reference annotations are given, a DataFrame with detailed
             performance on each file as well as average performance.
-
-            If no reference annotations, a list of predictions.
+            If no reference annotations are given, a dictionary mapping each file URI to its prediction.
         """
         audio_file_paths = self.benchmark.get_file_paths()
         num_audio_files = len(audio_file_paths)
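
And the parallel counterpart, wrapped exactly as the console benchmark script does later in this series (models are copied into each worker, as noted above; paths are placeholders):

    from diart.inference import Benchmark, Parallelize
    from diart.pipelines import SpeakerDiarization, SpeakerDiarizationConfig

    benchmark = Benchmark("/data/conversations/audio", reference_path="/data/conversations/rttm")
    parallel_benchmark = Parallelize(benchmark, num_workers=4)   # one pipeline copy per worker
    report = parallel_benchmark(SpeakerDiarization, SpeakerDiarizationConfig())
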
diff --git a/src/diart/optim.py b/src/diart/optim.py
index 0ea27910..be371a71 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -11,7 +11,7 @@
 
 from .audio import FilePath
 from .inference import Benchmark
-from .pipelines import StreamingConfig
+from .pipelines import PipelineConfig
 from .pipelines.hparams import HyperParameter
 
 
@@ -24,7 +24,7 @@ def __init__(
         study_or_path: Union[FilePath, Study],
         batch_size: int = 32,
         hparams: Optional[Sequence[HyperParameter]] = None,
-        base_config: Optional[StreamingConfig] = None,
+        base_config: Optional[PipelineConfig] = None,
         do_kickstart_hparams: bool = True,
         metric: Optional[BaseMetric] = None,
         direction: Literal["minimize", "maximize"] = "minimize",
@@ -97,6 +97,9 @@ def _callback(self, study: Study, trial: FrozenTrial):
     def objective(self, trial: Trial) -> float:
         # Set suggested values for optimized hyper-parameters
         trial_config = vars(self.base_config)
+        trial_config["duration"] = self.base_config.duration
+        trial_config["step"] = self.base_config.step
+        trial_config["latency"] = self.base_config.latency
         for hparam in self.hparams:
             trial_config[hparam.name] = trial.suggest_uniform(
                 hparam.name, hparam.low, hparam.high
@@ -118,12 +121,12 @@ def objective(self, trial: Trial) -> float:
         report = self.benchmark(self.pipeline_class, config, metric)
 
         # Extract target metric from report
-        return report.loc["TOTAL", metric.name]["%"]
+        return report.loc["TOTAL", metric.name].item()
 
     def __call__(self, num_iter: int, show_progress: bool = True):
         self._progress = None
         if show_progress:
-            self._progress = trange(num_iter)
+            self._progress = trange(num_iter, unit="trial")
             last_trial = -1
             if self.study.trials:
                 last_trial = self.study.trials[-1].number
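
The three assignments added to `objective` exist because `duration`, `step` and `latency` are read-only properties backed by private attributes, so `vars(config)` alone does not expose them under their public names. A small sketch of the same pattern, here on a copy so the base config's `__dict__` is left untouched, and rebuilding the per-trial config through the constructor, assuming the `**kwargs` catch-all shown in the other config constructors in this series:

    from diart.pipelines import SpeakerDiarizationConfig

    base_config = SpeakerDiarizationConfig(step=0.5, latency=0.5)

    trial_config = dict(vars(base_config))            # copy of the instance __dict__
    trial_config["duration"] = base_config.duration   # properties are backed by _duration/_step/_latency,
    trial_config["step"] = base_config.step           # so vars() alone does not expose them under the
    trial_config["latency"] = base_config.latency     # public names expected when rebuilding the config
    trial_config["tau_active"] = 0.55                 # in the optimizer this comes from trial.suggest_uniform

    # Rebuild a config for this trial; **kwargs absorbs the private entries
    config = SpeakerDiarizationConfig(**trial_config)
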
diff --git a/src/diart/pipelines/__init__.py b/src/diart/pipelines/__init__.py
index f430e4f7..11fe5e82 100644
--- a/src/diart/pipelines/__init__.py
+++ b/src/diart/pipelines/__init__.py
@@ -1,4 +1,4 @@
-from .base import StreamingPipeline, StreamingConfig
+from .base import Pipeline, PipelineConfig
 from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
 from .transcription import Transcription, TranscriptionConfig
 from .voice import VoiceActivityDetection, VoiceActivityDetectionConfig
diff --git a/src/diart/pipelines/base.py b/src/diart/pipelines/base.py
index 1a220203..de0582a7 100644
--- a/src/diart/pipelines/base.py
+++ b/src/diart/pipelines/base.py
@@ -11,7 +11,7 @@
 from ..metrics import Metric
 
 
-class StreamingConfig:
+class PipelineConfig:
     @property
     def duration(self) -> float:
         raise NotImplementedError
@@ -29,7 +29,7 @@ def sample_rate(self) -> int:
         raise NotImplementedError
 
     @staticmethod
-    def from_dict(data: Any) -> 'StreamingConfig':
+    def from_dict(data: Any) -> 'PipelineConfig':
         raise NotImplementedError
 
     def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
@@ -42,17 +42,21 @@ def optimal_block_size(self) -> int:
         return int(np.rint(self.step * self.sample_rate))
 
 
-class StreamingPipeline:
+class Pipeline:
     @staticmethod
     def get_config_class() -> type:
         raise NotImplementedError
 
+    @staticmethod
+    def suggest_metric() -> Metric:
+        raise NotImplementedError
+
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         raise NotImplementedError
 
     @property
-    def config(self) -> StreamingConfig:
+    def config(self) -> PipelineConfig:
         raise NotImplementedError
 
     def reset(self):
@@ -67,9 +71,6 @@ def join_predictions(self, predictions: List[Any]) -> Any:
     def write_prediction(self, uri: Text, prediction: Any, dir_path: Union[Text, Path]):
         raise NotImplementedError
 
-    def suggest_metric(self) -> Metric:
-        raise NotImplementedError
-
     def suggest_display(self) -> Observer:
         raise NotImplementedError
 
diff --git a/src/diart/pipelines/diarization.py b/src/diart/pipelines/diarization.py
index 57cb7739..114a4223 100644
--- a/src/diart/pipelines/diarization.py
+++ b/src/diart/pipelines/diarization.py
@@ -16,7 +16,7 @@
 from ..metrics import Metric, DiarizationErrorRate
 
 
-class SpeakerDiarizationConfig(base.StreamingConfig):
+class SpeakerDiarizationConfig(base.PipelineConfig):
     def __init__(
         self,
         segmentation: Optional[m.SegmentationModel] = None,
@@ -131,7 +131,7 @@ def sample_rate(self) -> int:
         return self._sample_rate
 
 
-class SpeakerDiarization(base.StreamingPipeline):
+class SpeakerDiarization(base.Pipeline):
     def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
         self._config = SpeakerDiarizationConfig() if config is None else config
 
@@ -166,6 +166,10 @@ def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
     def get_config_class() -> type:
         return SpeakerDiarizationConfig
 
+    @staticmethod
+    def suggest_metric() -> Metric:
+        return DiarizationErrorRate(collar=0, skip_overlap=False)
+
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive, RhoUpdate, DeltaNew]
@@ -187,9 +191,6 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te
         with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file:
             prediction.write_rttm(out_file)
 
-    def suggest_metric(self) -> Metric:
-        return DiarizationErrorRate(collar=0, skip_overlap=False)
-
     def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
         return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
 
diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py
index cb5f40e7..3616222e 100644
--- a/src/diart/pipelines/transcription.py
+++ b/src/diart/pipelines/transcription.py
@@ -15,7 +15,7 @@
 from ..metrics import Metric, WordErrorRate
 
 
-class TranscriptionConfig(base.StreamingConfig):
+class TranscriptionConfig(base.PipelineConfig):
     def __init__(
         self,
         asr: Optional[m.SpeechRecognitionModel] = None,
@@ -25,7 +25,11 @@ def __init__(
         language: Optional[Text] = None,
         beam_size: int = None,
         device: Optional[torch.device] = None,
+        **kwargs,
     ):
+        self.language = language
+        self.beam_size = beam_size
+
         self.device = device
         if self.device is None:
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -34,8 +38,8 @@ def __init__(
         self.asr = asr
         if self.asr is None:
             self.asr = m.SpeechRecognitionModel.from_whisper("small")
-        self.asr.set_language(language)
-        self.asr.set_beam_size(beam_size)
+        self.asr.set_language(self.language)
+        self.asr.set_beam_size(self.beam_size)
 
         self.segmentation = segmentation
         self.tau_active = tau_active
@@ -96,7 +100,7 @@ def from_dict(data: Any) -> 'TranscriptionConfig':
         )
 
 
-class Transcription(base.StreamingPipeline):
+class Transcription(base.Pipeline):
     def __init__(self, config: Optional[TranscriptionConfig] = None):
         self._config = TranscriptionConfig() if config is None else config
         self.asr = blocks.SpeechRecognition(self.config.asr, self.config.device)
@@ -108,6 +112,10 @@ def __init__(self, config: Optional[TranscriptionConfig] = None):
     def get_config_class() -> type:
         return TranscriptionConfig
 
+    @staticmethod
+    def suggest_metric() -> Metric:
+        return WordErrorRate()
+
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive]
@@ -131,9 +139,6 @@ def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Pa
         with open(Path(dir_path) / f"{uri}.txt", "w") as out_file:
             out_file.write(prediction)
 
-    def suggest_metric(self) -> Metric:
-        return WordErrorRate()
-
     def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
         return sinks.TextWriter(Path(output_dir) / f"{uri}.txt")
 
diff --git a/src/diart/pipelines/voice.py b/src/diart/pipelines/voice.py
index b22806ab..05eaa216 100644
--- a/src/diart/pipelines/voice.py
+++ b/src/diart/pipelines/voice.py
@@ -16,7 +16,7 @@
 from ..metrics import Metric, DetectionErrorRate
 
 
-class VoiceActivityDetectionConfig(base.StreamingConfig):
+class VoiceActivityDetectionConfig(base.PipelineConfig):
     def __init__(
         self,
         segmentation: Optional[m.SegmentationModel] = None,
@@ -100,7 +100,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig':
         )
 
 
-class VoiceActivityDetection(base.StreamingPipeline):
+class VoiceActivityDetection(base.Pipeline):
     def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
         self._config = VoiceActivityDetectionConfig() if config is None else config
 
@@ -130,6 +130,10 @@ def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
     def get_config_class() -> type:
         return VoiceActivityDetectionConfig
 
+    @staticmethod
+    def suggest_metric() -> Metric:
+        return DetectionErrorRate(collar=0, skip_overlap=False)
+
     @staticmethod
     def hyper_parameters() -> Sequence[HyperParameter]:
         return [TauActive]
@@ -155,9 +159,6 @@ def write_prediction(self, uri: Text, prediction: Annotation, dir_path: Union[Te
         with open(Path(dir_path) / f"{uri}.rttm", "w") as out_file:
             prediction.write_rttm(out_file)
 
-    def suggest_metric(self) -> Metric:
-        return DetectionErrorRate(collar=0, skip_overlap=False)
-
     def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
         return sinks.RTTMWriter(uri, Path(output_dir) / f"{uri}.rttm")
 

From 6609e3ca2cb2f92e5839ab9cf1ec4b647211f7ca Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Mon, 24 Apr 2023 11:25:51 +0200
Subject: [PATCH 16/23] Rename base pipeline and config objects

---
 src/diart/__init__.py           |  4 ++--
 src/diart/blocks/__init__.py    |  2 +-
 src/diart/blocks/base.py        |  8 ++++----
 src/diart/blocks/diarization.py |  4 ++--
 src/diart/blocks/vad.py         |  6 +++---
 src/diart/inference.py          | 10 +++++-----
 src/diart/optim.py              |  2 +-
 7 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index e29287a0..4bd51327 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1,8 +1,8 @@
 from .blocks import (
     SpeakerDiarization,
-    StreamingPipeline,
+    Pipeline,
     SpeakerDiarizationConfig,
-    StreamingConfig,
+    PipelineConfig,
     VoiceActivityDetection,
     VoiceActivityDetectionConfig,
 )
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index e6e8c479..15cf81d9 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -14,6 +14,6 @@
 )
 from .segmentation import SpeakerSegmentation
 from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
-from .base import StreamingConfig, StreamingPipeline
+from .base import PipelineConfig, Pipeline
 from .utils import Binarize, Resample, AdjustVolume
 from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
index 28f313eb..11ef961d 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/blocks/base.py
@@ -31,7 +31,7 @@ def from_name(name: Text) -> 'HyperParameter':
 DeltaNew = HyperParameter("delta_new", low=0, high=2)
 
 
-class StreamingConfig:
+class PipelineConfig:
     @property
     def duration(self) -> float:
         raise NotImplementedError
@@ -49,7 +49,7 @@ def sample_rate(self) -> int:
         raise NotImplementedError
 
     @staticmethod
-    def from_dict(data: Any) -> 'StreamingConfig':
+    def from_dict(data: Any) -> 'PipelineConfig':
         raise NotImplementedError
 
     def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
@@ -62,7 +62,7 @@ def optimal_block_size(self) -> int:
         return int(np.rint(self.step * self.sample_rate))
 
 
-class StreamingPipeline:
+class Pipeline:
     @staticmethod
     def get_config_class() -> type:
         raise NotImplementedError
@@ -76,7 +76,7 @@ def hyper_parameters() -> Sequence[HyperParameter]:
         raise NotImplementedError
 
     @property
-    def config(self) -> StreamingConfig:
+    def config(self) -> PipelineConfig:
         raise NotImplementedError
 
     def reset(self):
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index f2a25119..06658cfc 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -17,7 +17,7 @@
 from .. import utils
 
 
-class SpeakerDiarizationConfig(base.StreamingConfig):
+class SpeakerDiarizationConfig(base.PipelineConfig):
     def __init__(
         self,
         segmentation: Optional[m.SegmentationModel] = None,
@@ -129,7 +129,7 @@ def sample_rate(self) -> int:
         return self._sample_rate
 
 
-class SpeakerDiarization(base.StreamingPipeline):
+class SpeakerDiarization(base.Pipeline):
     def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
         self._config = SpeakerDiarizationConfig() if config is None else config
 
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
index def833b6..e519a9cf 100644
--- a/src/diart/blocks/vad.py
+++ b/src/diart/blocks/vad.py
@@ -15,7 +15,7 @@
 from .. import utils
 
 
-class VoiceActivityDetectionConfig(base.StreamingConfig):
+class VoiceActivityDetectionConfig(base.PipelineConfig):
     def __init__(
         self,
         segmentation: Optional[m.SegmentationModel] = None,
@@ -96,7 +96,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig':
         )
 
 
-class VoiceActivityDetection(base.StreamingPipeline):
+class VoiceActivityDetection(base.Pipeline):
     def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
         self._config = VoiceActivityDetectionConfig() if config is None else config
 
@@ -135,7 +135,7 @@ def hyper_parameters() -> Sequence[base.HyperParameter]:
         return [base.TauActive]
 
     @property
-    def config(self) -> base.StreamingConfig:
+    def config(self) -> base.PipelineConfig:
         return self._config
 
     def reset(self):
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 6afda89e..f562fdd9 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -53,7 +53,7 @@ class StreamingInference:
     """
     def __init__(
         self,
-        pipeline: blocks.StreamingPipeline,
+        pipeline: blocks.Pipeline,
         source: src.AudioSource,
         batch_size: int = 1,
         do_profile: bool = True,
@@ -289,7 +289,7 @@ def get_file_paths(self) -> List[Path]:
 
     def run_single(
         self,
-        pipeline: blocks.StreamingPipeline,
+        pipeline: blocks.Pipeline,
         filepath: Path,
         progress_bar: ProgressBar,
     ) -> Annotation:
@@ -374,7 +374,7 @@ def evaluate(
     def __call__(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: blocks.PipelineConfig,
         metric: Optional[BaseMetric] = None,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files.
@@ -437,7 +437,7 @@ def __init__(
     def run_single_job(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: blocks.PipelineConfig,
         filepath: Path,
         description: Text,
     ) -> Annotation:
@@ -474,7 +474,7 @@ def run_single_job(
     def __call__(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: blocks.PipelineConfig,
         metric: Optional[BaseMetric] = None,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files in parallel.
diff --git a/src/diart/optim.py b/src/diart/optim.py
index f7a96a6e..86492627 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -23,7 +23,7 @@ def __init__(
         study_or_path: Union[FilePath, Study],
         batch_size: int = 32,
         hparams: Optional[Sequence[blocks.base.HyperParameter]] = None,
-        base_config: Optional[blocks.StreamingConfig] = None,
+        base_config: Optional[blocks.PipelineConfig] = None,
         do_kickstart_hparams: bool = True,
         metric: Optional[BaseMetric] = None,
         direction: Literal["minimize", "maximize"] = "minimize",

From d19b04464355c5b0b282aee38327a540d8d2163d Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 19 Apr 2023 17:41:41 +0200
Subject: [PATCH 17/23] New feature: streaming voice activity detection.
 Pipeline name changes

---
 src/diart/__init__.py           |  10 +-
 src/diart/blocks/__init__.py    |   5 +-
 src/diart/blocks/base.py        |  92 ++++++++++++++
 src/diart/blocks/config.py      | 153 -----------------------
 src/diart/blocks/diarization.py | 145 ++++++++++++++++++----
 src/diart/blocks/vad.py         | 208 ++++++++++++++++++++++++++++++++
 src/diart/console/benchmark.py  |  12 +-
 src/diart/console/client.py     |   6 +-
 src/diart/console/serve.py      |  19 +--
 src/diart/console/stream.py     |  19 +--
 src/diart/console/tune.py       |  26 +++-
 src/diart/inference.py          |  86 +++++++------
 src/diart/optim.py              |  56 ++++-----
 src/diart/sinks.py              |  47 +++++---
 src/diart/sources.py            |   2 +-
 src/diart/utils.py              |  16 ++-
 16 files changed, 605 insertions(+), 297 deletions(-)
 create mode 100644 src/diart/blocks/base.py
 delete mode 100644 src/diart/blocks/config.py
 create mode 100644 src/diart/blocks/vad.py

diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index c9692638..e29287a0 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1,6 +1,8 @@
 from .blocks import (
-    OnlineSpeakerDiarization,
-    BasePipeline,
-    PipelineConfig,
-    BasePipelineConfig,
+    SpeakerDiarization,
+    StreamingPipeline,
+    SpeakerDiarizationConfig,
+    StreamingConfig,
+    VoiceActivityDetection,
+    VoiceActivityDetectionConfig,
 )
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index 59a6ef36..e6e8c479 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -13,6 +13,7 @@
     OverlapAwareSpeakerEmbedding,
 )
 from .segmentation import SpeakerSegmentation
-from .diarization import OnlineSpeakerDiarization, BasePipeline
-from .config import BasePipelineConfig, PipelineConfig
+from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .base import StreamingConfig, StreamingPipeline
 from .utils import Binarize, Resample, AdjustVolume
+from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
new file mode 100644
index 00000000..28f313eb
--- /dev/null
+++ b/src/diart/blocks/base.py
@@ -0,0 +1,92 @@
+from typing import Any, Tuple, Sequence, Text
+from dataclasses import dataclass
+
+import numpy as np
+from pyannote.core import SlidingWindowFeature
+from pyannote.metrics.base import BaseMetric
+
+from .. import utils
+from ..audio import FilePath, AudioLoader
+
+
+@dataclass
+class HyperParameter:
+    name: Text
+    low: float
+    high: float
+
+    @staticmethod
+    def from_name(name: Text) -> 'HyperParameter':
+        if name == "tau_active":
+            return TauActive
+        if name == "rho_update":
+            return RhoUpdate
+        if name == "delta_new":
+            return DeltaNew
+        raise ValueError(f"Hyper-parameter '{name}' not recognized")
+
+
+TauActive = HyperParameter("tau_active", low=0, high=1)
+RhoUpdate = HyperParameter("rho_update", low=0, high=1)
+DeltaNew = HyperParameter("delta_new", low=0, high=2)
+
+
+class StreamingConfig:
+    @property
+    def duration(self) -> float:
+        raise NotImplementedError
+
+    @property
+    def step(self) -> float:
+        raise NotImplementedError
+
+    @property
+    def latency(self) -> float:
+        raise NotImplementedError
+
+    @property
+    def sample_rate(self) -> int:
+        raise NotImplementedError
+
+    @staticmethod
+    def from_dict(data: Any) -> 'StreamingConfig':
+        raise NotImplementedError
+
+    def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
+        file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
+        right = utils.get_padding_right(self.latency, self.step)
+        left = utils.get_padding_left(file_duration + right, self.duration)
+        return left, right
+
+    def optimal_block_size(self) -> int:
+        return int(np.rint(self.step * self.sample_rate))
+
+
+class StreamingPipeline:
+    @staticmethod
+    def get_config_class() -> type:
+        raise NotImplementedError
+
+    @staticmethod
+    def suggest_metric() -> BaseMetric:
+        raise NotImplementedError
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[HyperParameter]:
+        raise NotImplementedError
+
+    @property
+    def config(self) -> StreamingConfig:
+        raise NotImplementedError
+
+    def reset(self):
+        raise NotImplementedError
+
+    def set_timestamp_shift(self, shift: float):
+        raise NotImplementedError
+
+    def __call__(
+        self,
+        waveforms: Sequence[SlidingWindowFeature]
+    ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
+        raise NotImplementedError
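
To make `optimal_block_size` above concrete: with pyannote/segmentation (16 kHz audio) and the default step of 0.5 s, an audio source should emit 8000 samples per block. A one-line check of that arithmetic:

    import numpy as np

    step, sample_rate = 0.5, 16000                 # defaults with pyannote/segmentation
    block_size = int(np.rint(step * sample_rate))  # what optimal_block_size() returns
    assert block_size == 8000
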
diff --git a/src/diart/blocks/config.py b/src/diart/blocks/config.py
deleted file mode 100644
index d8e2a656..00000000
--- a/src/diart/blocks/config.py
+++ /dev/null
@@ -1,153 +0,0 @@
-from typing import Any, Optional, Union, Tuple
-
-import numpy as np
-import torch
-from typing_extensions import Literal
-
-from .. import models as m
-from .. import utils
-from ..audio import FilePath, AudioLoader
-
-
-class BasePipelineConfig:
-    @property
-    def duration(self) -> float:
-        raise NotImplementedError
-
-    @property
-    def step(self) -> float:
-        raise NotImplementedError
-
-    @property
-    def latency(self) -> float:
-        raise NotImplementedError
-
-    @property
-    def sample_rate(self) -> int:
-        raise NotImplementedError
-
-    @staticmethod
-    def from_dict(data: Any) -> 'BasePipelineConfig':
-        raise NotImplementedError
-
-    def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
-        file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
-        right = utils.get_padding_right(self.latency, self.step)
-        left = utils.get_padding_left(file_duration + right, self.duration)
-        return left, right
-
-    def optimal_block_size(self) -> int:
-        return int(np.rint(self.step * self.sample_rate))
-
-
-class PipelineConfig(BasePipelineConfig):
-    def __init__(
-        self,
-        segmentation: Optional[m.SegmentationModel] = None,
-        embedding: Optional[m.EmbeddingModel] = None,
-        duration: Optional[float] = None,
-        step: float = 0.5,
-        latency: Optional[Union[float, Literal["max", "min"]]] = None,
-        tau_active: float = 0.6,
-        rho_update: float = 0.3,
-        delta_new: float = 1,
-        gamma: float = 3,
-        beta: float = 10,
-        max_speakers: int = 20,
-        device: Optional[torch.device] = None,
-        **kwargs,
-    ):
-        # Default segmentation model is pyannote/segmentation
-        self.segmentation = segmentation
-        if self.segmentation is None:
-            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
-
-        # Default duration is the one given by the segmentation model
-        self._duration = duration
-
-        # Expected sample rate is given by the segmentation model
-        self._sample_rate: Optional[int] = None
-
-        # Default embedding model is pyannote/embedding
-        self.embedding = embedding
-        if self.embedding is None:
-            self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
-
-        # Latency defaults to the step duration
-        self._step = step
-        self._latency = latency
-        if self._latency is None or self._latency == "min":
-            self._latency = self._step
-        elif self._latency == "max":
-            self._latency = self._duration
-
-        self.tau_active = tau_active
-        self.rho_update = rho_update
-        self.delta_new = delta_new
-        self.gamma = gamma
-        self.beta = beta
-        self.max_speakers = max_speakers
-
-        self.device = device
-        if self.device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    @staticmethod
-    def from_dict(data: Any) -> 'PipelineConfig':
-        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
-        device = utils.get(data, "device", None)
-        if device is None:
-            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
-
-        # Instantiate models
-        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
-        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
-        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
-        embedding = utils.get(data, "embedding", "pyannote/embedding")
-        embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
-
-        # Hyper-parameters and their aliases
-        tau = utils.get(data, "tau_active", None)
-        if tau is None:
-            tau = utils.get(data, "tau", 0.6)
-        rho = utils.get(data, "rho_update", None)
-        if rho is None:
-            rho = utils.get(data, "rho", 0.3)
-        delta = utils.get(data, "delta_new", None)
-        if delta is None:
-            delta = utils.get(data, "delta", 1)
-
-        return PipelineConfig(
-            segmentation=segmentation,
-            embedding=embedding,
-            duration=utils.get(data, "duration", None),
-            step=utils.get(data, "step", 0.5),
-            latency=utils.get(data, "latency", None),
-            tau_active=tau,
-            rho_update=rho,
-            delta_new=delta,
-            gamma=utils.get(data, "gamma", 3),
-            beta=utils.get(data, "beta", 10),
-            max_speakers=utils.get(data, "max_speakers", 20),
-            device=device,
-        )
-
-    @property
-    def duration(self) -> float:
-        if self._duration is None:
-            self._duration = self.segmentation.duration
-        return self._duration
-
-    @property
-    def step(self) -> float:
-        return self._step
-
-    @property
-    def latency(self) -> float:
-        return self._latency
-
-    @property
-    def sample_rate(self) -> int:
-        if self._sample_rate is None:
-            self._sample_rate = self.segmentation.sample_rate
-        return self._sample_rate
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index 7f0e162c..f2a25119 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -1,42 +1,137 @@
-from typing import Optional, Tuple, Sequence
+from typing import Optional, Tuple, Sequence, Union, Any
 
 import numpy as np
 import torch
 from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
+from pyannote.metrics.base import BaseMetric
+from pyannote.metrics.diarization import DiarizationErrorRate
+from typing_extensions import Literal
 
 from .aggregation import DelayedAggregation
+from . import base
 from .clustering import OnlineSpeakerClustering
 from .embedding import OverlapAwareSpeakerEmbedding
 from .segmentation import SpeakerSegmentation
 from .utils import Binarize
-from .config import BasePipelineConfig, PipelineConfig
+from .. import models as m
+from .. import utils
 
 
-class BasePipeline:
+class SpeakerDiarizationConfig(base.StreamingConfig):
+    def __init__(
+        self,
+        segmentation: Optional[m.SegmentationModel] = None,
+        embedding: Optional[m.EmbeddingModel] = None,
+        duration: Optional[float] = None,
+        step: float = 0.5,
+        latency: Optional[Union[float, Literal["max", "min"]]] = None,
+        tau_active: float = 0.6,
+        rho_update: float = 0.3,
+        delta_new: float = 1,
+        gamma: float = 3,
+        beta: float = 10,
+        max_speakers: int = 20,
+        device: Optional[torch.device] = None,
+        **kwargs,
+    ):
+        # Default segmentation model is pyannote/segmentation
+        self.segmentation = segmentation
+        if self.segmentation is None:
+            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+        self._duration = duration
+        self._sample_rate: Optional[int] = None
+
+        # Default embedding model is pyannote/embedding
+        self.embedding = embedding
+        if self.embedding is None:
+            self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
+
+        # Latency defaults to the step duration
+        self._step = step
+        self._latency = latency
+        if self._latency is None or self._latency == "min":
+            self._latency = self._step
+        elif self._latency == "max":
+            self._latency = self._duration
+
+        self.tau_active = tau_active
+        self.rho_update = rho_update
+        self.delta_new = delta_new
+        self.gamma = gamma
+        self.beta = beta
+        self.max_speakers = max_speakers
+
+        self.device = device
+        if self.device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     @staticmethod
-    def get_config_class() -> type:
-        raise NotImplementedError
+    def from_dict(data: Any) -> 'SpeakerDiarizationConfig':
+        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+        device = utils.get(data, "device", None)
+        if device is None:
+            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+        # Instantiate models
+        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
+        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
+        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
+        embedding = utils.get(data, "embedding", "pyannote/embedding")
+        embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
+
+        # Hyper-parameters and their aliases
+        tau = utils.get(data, "tau_active", None)
+        if tau is None:
+            tau = utils.get(data, "tau", 0.6)
+        rho = utils.get(data, "rho_update", None)
+        if rho is None:
+            rho = utils.get(data, "rho", 0.3)
+        delta = utils.get(data, "delta_new", None)
+        if delta is None:
+            delta = utils.get(data, "delta", 1)
+
+        return SpeakerDiarizationConfig(
+            segmentation=segmentation,
+            embedding=embedding,
+            duration=utils.get(data, "duration", None),
+            step=utils.get(data, "step", 0.5),
+            latency=utils.get(data, "latency", None),
+            tau_active=tau,
+            rho_update=rho,
+            delta_new=delta,
+            gamma=utils.get(data, "gamma", 3),
+            beta=utils.get(data, "beta", 10),
+            max_speakers=utils.get(data, "max_speakers", 20),
+            device=device,
+        )
 
     @property
-    def config(self) -> BasePipelineConfig:
-        raise NotImplementedError
+    def duration(self) -> float:
+        # Default duration is the one given by the segmentation model
+        if self._duration is None:
+            self._duration = self.segmentation.duration
+        return self._duration
 
-    def reset(self):
-        raise NotImplementedError
+    @property
+    def step(self) -> float:
+        return self._step
 
-    def set_timestamp_shift(self, shift: float):
-        raise NotImplementedError
+    @property
+    def latency(self) -> float:
+        return self._latency
 
-    def __call__(
-        self,
-        waveforms: Sequence[SlidingWindowFeature]
-    ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
-        raise NotImplementedError
+    @property
+    def sample_rate(self) -> int:
+        # Expected sample rate is given by the segmentation model
+        if self._sample_rate is None:
+            self._sample_rate = self.segmentation.sample_rate
+        return self._sample_rate
 
 
-class OnlineSpeakerDiarization(BasePipeline):
-    def __init__(self, config: Optional[PipelineConfig] = None):
-        self._config = PipelineConfig() if config is None else config
+class SpeakerDiarization(base.StreamingPipeline):
+    def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
+        self._config = SpeakerDiarizationConfig() if config is None else config
 
         msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
         assert self._config.step <= self._config.latency <= self._config.duration, msg
@@ -67,10 +162,18 @@ def __init__(self, config: Optional[PipelineConfig] = None):
 
     @staticmethod
     def get_config_class() -> type:
-        return PipelineConfig
+        return SpeakerDiarizationConfig
+
+    @staticmethod
+    def suggest_metric() -> BaseMetric:
+        return DiarizationErrorRate(collar=0, skip_overlap=False)
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[base.HyperParameter]:
+        return [base.TauActive, base.RhoUpdate, base.DeltaNew]
 
     @property
-    def config(self) -> PipelineConfig:
+    def config(self) -> SpeakerDiarizationConfig:
         return self._config
 
     def set_timestamp_shift(self, shift: float):
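
With `suggest_metric` now on the class, the default evaluation metric can be obtained and used without building a pipeline. A tiny sketch with hand-made annotations (the toy reference/hypothesis below are illustrative only):

    from pyannote.core import Annotation, Segment
    from diart.blocks import SpeakerDiarization

    metric = SpeakerDiarization.suggest_metric()   # DiarizationErrorRate(collar=0, skip_overlap=False)

    reference, hypothesis = Annotation(), Annotation()
    reference[Segment(0, 10)] = "spk0"
    hypothesis[Segment(0, 9)] = "A"                # labels are mapped optimally by the metric
    print(f"DER = {metric(reference, hypothesis):.3f}")
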
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
new file mode 100644
index 00000000..def833b6
--- /dev/null
+++ b/src/diart/blocks/vad.py
@@ -0,0 +1,208 @@
+from typing import Any, Optional, Union, Sequence, Tuple
+
+import numpy as np
+import torch
+from pyannote.core import Annotation, Timeline, SlidingWindowFeature, SlidingWindow, Segment
+from pyannote.metrics.base import BaseMetric
+from pyannote.metrics.detection import DetectionErrorRate
+from typing_extensions import Literal
+
+from .aggregation import DelayedAggregation
+from . import base
+from .segmentation import SpeakerSegmentation
+from .utils import Binarize
+from .. import models as m
+from .. import utils
+
+
+class VoiceActivityDetectionConfig(base.StreamingConfig):
+    def __init__(
+        self,
+        segmentation: Optional[m.SegmentationModel] = None,
+        duration: Optional[float] = None,
+        step: float = 0.5,
+        latency: Optional[Union[float, Literal["max", "min"]]] = None,
+        tau_active: float = 0.6,
+        device: Optional[torch.device] = None,
+        **kwargs,
+    ):
+        # Default segmentation model is pyannote/segmentation
+        self.segmentation = segmentation
+        if self.segmentation is None:
+            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+        self._duration = duration
+        self._step = step
+        self._sample_rate: Optional[int] = None
+
+        # Latency defaults to the step duration
+        self._latency = latency
+        if self._latency is None or self._latency == "min":
+            self._latency = self._step
+        elif self._latency == "max":
+            self._latency = self._duration
+
+        self.tau_active = tau_active
+        self.device = device
+        if self.device is None:
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    @property
+    def duration(self) -> float:
+        # Default duration is the one given by the segmentation model
+        if self._duration is None:
+            self._duration = self.segmentation.duration
+        return self._duration
+
+    @property
+    def step(self) -> float:
+        return self._step
+
+    @property
+    def latency(self) -> float:
+        return self._latency
+
+    @property
+    def sample_rate(self) -> int:
+        # Expected sample rate is given by the segmentation model
+        if self._sample_rate is None:
+            self._sample_rate = self.segmentation.sample_rate
+        return self._sample_rate
+
+    @staticmethod
+    def from_dict(data: Any) -> 'VoiceActivityDetectionConfig':
+        # Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
+        device = utils.get(data, "device", None)
+        if device is None:
+            device = torch.device("cpu") if utils.get(data, "cpu", False) else None
+
+        # Instantiate segmentation model
+        hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
+        segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
+        segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
+
+        # Tau active and its alias
+        tau = utils.get(data, "tau_active", None)
+        if tau is None:
+            tau = utils.get(data, "tau", 0.6)
+
+        return VoiceActivityDetectionConfig(
+            segmentation=segmentation,
+            duration=utils.get(data, "duration", None),
+            step=utils.get(data, "step", 0.5),
+            latency=utils.get(data, "latency", None),
+            tau_active=tau,
+            device=device,
+        )
+
+
+class VoiceActivityDetection(base.StreamingPipeline):
+    def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
+        self._config = VoiceActivityDetectionConfig() if config is None else config
+
+        msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
+        assert self._config.step <= self._config.latency <= self._config.duration, msg
+
+        self.segmentation = SpeakerSegmentation(self._config.segmentation, self._config.device)
+        self.pred_aggregation = DelayedAggregation(
+            self._config.step,
+            self._config.latency,
+            strategy="hamming",
+            cropping_mode="loose",
+        )
+        self.audio_aggregation = DelayedAggregation(
+            self._config.step,
+            self._config.latency,
+            strategy="first",
+            cropping_mode="center",
+        )
+        self.binarize = Binarize(self._config.tau_active)
+
+        # Internal state, handle with care
+        self.timestamp_shift = 0
+        self.chunk_buffer, self.pred_buffer = [], []
+
+    @staticmethod
+    def get_config_class() -> type:
+        return VoiceActivityDetectionConfig
+
+    @staticmethod
+    def suggest_metric() -> BaseMetric:
+        return DetectionErrorRate(collar=0, skip_overlap=False)
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[base.HyperParameter]:
+        return [base.TauActive]
+
+    @property
+    def config(self) -> base.StreamingConfig:
+        return self._config
+
+    def reset(self):
+        self.set_timestamp_shift(0)
+        self.chunk_buffer, self.pred_buffer = [], []
+
+    def set_timestamp_shift(self, shift: float):
+        self.timestamp_shift = shift
+
+    def __call__(
+        self,
+        waveforms: Sequence[SlidingWindowFeature],
+    ) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
+        batch_size = len(waveforms)
+        msg = "Pipeline expected at least 1 input"
+        assert batch_size >= 1, msg
+
+        # Create batch from chunk sequence, shape (batch, samples, channels)
+        batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
+
+        expected_num_samples = int(np.rint(self.config.duration * self.config.sample_rate))
+        msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
+        assert batch.shape[1] == expected_num_samples, msg
+
+        # Extract segmentation
+        segmentations = self.segmentation(batch)  # shape (batch, frames, speakers)
+        voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[0]  # shape (batch, frames, 1)
+
+        seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
+
+        outputs = []
+        for wav, vad in zip(waveforms, voice_detection):
+            # Add timestamps to segmentation
+            sw = SlidingWindow(
+                start=wav.extent.start,
+                duration=seg_resolution,
+                step=seg_resolution,
+            )
+            vad = SlidingWindowFeature(vad.cpu().numpy(), sw)
+
+            # Update sliding buffer
+            self.chunk_buffer.append(wav)
+            self.pred_buffer.append(vad)
+
+            # Aggregate buffer outputs for this time step
+            agg_waveform = self.audio_aggregation(self.chunk_buffer)
+            agg_prediction = self.pred_aggregation(self.pred_buffer)
+            agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False)
+
+            # Shift prediction timestamps if required
+            if self.timestamp_shift != 0:
+                shifted_agg_prediction = Timeline(uri=agg_prediction.uri)
+                for segment in agg_prediction:
+                    new_segment = Segment(
+                        segment.start + self.timestamp_shift,
+                        segment.end + self.timestamp_shift,
+                    )
+                    shifted_agg_prediction.add(new_segment)
+                agg_prediction = shifted_agg_prediction
+
+            # Convert timeline into annotation with single speaker "speech"
+            agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech"))
+            outputs.append((agg_prediction, agg_waveform))
+
+            # Make place for new chunks in buffer if required
+            if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
+                self.chunk_buffer = self.chunk_buffer[1:]
+                self.pred_buffer = self.pred_buffer[1:]
+
+        return outputs
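
A smoke-test sketch of the new pipeline's `__call__` contract, feeding one chunk of silence shaped the way the audio sources in this package produce chunks (one column per channel, per-sample sliding window); it assumes pyannote/segmentation can be loaded, i.e. an accepted Hugging Face token:

    import numpy as np
    from pyannote.core import SlidingWindow, SlidingWindowFeature
    from diart.blocks import VoiceActivityDetection, VoiceActivityDetectionConfig

    config = VoiceActivityDetectionConfig(tau_active=0.6)   # downloads pyannote/segmentation by default
    vad = VoiceActivityDetection(config)

    # Exactly duration * sample_rate samples, as asserted in __call__
    num_samples = int(np.rint(config.duration * config.sample_rate))
    resolution = 1 / config.sample_rate
    chunk = SlidingWindowFeature(
        np.zeros((num_samples, 1), dtype=np.float32),
        SlidingWindow(start=0.0, duration=resolution, step=resolution),
    )

    annotation, waveform = vad([chunk])[0]   # Annotation labeled "speech" (empty here, since the chunk is silence)
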
diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py
index b6a3f9ff..27d524c5 100644
--- a/src/diart/console/benchmark.py
+++ b/src/diart/console/benchmark.py
@@ -1,15 +1,17 @@
 import argparse
 from pathlib import Path
 
-import diart.argdoc as argdoc
 import pandas as pd
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
+from diart import argdoc
+from diart import utils
 from diart.inference import Benchmark, Parallelize
 
 
 def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("root", type=Path, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+    parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+                        help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -34,6 +36,8 @@ def run():
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
+    pipeline_class = utils.get_pipeline_class(args.pipeline)
+
     benchmark = Benchmark(
         args.root,
         args.reference,
@@ -43,11 +47,11 @@ def run():
         batch_size=args.batch_size,
     )
 
-    config = PipelineConfig.from_dict(vars(args))
+    config = pipeline_class.get_config_class().from_dict(vars(args))
     if args.num_workers > 0:
         benchmark = Parallelize(benchmark, args.num_workers)
 
-    report = benchmark(OnlineSpeakerDiarization, config)
+    report = benchmark(pipeline_class, config)
 
     if args.output is not None and isinstance(report, pd.DataFrame):
         report.to_csv(args.output / "benchmark_report.csv")
diff --git a/src/diart/console/client.py b/src/diart/console/client.py
index 084dbc13..db4915fa 100644
--- a/src/diart/console/client.py
+++ b/src/diart/console/client.py
@@ -3,11 +3,11 @@
 from threading import Thread
 from typing import Text, Optional
 
-import diart.argdoc as argdoc
-import diart.sources as src
-import diart.utils as utils
 import numpy as np
 import rx.operators as ops
+from diart import argdoc
+from diart import sources as src
+from diart import utils
 from websocket import WebSocket
 
 
diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index 2f632d57..46bb9328 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -1,10 +1,10 @@
 import argparse
 from pathlib import Path
 
-import diart.argdoc as argdoc
-import diart.sources as src
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
-from diart.inference import RealTimeInference
+from diart import argdoc
+from diart import sources as src
+from diart import utils
+from diart.inference import StreamingInference
 from diart.sinks import RTTMWriter
 
 
@@ -12,6 +12,8 @@ def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host")
     parser.add_argument("--port", default=7007, type=int, help="Server port")
+    parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+                        help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -31,15 +33,16 @@ def run():
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
-    # Define online speaker diarization pipeline
-    config = PipelineConfig.from_dict(vars(args))
-    pipeline = OnlineSpeakerDiarization(config)
+    # Resolve pipeline
+    pipeline_class = utils.get_pipeline_class(args.pipeline)
+    config = pipeline_class.get_config_class().from_dict(vars(args))
+    pipeline = pipeline_class(config)
 
     # Create websocket audio source
     audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port)
 
     # Run online inference
-    inference = RealTimeInference(
+    inference = StreamingInference(
         pipeline,
         audio_source,
         batch_size=1,
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index d7218f07..e0c670c5 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -1,16 +1,18 @@
 import argparse
 from pathlib import Path
 
-import diart.argdoc as argdoc
-import diart.sources as src
-from diart.blocks import OnlineSpeakerDiarization, PipelineConfig
-from diart.inference import RealTimeInference
+from diart import argdoc
+from diart import sources as src
+from diart import utils
+from diart.inference import StreamingInference
 from diart.sinks import RTTMWriter
 
 
 def run():
     parser = argparse.ArgumentParser()
     parser.add_argument("source", type=str, help="Path to an audio file | 'microphone' | 'microphone:<DEVICE_ID>'")
+    parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+                        help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -32,9 +34,10 @@ def run():
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
-    # Define online speaker diarization pipeline
-    config = PipelineConfig.from_dict(vars(args))
-    pipeline = OnlineSpeakerDiarization(config)
+    # Resolve pipeline
+    pipeline_class = utils.get_pipeline_class(args.pipeline)
+    config = pipeline_class.get_config_class().from_dict(vars(args))
+    pipeline = pipeline_class(config)
 
     # Manage audio source
     block_size = config.optimal_block_size()
@@ -51,7 +54,7 @@ def run():
         audio_source = src.MicrophoneAudioSource(config.sample_rate, block_size, device)
 
     # Run online inference
-    inference = RealTimeInference(
+    inference = StreamingInference(
         pipeline,
         audio_source,
         batch_size=1,
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index 4ad8852a..a1f1b63a 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -1,10 +1,11 @@
 import argparse
 from pathlib import Path
 
-import diart.argdoc as argdoc
 import optuna
-from diart.blocks import PipelineConfig, OnlineSpeakerDiarization
-from diart.optim import Optimizer, HyperParameter
+from diart import argdoc
+from diart import utils
+from diart.blocks.base import HyperParameter
+from diart.optim import Optimizer
 from optuna.samplers import TPESampler
 
 
@@ -13,6 +14,8 @@ def run():
     parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
     parser.add_argument("--reference", required=True, type=str,
                         help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+    parser.add_argument("--pipeline", default="SpeakerDiarization", type=str,
+                        help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'")
     parser.add_argument("--segmentation", default="pyannote/segmentation", type=str,
                         help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation")
     parser.add_argument("--embedding", default="pyannote/embedding", type=str,
@@ -38,17 +41,28 @@ def run():
                         help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)")
     args = parser.parse_args()
 
+    # Retrieve pipeline class
+    pipeline_class = utils.get_pipeline_class(args.pipeline)
+
     # Create the base configuration for each trial
-    base_config = PipelineConfig.from_dict(vars(args))
+    base_config = pipeline_class.get_config_class().from_dict(vars(args))
 
     # Create hyper-parameters to optimize
+    possible_hparams = pipeline_class.hyper_parameters()
     hparams = [HyperParameter.from_name(name) for name in args.hparams]
+    hparams = [hp for hp in hparams if hp in possible_hparams]
+    if not hparams:
+        print(
+            f"No hyper-parameters to optimize. "
+            f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}"
+        )
+        exit(1)
 
     # Use a custom storage if given
     if args.output is not None:
         msg = "Both `output` and `storage` were set, but only one was expected"
         assert args.storage is None, msg
-        args.output = Path(args.output)
+        args.output = Path(args.output).expanduser()
         args.output.mkdir(parents=True, exist_ok=True)
         study_or_path = args.output
     elif args.storage is not None:
@@ -60,11 +74,11 @@ def run():
 
     # Run optimization
     Optimizer(
+        pipeline_class=pipeline_class,
         speech_path=args.root,
         reference_path=args.reference,
         study_or_path=study_or_path,
         batch_size=args.batch_size,
-        pipeline_class=OnlineSpeakerDiarization,
         hparams=hparams,
         base_config=base_config,
     )(num_iter=args.num_iter, show_progress=True)
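
The same hyper-parameter filtering can be reproduced from Python; the sketch below mirrors the CLI flow (paths are placeholders and the default pretrained models are assumed to be available):

```python
from pathlib import Path

from diart import utils
from diart.blocks.base import HyperParameter
from diart.optim import Optimizer

pipeline_class = utils.get_pipeline_class("SpeakerDiarization")
base_config = pipeline_class.get_config_class().from_dict({"step": 0.5})

# Keep only the hyper-parameters that the chosen pipeline actually exposes
requested = [HyperParameter.from_name(name) for name in ("tau_active", "rho_update")]
hparams = [hp for hp in requested if hp in pipeline_class.hyper_parameters()]

Optimizer(
    pipeline_class=pipeline_class,
    speech_path="/wav/dir",
    reference_path="/rttm/dir",
    study_or_path=Path("/output/dir"),
    hparams=hparams,
    base_config=base_config,
)(num_iter=100, show_progress=True)
```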
diff --git a/src/diart/inference.py b/src/diart/inference.py
index f4b65f5f..6afda89e 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -4,32 +4,33 @@
 from traceback import print_exc
 from typing import Union, Text, Optional, Callable, Tuple, List
 
-import diart.operators as dops
-import diart.sources as src
 import numpy as np
 import pandas as pd
 import rx
 import rx.operators as ops
 import torch
-from diart import utils
-from diart.blocks import BasePipeline, Resample, BasePipelineConfig
-from diart.progress import ProgressBar, RichProgressBar, TQDMProgressBar
-from diart.sinks import DiarizationPredictionAccumulator, RealTimePlot, WindowClosedException
 from pyannote.core import Annotation, SlidingWindowFeature
 from pyannote.database.util import load_rttm
-from pyannote.metrics.diarization import DiarizationErrorRate
+from pyannote.metrics.base import BaseMetric
 from rx.core import Observer
 from tqdm import tqdm
 
+from . import blocks
+from . import operators as dops
+from . import sources as src
+from . import utils
+from .progress import ProgressBar, RichProgressBar, TQDMProgressBar
+from .sinks import PredictionAccumulator, StreamingPlot, WindowClosedException
 
-class RealTimeInference:
+
+class StreamingInference:
     """Performs inference in real time given a pipeline and an audio source.
     Streams an audio source to an online speaker diarization pipeline.
     It allows users to attach a chain of operations in the form of hooks.
 
     Parameters
     ----------
-    pipeline: BasePipeline
+    pipeline: StreamingPipeline
         Configured speaker diarization pipeline.
     source: AudioSource
         Audio source to be read and streamed.
@@ -52,7 +53,7 @@ class RealTimeInference:
     """
     def __init__(
         self,
-        pipeline: BasePipeline,
+        pipeline: blocks.StreamingPipeline,
         source: src.AudioSource,
         batch_size: int = 1,
         do_profile: bool = True,
@@ -66,7 +67,7 @@ def __init__(
         self.do_profile = do_profile
         self.do_plot = do_plot
         self.show_progress = show_progress
-        self.accumulator = DiarizationPredictionAccumulator(self.source.uri)
+        self.accumulator = PredictionAccumulator(self.source.uri)
         self.unit = "chunk" if self.batch_size == 1 else "batch"
         self._observers = []
 
@@ -102,7 +103,7 @@ def __init__(
                   f"but pipeline's is {sample_rate}. Will resample."
             logging.warning(msg)
             self.stream = self.stream.pipe(
-                ops.map(Resample(self.source.sample_rate, sample_rate))
+                ops.map(blocks.Resample(self.source.sample_rate, sample_rate))
             )
 
         # Add rx operators to manage the inputs and outputs of the pipeline
@@ -202,7 +203,7 @@ def __call__(self) -> Annotation:
                     latency=config.latency,
                     sample_rate=config.sample_rate,
                 ),
-                ops.do(RealTimePlot(config.duration, config.latency)),
+                ops.do(StreamingPlot(config.duration, config.latency)),
             )
         observable.subscribe(
             on_error=self._handle_error,
@@ -288,7 +289,7 @@ def get_file_paths(self) -> List[Path]:
 
     def run_single(
         self,
-        pipeline: BasePipeline,
+        pipeline: blocks.StreamingPipeline,
         filepath: Path,
         progress_bar: ProgressBar,
     ) -> Annotation:
@@ -298,7 +299,7 @@ def run_single(
 
         Parameters
         ----------
-        pipeline: BasePipeline
+        pipeline: StreamingPipeline
             Speaker diarization pipeline to run.
         filepath: Path
             Path to the target file.
@@ -318,7 +319,7 @@ def run_single(
             pipeline.config.optimal_block_size(),
         )
         pipeline.set_timestamp_shift(-padding[0])
-        inference = RealTimeInference(
+        inference = StreamingInference(
             pipeline,
             source,
             self.batch_size,
@@ -337,7 +338,11 @@ def run_single(
 
         return pred
 
-    def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[Annotation]]:
+    def evaluate(
+        self,
+        predictions: List[Annotation],
+        metric: BaseMetric,
+    ) -> Union[pd.DataFrame, List[Annotation]]:
         """If a reference path was provided,
         compute the diarization error rate of a list of predictions.
 
@@ -345,6 +350,8 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An
         ----------
         predictions: List[Annotation]
             Predictions to evaluate.
+        metric: BaseMetric
+            Evaluation metric from pyannote.metrics.
 
         Returns
         -------
@@ -353,8 +360,7 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An
             reference path was given. Otherwise return the same predictions.
         """
         if self.reference_path is not None:
-            metric = DiarizationErrorRate(collar=0, skip_overlap=False)
-            progress_bar = TQDMProgressBar("Computing DER", leave=False)
+            progress_bar = TQDMProgressBar(f"Computing {metric.name}", leave=False)
             progress_bar.create(total=len(predictions), unit="file")
             progress_bar.start()
             for hyp in predictions:
@@ -368,18 +374,22 @@ def evaluate(self, predictions: List[Annotation]) -> Union[pd.DataFrame, List[An
     def __call__(
         self,
         pipeline_class: type,
-        config: BasePipelineConfig,
+        config: blocks.StreamingConfig,
+        metric: Optional[BaseMetric] = None,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files.
-        Notice that the internal state of the pipeline is reset before benchmarking.
+        The internal state of the pipeline is reset before benchmarking.
 
         Parameters
         ----------
         pipeline_class: class
-            Class from the BasePipeline hierarchy.
+            Class from the StreamingPipeline hierarchy.
             A pipeline from this class will be instantiated by each worker.
-        config: BasePipelineConfig
-            Diarization pipeline configuration.
+        config: StreamingConfig
+            Streaming pipeline configuration.
+        metric: Optional[BaseMetric]
+            Evaluation metric from pyannote.metrics.
+            Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`)
 
         Returns
         -------
@@ -400,7 +410,8 @@ def __call__(
             progress = TQDMProgressBar(desc, leave=False, do_close=True)
             predictions.append(self.run_single(pipeline, filepath, progress))
 
-        return self.evaluate(predictions)
+        metric = pipeline.suggest_metric() if metric is None else metric
+        return self.evaluate(predictions, metric)
 
 
 class Parallelize:
@@ -426,20 +437,20 @@ def __init__(
     def run_single_job(
         self,
         pipeline_class: type,
-        config: BasePipelineConfig,
+        config: blocks.StreamingConfig,
         filepath: Path,
         description: Text,
-    ):
+    ) -> Annotation:
         """Build and run a pipeline on a single file.
         Configure execution to show progress alongside parallel runs.
 
         Parameters
         ----------
         pipeline_class: class
-            Class from the BasePipeline hierarchy.
+            Class from the StreamingPipeline hierarchy.
             A pipeline from this class will be instantiated.
-        config: BasePipelineConfig
-            Diarization pipeline configuration.
+        config: StreamingConfig
+            Streaming pipeline configuration.
         filepath: Path
             Path to the target file.
         description: Text
@@ -463,7 +474,8 @@ def run_single_job(
     def __call__(
         self,
         pipeline_class: type,
-        config: BasePipelineConfig,
+        config: blocks.StreamingConfig,
+        metric: Optional[BaseMetric] = None,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files in parallel.
         Each worker will build and run the pipeline on a different file.
@@ -471,10 +483,13 @@ def __call__(
         Parameters
         ----------
         pipeline_class: class
-            Class from the BasePipeline hierarchy.
+            Class from the StreamingPipeline hierarchy.
             A pipeline from this class will be instantiated by each worker.
-        config: BasePipelineConfig
-            Diarization pipeline configuration.
+        config: StreamingConfig
+            Streaming pipeline configuration.
+        metric: Optional[BaseMetric]
+            Evaluation metric from pyannote.metrics.
+            Defaults to the pipeline's suggested metric (see `StreamingPipeline.suggest_metric()`)
 
         Returns
         -------
@@ -512,4 +527,5 @@ def __call__(
         predictions = [job.get() for job in jobs]
 
         # Evaluate results
-        return self.benchmark.evaluate(predictions)
+        metric = pipeline_class.suggest_metric() if metric is None else metric
+        return self.benchmark.evaluate(predictions, metric)
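
With the new `metric` argument, callers can override the metric suggested by the pipeline. A minimal sketch (assuming the default pretrained models are available; paths are placeholders):

```python
from pyannote.metrics.diarization import DiarizationErrorRate

from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.inference import Benchmark

benchmark = Benchmark("/wav/dir", "/rttm/dir")
config = SpeakerDiarizationConfig(step=0.5, latency=0.5)

# Explicit metric; omitting it falls back to SpeakerDiarization.suggest_metric()
report = benchmark(
    SpeakerDiarization,
    config,
    metric=DiarizationErrorRate(collar=0, skip_overlap=False),
)
```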
diff --git a/src/diart/optim.py b/src/diart/optim.py
index 05800a05..f7a96a6e 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -1,51 +1,32 @@
 from collections import OrderedDict
-from dataclasses import dataclass
 from pathlib import Path
 from typing import Sequence, Text, Optional, Union
 
 from optuna import TrialPruned, Study, create_study
 from optuna.samplers import TPESampler
 from optuna.trial import Trial, FrozenTrial
+from pyannote.metrics.base import BaseMetric
 from tqdm import trange, tqdm
+from typing_extensions import Literal
 
+from . import blocks
 from .audio import FilePath
-from .blocks import BasePipelineConfig, PipelineConfig, OnlineSpeakerDiarization
 from .inference import Benchmark
 
 
-@dataclass
-class HyperParameter:
-    name: Text
-    low: float
-    high: float
-
-    @staticmethod
-    def from_name(name: Text) -> 'HyperParameter':
-        if name == "tau_active":
-            return TauActive
-        if name == "rho_update":
-            return RhoUpdate
-        if name == "delta_new":
-            return DeltaNew
-        raise ValueError(f"Hyper-parameter '{name}' not recognized")
-
-
-TauActive = HyperParameter("tau_active", low=0, high=1)
-RhoUpdate = HyperParameter("rho_update", low=0, high=1)
-DeltaNew = HyperParameter("delta_new", low=0, high=2)
-
-
 class Optimizer:
     def __init__(
         self,
+        pipeline_class: type,
         speech_path: Union[Text, Path],
         reference_path: Union[Text, Path],
         study_or_path: Union[FilePath, Study],
         batch_size: int = 32,
-        pipeline_class: type = OnlineSpeakerDiarization,
-        hparams: Optional[Sequence[HyperParameter]] = None,
-        base_config: Optional[BasePipelineConfig] = None,
+        hparams: Optional[Sequence[blocks.base.HyperParameter]] = None,
+        base_config: Optional[blocks.StreamingConfig] = None,
         do_kickstart_hparams: bool = True,
+        metric: Optional[BaseMetric] = None,
+        direction: Literal["minimize", "maximize"] = "minimize",
     ):
         self.pipeline_class = pipeline_class
         # FIXME can we run this benchmark in parallel?
@@ -58,15 +39,17 @@ def __init__(
             batch_size=batch_size,
         )
 
+        self.metric = metric
+        self.direction = direction
         self.base_config = base_config
         self.do_kickstart_hparams = do_kickstart_hparams
         if self.base_config is None:
-            self.base_config = PipelineConfig()
+            self.base_config = self.pipeline_class.get_config_class()()
             self.do_kickstart_hparams = False
 
         self.hparams = hparams
         if self.hparams is None:
-            self.hparams = [TauActive, RhoUpdate, DeltaNew]
+            self.hparams = self.pipeline_class.hyper_parameters()
 
         # Make sure hyper-parameters exist in the configuration class given
         possible_hparams = vars(self.base_config)
@@ -85,7 +68,7 @@ def __init__(
                 storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"),
                 sampler=TPESampler(),
                 study_name=study_or_path.stem,
-                direction="minimize",
+                direction=self.direction,
                 load_if_exists=True,
             )
         else:
@@ -105,7 +88,7 @@ def _callback(self, study: Study, trial: FrozenTrial):
             return
         self._progress.update(1)
         self._progress.set_description(f"Trial {trial.number + 1}")
-        values = {"best_der": study.best_value}
+        values = {"best_perf": study.best_value}
         for name, value in study.best_params.items():
             values[f"best_{name}"] = value
         self._progress.set_postfix(OrderedDict(values))
@@ -125,11 +108,16 @@ def objective(self, trial: Trial) -> float:
         # Instantiate the new configuration for the trial
         config = self.base_config.__class__(**trial_config)
 
+        # Determine the evaluation metric
+        metric = self.metric
+        if metric is None:
+            metric = self.pipeline_class.suggest_metric()
+
         # Run pipeline over the dataset
-        report = self.benchmark(self.pipeline_class, config)
+        report = self.benchmark(self.pipeline_class, config, metric)
 
-        # Extract DER from report
-        return report.loc["TOTAL", "diarization error rate"]["%"]
+        # Extract target metric from report
+        return report.loc["TOTAL", metric.name]["%"]
 
     def __call__(self, num_iter: int, show_progress: bool = True):
         self._progress = None
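
The `metric` and `direction` arguments let the optimizer target something other than DER. For instance, tuning the new voice activity detection pipeline on a detection metric could look like this (a sketch assuming the default segmentation model is available; the metric choice and paths are illustrative):

```python
from pathlib import Path

from pyannote.metrics.detection import DetectionErrorRate

from diart import VoiceActivityDetection
from diart.optim import Optimizer

Optimizer(
    pipeline_class=VoiceActivityDetection,
    speech_path="/wav/dir",
    reference_path="/rttm/dir",
    study_or_path=Path("/output/dir"),
    metric=DetectionErrorRate(),
    direction="minimize",  # lower detection error is better
)(num_iter=50)
```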
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index cf480bed..63c170d0 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -8,12 +8,14 @@
 from rx.core import Observer
 from typing_extensions import Literal
 
+from . import utils
+
 
 class WindowClosedException(Exception):
     pass
 
 
-def _extract_annotation(value: Union[Tuple, Annotation]) -> Annotation:
+def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation:
     if isinstance(value, tuple):
         return value[0]
     if isinstance(value, Annotation):
@@ -43,10 +45,11 @@ def patch(self):
                 annotation.support(self.patch_collar).write_rttm(file)
 
     def on_next(self, value: Union[Tuple, Annotation]):
-        annotation = _extract_annotation(value)
-        annotation.uri = self.uri
+        prediction = _extract_prediction(value)
+        # Write prediction in RTTM format
+        prediction.uri = self.uri
         with open(self.path, 'a') as file:
-            annotation.write_rttm(file)
+            prediction.write_rttm(file)
 
     def on_error(self, error: Exception):
         self.patch()
@@ -55,30 +58,30 @@ def on_completed(self):
         self.patch()
 
 
-class DiarizationPredictionAccumulator(Observer):
+class PredictionAccumulator(Observer):
     def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05):
         super().__init__()
         self.uri = uri
         self.patch_collar = patch_collar
-        self._annotation = None
+        self._prediction: Optional[Annotation] = None
 
     def patch(self):
         """Stitch same-speaker turns that are close to each other"""
-        if self._annotation is not None:
-            self._annotation = self._annotation.support(self.patch_collar)
+        if self._prediction is not None:
+            self._prediction = self._prediction.support(self.patch_collar)
 
     def get_prediction(self) -> Annotation:
         # Patch again in case this is called before on_completed
         self.patch()
-        return self._annotation
+        return self._prediction
 
     def on_next(self, value: Union[Tuple, Annotation]):
-        annotation = _extract_annotation(value)
-        annotation.uri = self.uri
-        if self._annotation is None:
-            self._annotation = annotation
+        prediction = _extract_prediction(value)
+        prediction.uri = self.uri
+        if self._prediction is None:
+            self._prediction = prediction
         else:
-            self._annotation.update(annotation)
+            self._prediction.update(prediction)
 
     def on_error(self, error: Exception):
         self.patch()
@@ -87,7 +90,7 @@ def on_completed(self):
         self.patch()
 
 
-class RealTimePlot(Observer):
+class StreamingPlot(Observer):
     def __init__(
         self,
         duration: float,
@@ -134,11 +137,15 @@ def get_plot_bounds(self, real_time: float) -> Segment:
             start_time = max(0., end_time - self.window_duration)
         return Segment(start_time, end_time)
 
-    def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
+    def on_next(
+        self,
+        values: Tuple[Annotation, SlidingWindowFeature, float]
+    ):
         if self.window_closed:
             raise WindowClosedException
 
         prediction, waveform, real_time = values
+
         # Initialize figure if first call
         if self.figure is None:
             self._init_figure()
@@ -147,15 +154,21 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
         # Set plot bounds
         notebook.crop = self.get_plot_bounds(real_time)
 
-        # Plot current values
+        # Align prediction and reference if possible
         if self.reference is not None:
             metric = DiarizationErrorRate()
             mapping = metric.optimal_mapping(self.reference, prediction)
             prediction.rename_labels(mapping=mapping, copy=False)
+
+        # Plot prediction
         notebook.plot_annotation(prediction, self.axs[0])
         self.axs[0].set_title("Output")
+
+        # Plot waveform
         notebook.plot_feature(waveform, self.axs[1])
         self.axs[1].set_title("Audio")
+
+        # Plot reference if available
         if self.num_axs == 3:
             notebook.plot_annotation(self.reference, self.axs[2])
             self.axs[2].set_title("Reference")
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 0f5dedf7..b34d5cf3 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -5,12 +5,12 @@
 import numpy as np
 import sounddevice as sd
 import torch
-from diart import utils
 from einops import rearrange
 from rx.subject import Subject
 from torchaudio.io import StreamReader
 from websocket_server import WebsocketServer
 
+from . import utils
 from .audio import FilePath, AudioLoader
 
 
diff --git a/src/diart/utils.py b/src/diart/utils.py
index e90861c7..e825ef29 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -4,9 +4,11 @@
 
 import matplotlib.pyplot as plt
 import numpy as np
-from diart.progress import ProgressBar
 from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
 
+from .progress import ProgressBar
+from . import blocks
+
 
 class Chronometer:
     def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None):
@@ -74,6 +76,18 @@ def get_padding_left(stream_duration: float, chunk_duration: float) -> float:
     return 0
 
 
+def repeat_label(label: Text):
+    while True:
+        yield label
+
+
+def get_pipeline_class(class_name: Text) -> type:
+    pipeline_class = getattr(blocks, class_name, None)
+    msg = f"Pipeline '{class_name}' doesn't exist"
+    assert pipeline_class is not None, msg
+    return pipeline_class
+
+
 def get_padding_right(latency: float, step: float) -> float:
     return latency - step
 

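Putting the pieces of this patch together, streaming voice activity detection can be run much like diarization; a minimal sketch (assuming the default `pyannote/segmentation` model can be loaded; the output path is a placeholder):

```python
from diart import VoiceActivityDetection
from diart.inference import StreamingInference
from diart.sinks import RTTMWriter
from diart.sources import MicrophoneAudioSource

vad = VoiceActivityDetection()  # default config
mic = MicrophoneAudioSource(vad.config.sample_rate)

inference = StreamingInference(vad, mic)
inference.attach_observers(RTTMWriter(mic.uri, "/output/voice.rttm"))
prediction = inference()  # an Annotation with the single label "speech"
```
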
From 6caa4a4ab9b2e8c2ab7bc5049b1112f325a08d5b Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 19 Apr 2023 17:43:51 +0200
Subject: [PATCH 18/23] Update link in setup.cfg

---
 setup.cfg | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 594c876e..e67e4426 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,11 +2,11 @@
 name=diart
 version=0.7.0
 author=Juan Manuel Coria
-description=Speaker diarization in real time
+description=Streaming speaker diarization in real time
 long_description=file: README.md
 long_description_content_type=text/markdown
 keywords=speaker diarization, streaming, online, real time, rxpy
-url=https://github.com/juanmc2005/StreamingSpeakerDiarization
+url=https://github.com/juanmc2005/diart
 license=MIT
 classifiers=
     Development Status :: 4 - Beta

From 0993fe85411d09f3b0bb5db709b7b02a7a56f0be Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 19 Apr 2023 17:51:41 +0200
Subject: [PATCH 19/23] Update code snippets in README

---
 README.md | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/README.md b/README.md
index ef533946..57ca293a 100644
--- a/README.md
+++ b/README.md
@@ -110,17 +110,17 @@ See `diart.stream -h` for more options.
 
 ### From python
 
-Use `RealTimeInference` to easily run a pipeline on an audio source and write the results to disk:
+Use `StreamingInference` to run a pipeline on an audio source and write the results to disk:
 
 ```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
 from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 from diart.sinks import RTTMWriter
 
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
 mic = MicrophoneAudioSource(pipeline.config.sample_rate)
-inference = RealTimeInference(pipeline, mic, do_plot=True)
+inference = StreamingInference(pipeline, mic, do_plot=True)
 inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
 prediction = inference()
 ```
@@ -129,13 +129,13 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n
 
 ## 🤖 Custom models
 
-Third-party models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel` (which are PyTorch `Module` subclasses):
+Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`):
 
 ```python
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
 from diart.models import EmbeddingModel, SegmentationModel
 from diart.sources import MicrophoneAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 
 
 def model_loader():
@@ -168,19 +168,19 @@ class MyEmbeddingModel(EmbeddingModel):
         return self.model(waveform, weights)
 
     
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
     segmentation=MySegmentationModel(),
     embedding=MyEmbeddingModel()
 )
-pipeline = OnlineSpeakerDiarization(config)
+pipeline = SpeakerDiarization(config)
 mic = MicrophoneAudioSource(config.sample_rate)
-inference = RealTimeInference(pipeline, mic)
+inference = StreamingInference(pipeline, mic)
 prediction = inference()
 ```
 
 ## 📈 Tune hyper-parameters
 
-Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
+Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune the hyper-parameters of any pipeline to your own dataset.
 
 ### From the command line
 
@@ -281,7 +281,7 @@ diart.serve --host 0.0.0.0 --port 7007
 diart.client microphone --host <server-address> --port 7007
 ```
 
-**Note:** please make sure that the client uses the same `step` and `sample_rate` than the server with `--step` and `-sr`.
+**Note:** make sure that the client uses the same `step` and `sample_rate` as the server with `--step` and `-sr`.
 
 See `-h` for more options.
 
@@ -290,13 +290,13 @@ See `-h` for more options.
 For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:
 
 ```python
-from diart import OnlineSpeakerDiarization
+from diart import SpeakerDiarization
 from diart.sources import WebSocketAudioSource
-from diart.inference import RealTimeInference
+from diart.inference import StreamingInference
 
-pipeline = OnlineSpeakerDiarization()
+pipeline = SpeakerDiarization()
 source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
-inference = RealTimeInference(pipeline, source)
+inference = StreamingInference(pipeline, source)
 inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
 prediction = inference()
 ```
@@ -354,14 +354,14 @@ or using the inference API:
 
 ```python
 from diart.inference import Benchmark, Parallelize
-from diart import OnlineSpeakerDiarization, PipelineConfig
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
 from diart.models import SegmentationModel
 
 benchmark = Benchmark("/wav/dir", "/rttm/dir")
 
 name = "pyannote/segmentation@Interspeech2021"
 segmentation = SegmentationModel.from_pyannote(name)
-config = PipelineConfig(
+config = SpeakerDiarizationConfig(
     # Set the model used in the paper
     segmentation=segmentation,
     step=0.5,
@@ -370,12 +370,12 @@ config = PipelineConfig(
     rho_update=0.422,
     delta_new=1.517
 )
-benchmark(OnlineSpeakerDiarization, config)
+benchmark(SpeakerDiarization, config)
 
 # Run the same benchmark in parallel
 p_benchmark = Parallelize(benchmark, num_workers=4)
 if __name__ == "__main__":  # Needed for multiprocessing
-    p_benchmark(OnlineSpeakerDiarization, config)
+    p_benchmark(SpeakerDiarization, config)
 ```
 
 This pre-calculates model outputs in batches, so it runs a lot faster.

From 95d4fae66dea06e1cbb12ac591e5f323687cd02f Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 19 Apr 2023 21:18:36 +0200
Subject: [PATCH 20/23] Add minor README modifications

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 57ca293a..ae13059f 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@
     </a>
     <span> | </span>
     <a href="#-custom-models">
-      🤖 Custom models
+      🤖 Add your model
     </a>
     <span> | </span>
     <a href="#-tune-hyper-parameters">
@@ -127,7 +127,7 @@ prediction = inference()
 
 For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)).
 
-## 🤖 Custom models
+## 🤖 Add your model
 
 Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`):
 

From 569c68fa5648c9c940dae215ed557582f17a513f Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Mon, 24 Apr 2023 11:25:51 +0200
Subject: [PATCH 21/23] Rename base pipeline and config objects

---
 src/diart/__init__.py           |  4 ++--
 src/diart/blocks/__init__.py    |  2 +-
 src/diart/blocks/base.py        |  8 ++++----
 src/diart/blocks/diarization.py |  4 ++--
 src/diart/blocks/vad.py         |  6 +++---
 src/diart/inference.py          | 10 +++++-----
 src/diart/optim.py              |  2 +-
 7 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/src/diart/__init__.py b/src/diart/__init__.py
index e29287a0..4bd51327 100644
--- a/src/diart/__init__.py
+++ b/src/diart/__init__.py
@@ -1,8 +1,8 @@
 from .blocks import (
     SpeakerDiarization,
-    StreamingPipeline,
+    Pipeline,
     SpeakerDiarizationConfig,
-    StreamingConfig,
+    PipelineConfig,
     VoiceActivityDetection,
     VoiceActivityDetectionConfig,
 )
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
index e6e8c479..15cf81d9 100644
--- a/src/diart/blocks/__init__.py
+++ b/src/diart/blocks/__init__.py
@@ -14,6 +14,6 @@
 )
 from .segmentation import SpeakerSegmentation
 from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
-from .base import StreamingConfig, StreamingPipeline
+from .base import PipelineConfig, Pipeline
 from .utils import Binarize, Resample, AdjustVolume
 from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig
diff --git a/src/diart/blocks/base.py b/src/diart/blocks/base.py
index 28f313eb..11ef961d 100644
--- a/src/diart/blocks/base.py
+++ b/src/diart/blocks/base.py
@@ -31,7 +31,7 @@ def from_name(name: Text) -> 'HyperParameter':
 DeltaNew = HyperParameter("delta_new", low=0, high=2)
 
 
-class StreamingConfig:
+class PipelineConfig:
     @property
     def duration(self) -> float:
         raise NotImplementedError
@@ -49,7 +49,7 @@ def sample_rate(self) -> int:
         raise NotImplementedError
 
     @staticmethod
-    def from_dict(data: Any) -> 'StreamingConfig':
+    def from_dict(data: Any) -> 'PipelineConfig':
         raise NotImplementedError
 
     def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
@@ -62,7 +62,7 @@ def optimal_block_size(self) -> int:
         return int(np.rint(self.step * self.sample_rate))
 
 
-class StreamingPipeline:
+class Pipeline:
     @staticmethod
     def get_config_class() -> type:
         raise NotImplementedError
@@ -76,7 +76,7 @@ def hyper_parameters() -> Sequence[HyperParameter]:
         raise NotImplementedError
 
     @property
-    def config(self) -> StreamingConfig:
+    def config(self) -> PipelineConfig:
         raise NotImplementedError
 
     def reset(self):
diff --git a/src/diart/blocks/diarization.py b/src/diart/blocks/diarization.py
index f2a25119..06658cfc 100644
--- a/src/diart/blocks/diarization.py
+++ b/src/diart/blocks/diarization.py
@@ -17,7 +17,7 @@
 from .. import utils
 
 
-class SpeakerDiarizationConfig(base.StreamingConfig):
+class SpeakerDiarizationConfig(base.PipelineConfig):
     def __init__(
         self,
         segmentation: Optional[m.SegmentationModel] = None,
@@ -129,7 +129,7 @@ def sample_rate(self) -> int:
         return self._sample_rate
 
 
-class SpeakerDiarization(base.StreamingPipeline):
+class SpeakerDiarization(base.Pipeline):
     def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
         self._config = SpeakerDiarizationConfig() if config is None else config
 
diff --git a/src/diart/blocks/vad.py b/src/diart/blocks/vad.py
index def833b6..e519a9cf 100644
--- a/src/diart/blocks/vad.py
+++ b/src/diart/blocks/vad.py
@@ -15,7 +15,7 @@
 from .. import utils
 
 
-class VoiceActivityDetectionConfig(base.StreamingConfig):
+class VoiceActivityDetectionConfig(base.PipelineConfig):
     def __init__(
         self,
         segmentation: Optional[m.SegmentationModel] = None,
@@ -96,7 +96,7 @@ def from_dict(data: Any) -> 'VoiceActivityDetectionConfig':
         )
 
 
-class VoiceActivityDetection(base.StreamingPipeline):
+class VoiceActivityDetection(base.Pipeline):
     def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
         self._config = VoiceActivityDetectionConfig() if config is None else config
 
@@ -135,7 +135,7 @@ def hyper_parameters() -> Sequence[base.HyperParameter]:
         return [base.TauActive]
 
     @property
-    def config(self) -> base.StreamingConfig:
+    def config(self) -> base.PipelineConfig:
         return self._config
 
     def reset(self):
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 6afda89e..f562fdd9 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -53,7 +53,7 @@ class StreamingInference:
     """
     def __init__(
         self,
-        pipeline: blocks.StreamingPipeline,
+        pipeline: blocks.Pipeline,
         source: src.AudioSource,
         batch_size: int = 1,
         do_profile: bool = True,
@@ -289,7 +289,7 @@ def get_file_paths(self) -> List[Path]:
 
     def run_single(
         self,
-        pipeline: blocks.StreamingPipeline,
+        pipeline: blocks.Pipeline,
         filepath: Path,
         progress_bar: ProgressBar,
     ) -> Annotation:
@@ -374,7 +374,7 @@ def evaluate(
     def __call__(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: blocks.PipelineConfig,
         metric: Optional[BaseMetric] = None,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files.
@@ -437,7 +437,7 @@ def __init__(
     def run_single_job(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: blocks.PipelineConfig,
         filepath: Path,
         description: Text,
     ) -> Annotation:
@@ -474,7 +474,7 @@ def run_single_job(
     def __call__(
         self,
         pipeline_class: type,
-        config: blocks.StreamingConfig,
+        config: blocks.PipelineConfig,
         metric: Optional[BaseMetric] = None,
     ) -> Union[pd.DataFrame, List[Annotation]]:
         """Run a given pipeline on a set of audio files in parallel.
diff --git a/src/diart/optim.py b/src/diart/optim.py
index f7a96a6e..86492627 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -23,7 +23,7 @@ def __init__(
         study_or_path: Union[FilePath, Study],
         batch_size: int = 32,
         hparams: Optional[Sequence[blocks.base.HyperParameter]] = None,
-        base_config: Optional[blocks.StreamingConfig] = None,
+        base_config: Optional[blocks.PipelineConfig] = None,
         do_kickstart_hparams: bool = True,
         metric: Optional[BaseMetric] = None,
         direction: Literal["minimize", "maximize"] = "minimize",
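
After this rename, custom pipelines plug into `StreamingInference`, `Benchmark` and `Optimizer` through the `Pipeline`/`PipelineConfig` interface. The skeleton below only sketches the methods visible in this patch set; `MyConfig`, `MyPipeline` and the elided bodies are placeholders, not part of the patch:

```python
from typing import Sequence

from pyannote.metrics.base import BaseMetric

from diart import Pipeline, PipelineConfig
from diart.blocks.base import HyperParameter, TauActive


class MyConfig(PipelineConfig):
    # duration, step, latency, sample_rate and from_dict would be implemented here
    ...


class MyPipeline(Pipeline):
    @staticmethod
    def get_config_class() -> type:
        return MyConfig

    @staticmethod
    def suggest_metric() -> BaseMetric:
        # Used by Benchmark and Optimizer when no explicit metric is given
        ...

    @staticmethod
    def hyper_parameters() -> Sequence[HyperParameter]:
        return [TauActive]

    @property
    def config(self) -> PipelineConfig:
        ...

    def reset(self):
        # Clear internal state between files
        ...

    def set_timestamp_shift(self, shift: float):
        # Benchmark shifts predictions to compensate for file padding
        ...

    def __call__(self, waveforms) -> Sequence:
        # One prediction per input chunk, as in SpeakerDiarization and VoiceActivityDetection
        ...
```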

From a16bb5c40c0c0850f3ab0bf197ea0a86ef214484 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Mon, 24 Apr 2023 18:30:33 +0200
Subject: [PATCH 22/23] Add initial implementation of SpeakerAwareTranscription

---
 src/diart/inference.py                       |   2 +-
 src/diart/pipelines/__init__.py              |   1 +
 src/diart/pipelines/speaker_transcription.py | 316 +++++++++++++++++++
 src/diart/pipelines/transcription.py         |  28 --
 src/diart/sinks.py                           |  21 +-
 5 files changed, 337 insertions(+), 31 deletions(-)
 create mode 100644 src/diart/pipelines/speaker_transcription.py

diff --git a/src/diart/inference.py b/src/diart/inference.py
index eb2de85e..1589feed 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -116,7 +116,7 @@ def __init__(
 
         self.stream = self.stream.pipe(
             ops.flat_map(lambda results: rx.from_iterable(results)),
-            ops.do_action(lambda pred_wav: self._predictions.append(pred_wav[0])),
+            ops.do_action(lambda res: self._predictions.append(res[0] if isinstance(res, tuple) else res)),
         )
 
         if show_progress:
diff --git a/src/diart/pipelines/__init__.py b/src/diart/pipelines/__init__.py
index 11fe5e82..55676f66 100644
--- a/src/diart/pipelines/__init__.py
+++ b/src/diart/pipelines/__init__.py
@@ -1,4 +1,5 @@
 from .base import Pipeline, PipelineConfig
 from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .speaker_transcription import SpeakerAwareTranscription, SpeakerAwareTranscriptionConfig
 from .transcription import Transcription, TranscriptionConfig
 from .voice import VoiceActivityDetection, VoiceActivityDetectionConfig
diff --git a/src/diart/pipelines/speaker_transcription.py b/src/diart/pipelines/speaker_transcription.py
new file mode 100644
index 00000000..87d7d192
--- /dev/null
+++ b/src/diart/pipelines/speaker_transcription.py
@@ -0,0 +1,316 @@
+from pathlib import Path
+from typing import Any, Optional, Union, Sequence, Tuple, Text, List
+
+import numpy as np
+import torch
+from diart.metrics import Metric
+from pyannote.core import SlidingWindowFeature, SlidingWindow, Annotation, Segment
+from rx.core import Observer
+from typing_extensions import Literal
+
+from .base import Pipeline, PipelineConfig
+from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .hparams import HyperParameter, TauActive, RhoUpdate, DeltaNew
+from .. import models as m
+from .. import sinks
+from .. import blocks
+from .. import utils
+from ..metrics import WordErrorRate
+
+
+class SpeakerAwareTranscriptionConfig(PipelineConfig):
+    def __init__(
+        self,
+        asr: Optional[m.SpeechRecognitionModel] = None,
+        segmentation: Optional[m.SegmentationModel] = None,
+        embedding: Optional[m.EmbeddingModel] = None,
+        duration: Optional[float] = None,
+        asr_duration: float = 3,
+        step: float = 0.5,
+        latency: Optional[Union[float, Literal["max", "min"]]] = None,
+        tau_active: float = 0.5,
+        rho_update: float = 0.3,
+        delta_new: float = 1,
+        language: Optional[Text] = None,
+        beam_size: Optional[int] = None,
+        gamma: float = 3,
+        beta: float = 10,
+        max_speakers: int = 20,
+        merge_collar: float = 0.05,
+        diarization_device: Optional[torch.device] = None,
+        asr_device: Optional[torch.device] = None,
+        **kwargs,
+    ):
+        # Default segmentation model is pyannote/segmentation
+        self.segmentation = segmentation
+        if self.segmentation is None:
+            self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+        self._duration = duration
+        self._sample_rate: Optional[int] = None
+
+        # Default embedding model is pyannote/embedding
+        self.embedding = embedding
+        if self.embedding is None:
+            self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
+
+        # Latency defaults to the step duration
+        self._step = step
+        self._latency = latency
+        if self._latency is None or self._latency == "min":
+            self._latency = self._step
+        elif self._latency == "max":
+            self._latency = self.duration
+
+        self.tau_active = tau_active
+        self.rho_update = rho_update
+        self.delta_new = delta_new
+        self.gamma = gamma
+        self.beta = beta
+        self.max_speakers = max_speakers
+        self.merge_collar = merge_collar
+        self.asr_duration = asr_duration
+
+        self.diarization_device = diarization_device
+        if self.diarization_device is None:
+            self.diarization_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.language = language
+        self.beam_size = beam_size
+
+        self.asr_device = asr_device
+        if self.asr_device is None:
+            self.asr_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Default ASR model is Whisper small (244M parameters)
+        self.asr = asr
+        if self.asr is None:
+            self.asr = m.SpeechRecognitionModel.from_whisper("small")
+        self.asr.set_language(self.language)
+        self.asr.set_beam_size(self.beam_size)
+
+    def to_diarization_config(self) -> SpeakerDiarizationConfig:
+        return SpeakerDiarizationConfig(
+            segmentation=self.segmentation,
+            embedding=self.embedding,
+            duration=self.duration,
+            step=self.step,
+            latency=self.latency,
+            tau_active=self.tau_active,
+            rho_update=self.rho_update,
+            delta_new=self.delta_new,
+            gamma=self.gamma,
+            beta=self.beta,
+            max_speakers=self.max_speakers,
+            merge_collar=self.merge_collar,
+            device=self.diarization_device,
+        )
+
+    @property
+    def duration(self) -> float:
+        # Default duration is the one given by the segmentation model
+        if self._duration is None:
+            self._duration = self.segmentation.duration
+        return self._duration
+
+    @property
+    def step(self) -> float:
+        return self._step
+
+    @property
+    def latency(self) -> float:
+        return self._latency
+
+    @property
+    def sample_rate(self) -> int:
+        if self._sample_rate is None:
+            dia_sample_rate = self.segmentation.sample_rate
+            asr_sample_rate = self.asr.sample_rate
+            msg = "Sample rates for speech recognition and speaker segmentation models must match"
+            assert dia_sample_rate == asr_sample_rate, msg
+            self._sample_rate = dia_sample_rate
+        return self._sample_rate
+
+    @staticmethod
+    def from_dict(data: Any) -> 'SpeakerAwareTranscriptionConfig':
+        # Resolve arguments exactly like diarization
+        dia_config = SpeakerDiarizationConfig.from_dict(data)
+
+        # Default ASR model is Whisper small (244M parameters)
+        whisper_size = utils.get(data, "whisper", "small")
+        asr = m.SpeechRecognitionModel.from_whisper(whisper_size)
+
+        return SpeakerAwareTranscriptionConfig(
+            asr=asr,
+            segmentation=dia_config.segmentation,
+            embedding=dia_config.embedding,
+            duration=dia_config.duration,
+            asr_duration=utils.get(data, "asr_duration", 3),
+            step=dia_config.step,
+            latency=dia_config.latency,
+            tau_active=dia_config.tau_active,
+            rho_update=dia_config.rho_update,
+            delta_new=dia_config.delta_new,
+            language=utils.get(data, "language", None),
+            beam_size=utils.get(data, "beam_size", None),
+            gamma=dia_config.gamma,
+            beta=dia_config.beta,
+            max_speakers=dia_config.max_speakers,
+            merge_collar=dia_config.merge_collar,
+            diarization_device=dia_config.device,
+            # TODO handle different devices
+            asr_device=dia_config.device,
+        )
+
+
+class SpeakerAwareTranscription(Pipeline):
+    def __init__(self, config: Optional[SpeakerAwareTranscriptionConfig] = None):
+        self._config = SpeakerAwareTranscriptionConfig() if config is None else config
+        self.diarization = SpeakerDiarization(self.config.to_diarization_config())
+        self.asr = blocks.SpeechRecognition(self.config.asr, self.config.asr_device)
+
+        # Internal state, handle with care
+        self.audio_buffer, self.dia_buffer = None, None
+
+    @staticmethod
+    def get_config_class() -> type:
+        return SpeakerAwareTranscriptionConfig
+
+    @staticmethod
+    def suggest_metric() -> Metric:
+        # TODO per-speaker WER?
+        return WordErrorRate()
+
+    @staticmethod
+    def hyper_parameters() -> Sequence[HyperParameter]:
+        return [TauActive, RhoUpdate, DeltaNew]
+
+    @property
+    def config(self) -> SpeakerAwareTranscriptionConfig:
+        return self._config
+
+    def reset(self):
+        self.diarization.reset()
+        self.audio_buffer, self.dia_buffer = None, None
+
+    def set_timestamp_shift(self, shift: float):
+        self.diarization.set_timestamp_shift(shift)
+
+    def join_predictions(self, predictions: List[Text]) -> Text:
+        return "\n".join(predictions)
+
+    def write_prediction(self, uri: Text, prediction: Text, dir_path: Union[Text, Path]):
+        with open(Path(dir_path) / f"{uri}.txt", "w") as out_file:
+            out_file.write(prediction)
+
+    def suggest_display(self) -> Observer:
+        return sinks.RichScreen()
+
+    def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
+        return sinks.TextWriter(Path(output_dir) / f"{uri}.txt")
+
+    def __call__(
+        self,
+        waveforms: Sequence[SlidingWindowFeature],
+    ) -> Sequence[Text]:
+        # Compute diarization output
+        diarization_output = self.diarization(waveforms)
+
+        first_chunk = diarization_output[0][1]
+        output_start = first_chunk.extent.start
+        resolution = first_chunk.sliding_window.duration
+        diarization, chunk_data = Annotation(), []
+        for dia, chunk in diarization_output:
+            diarization = diarization.update(dia)
+            chunk_data.append(chunk.data)
+
+        # Update diarization output buffer
+        if self.dia_buffer is None:
+            self.dia_buffer = diarization
+        else:
+            self.dia_buffer = self.dia_buffer.update(diarization)
+        self.dia_buffer = self.dia_buffer.support(self.config.merge_collar)
+
+        # Update audio buffer
+        if self.audio_buffer is None:
+            window = SlidingWindow(resolution, resolution, output_start)
+            self.audio_buffer = SlidingWindowFeature(np.concatenate(chunk_data, axis=0), window)
+        else:
+            chunk_data.insert(0, self.audio_buffer.data)
+            self.audio_buffer = SlidingWindowFeature(
+                np.concatenate(chunk_data, axis=0),
+                self.audio_buffer.sliding_window
+            )
+
+        # Extract audio to transcribe from the buffer
+        asr_duration = self.config.asr_duration
+        buffer_duration = self.audio_buffer.extent.duration
+        asr_batch_size = int(buffer_duration / asr_duration)
+
+        if asr_batch_size == 0:
+            return ["" for _ in waveforms]
+
+        buffer_start = self.audio_buffer.extent.start
+        asr_inputs, input_dia, last_end_time = [], [], None
+        for i in range(asr_batch_size):
+            start = buffer_start + i * asr_duration
+            last_end_time = start + asr_duration
+            region = Segment(start, last_end_time)
+            chunk = self.audio_buffer.crop(region, fixed=asr_duration)
+            window = SlidingWindow(resolution, resolution, start)
+            asr_inputs.append(SlidingWindowFeature(chunk, window))
+            input_dia.append(self.dia_buffer.crop(region))
+
+        # Create ASR batch, shape (batch, samples, channels)
+        batch = torch.stack([torch.from_numpy(w.data) for w in asr_inputs])
+
+        # Remove transcribed chunks from buffer
+        new_buffer_bounds = Segment(last_end_time, self.audio_buffer.extent.end)
+        new_buffer = self.audio_buffer.crop(new_buffer_bounds, fixed=new_buffer_bounds.duration)
+        window = SlidingWindow(resolution, resolution, last_end_time)
+        self.audio_buffer = SlidingWindowFeature(new_buffer, window)
+        self.dia_buffer = self.dia_buffer.extrude(Segment(0, last_end_time))
+
+        # Filter out non-speech chunks
+        has_voice = torch.tensor([dia.get_timeline().duration() > 0 for dia in input_dia])
+        has_voice = torch.where(has_voice)[0]
+
+        # Return empty list if no speech in the entire batch
+        if len(has_voice) == 0:
+            return ["" for _ in waveforms]
+
+        # Transcribe only the chunks that contain speech
+        outputs = self.asr(batch[has_voice])
+        # Map original chunk indices to positions in the filtered ASR batch
+        mapping = {chunk_idx.item(): j for j, chunk_idx in enumerate(has_voice)}
+
+        # Align transcription with diarization to determine speakers
+        full_transcription = []
+        for i, waveform in enumerate(asr_inputs):
+            if i not in has_voice:
+                continue
+            buffer_shift = waveform.sliding_window.start
+            asr_output = outputs[mapping[i]]
+            for text, timestamp in zip(asr_output.chunks, asr_output.timestamps):
+                if not text.strip():
+                    continue
+                target_region = Segment(
+                    buffer_shift + timestamp.start,
+                    buffer_shift + timestamp.end,
+                )
+                dia = input_dia[i].crop(target_region)
+                speakers = dia.labels()
+                num_speakers = len(speakers)
+                if num_speakers == 0:
+                    # Include transcription but don't assign a speaker
+                    full_transcription.append(text)
+                elif num_speakers == 1:
+                    # Typical case, annotate text with the only speaker
+                    full_transcription.append(f"[{speakers[0]}]{text}")
+                else:
+                    # Multiple speakers for the same text block, choose the most active one
+                    max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
+                    full_transcription.append(f"[{speakers[max_spk]}]{text}")
+
+        batch_size = len(waveforms)
+        output = [" ".join(full_transcription).strip()]
+        if batch_size > 1:
+            output += [""] * (batch_size - 1)
+        return output
diff --git a/src/diart/pipelines/transcription.py b/src/diart/pipelines/transcription.py
index 3616222e..49b60175 100644
--- a/src/diart/pipelines/transcription.py
+++ b/src/diart/pipelines/transcription.py
@@ -182,31 +182,3 @@ def __call__(
             (outputs[mapping[i]].text if i in has_voice else "", waveforms[i])
             for i in range(batch_size)
         ]
-
-        # TODO align text with speakers if diarization is not None
-
-        # diarization = diarization[0]
-        #
-        # # Align transcription with diarization to determine speakers
-        # full_transcription = []
-        # buffer_shift = waveform.sliding_window.start
-        # for text, timestamp in zip(outputs.chunks, outputs.timestamps):
-        #     target_region = Segment(
-        #         buffer_shift + timestamp.start,
-        #         buffer_shift + timestamp.end
-        #     )
-        #     dia = diarization.crop(target_region)
-        #     speakers = dia.labels()
-        #     num_speakers = len(speakers)
-        #     if num_speakers == 0:
-        #         # Include transcription but don't assign a speaker
-        #         full_transcription.append(text)
-        #     elif num_speakers == 1:
-        #         # Typical case, annotate text with the only speaker
-        #         full_transcription.append(f"[{speakers[0]}]{text}")
-        #     else:
-        #         # Multiple speakers for the same text block, choose the most active one
-        #         max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
-        #         full_transcription.append(f"[{speakers[max_spk]}]{text}")
-        #
-        # return [(" ".join(full_transcription).strip(), waveform)]
\ No newline at end of file
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index be461fff..fbc46904 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -62,6 +62,9 @@ def __init__(self, path: Union[Path, Text]):
         if self.path.exists():
             self.path.unlink()
 
+    def on_error(self, error: Exception):
+        pass
+
     def on_next(self, value: Union[Tuple, Text]):
         # Write transcription to file
         prediction = _extract_prediction(value)
@@ -111,15 +114,25 @@ def __init__(self, speaker_colors: Optional[List[Text]] = None):
                 "yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"
             ]
         self.num_colors = len(self.colors)
+        self._speaker_to_color = {}
+
+    def on_error(self, error: Exception):
+        pass
 
     def on_next(self, value: Union[Tuple, Text]):
         prediction = _extract_prediction(value)
+        if not prediction.strip():
+            return
         # Extract speakers
         speakers = sorted(re.findall(r'\[.*?]', prediction))
         # Colorize based on speakers
         colorized = prediction
-        for i, speaker in enumerate(speakers):
-            colorized = colorized.replace(speaker, f"[{self.colors[i % self.num_colors]}]")
+        for spk in speakers:
+            name = spk[1:-1]
+            if name not in self._speaker_to_color:
+                next_color_idx = len(self._speaker_to_color) % self.num_colors
+                self._speaker_to_color[name] = self.colors[next_color_idx]
+            colorized = colorized.replace(spk, f"[{self._speaker_to_color[name]}]")
         # Print result
         rich.print(colorized)
 
@@ -180,6 +193,10 @@ def get_plot_bounds(self) -> Segment:
         start_time = max(0., end_time - self.window_duration)
         return Segment(start_time, end_time)
 
+    def on_error(self, error: Exception):
+        # Do nothing on error
+        pass
+
     def on_next(
         self,
         values: Tuple[Annotation, SlidingWindowFeature]
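
The RichScreen change above keeps speaker colors stable across updates by remembering the color assigned to each speaker the first time it appears. A standalone sketch of that mapping follows; the color names are the ones visible in the diff, but the function itself is illustrative rather than the sink's actual code.

    import re

    COLORS = ["yellow2", "magenta", "cyan", "bright_magenta", "dodger_blue2"]
    _speaker_to_color = {}

    def colorize(prediction: str) -> str:
        """Replace [speakerX] tags with a Rich color tag that is stable per speaker."""
        for tag in sorted(set(re.findall(r"\[.*?]", prediction))):
            name = tag[1:-1]
            if name not in _speaker_to_color:
                _speaker_to_color[name] = COLORS[len(_speaker_to_color) % len(COLORS)]
            prediction = prediction.replace(tag, f"[{_speaker_to_color[name]}]")
        return prediction

    # colorize("[speaker0]hi [speaker1]hello") -> "[yellow2]hi [magenta]hello"
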

From c7bbcc43aadd65358da6d95065f7a250a1f3ad81 Mon Sep 17 00:00:00 2001
From: juanmc2005 <juanmc2005@hotmail.com>
Date: Wed, 26 Apr 2023 14:37:34 +0200
Subject: [PATCH 23/23] Refactor SpeakerAwareTranscription

---
 src/diart/console/benchmark.py               |   2 +
 src/diart/console/serve.py                   |   4 +-
 src/diart/console/stream.py                  |   2 +
 src/diart/console/tune.py                    |   2 +
 src/diart/pipelines/speaker_transcription.py | 120 +++++++++++--------
 src/diart/utils.py                           |   6 +-
 6 files changed, 86 insertions(+), 50 deletions(-)

diff --git a/src/diart/console/benchmark.py b/src/diart/console/benchmark.py
index d8f04183..3c87edab 100644
--- a/src/diart/console/benchmark.py
+++ b/src/diart/console/benchmark.py
@@ -24,6 +24,8 @@ def run():
                         help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
     parser.add_argument("--duration", default=5, type=float,
                         help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
+    parser.add_argument("--asr-duration", default=3, type=float,
+                        help=f"Duration of the transcription window (in seconds). Defaults to 3")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py
index 0698ede0..fe668c5b 100644
--- a/src/diart/console/serve.py
+++ b/src/diart/console/serve.py
@@ -24,6 +24,8 @@ def run():
                         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
     parser.add_argument("--duration", type=float,
                         help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
+    parser.add_argument("--asr-duration", default=3, type=float,
+                        help=f"Duration of the transcription window (in seconds). Defaults to 3")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -61,7 +63,7 @@ def run():
         inference.attach_observers(pipeline.suggest_writer(audio_source.uri, args.output))
 
     # Send back responses as text
-    inference.attach_hooks(lambda pred_wav: audio_source.send(utils.serialize_prediction(pred_wav[0])))
+    inference.attach_hooks(lambda result: audio_source.send(utils.serialize_prediction(result)))
 
     # Run server and pipeline
     inference()
diff --git a/src/diart/console/stream.py b/src/diart/console/stream.py
index 1436eb8a..af8e2cf8 100644
--- a/src/diart/console/stream.py
+++ b/src/diart/console/stream.py
@@ -23,6 +23,8 @@ def run():
                         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
     parser.add_argument("--duration", default=5, type=float,
                         help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
+    parser.add_argument("--asr-duration", default=3, type=float,
+                        help=f"Duration of the transcription window (in seconds). Defaults to 3")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
diff --git a/src/diart/console/tune.py b/src/diart/console/tune.py
index f492c704..ea34ed97 100644
--- a/src/diart/console/tune.py
+++ b/src/diart/console/tune.py
@@ -26,6 +26,8 @@ def run():
                         help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding")
     parser.add_argument("--duration", default=5, type=float,
                         help=f"Duration of the sliding window (in seconds). Default value depends on the pipeline")
+    parser.add_argument("--asr-duration", default=3, type=float,
+                        help=f"Duration of the transcription window (in seconds). Defaults to 3")
     parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
     parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
     parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
diff --git a/src/diart/pipelines/speaker_transcription.py b/src/diart/pipelines/speaker_transcription.py
index 87d7d192..ad2303fa 100644
--- a/src/diart/pipelines/speaker_transcription.py
+++ b/src/diart/pipelines/speaker_transcription.py
@@ -209,13 +209,8 @@ def suggest_display(self) -> Observer:
     def suggest_writer(self, uri: Text, output_dir: Union[Text, Path]) -> Observer:
         return sinks.TextWriter(Path(output_dir) / f"{uri}.txt")
 
-    def __call__(
-        self,
-        waveforms: Sequence[SlidingWindowFeature],
-    ) -> Sequence[Text]:
-        # Compute diarization output
-        diarization_output = self.diarization(waveforms)
-
+    def _update_buffers(self, diarization_output: Sequence[Tuple[Annotation, SlidingWindowFeature]]):
+        # Separate diarization and aligned audio chunks
         first_chunk = diarization_output[0][1]
         output_start = first_chunk.extent.start
         resolution = first_chunk.sliding_window.duration
@@ -239,78 +234,109 @@ def __call__(
             chunk_data.insert(0, self.audio_buffer.data)
             self.audio_buffer = SlidingWindowFeature(
                 np.concatenate(chunk_data, axis=0),
-                self.audio_buffer.sliding_window
+                self.audio_buffer.sliding_window,
             )
 
-        # Extract audio to transcribe from the buffer
-        asr_duration = self.config.asr_duration
+    def _extract_asr_inputs(self) -> Tuple[List[SlidingWindowFeature], List[Annotation]]:
+        chunk_duration = self.config.asr_duration
         buffer_duration = self.audio_buffer.extent.duration
-        asr_batch_size = int(buffer_duration / asr_duration)
-
-        if asr_batch_size == 0:
-            return ["" for _ in waveforms]
-
+        batch_size = int(buffer_duration / chunk_duration)
         buffer_start = self.audio_buffer.extent.start
+        resolution = self.audio_buffer.sliding_window.duration
+
+        # Extract audio chunks with their diarization
         asr_inputs, input_dia, last_end_time = [], [], None
-        for i in range(asr_batch_size):
-            start = buffer_start + i * asr_duration
-            last_end_time = start + asr_duration
+        for i in range(batch_size):
+            start = buffer_start + i * chunk_duration
+            last_end_time = start + chunk_duration
             region = Segment(start, last_end_time)
-            chunk = self.audio_buffer.crop(region, fixed=asr_duration)
+            chunk = self.audio_buffer.crop(region, fixed=chunk_duration)
             window = SlidingWindow(resolution, resolution, start)
             asr_inputs.append(SlidingWindowFeature(chunk, window))
             input_dia.append(self.dia_buffer.crop(region))
 
-        # Create ASR batch, shape (batch, samples, channels)
-        batch = torch.stack([torch.from_numpy(w.data) for w in asr_inputs])
-
-        # Remove transcribed chunks from buffer
-        new_buffer_bounds = Segment(last_end_time, self.audio_buffer.extent.end)
-        new_buffer = self.audio_buffer.crop(new_buffer_bounds, fixed=new_buffer_bounds.duration)
-        window = SlidingWindow(resolution, resolution, last_end_time)
-        self.audio_buffer = SlidingWindowFeature(new_buffer, window)
-        self.dia_buffer = self.dia_buffer.extrude(Segment(0, last_end_time))
+        # Remove extracted chunks from buffers
+        if asr_inputs:
+            new_buffer_bounds = Segment(last_end_time, self.audio_buffer.extent.end)
+            new_buffer = self.audio_buffer.crop(new_buffer_bounds, fixed=new_buffer_bounds.duration)
+            window = SlidingWindow(resolution, resolution, last_end_time)
+            self.audio_buffer = SlidingWindowFeature(new_buffer, window)
+            self.dia_buffer = self.dia_buffer.extrude(Segment(0, last_end_time))
 
-        # Filter out non-speech chunks
-        has_voice = torch.tensor([dia.get_timeline().duration() > 0 for dia in input_dia])
-        has_voice = torch.where(has_voice)[0]
+        return asr_inputs, input_dia
 
-        # Return empty list if no speech in the entire batch
-        if len(has_voice) == 0:
-            return ["" for _ in waveforms]
-
-        # Transcribe batch
-        outputs = self.asr(batch[has_voice])
-
-        # Align transcription with diarization to determine speakers
-        full_transcription = []
+    def _get_speaker_transcriptions(
+        self,
+        input_diarization: List[Annotation],
+        asr_inputs: List[SlidingWindowFeature],
+        asr_outputs: List[m.TranscriptionResult],
+    ) -> Text:
+        transcriptions = []
         for i, waveform in enumerate(asr_inputs):
-            if i not in has_voice:
+            if asr_outputs[i] is None:
                 continue
             buffer_shift = waveform.sliding_window.start
-            for text, timestamp in zip(outputs[i].chunks, outputs[i].timestamps):
+            for text, timestamp in zip(asr_outputs[i].chunks, asr_outputs[i].timestamps):
                 if not text.strip():
                     continue
                 target_region = Segment(
                     buffer_shift + timestamp.start,
                     buffer_shift + timestamp.end,
                 )
-                dia = input_dia[i].crop(target_region)
+                dia = input_diarization[i].crop(target_region)
                 speakers = dia.labels()
                 num_speakers = len(speakers)
                 if num_speakers == 0:
                     # Include transcription but don't assign a speaker
-                    full_transcription.append(text)
+                    transcriptions.append(text)
                 elif num_speakers == 1:
                     # Typical case, annotate text with the only speaker
-                    full_transcription.append(f"[{speakers[0]}]{text}")
+                    transcriptions.append(f"[{speakers[0]}]{text}")
                 else:
                     # Multiple speakers for the same text block, choose the most active one
                     max_spk = np.argmax([dia.label_duration(spk) for spk in speakers])
-                    full_transcription.append(f"[{speakers[max_spk]}]{text}")
+                    transcriptions.append(f"[{speakers[max_spk]}]{text}")
+        return " ".join(transcriptions).strip()
 
+    def __call__(
+        self,
+        waveforms: Sequence[SlidingWindowFeature],
+    ) -> Sequence[Text]:
+        # Compute diarization output
+        diarization_output = self.diarization(waveforms)
+        self._update_buffers(diarization_output)
+
+        # Extract audio to transcribe from the buffer
+        asr_inputs, asr_input_dia = self._extract_asr_inputs()
+        if not asr_inputs:
+            return ["" for _ in waveforms]
+
+        # Detect non-speech chunks
+        has_voice = torch.tensor([dia.get_timeline().duration() > 0 for dia in asr_input_dia])
+        has_voice = torch.where(has_voice)[0]
+        # Return empty strings if no speech in the entire batch
+        if len(has_voice) == 0:
+            return ["" for _ in waveforms]
+
+        # Create ASR batch, shape (batch, samples, channels)
+        batch = torch.stack([torch.from_numpy(w.data) for w in asr_inputs])
+
+        # Transcribe voiced chunks and scatter results back to batch positions
+        voiced_outputs = iter(self.asr(batch[has_voice]))
+        asr_outputs = [
+            next(voiced_outputs) if i in has_voice else None
+            for i in range(batch.shape[0])
+        ]
+
+        # Attach speaker labels to ASR output and concatenate
+        transcription = self._get_speaker_transcriptions(
+            asr_input_dia, asr_inputs, asr_outputs
+        )
+
+        # Fill output sequence with empty strings
         batch_size = len(waveforms)
-        output = [" ".join(full_transcription).strip()]
+        output = [transcription]
         if batch_size > 1:
             output += [""] * (batch_size - 1)
+
         return output
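
The refactored `__call__` only runs ASR on chunks that contain speech and then scatters the results back to their original batch positions, padding silent chunks with `None`. A toy version of that pattern is sketched below; `fake_asr` stands in for `self.asr` and is purely illustrative.

    import torch

    def fake_asr(batch: torch.Tensor):
        # Pretend each chunk yields one transcript, in batch order
        return [f"transcript-{i}" for i in range(batch.shape[0])]

    batch = torch.randn(4, 16000, 1)                 # (batch, samples, channels)
    has_voice = torch.where(torch.tensor([True, False, True, True]))[0]

    voiced_outputs = fake_asr(batch[has_voice])      # results only for voiced chunks
    by_position = {int(idx): out for idx, out in zip(has_voice, voiced_outputs)}
    outputs = [by_position.get(i) for i in range(batch.shape[0])]
    # outputs == ["transcript-0", None, "transcript-1", "transcript-2"]
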
diff --git a/src/diart/utils.py b/src/diart/utils.py
index 018bc02c..725d9bf4 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -1,6 +1,6 @@
 import base64
 import time
-from typing import Optional, Text, Union, Any, Dict
+from typing import Optional, Text, Union, Any, Dict, Tuple
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -92,7 +92,9 @@ def get_padding_right(latency: float, step: float) -> float:
     return latency - step
 
 
-def serialize_prediction(value: Union[Annotation, Text]) -> Text:
+def serialize_prediction(value: Union[Tuple, Annotation, Text]) -> Text:
+    if isinstance(value, tuple):
+        value = value[0]
     if isinstance(value, Annotation):
         return value.to_rttm()
     return value
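
`serialize_prediction` now also accepts the `(prediction, waveform)` tuples emitted by the transcription pipelines, which is what the updated serve.py hook relies on. A quick illustration; the example values are made up.

    from pyannote.core import Annotation, Segment
    from diart.utils import serialize_prediction

    # Plain text predictions pass through unchanged
    assert serialize_prediction("[speaker0]hello there") == "[speaker0]hello there"

    # Tuples such as (prediction, waveform) are reduced to their first element
    assert serialize_prediction(("[speaker0]hello there", None)) == "[speaker0]hello there"

    # Diarization annotations are serialized to RTTM
    annotation = Annotation(uri="demo")
    annotation[Segment(0.0, 1.0)] = "speaker0"
    print(serialize_prediction(annotation))  # one RTTM line for speaker0
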