Skip to content

Commit 763de25

Browse files
Add compatibility with pyannote 3.0 embedding wrappers (#188)
* bump pyannote to 3.0 * add wespeaker inference * add weights normalization, cpu for numpy conversion * unify api * remove try catch * always normalize * use PretrainedSpeakerEmbedding in Loader * Fix min-max normalization equation * fix: remove imports * Change embedding model return type to Callable Co-authored-by: Simon <80467011+sorgfresser@users.noreply.github.com> * fix: remove type checking * remove from active if NaN embeddings * Fix wrong typing of model in `LazyModel` * add docstrings * Simplify EmbeddingModel.__call__() * Add numpy import * add normalize boolean * Update requirements.txt * Update setup.cfg * Apply suggestions from code review * Fix wrong kwarg name * add abstract __call__ * move __call__ to parent class --------- Co-authored-by: Juan Coria <juanmc2005@hotmail.com>
1 parent 2118110 commit 763de25

11 files changed

+81
-35
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ torch>=1.12.1
1010
torchvision>=0.14.0
1111
torchaudio>=2.0.2
1212
pyannote.audio>=2.1.1
13+
requests>=2.31.0
1314
pyannote.core>=4.5
1415
pyannote.database>=4.1.1
1516
pyannote.metrics>=3.2

setup.cfg

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ install_requires=
3232
torchvision>=0.14.0
3333
torchaudio>=2.0.2
3434
pyannote.audio>=2.1.1
35+
requests>=2.31.0
3536
pyannote.core>=4.5
3637
pyannote.database>=4.1.1
3738
pyannote.metrics>=3.2

src/diart/argdoc.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
OUTPUT = "Directory to store the system's output in RTTM format"
1616
HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | <token>). If 'true', it will use the token from huggingface-cli login"
1717
SAMPLE_RATE = "Sample rate of the audio stream"
18+
NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like WeSpeaker or ECAPA-TDNN"

src/diart/blocks/clustering.py

+4
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ def identify(
140140
long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[
141141
0
142142
]
143+
# Remove speakers that have NaN embeddings
144+
no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0]
145+
active_speakers = np.intersect1d(active_speakers, no_nan_embeddings)
146+
143147
num_local_speakers = segmentation.data.shape[1]
144148

145149
if self.centers is None:

src/diart/blocks/diarization.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
gamma: float = 3,
3333
beta: float = 10,
3434
max_speakers: int = 20,
35+
normalize_embedding_weights: bool = False,
3536
device: torch.device | None = None,
3637
**kwargs,
3738
):
@@ -62,7 +63,7 @@ def __init__(
6263
self.gamma = gamma
6364
self.beta = beta
6465
self.max_speakers = max_speakers
65-
66+
self.normalize_embedding_weights = normalize_embedding_weights
6667
self.device = device or torch.device(
6768
"cuda" if torch.cuda.is_available() else "cpu"
6869
)
@@ -105,6 +106,7 @@ def __init__(self, config: SpeakerDiarizationConfig | None = None):
105106
self._config.gamma,
106107
self._config.beta,
107108
norm=1,
109+
normalize_weights=self._config.normalize_embedding_weights,
108110
device=self._config.device,
109111
)
110112
self.pred_aggregation = DelayedAggregation(

src/diart/blocks/embedding.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -77,19 +77,28 @@ class OverlappedSpeechPenalty:
7777
beta: float, optional
7878
Temperature parameter (actually 1/beta) to lower joint speaker activations.
7979
Defaults to 10.
80+
normalize: bool, optional
81+
Whether to min-max normalize weights to be in the range [0, 1].
82+
Defaults to False.
8083
"""
8184

82-
def __init__(self, gamma: float = 3, beta: float = 10):
85+
def __init__(self, gamma: float = 3, beta: float = 10, normalize: bool = False):
8386
self.gamma = gamma
8487
self.beta = beta
8588
self.formatter = TemporalFeatureFormatter()
89+
self.normalize = normalize
8690

8791
def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
8892
weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers)
8993
with torch.no_grad():
9094
probs = torch.softmax(self.beta * weights, dim=-1)
9195
weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma)
9296
weights[weights < 1e-8] = 1e-8
97+
if self.normalize:
98+
min_values = weights.min(dim=1, keepdim=True).values
99+
max_values = weights.max(dim=1, keepdim=True).values
100+
weights = (weights - min_values) / (max_values - min_values)
101+
weights.nan_to_num_(1e-8)
93102
return self.formatter.restore_type(weights)
94103

95104

@@ -134,6 +143,8 @@ class OverlapAwareSpeakerEmbedding:
134143
norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional
135144
The target norm for the embeddings. It can be different for each speaker.
136145
Defaults to 1.
146+
normalize_weights: bool, optional
147+
Whether to min-max normalize embedding weights to be in the range [0, 1].
137148
device: Optional[torch.device]
138149
The device on which to run the embedding model.
139150
Defaults to GPU if available or CPU if not.
@@ -145,10 +156,11 @@ def __init__(
145156
gamma: float = 3,
146157
beta: float = 10,
147158
norm: Union[float, torch.Tensor] = 1,
159+
normalize_weights: bool = False,
148160
device: Optional[torch.device] = None,
149161
):
150162
self.embedding = SpeakerEmbedding(model, device)
151-
self.osp = OverlappedSpeechPenalty(gamma, beta)
163+
self.osp = OverlappedSpeechPenalty(gamma, beta, normalize_weights)
152164
self.normalize = EmbeddingNormalization(norm)
153165

154166
@staticmethod
@@ -158,10 +170,11 @@ def from_pyannote(
158170
beta: float = 10,
159171
norm: Union[float, torch.Tensor] = 1,
160172
use_hf_token: Union[Text, bool, None] = True,
173+
normalize_weights: bool = False,
161174
device: Optional[torch.device] = None,
162175
):
163176
model = EmbeddingModel.from_pyannote(model, use_hf_token)
164-
return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device)
177+
return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, normalize_weights, device)
165178

166179
def __call__(
167180
self, waveform: TemporalFeatures, segmentation: TemporalFeatures

src/diart/console/benchmark.py

+5
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def run():
9999
type=str,
100100
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
101101
)
102+
parser.add_argument(
103+
"--normalize-embedding-weights",
104+
action="store_true",
105+
help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
106+
)
102107
args = parser.parse_args()
103108

104109
# Resolve device

src/diart/console/serve.py

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ def run():
8080
type=str,
8181
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
8282
)
83+
parser.add_argument(
84+
"--normalize-embedding-weights",
85+
action="store_true",
86+
help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
87+
)
8388
args = parser.parse_args()
8489

8590
# Resolve device

src/diart/console/stream.py

+5
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ def run():
9191
type=str,
9292
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
9393
)
94+
parser.add_argument(
95+
"--normalize-embedding-weights",
96+
action="store_true",
97+
help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
98+
)
9499
args = parser.parse_args()
95100

96101
# Resolve device

src/diart/console/tune.py

+5
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ def run():
108108
type=str,
109109
help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
110110
)
111+
parser.add_argument(
112+
"--normalize-embedding-weights",
113+
action="store_true",
114+
help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
115+
)
111116
args = parser.parse_args()
112117

113118
# Resolve device

src/diart/models.py

+35-31
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from abc import ABC, abstractmethod
22
from typing import Optional, Text, Union, Callable
33

4+
import numpy as np
45
import torch
56
import torch.nn as nn
7+
from requests import HTTPError
68

79
try:
8-
import pyannote.audio.pipelines.utils as pyannote_loader
10+
from pyannote.audio import Inference, Model
11+
from pyannote.audio.pipelines.speaker_verification import (
12+
PretrainedSpeakerEmbedding,
13+
)
914

1015
_has_pyannote = True
1116
except ImportError:
@@ -18,15 +23,20 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
1823
self.model_info = model_info
1924
self.hf_token = hf_token
2025

21-
def __call__(self) -> nn.Module:
22-
return pyannote_loader.get_model(self.model_info, self.hf_token)
26+
def __call__(self) -> Callable:
27+
try:
28+
return Model.from_pretrained(self.model_info, use_auth_token=self.hf_token)
29+
except HTTPError:
30+
return PretrainedSpeakerEmbedding(
31+
self.model_info, use_auth_token=self.hf_token
32+
)
2333

2434

25-
class LazyModel(nn.Module, ABC):
26-
def __init__(self, loader: Callable[[], nn.Module]):
35+
class LazyModel(ABC):
36+
def __init__(self, loader: Callable[[], Callable]):
2737
super().__init__()
2838
self.get_model = loader
29-
self.model: Optional[nn.Module] = None
39+
self.model: Optional[Callable] = None
3040

3141
def is_in_memory(self) -> bool:
3242
"""Return whether the model has been loaded into memory"""
@@ -36,13 +46,20 @@ def load(self):
3646
if not self.is_in_memory():
3747
self.model = self.get_model()
3848

39-
def to(self, *args, **kwargs) -> nn.Module:
49+
def to(self, device: torch.device) -> "LazyModel":
4050
self.load()
41-
return super().to(*args, **kwargs)
51+
self.model = self.model.to(device)
52+
return self
4253

4354
def __call__(self, *args, **kwargs):
4455
self.load()
45-
return super().__call__(*args, **kwargs)
56+
return self.model(*args, **kwargs)
57+
58+
def eval(self) -> "LazyModel":
59+
self.load()
60+
if isinstance(self.model, nn.Module):
61+
self.model.eval()
62+
return self
4663

4764

4865
class SegmentationModel(LazyModel):
@@ -83,20 +100,17 @@ def sample_rate(self) -> int:
83100
def duration(self) -> float:
84101
pass
85102

86-
@abstractmethod
87-
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
103+
def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
88104
"""
89-
Forward pass of the segmentation model.
90-
105+
Call the forward pass of the segmentation model.
91106
Parameters
92107
----------
93108
waveform: torch.Tensor, shape (batch, channels, samples)
94-
95109
Returns
96110
-------
97111
speaker_segmentation: torch.Tensor, shape (batch, frames, speakers)
98112
"""
99-
pass
113+
return super().__call__(waveform)
100114

101115

102116
class PyannoteSegmentationModel(SegmentationModel):
@@ -113,9 +127,6 @@ def duration(self) -> float:
113127
self.load()
114128
return self.model.specifications.duration
115129

116-
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
117-
return self.model(waveform)
118-
119130

120131
class EmbeddingModel(LazyModel):
121132
"""Minimal interface for an embedding model."""
@@ -143,33 +154,26 @@ def from_pyannote(
143154
assert _has_pyannote, "No pyannote.audio installation found"
144155
return PyannoteEmbeddingModel(model, use_hf_token)
145156

146-
@abstractmethod
147-
def forward(
157+
def __call__(
148158
self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None
149159
) -> torch.Tensor:
150160
"""
151-
Forward pass of an embedding model with optional weights.
152-
161+
Call the forward pass of an embedding model with optional weights.
153162
Parameters
154163
----------
155164
waveform: torch.Tensor, shape (batch, channels, samples)
156165
weights: Optional[torch.Tensor], shape (batch, frames)
157166
Temporal weights for each sample in the batch. Defaults to no weights.
158-
159167
Returns
160168
-------
161169
speaker_embeddings: torch.Tensor, shape (batch, embedding_dim)
162170
"""
163-
pass
171+
embeddings = super().__call__(waveform, weights)
172+
if isinstance(embeddings, np.ndarray):
173+
embeddings = torch.from_numpy(embeddings)
174+
return embeddings
164175

165176

166177
class PyannoteEmbeddingModel(EmbeddingModel):
167178
def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
168179
super().__init__(PyannoteLoader(model_info, hf_token))
169-
170-
def forward(
171-
self,
172-
waveform: torch.Tensor,
173-
weights: Optional[torch.Tensor] = None,
174-
) -> torch.Tensor:
175-
return self.model(waveform, weights=weights)

0 commit comments

Comments
 (0)