
Commit aee7bce

Add ONNX compatibility (#204)
* Add ONNX segmentation and embedding models
* Minor readability improvements
* Replace onnxruntime with onnxruntime-gpu
* ONNX can have only one output
* Clean up useless embedding model subclasses
* Remove duration and sample_rate properties from SegmentationModel. Clean up code
* Update README
1 parent c1077a4 commit aee7bce

14 files changed: +180 −116 lines

README.md

+27-34
@@ -123,7 +123,7 @@ For inference and evaluation on a dataset we recommend to use `Benchmark` (see n
 
 ## 🤖 Add your model
 
-Third-party models can be integrated by subclassing `SegmentationModel` and `EmbeddingModel` (both PyTorch `nn.Module`):
+Third-party models can be integrated by providing a loader function:
 
 ```python
 from diart import SpeakerDiarization, SpeakerDiarizationConfig
@@ -132,46 +132,39 @@ from diart.sources import MicrophoneAudioSource
 from diart.inference import StreamingInference
 
 
-def model_loader():
+def segmentation_loader():
+    # It should take a waveform and return a segmentation tensor
     return load_pretrained_model("my_model.ckpt")
 
+def embedding_loader():
+    # It should take (waveform, weights) and return per-speaker embeddings
+    return load_pretrained_model("my_other_model.ckpt")
 
-class MySegmentationModel(SegmentationModel):
-    def __init__(self):
-        super().__init__(model_loader)
-
-    @property
-    def sample_rate(self) -> int:
-        return 16000
-
-    @property
-    def duration(self) -> float:
-        return 2  # seconds
-
-    def forward(self, waveform):
-        # self.model is created lazily
-        return self.model(waveform)
-
-
-class MyEmbeddingModel(EmbeddingModel):
-    def __init__(self):
-        super().__init__(model_loader)
-
-    def forward(self, waveform, weights):
-        # self.model is created lazily
-        return self.model(waveform, weights)
-
-
+
+segmentation = SegmentationModel(segmentation_loader)
+embedding = EmbeddingModel(embedding_loader)
 config = SpeakerDiarizationConfig(
-    segmentation=MySegmentationModel(),
-    embedding=MyEmbeddingModel()
+    segmentation=segmentation,
+    embedding=embedding,
 )
 pipeline = SpeakerDiarization(config)
 mic = MicrophoneAudioSource()
 inference = StreamingInference(pipeline, mic)
 prediction = inference()
 ```
 
+If you have an ONNX model, you can use `from_onnx()`:
+
+```python
+from diart.models import EmbeddingModel
+
+embedding = EmbeddingModel.from_onnx(
+    model_path="my_model.ckpt",
+    input_names=["x", "w"],  # defaults to ["waveform", "weights"]
+    output_name="output",  # defaults to "embedding"
+)
+```
+
 ## 📈 Tune hyper-parameters
 
 Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs.
@@ -352,11 +345,11 @@ from diart.models import SegmentationModel
 
 benchmark = Benchmark("/wav/dir", "/rttm/dir")
 
-name = "pyannote/segmentation@Interspeech2021"
-segmentation = SegmentationModel.from_pyannote(name)
+model_name = "pyannote/segmentation@Interspeech2021"
+model = SegmentationModel.from_pretrained(model_name)
 config = SpeakerDiarizationConfig(
-    # Set the model used in the paper
-    segmentation=segmentation,
+    # Set the segmentation model used in the paper
+    segmentation=model,
     step=0.5,
     latency=0.5,
     tau_active=0.555,
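
Putting the two README additions together: the sketch below pairs the pretrained segmentation model used above with an ONNX embedding model in a streaming pipeline. It only combines calls shown in this diff; the file name `my_embedding.onnx` is a placeholder, and the input/output names must match however your model was exported.

```python
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import SegmentationModel, EmbeddingModel
from diart.sources import MicrophoneAudioSource
from diart.inference import StreamingInference

# Pretrained pyannote segmentation model (diart's default)
segmentation = SegmentationModel.from_pretrained("pyannote/segmentation")

# Hypothetical ONNX embedding model; names must match the exported graph
embedding = EmbeddingModel.from_onnx(
    model_path="my_embedding.onnx",
    input_names=["waveform", "weights"],
    output_name="embedding",
)

config = SpeakerDiarizationConfig(segmentation=segmentation, embedding=embedding)
pipeline = SpeakerDiarization(config)
inference = StreamingInference(pipeline, MicrophoneAudioSource())
prediction = inference()
```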

environment.yml

+1-1
@@ -3,7 +3,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - python=3.8
+  - python=3.10
   - portaudio=19.6.*
   - pysoundfile=0.12.*
   - ffmpeg[version='<4.4']

requirements.txt

+1
@@ -18,3 +18,4 @@ optuna>=2.10
 websocket-server>=0.6.4
 websocket-client>=0.58.0
 rich>=12.5.1
+onnxruntime-gpu>=1.16.1

setup.cfg

+1
@@ -40,6 +40,7 @@ install_requires=
     websocket-server>=0.6.4
     websocket-client>=0.58.0
     rich>=12.5.1
+    onnxruntime-gpu>=1.16.1
 
 [options.packages.find]
 where=src
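
`onnxruntime-gpu` is now a hard dependency in both requirements.txt and setup.cfg. Not part of this commit, but a quick way to check which execution providers the installed runtime exposes (and hence whether ONNX models will actually run on GPU):

```python
import onnxruntime as ort

# Typically includes "CUDAExecutionProvider" on a working CUDA setup;
# otherwise ONNX models fall back to "CPUExecutionProvider".
print(ort.get_available_providers())
```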

src/diart/blocks/diarization.py

+5-11
@@ -23,7 +23,7 @@ def __init__(
         self,
         segmentation: m.SegmentationModel | None = None,
         embedding: m.EmbeddingModel | None = None,
-        duration: float | None = None,
+        duration: float = 5,
         step: float = 0.5,
         latency: float | Literal["max", "min"] | None = None,
         tau_active: float = 0.6,
@@ -34,6 +34,7 @@ def __init__(
         max_speakers: int = 20,
         normalize_embedding_weights: bool = False,
         device: torch.device | None = None,
+        sample_rate: int = 16000,
         **kwargs,
     ):
         # Default segmentation model is pyannote/segmentation
@@ -47,7 +48,7 @@ def __init__(
             )
 
         self._duration = duration
-        self._sample_rate: int | None = None
+        self._sample_rate = sample_rate
 
         # Latency defaults to the step duration
         self._step = step
@@ -70,9 +71,6 @@ def __init__(
 
     @property
     def duration(self) -> float:
-        # Default duration is the one given by the segmentation model
-        if self._duration is None:
-            self._duration = self.segmentation.duration
         return self._duration
 
     @property
@@ -85,9 +83,6 @@ def latency(self) -> float:
 
     @property
     def sample_rate(self) -> int:
-        # Expected sample rate is given by the segmentation model
-        if self._sample_rate is None:
-            self._sample_rate = self.segmentation.sample_rate
         return self._sample_rate
 
 
@@ -177,9 +172,8 @@ def __call__(
 
         # Extract segmentation and embeddings
         segmentations = self.segmentation(batch)  # shape (batch, frames, speakers)
-        embeddings = self.embedding(
-            batch, segmentations
-        )  # shape (batch, speakers, emb_dim)
+        # embeddings has shape (batch, speakers, emb_dim)
+        embeddings = self.embedding(batch, segmentations)
 
         seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
 
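
With this change, `duration` and `sample_rate` are no longer read from the segmentation model; they are plain constructor arguments defaulting to 5 seconds and 16 kHz. A minimal sketch of what that looks like from user code (the segmentation model name is the default mentioned in this file, the embedding model name is illustrative):

```python
from diart import SpeakerDiarizationConfig
from diart.models import SegmentationModel, EmbeddingModel

config = SpeakerDiarizationConfig(
    segmentation=SegmentationModel.from_pretrained("pyannote/segmentation"),
    embedding=EmbeddingModel.from_pretrained("pyannote/embedding"),
    duration=5,  # chunk length in seconds, previously inferred from the model
    sample_rate=16000,  # expected input sample rate, previously inferred from the model
    step=0.5,
    latency=0.5,
)
```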

src/diart/blocks/embedding.py

+5-17
@@ -3,6 +3,7 @@
 import torch
 from einops import rearrange
 
+from .. import functional as F
 from ..features import TemporalFeatures, TemporalFeatureFormatter
 from ..models import EmbeddingModel
 
@@ -90,10 +91,8 @@ def __init__(self, gamma: float = 3, beta: float = 10, normalize: bool = False):
 
     def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
         weights = self.formatter.cast(segmentation)  # shape (batch, frames, speakers)
-        with torch.no_grad():
-            probs = torch.softmax(self.beta * weights, dim=-1)
-            weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma)
-            weights[weights < 1e-8] = 1e-8
+        with torch.inference_mode():
+            weights = F.overlapped_speech_penalty(weights, self.gamma, self.beta)
         if self.normalize:
             min_values = weights.min(dim=1, keepdim=True).values
             max_values = weights.max(dim=1, keepdim=True).values
@@ -110,19 +109,8 @@ def __init__(self, norm: Union[float, torch.Tensor] = 1):
             self.norm = self.norm.unsqueeze(0)
 
     def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
-        # Add batch dimension if missing
-        if embeddings.ndim == 2:
-            embeddings = embeddings.unsqueeze(0)
-        if isinstance(self.norm, torch.Tensor):
-            batch_size1, num_speakers1, _ = self.norm.shape
-            batch_size2, num_speakers2, _ = embeddings.shape
-            assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
-        with torch.no_grad():
-            norm_embs = (
-                self.norm
-                * embeddings
-                / torch.norm(embeddings, p=2, dim=-1, keepdim=True)
-            )
+        with torch.inference_mode():
+            norm_embs = F.normalize_embeddings(embeddings, self.norm)
         return norm_embs
 
 

src/diart/blocks/vad.py

+3-8
@@ -27,11 +27,12 @@ class VoiceActivityDetectionConfig(base.PipelineConfig):
     def __init__(
         self,
         segmentation: m.SegmentationModel | None = None,
-        duration: float | None = None,
+        duration: float = 5,
         step: float = 0.5,
         latency: float | Literal["max", "min"] | None = None,
         tau_active: float = 0.6,
         device: torch.device | None = None,
+        sample_rate: int = 16000,
         **kwargs,
     ):
         # Default segmentation model is pyannote/segmentation
@@ -41,7 +42,7 @@ def __init__(
 
         self._duration = duration
         self._step = step
-        self._sample_rate: int | None = None
+        self._sample_rate = sample_rate
 
         # Latency defaults to the step duration
         self._latency = latency
@@ -57,9 +58,6 @@ def __init__(
 
     @property
     def duration(self) -> float:
-        # Default duration is the one given by the segmentation model
-        if self._duration is None:
-            self._duration = self.segmentation.duration
         return self._duration
 
     @property
@@ -72,9 +70,6 @@ def latency(self) -> float:
 
     @property
     def sample_rate(self) -> int:
-        # Expected sample rate is given by the segmentation model
-        if self._sample_rate is None:
-            self._sample_rate = self.segmentation.sample_rate
         return self._sample_rate
 
 
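
`VoiceActivityDetectionConfig` gets the same treatment as the diarization config: `duration` and `sample_rate` are now explicit arguments. A minimal sketch, assuming the `VoiceActivityDetection` pipeline that consumes this config lives in the same module:

```python
from diart.blocks.vad import VoiceActivityDetection, VoiceActivityDetectionConfig
from diart.models import SegmentationModel

config = VoiceActivityDetectionConfig(
    segmentation=SegmentationModel.from_pretrained("pyannote/segmentation"),
    duration=5,  # no longer derived from the segmentation model
    sample_rate=16000,  # new explicit argument
    tau_active=0.6,
)
vad_pipeline = VoiceActivityDetection(config)
```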

src/diart/console/benchmark.py

+3-2
@@ -43,6 +43,7 @@ def run():
     parser.add_argument(
         "--duration",
         type=float,
+        default=5,
         help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
     )
     parser.add_argument(
@@ -111,8 +112,8 @@ def run():
 
     # Resolve models
     hf_token = utils.parse_hf_token_arg(args.hf_token)
-    args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
-    args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+    args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
+    args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
 
     pipeline_class = utils.get_pipeline_class(args.pipeline)
 
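
The same `from_pyannote` → `from_pretrained` rename appears in the serve, stream, and tune entry points below. For user scripts that resolve models the way the console scripts do, the migration is a one-line change (the token value is a placeholder):

```python
from diart.models import SegmentationModel

hf_token = "YOUR_HF_TOKEN"  # HuggingFace access token, needed for gated models

# Before this commit:
# segmentation = SegmentationModel.from_pyannote("pyannote/segmentation", hf_token)

# After this commit:
segmentation = SegmentationModel.from_pretrained("pyannote/segmentation", hf_token)
```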

src/diart/console/serve.py

+3-2
@@ -36,6 +36,7 @@ def run():
     parser.add_argument(
         "--duration",
         type=float,
+        default=5,
         help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
     )
     parser.add_argument(
@@ -92,8 +93,8 @@ def run():
 
     # Resolve models
     hf_token = utils.parse_hf_token_arg(args.hf_token)
-    args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
-    args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+    args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
+    args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
 
     # Resolve pipeline
     pipeline_class = utils.get_pipeline_class(args.pipeline)

src/diart/console/stream.py

+3-2
@@ -39,6 +39,7 @@ def run():
     parser.add_argument(
         "--duration",
         type=float,
+        default=5,
         help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
     )
     parser.add_argument(
@@ -103,8 +104,8 @@ def run():
 
     # Resolve models
     hf_token = utils.parse_hf_token_arg(args.hf_token)
-    args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
-    args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+    args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
+    args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
 
     # Resolve pipeline
     pipeline_class = utils.get_pipeline_class(args.pipeline)

src/diart/console/tune.py

+3-2
@@ -46,6 +46,7 @@ def run():
     parser.add_argument(
         "--duration",
         type=float,
+        default=5,
         help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
     )
     parser.add_argument(
@@ -120,8 +121,8 @@ def run():
 
     # Resolve models
     hf_token = utils.parse_hf_token_arg(args.hf_token)
-    args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
-    args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
+    args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
+    args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
 
     # Retrieve pipeline class
     pipeline_class = utils.get_pipeline_class(args.pipeline)

src/diart/functional.py

+27
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+import torch
+
+
+def overlapped_speech_penalty(
+    segmentation: torch.Tensor, gamma: float = 3, beta: float = 10
+):
+    # segmentation has shape (batch, frames, speakers)
+    probs = torch.softmax(beta * segmentation, dim=-1)
+    weights = torch.pow(segmentation, gamma) * torch.pow(probs, gamma)
+    weights[weights < 1e-8] = 1e-8
+    return weights
+
+
+def normalize_embeddings(
+    embeddings: torch.Tensor, norm: float | torch.Tensor = 1
+) -> torch.Tensor:
+    # embeddings has shape (batch, speakers, feat) or (speakers, feat)
+    if embeddings.ndim == 2:
+        embeddings = embeddings.unsqueeze(0)
+    if isinstance(norm, torch.Tensor):
+        batch_size1, num_speakers1, _ = norm.shape
+        batch_size2, num_speakers2, _ = embeddings.shape
+        assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
+    emb_norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
+    return norm * embeddings / emb_norm
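
A quick, self-contained illustration of the two new helpers with dummy tensors; shapes follow the comments in the module above, the sizes themselves are arbitrary:

```python
import torch

from diart import functional as F

batch, frames, speakers, feat = 2, 50, 3, 192
segmentation = torch.rand(batch, frames, speakers)  # per-frame speaker activities
embeddings = torch.randn(batch, speakers, feat)  # raw speaker embeddings

weights = F.overlapped_speech_penalty(segmentation)  # same shape as segmentation
normalized = F.normalize_embeddings(embeddings)  # rescaled to unit L2 norm per speaker

print(weights.shape, normalized.shape)
# torch.Size([2, 50, 3]) torch.Size([2, 3, 192])
```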
