Skip to content

Commit c1077a4

Browse files
hbredinjuanmc2005
andcommitted
Add support for powerset segmentation models (#198)
* feat: add support for powerset segmentation models * wip: trying this PowersetAdapter thing * fix: initialize nn.Module before setting attribute * Fix unresolved duration and sample rate * Apply suggestions from code review * fix: remove Inference import * fix: black embedding.py ... though it has nothing to do with this PR... --------- Co-authored-by: Juan Coria <juanmc2005@hotmail.com>
1 parent 763de25 commit c1077a4

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

src/diart/blocks/embedding.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ def from_pyannote(
174174
device: Optional[torch.device] = None,
175175
):
176176
model = EmbeddingModel.from_pyannote(model, use_hf_token)
177-
return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, normalize_weights, device)
177+
return OverlapAwareSpeakerEmbedding(
178+
model, gamma, beta, norm, normalize_weights, device
179+
)
178180

179181
def __call__(
180182
self, waveform: TemporalFeatures, segmentation: TemporalFeatures

src/diart/models.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,38 @@
77
from requests import HTTPError
88

99
try:
10-
from pyannote.audio import Inference, Model
10+
from pyannote.audio import Model
1111
from pyannote.audio.pipelines.speaker_verification import (
1212
PretrainedSpeakerEmbedding,
1313
)
14+
from pyannote.audio.utils.powerset import Powerset
1415

1516
_has_pyannote = True
1617
except ImportError:
1718
_has_pyannote = False
1819

1920

21+
class PowersetAdapter(nn.Module):
22+
def __init__(self, segmentation_model: nn.Module):
23+
super().__init__()
24+
self.model = segmentation_model
25+
specs = self.model.specifications
26+
max_speakers_per_frame = specs.powerset_max_classes
27+
max_speakers_per_chunk = len(specs.classes)
28+
self.powerset = Powerset(max_speakers_per_chunk, max_speakers_per_frame)
29+
30+
@property
31+
def specifications(self):
32+
return self.model.specifications
33+
34+
@property
35+
def audio(self):
36+
return self.model.audio
37+
38+
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
39+
return self.powerset.to_multilabel(self.model(waveform), soft=False)
40+
41+
2042
class PyannoteLoader:
2143
def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
2244
super().__init__()
@@ -25,7 +47,11 @@ def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
2547

2648
def __call__(self) -> Callable:
2749
try:
28-
return Model.from_pretrained(self.model_info, use_auth_token=self.hf_token)
50+
model = Model.from_pretrained(self.model_info, use_auth_token=self.hf_token)
51+
specs = getattr(model, "specifications", None)
52+
if specs is not None and specs.powerset:
53+
model = PowersetAdapter(model)
54+
return model
2955
except HTTPError:
3056
return PretrainedSpeakerEmbedding(
3157
self.model_info, use_auth_token=self.hf_token

0 commit comments

Comments
 (0)