wip: add pseudo speaker diarization pipeline based on segmentation stitching #201

Draft
wants to merge 19 commits into
base: develop
19 commits
3ecb58c  feat: add support for powerset segmentation models (hbredin, Nov 8, 2023)
72b0e60  wip: trying this PowersetAdapter thing (hbredin, Nov 9, 2023)
da517ac  fix: initialize nn.Module before setting attribute (hbredin, Nov 9, 2023)
c3fda2d  wip: add pseudo speaker diarization based on stitching only (hbredin, Nov 9, 2023)
14910e1  Add compatibility with pyannote 3.0 embedding wrappers (#188) (sorgfresser, Nov 9, 2023)
d5cbf72  Add support for powerset segmentation models (#198) (hbredin, Nov 10, 2023)
51ca233  Merge branch 'develop' into feat/pseudo-speaker-diarization (hbredin, Nov 10, 2023)
763de25  Add compatibility with pyannote 3.0 embedding wrappers (#188) (sorgfresser, Nov 9, 2023)
c1077a4  Add support for powerset segmentation models (#198) (hbredin, Nov 10, 2023)
aee7bce  Add ONNX compatibility (#204) (juanmc2005, Nov 11, 2023)
c15e395  hotfix: Catch ModuleNotFoundError when loading a model with pyannote (juanmc2005, Nov 11, 2023)
797d7b9  Merge branch 'develop' into feat/pseudo-speaker-diarization (hbredin, Nov 13, 2023)
6041c77  Add documentation page (#209) (juanmc2005, Nov 13, 2023)
8cad376  README improvements (#207) (juanmc2005, Nov 16, 2023)
8e9f74c  Make ONNX runtime optional (#215) (juanmc2005, Nov 18, 2023)
c5a1bae  Add reproducibility warning in README (#216) (juanmc2005, Nov 18, 2023)
7a031d2  Rename from_pyannote to from_pretrained in SpeakerEmbedding (juanmc2005, Nov 18, 2023)
65a3fd9  Clean unused methods and incorrect arg doc (juanmc2005, Nov 18, 2023)
658af40  Merge branch 'develop' into feat/pseudo-speaker-diarization (juanmc2005, Dec 11, 2023)
16 changes: 16 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
version: 2

build:
  os: "ubuntu-22.04"
  tools:
    python: "3.10"

python:
  install:
    - requirements: docs/requirements.txt
    # Install diart before building the docs
    - method: pip
      path: .

sphinx:
  configuration: docs/conf.py
218 changes: 139 additions & 79 deletions README.md

Large diffs are not rendered by default.

Binary file modified demo.gif
20 changes: 20 additions & 0 deletions docs/Makefile
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
Binary file added docs/_static/logo.png
65 changes: 65 additions & 0 deletions docs/conf.py
@@ -0,0 +1,65 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = "diart"
copyright = "2023, Juan Manuel Coria"
author = "Juan Manuel Coria"
release = "v0.9"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
    "autoapi.extension",
    "sphinx.ext.coverage",
    "sphinx.ext.napoleon",
    "sphinx_mdinclude",
]

autoapi_dirs = ["../src/diart"]
autoapi_options = [
    "members",
    "undoc-members",
    "show-inheritance",
    "show-module-summary",
    "special-members",
    "imported-members",
]

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

# -- Options for autodoc ----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration

# Automatically extract typehints when specified and place them in
# descriptions of the relevant function/method.
autodoc_typehints = "description"

# Don't show class signature with the class' name.
autodoc_class_signature = "separated"

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "furo"
html_static_path = ["_static"]
html_logo = "_static/logo.png"
html_title = "diart documentation"


def skip_submodules(app, what, name, obj, skip, options):
    return (
        name.endswith("__init__")
        or name.startswith("diart.console")
        or name.startswith("diart.argdoc")
    )


def setup(sphinx):
    sphinx.connect("autoapi-skip-member", skip_submodules)
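
The skip rule registered in docs/conf.py can be checked on its own by calling the predicate with sample member names. This is a minimal sketch: the predicate body is copied from the diff, while the helper `is_skipped`, the placeholder arguments, and the sample names are assumptions made up for illustration.

```python
def skip_submodules(app, what, name, obj, skip, options):
    # Same predicate as in docs/conf.py: skip package __init__ entries
    # and everything under diart.console and diart.argdoc.
    return (
        name.endswith("__init__")
        or name.startswith("diart.console")
        or name.startswith("diart.argdoc")
    )


def is_skipped(name):
    # Hypothetical helper: only `name` is inspected by the predicate,
    # so the remaining arguments are passed as placeholders.
    return skip_submodules(None, "module", name, None, False, {})


# Hypothetical member names, chosen only to exercise each branch.
samples = [
    "diart.console.stream",
    "diart.argdoc",
    "diart.blocks.__init__",
    "diart.blocks.base",
]
skipped = [name for name in samples if is_skipped(name)]
print(skipped)
```

Note that the predicate always returns a boolean, so under this sketch any name that matches none of the three conditions is reported as not skipped.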
11 changes: 11 additions & 0 deletions docs/index.rst
@@ -0,0 +1,11 @@
Get started with diart
======================

.. mdinclude:: ../README.md


Useful Links
============

.. toctree::
   :maxdepth: 1
35 changes: 35 additions & 0 deletions docs/make.bat
@@ -0,0 +1,35 @@
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

if "%1" == "" goto help

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
4 changes: 4 additions & 0 deletions docs/requirements.txt
@@ -0,0 +1,4 @@
sphinx==6.2.1
sphinx-autoapi==3.0.0
sphinx-mdinclude==0.5.3
furo==2023.9.10
2 changes: 1 addition & 1 deletion environment.yml
@@ -3,7 +3,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - python=3.8
+  - python=3.10
   - portaudio=19.6.*
   - pysoundfile=0.12.*
   - ffmpeg[version='<4.4']
Binary file modified logo.jpg
Binary file added pipeline.gif
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,6 +10,7 @@ torch>=1.12.1
 torchvision>=0.14.0
 torchaudio>=2.0.2
 pyannote.audio>=2.1.1
+requests>=2.31.0
 pyannote.core>=4.5
 pyannote.database>=4.1.1
 pyannote.metrics>=3.2
5 changes: 3 additions & 2 deletions setup.cfg
@@ -1,8 +1,8 @@
 [metadata]
 name=diart
-version=0.8.0
+version=0.9.0
 author=Juan Manuel Coria
-description=Streaming speaker diarization in real-time
+description=A python framework to build AI for real-time speech
 long_description=file: README.md
 long_description_content_type=text/markdown
 keywords=speaker diarization, streaming, online, real time, rxpy
@@ -32,6 +32,7 @@ install_requires=
     torchvision>=0.14.0
     torchaudio>=2.0.2
     pyannote.audio>=2.1.1
+    requests>=2.31.0
     pyannote.core>=4.5
     pyannote.database>=4.1.1
     pyannote.metrics>=3.2
1 change: 1 addition & 0 deletions src/diart/argdoc.py
@@ -15,3 +15,4 @@
 OUTPUT = "Directory to store the system's output in RTTM format"
 HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | <token>). If 'true', it will use the token from huggingface-cli login"
 SAMPLE_RATE = "Sample rate of the audio stream"
+NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like Nvidia's NeMo or ECAPA-TDNN"
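
The help text for NORMALIZE_EMBEDDING_WEIGHTS refers to min-max normalization into [0, 1]. A minimal sketch of that rescaling follows. It is not diart's implementation: the function name `minmax_normalize`, the `eps` guard for constant rows, the use of NumPy, and the sample values are all assumptions made for illustration.

```python
import numpy as np


def minmax_normalize(weights, eps=1e-8):
    """Rescale each row of `weights` to the range [0, 1] (illustrative only)."""
    weights = np.asarray(weights, dtype=float)
    w_min = weights.min(axis=-1, keepdims=True)
    w_max = weights.max(axis=-1, keepdims=True)
    # `eps` (an assumption of this sketch) avoids dividing by zero
    # when a row is constant; such rows map to all zeros.
    return (weights - w_min) / np.maximum(w_max - w_min, eps)


# Hypothetical per-frame weights for two speakers.
normalized = minmax_normalize([[0.2, 0.4, 0.6], [1.0, 1.0, 1.0]])
print(normalized)
```

After rescaling, the smallest weight in each row becomes 0 and the largest becomes 1, which is the property a masking-based model would rely on.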
1 change: 1 addition & 0 deletions src/diart/blocks/__init__.py
@@ -14,6 +14,7 @@
 )
 from .segmentation import SpeakerSegmentation
 from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
+from .pseudo_diarization import PseudoSpeakerDiarization, PseudoSpeakerDiarizationConfig
 from .base import PipelineConfig, Pipeline
 from .utils import Binarize, Resample, AdjustVolume
 from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig
42 changes: 42 additions & 0 deletions src/diart/blocks/base.py
@@ -11,12 +11,28 @@

 @dataclass
 class HyperParameter:
+    """Represents a pipeline hyper-parameter that can be tuned by diart"""
+
     name: Text
+    """Name of the hyper-parameter (e.g. tau_active)"""
     low: float
+    """Lowest value that this parameter can take"""
     high: float
+    """Highest value that this parameter can take"""

     @staticmethod
     def from_name(name: Text) -> "HyperParameter":
+        """Create a HyperParameter object given its name.
+
+        Parameters
+        ----------
+        name: str
+            Name of the hyper-parameter
+
+        Returns
+        -------
+        HyperParameter
+        """
         if name == "tau_active":
             return TauActive
         if name == "rho_update":
@@ -32,24 +48,34 @@ def from_name(name: Text) -> "HyperParameter":


 class PipelineConfig(ABC):
+    """Configuration containing the required
+    parameters to build and run a pipeline"""
+
     @property
     @abstractmethod
     def duration(self) -> float:
+        """The duration of an input audio chunk (in seconds)"""
         pass

     @property
     @abstractmethod
     def step(self) -> float:
+        """The step between two consecutive input audio chunks (in seconds)"""
         pass

     @property
     @abstractmethod
     def latency(self) -> float:
+        """The algorithmic latency of the pipeline (in seconds).
+        At time `t` of the audio stream, the pipeline will
+        output predictions for time `t - latency`.
+        """
         pass

     @property
     @abstractmethod
     def sample_rate(self) -> int:
+        """The sample rate of the input audio stream"""
         pass

     def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
@@ -60,6 +86,8 @@ def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:


 class Pipeline(ABC):
+    """Represents a streaming audio pipeline"""
+
     @staticmethod
     @abstractmethod
     def get_config_class() -> type:
@@ -92,4 +120,18 @@ def set_timestamp_shift(self, shift: float):
     def __call__(
         self, waveforms: Sequence[SlidingWindowFeature]
     ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
+        """Runs the next steps of the pipeline
+        given a list of consecutive audio chunks.
+
+        Parameters
+        ----------
+        waveforms: Sequence[SlidingWindowFeature]
+            Consecutive chunk waveforms for the pipeline to ingest
+
+        Returns
+        -------
+        Sequence[Tuple[Any, SlidingWindowFeature]]
+            For each input waveform, a tuple containing
+            the pipeline output and its respective audio
+        """
         pass
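
The docstrings added to `PipelineConfig` describe how `duration`, `step`, and `latency` relate to the audio stream. A small worked sketch of those relations follows. `ToyConfig` is a hypothetical stand-in rather than a diart class, its default values are invented for illustration, and the assumption that the first chunk starts at time 0 is not stated in the diff.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical values, chosen only so the arithmetic is easy to follow.
    duration: float = 5.0     # length of one input chunk, in seconds
    step: float = 0.5         # shift between consecutive chunks, in seconds
    latency: float = 0.5      # algorithmic latency, in seconds
    sample_rate: int = 16000  # samples per second of the input stream


def prediction_time(config: ToyConfig, t: float) -> float:
    # Per the latency docstring: at stream time t, the output covers t - latency.
    return t - config.latency


def chunk_bounds(config: ToyConfig, index: int) -> tuple:
    # Assumes chunk 0 starts at time 0 and each chunk starts `step` later.
    start = index * config.step
    return start, start + config.duration


config = ToyConfig()
print(prediction_time(config, 10.0))  # time covered by the output at t = 10 s
print(chunk_bounds(config, 3))        # (start, end) of the fourth chunk
```

With these values, consecutive chunks overlap by `duration - step` seconds, which is why a pipeline can trade latency for more context per prediction.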
4 changes: 4 additions & 0 deletions src/diart/blocks/clustering.py
@@ -140,6 +140,10 @@ def identify(
         long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[
             0
         ]
+        # Remove speakers that have NaN embeddings
+        no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0]
+        active_speakers = np.intersect1d(active_speakers, no_nan_embeddings)
+
         num_local_speakers = segmentation.data.shape[1]

         if self.centers is None:
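
The NaN filter added to `identify` can be illustrated in isolation. The sketch below reuses the two added lines verbatim on made-up data. It assumes `embeddings` is a NumPy array with one row per local speaker and `active_speakers` is an array of speaker indices; the actual types inside diart may differ.

```python
import numpy as np

# Hypothetical embeddings for 3 local speakers (rows), 4 dimensions each.
# Speaker 1 has an all-NaN embedding.
embeddings = np.array(
    [
        [0.1, 0.2, 0.3, 0.4],
        [np.nan, np.nan, np.nan, np.nan],
        [0.5, 0.6, 0.7, 0.8],
    ]
)
# Hypothetical result of the activity threshold: speakers 0 and 1 are active.
active_speakers = np.array([0, 1])

# Indices of speakers whose embedding contains no NaN in any dimension.
no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0]
# Keep only active speakers that also have a valid embedding.
active_speakers = np.intersect1d(active_speakers, no_nan_embeddings)
print(active_speakers)
```

Speaker 1 is dropped even though it passed the activity threshold, because its embedding cannot be compared against the cluster centers.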