Skip to content

Commit 467997d

Browse files
committed
Add Test Suite (#237)
* Add testing configuration and diarization tests * Add aggregation tests * Add end-to-end test for a sample wav file and several latencies * Fix rounding error in min latency unit test * Improve CI workflows and add pytest. Fix matplotlib colormap error * Install missing dependencies in CI * Add onnxruntime as a test dependency * Update expected timestamp tolerance to up to 50ms
1 parent 9e6c2e9 commit 467997d

18 files changed

+513
-5
lines changed

.github/workflows/pytest.yml

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Pytest
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
- develop
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- name: Checkout code
15+
uses: actions/checkout@v3
16+
17+
- name: Set up Python
18+
uses: actions/setup-python@v3
19+
with:
20+
python-version: '3.10'
21+
22+
- name: Install apt dependencies
23+
run: |
24+
sudo add-apt-repository ppa:savoury1/ffmpeg4
25+
sudo apt-get update
26+
sudo apt-get -y install ffmpeg libportaudio2=19.6.0-1.1
27+
28+
- name: Install pip dependencies
29+
run: |
30+
python -m pip install --upgrade pip
31+
pip install .[tests]
32+
33+
- name: Run tests
34+
run: |
35+
pytest

.github/workflows/quick-runs.yml

+4-3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ jobs:
3838
run: |
3939
python -m pip install --upgrade pip
4040
pip install .
41+
pip install onnxruntime==1.18.0
4142
- name: Crop audio and rttm
4243
run: |
4344
sox audio/ES2002a_long.wav audio/ES2002a.wav trim 00:40 00:30
@@ -50,10 +51,10 @@ jobs:
5051
rm rttms/ES2002b_long.rttm
5152
- name: Run stream
5253
run: |
53-
diart.stream audio/ES2002a.wav --output trash --no-plot --hf-token ${{ secrets.HUGGINGFACE }}
54+
diart.stream audio/ES2002a.wav --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx --output trash --no-plot
5455
- name: Run benchmark
5556
run: |
56-
diart.benchmark audio --reference rttms --batch-size 4 --hf-token ${{ secrets.HUGGINGFACE }}
57+
diart.benchmark audio --reference rttms --batch-size 4 --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
5758
- name: Run tuning
5859
run: |
59-
diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --hf-token ${{ secrets.HUGGINGFACE }}
60+
diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx

assets/models/embedding_uint8.onnx

4.31 MB
Binary file not shown.

assets/models/segmentation_uint8.onnx

1.53 MB
Binary file not shown.

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
numpy>=1.20.2
2-
matplotlib>=3.3.3
2+
matplotlib>=3.3.3,<3.6.0
33
rx>=3.2.0
44
scipy>=1.6.0
55
sounddevice>=0.4.2

setup.cfg

+6-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package_dir=
2121
packages=find:
2222
install_requires=
2323
numpy>=1.20.2
24-
matplotlib>=3.3.3
24+
matplotlib>=3.3.3,<3.6.0
2525
rx>=3.2.0
2626
scipy>=1.6.0
2727
sounddevice>=0.4.2
@@ -41,6 +41,11 @@ install_requires=
4141
websocket-client>=0.58.0
4242
rich>=12.5.1
4343

44+
[options.extras_require]
45+
tests=
46+
pytest>=7.4.0,<8.0.0
47+
onnxruntime==1.18.0
48+
4449
[options.packages.find]
4550
where=src
4651

tests/conftest.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import random
2+
3+
import pytest
4+
import torch
5+
6+
from diart.models import SegmentationModel, EmbeddingModel
7+
8+
9+
class DummySegmentationModel:
10+
def to(self, device):
11+
pass
12+
13+
def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
14+
assert waveform.ndim == 3
15+
16+
batch_size, num_channels, num_samples = waveform.shape
17+
num_frames = random.randint(250, 500)
18+
num_speakers = random.randint(3, 5)
19+
20+
return torch.rand(batch_size, num_frames, num_speakers)
21+
22+
23+
class DummyEmbeddingModel:
24+
def to(self, device):
25+
pass
26+
27+
def __call__(self, waveform: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
28+
assert waveform.ndim == 3
29+
assert weights.ndim == 2
30+
31+
batch_size, num_channels, num_samples = waveform.shape
32+
batch_size_weights, num_frames = weights.shape
33+
34+
assert batch_size == batch_size_weights
35+
36+
embedding_dim = random.randint(128, 512)
37+
38+
return torch.randn(batch_size, embedding_dim)
39+
40+
41+
@pytest.fixture(scope="session")
42+
def segmentation_model() -> SegmentationModel:
43+
return SegmentationModel(DummySegmentationModel)
44+
45+
46+
@pytest.fixture(scope="session")
47+
def embedding_model() -> EmbeddingModel:
48+
return EmbeddingModel(DummyEmbeddingModel)

tests/data/audio/sample.wav

938 KB
Binary file not shown.

tests/data/rttm/latency_0.5.rttm

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
SPEAKER sample 1 6.675 0.533 <NA> <NA> speaker0 <NA> <NA>
2+
SPEAKER sample 1 7.625 1.883 <NA> <NA> speaker0 <NA> <NA>
3+
SPEAKER sample 1 9.508 1.000 <NA> <NA> speaker1 <NA> <NA>
4+
SPEAKER sample 1 10.508 0.567 <NA> <NA> speaker0 <NA> <NA>
5+
SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
6+
SPEAKER sample 1 14.325 3.733 <NA> <NA> speaker0 <NA> <NA>
7+
SPEAKER sample 1 18.058 3.450 <NA> <NA> speaker1 <NA> <NA>
8+
SPEAKER sample 1 18.325 0.183 <NA> <NA> speaker0 <NA> <NA>
9+
SPEAKER sample 1 21.508 0.017 <NA> <NA> speaker0 <NA> <NA>
10+
SPEAKER sample 1 21.775 0.233 <NA> <NA> speaker1 <NA> <NA>
11+
SPEAKER sample 1 22.008 6.633 <NA> <NA> speaker0 <NA> <NA>
12+
SPEAKER sample 1 28.508 1.500 <NA> <NA> speaker1 <NA> <NA>
13+
SPEAKER sample 1 29.958 0.050 <NA> <NA> speaker0 <NA> <NA>

tests/data/rttm/latency_1.rttm

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
SPEAKER sample 1 6.708 0.450 <NA> <NA> speaker0 <NA> <NA>
2+
SPEAKER sample 1 7.625 1.383 <NA> <NA> speaker0 <NA> <NA>
3+
SPEAKER sample 1 9.008 1.500 <NA> <NA> speaker1 <NA> <NA>
4+
SPEAKER sample 1 10.008 1.067 <NA> <NA> speaker0 <NA> <NA>
5+
SPEAKER sample 1 10.592 4.200 <NA> <NA> speaker1 <NA> <NA>
6+
SPEAKER sample 1 14.308 3.700 <NA> <NA> speaker0 <NA> <NA>
7+
SPEAKER sample 1 18.042 3.250 <NA> <NA> speaker1 <NA> <NA>
8+
SPEAKER sample 1 18.508 0.033 <NA> <NA> speaker0 <NA> <NA>
9+
SPEAKER sample 1 21.108 0.383 <NA> <NA> speaker0 <NA> <NA>
10+
SPEAKER sample 1 21.508 0.033 <NA> <NA> speaker1 <NA> <NA>
11+
SPEAKER sample 1 21.775 6.817 <NA> <NA> speaker0 <NA> <NA>
12+
SPEAKER sample 1 28.008 2.000 <NA> <NA> speaker1 <NA> <NA>
13+
SPEAKER sample 1 29.975 0.033 <NA> <NA> speaker0 <NA> <NA>

tests/data/rttm/latency_2.rttm

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
2+
SPEAKER sample 1 7.592 0.817 <NA> <NA> speaker0 <NA> <NA>
3+
SPEAKER sample 1 8.475 1.617 <NA> <NA> speaker1 <NA> <NA>
4+
SPEAKER sample 1 9.892 1.150 <NA> <NA> speaker0 <NA> <NA>
5+
SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
6+
SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
7+
SPEAKER sample 1 18.008 3.533 <NA> <NA> speaker1 <NA> <NA>
8+
SPEAKER sample 1 18.225 0.283 <NA> <NA> speaker0 <NA> <NA>
9+
SPEAKER sample 1 21.758 6.867 <NA> <NA> speaker0 <NA> <NA>
10+
SPEAKER sample 1 27.875 2.133 <NA> <NA> speaker1 <NA> <NA>

tests/data/rttm/latency_3.rttm

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
2+
SPEAKER sample 1 7.625 0.467 <NA> <NA> speaker0 <NA> <NA>
3+
SPEAKER sample 1 8.008 2.050 <NA> <NA> speaker1 <NA> <NA>
4+
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
5+
SPEAKER sample 1 10.592 4.167 <NA> <NA> speaker1 <NA> <NA>
6+
SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
7+
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
8+
SPEAKER sample 1 18.192 0.367 <NA> <NA> speaker0 <NA> <NA>
9+
SPEAKER sample 1 21.758 6.833 <NA> <NA> speaker0 <NA> <NA>
10+
SPEAKER sample 1 27.825 2.183 <NA> <NA> speaker1 <NA> <NA>

tests/data/rttm/latency_4.rttm

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
SPEAKER sample 1 6.742 0.400 <NA> <NA> speaker0 <NA> <NA>
2+
SPEAKER sample 1 7.625 0.650 <NA> <NA> speaker0 <NA> <NA>
3+
SPEAKER sample 1 8.092 1.950 <NA> <NA> speaker1 <NA> <NA>
4+
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
5+
SPEAKER sample 1 10.575 4.183 <NA> <NA> speaker1 <NA> <NA>
6+
SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
7+
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
8+
SPEAKER sample 1 18.208 0.333 <NA> <NA> speaker0 <NA> <NA>
9+
SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
10+
SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>

tests/data/rttm/latency_5.rttm

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
SPEAKER sample 1 6.742 0.383 <NA> <NA> speaker0 <NA> <NA>
2+
SPEAKER sample 1 7.625 0.667 <NA> <NA> speaker0 <NA> <NA>
3+
SPEAKER sample 1 8.092 1.967 <NA> <NA> speaker1 <NA> <NA>
4+
SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
5+
SPEAKER sample 1 10.558 4.200 <NA> <NA> speaker1 <NA> <NA>
6+
SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
7+
SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
8+
SPEAKER sample 1 18.208 0.317 <NA> <NA> speaker0 <NA> <NA>
9+
SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
10+
SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>

tests/test_aggregation.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
import pytest
3+
from pyannote.core import SlidingWindow, SlidingWindowFeature
4+
5+
from diart.blocks.aggregation import (
6+
AggregationStrategy,
7+
HammingWeightedAverageStrategy,
8+
FirstOnlyStrategy,
9+
AverageStrategy,
10+
DelayedAggregation,
11+
)
12+
13+
14+
def test_strategy_build():
15+
strategy = AggregationStrategy.build("mean")
16+
assert isinstance(strategy, AverageStrategy)
17+
18+
strategy = AggregationStrategy.build("hamming")
19+
assert isinstance(strategy, HammingWeightedAverageStrategy)
20+
21+
strategy = AggregationStrategy.build("first")
22+
assert isinstance(strategy, FirstOnlyStrategy)
23+
24+
with pytest.raises(Exception):
25+
AggregationStrategy.build("invalid")
26+
27+
28+
def test_aggregation():
29+
duration = 5
30+
frames = 500
31+
step = 0.5
32+
speakers = 2
33+
start_time = 10
34+
resolution = duration / frames
35+
36+
dagg1 = DelayedAggregation(step=step, latency=2, strategy="mean")
37+
dagg2 = DelayedAggregation(step=step, latency=2, strategy="hamming")
38+
dagg3 = DelayedAggregation(step=step, latency=2, strategy="first")
39+
40+
for dagg in [dagg1, dagg2, dagg3]:
41+
assert dagg.num_overlapping_windows == 4
42+
43+
buffers = [
44+
SlidingWindowFeature(
45+
np.random.rand(frames, speakers),
46+
SlidingWindow(
47+
start=(i + start_time) * step, duration=resolution, step=resolution
48+
),
49+
)
50+
for i in range(dagg1.num_overlapping_windows)
51+
]
52+
53+
for dagg in [dagg1, dagg2, dagg3]:
54+
assert dagg(buffers).data.shape == (51, 2)

0 commit comments

Comments
 (0)