Skip to content

Commit 0113ab2

Browse files
authored
Remove PipelineConfig.from_dict() (#189)
* Unify hparam naming. Clean some typing annotations
* Remove config.from_dict(). Add --duration argument to CLI
* Update README.md accordingly
1 parent 45f8ad9 commit 0113ab2

File tree

10 files changed

+124
-140
lines changed

10 files changed

+124
-140
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ To obtain the best results, make sure to use the following hyper-parameters:
340340
`diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:
341341

342342
```shell
343-
diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --segmentation pyannote/segmentation@Interspeech2021
343+
diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021
344344
```
345345

346346
or using the inference API:

src/diart/blocks/base.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from typing import Any, Tuple, Sequence, Text
2-
from dataclasses import dataclass
31
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
3+
from typing import Any, Tuple, Sequence, Text
44

5-
import numpy as np
65
from pyannote.core import SlidingWindowFeature
76
from pyannote.metrics.base import BaseMetric
87

@@ -53,11 +52,6 @@ def latency(self) -> float:
5352
def sample_rate(self) -> int:
5453
pass
5554

56-
@staticmethod
57-
@abstractmethod
58-
def from_dict(data: Any) -> "PipelineConfig":
59-
pass
60-
6155
def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
6256
file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
6357
right = utils.get_padding_right(self.latency, self.step)

src/diart/blocks/diarization.py

+21-62
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Optional, Tuple, Sequence, Union, Any
1+
from __future__ import annotations
2+
3+
from typing import Sequence
24

35
import numpy as np
46
import torch
@@ -14,40 +16,37 @@
1416
from .segmentation import SpeakerSegmentation
1517
from .utils import Binarize
1618
from .. import models as m
17-
from .. import utils
1819

1920

2021
class SpeakerDiarizationConfig(base.PipelineConfig):
2122
def __init__(
2223
self,
23-
segmentation: Optional[m.SegmentationModel] = None,
24-
embedding: Optional[m.EmbeddingModel] = None,
25-
duration: Optional[float] = None,
24+
segmentation: m.SegmentationModel | None = None,
25+
embedding: m.EmbeddingModel | None = None,
26+
duration: float | None = None,
2627
step: float = 0.5,
27-
latency: Optional[Union[float, Literal["max", "min"]]] = None,
28+
latency: float | Literal["max", "min"] | None = None,
2829
tau_active: float = 0.6,
2930
rho_update: float = 0.3,
3031
delta_new: float = 1,
3132
gamma: float = 3,
3233
beta: float = 10,
3334
max_speakers: int = 20,
34-
device: Optional[torch.device] = None,
35+
device: torch.device | None = None,
3536
**kwargs,
3637
):
3738
# Default segmentation model is pyannote/segmentation
38-
self.segmentation = segmentation
39-
if self.segmentation is None:
40-
self.segmentation = m.SegmentationModel.from_pyannote(
41-
"pyannote/segmentation"
42-
)
43-
44-
self._duration = duration
45-
self._sample_rate: Optional[int] = None
39+
self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
40+
"pyannote/segmentation"
41+
)
4642

4743
# Default embedding model is pyannote/embedding
48-
self.embedding = embedding
49-
if self.embedding is None:
50-
self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
44+
self.embedding = embedding or m.EmbeddingModel.from_pyannote(
45+
"pyannote/embedding"
46+
)
47+
48+
self._duration = duration
49+
self._sample_rate: int | None = None
5150

5251
# Latency defaults to the step duration
5352
self._step = step
@@ -64,48 +63,8 @@ def __init__(
6463
self.beta = beta
6564
self.max_speakers = max_speakers
6665

67-
self.device = device
68-
if self.device is None:
69-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70-
71-
@staticmethod
72-
def from_dict(data: Any) -> "SpeakerDiarizationConfig":
73-
# Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
74-
device = utils.get(data, "device", None)
75-
if device is None:
76-
device = torch.device("cpu") if utils.get(data, "cpu", False) else None
77-
78-
# Instantiate models
79-
hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
80-
segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
81-
segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
82-
embedding = utils.get(data, "embedding", "pyannote/embedding")
83-
embedding = m.EmbeddingModel.from_pyannote(embedding, hf_token)
84-
85-
# Hyper-parameters and their aliases
86-
tau = utils.get(data, "tau_active", None)
87-
if tau is None:
88-
tau = utils.get(data, "tau", 0.6)
89-
rho = utils.get(data, "rho_update", None)
90-
if rho is None:
91-
rho = utils.get(data, "rho", 0.3)
92-
delta = utils.get(data, "delta_new", None)
93-
if delta is None:
94-
delta = utils.get(data, "delta", 1)
95-
96-
return SpeakerDiarizationConfig(
97-
segmentation=segmentation,
98-
embedding=embedding,
99-
duration=utils.get(data, "duration", None),
100-
step=utils.get(data, "step", 0.5),
101-
latency=utils.get(data, "latency", None),
102-
tau_active=tau,
103-
rho_update=rho,
104-
delta_new=delta,
105-
gamma=utils.get(data, "gamma", 3),
106-
beta=utils.get(data, "beta", 10),
107-
max_speakers=utils.get(data, "max_speakers", 20),
108-
device=device,
66+
self.device = device or torch.device(
67+
"cuda" if torch.cuda.is_available() else "cpu"
10968
)
11069

11170
@property
@@ -132,7 +91,7 @@ def sample_rate(self) -> int:
13291

13392

13493
class SpeakerDiarization(base.Pipeline):
135-
def __init__(self, config: Optional[SpeakerDiarizationConfig] = None):
94+
def __init__(self, config: SpeakerDiarizationConfig | None = None):
13695
self._config = SpeakerDiarizationConfig() if config is None else config
13796

13897
msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
@@ -200,7 +159,7 @@ def reset(self):
200159

201160
def __call__(
202161
self, waveforms: Sequence[SlidingWindowFeature]
203-
) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
162+
) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
204163
batch_size = len(waveforms)
205164
msg = "Pipeline expected at least 1 input"
206165
assert batch_size >= 1, msg

src/diart/blocks/vad.py

+17-43
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Optional, Union, Sequence, Tuple
1+
from __future__ import annotations
2+
3+
from typing import Sequence
24

35
import numpy as np
46
import torch
@@ -13,8 +15,8 @@
1315
from pyannote.metrics.detection import DetectionErrorRate
1416
from typing_extensions import Literal
1517

16-
from .aggregation import DelayedAggregation
1718
from . import base
19+
from .aggregation import DelayedAggregation
1820
from .segmentation import SpeakerSegmentation
1921
from .utils import Binarize
2022
from .. import models as m
@@ -24,24 +26,22 @@
2426
class VoiceActivityDetectionConfig(base.PipelineConfig):
2527
def __init__(
2628
self,
27-
segmentation: Optional[m.SegmentationModel] = None,
28-
duration: Optional[float] = None,
29+
segmentation: m.SegmentationModel | None = None,
30+
duration: float | None = None,
2931
step: float = 0.5,
30-
latency: Optional[Union[float, Literal["max", "min"]]] = None,
32+
latency: float | Literal["max", "min"] | None = None,
3133
tau_active: float = 0.6,
32-
device: Optional[torch.device] = None,
34+
device: torch.device | None = None,
3335
**kwargs,
3436
):
3537
# Default segmentation model is pyannote/segmentation
36-
self.segmentation = segmentation
37-
if self.segmentation is None:
38-
self.segmentation = m.SegmentationModel.from_pyannote(
39-
"pyannote/segmentation"
40-
)
38+
self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
39+
"pyannote/segmentation"
40+
)
4141

4242
self._duration = duration
4343
self._step = step
44-
self._sample_rate: Optional[int] = None
44+
self._sample_rate: int | None = None
4545

4646
# Latency defaults to the step duration
4747
self._latency = latency
@@ -51,9 +51,9 @@ def __init__(
5151
self._latency = self._duration
5252

5353
self.tau_active = tau_active
54-
self.device = device
55-
if self.device is None:
56-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54+
self.device = device or torch.device(
55+
"cuda" if torch.cuda.is_available() else "cpu"
56+
)
5757

5858
@property
5959
def duration(self) -> float:
@@ -77,35 +77,9 @@ def sample_rate(self) -> int:
7777
self._sample_rate = self.segmentation.sample_rate
7878
return self._sample_rate
7979

80-
@staticmethod
81-
def from_dict(data: Any) -> "VoiceActivityDetectionConfig":
82-
# Check for explicit device, otherwise check for 'cpu' bool, otherwise pass None
83-
device = utils.get(data, "device", None)
84-
if device is None:
85-
device = torch.device("cpu") if utils.get(data, "cpu", False) else None
86-
87-
# Instantiate segmentation model
88-
hf_token = utils.parse_hf_token_arg(utils.get(data, "hf_token", True))
89-
segmentation = utils.get(data, "segmentation", "pyannote/segmentation")
90-
segmentation = m.SegmentationModel.from_pyannote(segmentation, hf_token)
91-
92-
# Tau active and its alias
93-
tau = utils.get(data, "tau_active", None)
94-
if tau is None:
95-
tau = utils.get(data, "tau", 0.6)
96-
97-
return VoiceActivityDetectionConfig(
98-
segmentation=segmentation,
99-
duration=utils.get(data, "duration", None),
100-
step=utils.get(data, "step", 0.5),
101-
latency=utils.get(data, "latency", None),
102-
tau_active=tau,
103-
device=device,
104-
)
105-
10680

10781
class VoiceActivityDetection(base.Pipeline):
108-
def __init__(self, config: Optional[VoiceActivityDetectionConfig] = None):
82+
def __init__(self, config: VoiceActivityDetectionConfig | None = None):
10983
self._config = VoiceActivityDetectionConfig() if config is None else config
11084

11185
msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
@@ -158,7 +132,7 @@ def set_timestamp_shift(self, shift: float):
158132
def __call__(
159133
self,
160134
waveforms: Sequence[SlidingWindowFeature],
161-
) -> Sequence[Tuple[Annotation, SlidingWindowFeature]]:
135+
) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
162136
batch_size = len(waveforms)
163137
msg = "Pipeline expected at least 1 input"
164138
assert batch_size >= 1, msg

src/diart/console/benchmark.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
from pathlib import Path
33

44
import pandas as pd
5+
import torch
6+
57
from diart import argdoc
8+
from diart import models as m
69
from diart import utils
710
from diart.inference import Benchmark, Parallelize
811

@@ -37,20 +40,25 @@ def run():
3740
type=Path,
3841
help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files",
3942
)
43+
parser.add_argument(
44+
"--duration",
45+
type=float,
46+
help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
47+
)
4048
parser.add_argument(
4149
"--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
4250
)
4351
parser.add_argument(
4452
"--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
4553
)
4654
parser.add_argument(
47-
"--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
55+
"--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
4856
)
4957
parser.add_argument(
50-
"--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
58+
"--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
5159
)
5260
parser.add_argument(
53-
"--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
61+
"--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
5462
)
5563
parser.add_argument(
5664
"--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
@@ -93,6 +101,14 @@ def run():
93101
)
94102
args = parser.parse_args()
95103

104+
# Resolve device
105+
args.device = torch.device("cpu") if args.cpu else None
106+
107+
# Resolve models
108+
hf_token = utils.parse_hf_token_arg(args.hf_token)
109+
args.segmentation = m.SegmentationModel.from_pyannote(args.segmentation, hf_token)
110+
args.embedding = m.EmbeddingModel.from_pyannote(args.embedding, hf_token)
111+
96112
pipeline_class = utils.get_pipeline_class(args.pipeline)
97113

98114
benchmark = Benchmark(
@@ -104,7 +120,7 @@ def run():
104120
batch_size=args.batch_size,
105121
)
106122

107-
config = pipeline_class.get_config_class().from_dict(vars(args))
123+
config = pipeline_class.get_config_class()(**vars(args))
108124
if args.num_workers > 0:
109125
benchmark = Parallelize(benchmark, args.num_workers)
110126

src/diart/console/client.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from threading import Thread
44
from typing import Text, Optional
55

6-
import numpy as np
76
import rx.operators as ops
7+
from websocket import WebSocket
8+
89
from diart import argdoc
910
from diart import sources as src
1011
from diart import utils
11-
from websocket import WebSocket
1212

1313

1414
def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int):

0 commit comments

Comments
 (0)