1
1
from abc import ABC , abstractmethod
2
2
from typing import Optional , Text , Union , Callable
3
3
4
+ import numpy as np
4
5
import torch
5
6
import torch .nn as nn
7
+ from requests import HTTPError
6
8
7
9
try :
8
- import pyannote .audio .pipelines .utils as pyannote_loader
10
+ from pyannote .audio import Inference , Model
11
+ from pyannote .audio .pipelines .speaker_verification import (
12
+ PretrainedSpeakerEmbedding ,
13
+ )
9
14
10
15
_has_pyannote = True
11
16
except ImportError :
@@ -18,15 +23,20 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
18
23
self .model_info = model_info
19
24
self .hf_token = hf_token
20
25
21
- def __call__ (self ) -> nn .Module :
22
- return pyannote_loader .get_model (self .model_info , self .hf_token )
26
+ def __call__ (self ) -> Callable :
27
+ try :
28
+ return Model .from_pretrained (self .model_info , use_auth_token = self .hf_token )
29
+ except HTTPError :
30
+ return PretrainedSpeakerEmbedding (
31
+ self .model_info , use_auth_token = self .hf_token
32
+ )
23
33
24
34
25
- class LazyModel (nn . Module , ABC ):
26
- def __init__ (self , loader : Callable [[], nn . Module ]):
35
+ class LazyModel (ABC ):
36
+ def __init__ (self , loader : Callable [[], Callable ]):
27
37
super ().__init__ ()
28
38
self .get_model = loader
29
- self .model : Optional [nn . Module ] = None
39
+ self .model : Optional [Callable ] = None
30
40
31
41
def is_in_memory (self ) -> bool :
32
42
"""Return whether the model has been loaded into memory"""
@@ -36,13 +46,20 @@ def load(self):
36
46
if not self .is_in_memory ():
37
47
self .model = self .get_model ()
38
48
39
- def to (self , * args , ** kwargs ) -> nn . Module :
49
+ def to (self , device : torch . device ) -> "LazyModel" :
40
50
self .load ()
41
- return super ().to (* args , ** kwargs )
51
+ self .model = self .model .to (device )
52
+ return self
42
53
43
54
def __call__ (self , * args , ** kwargs ):
44
55
self .load ()
45
- return super ().__call__ (* args , ** kwargs )
56
+ return self .model (* args , ** kwargs )
57
+
58
+ def eval (self ) -> "LazyModel" :
59
+ self .load ()
60
+ if isinstance (self .model , nn .Module ):
61
+ self .model .eval ()
62
+ return self
46
63
47
64
48
65
class SegmentationModel (LazyModel ):
@@ -83,20 +100,17 @@ def sample_rate(self) -> int:
83
100
def duration (self ) -> float :
84
101
pass
85
102
86
- @abstractmethod
87
- def forward (self , waveform : torch .Tensor ) -> torch .Tensor :
103
+ def __call__ (self , waveform : torch .Tensor ) -> torch .Tensor :
88
104
"""
89
- Forward pass of the segmentation model.
90
-
105
+ Call the forward pass of the segmentation model.
91
106
Parameters
92
107
----------
93
108
waveform: torch.Tensor, shape (batch, channels, samples)
94
-
95
109
Returns
96
110
-------
97
111
speaker_segmentation: torch.Tensor, shape (batch, frames, speakers)
98
112
"""
99
- pass
113
+ return super (). __call__ ( waveform )
100
114
101
115
102
116
class PyannoteSegmentationModel (SegmentationModel ):
@@ -113,9 +127,6 @@ def duration(self) -> float:
113
127
self .load ()
114
128
return self .model .specifications .duration
115
129
116
- def forward (self , waveform : torch .Tensor ) -> torch .Tensor :
117
- return self .model (waveform )
118
-
119
130
120
131
class EmbeddingModel (LazyModel ):
121
132
"""Minimal interface for an embedding model."""
@@ -143,33 +154,26 @@ def from_pyannote(
143
154
assert _has_pyannote , "No pyannote.audio installation found"
144
155
return PyannoteEmbeddingModel (model , use_hf_token )
145
156
146
- @abstractmethod
147
- def forward (
157
+ def __call__ (
148
158
self , waveform : torch .Tensor , weights : Optional [torch .Tensor ] = None
149
159
) -> torch .Tensor :
150
160
"""
151
- Forward pass of an embedding model with optional weights.
152
-
161
+ Call the forward pass of an embedding model with optional weights.
153
162
Parameters
154
163
----------
155
164
waveform: torch.Tensor, shape (batch, channels, samples)
156
165
weights: Optional[torch.Tensor], shape (batch, frames)
157
166
Temporal weights for each sample in the batch. Defaults to no weights.
158
-
159
167
Returns
160
168
-------
161
169
speaker_embeddings: torch.Tensor, shape (batch, embedding_dim)
162
170
"""
163
- pass
171
+ embeddings = super ().__call__ (waveform , weights )
172
+ if isinstance (embeddings , np .ndarray ):
173
+ embeddings = torch .from_numpy (embeddings )
174
+ return embeddings
164
175
165
176
166
177
class PyannoteEmbeddingModel (EmbeddingModel ):
167
178
def __init__ (self , model_info , hf_token : Union [Text , bool , None ] = True ):
168
179
super ().__init__ (PyannoteLoader (model_info , hf_token ))
169
-
170
- def forward (
171
- self ,
172
- waveform : torch .Tensor ,
173
- weights : Optional [torch .Tensor ] = None ,
174
- ) -> torch .Tensor :
175
- return self .model (waveform , weights = weights )
0 commit comments