diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..b00b1eed --- /dev/null +++ b/.dockerignore @@ -0,0 +1,19 @@ +# Development +.git/ +.github/ +.idea/ +__pycache__/ + +# Data and examples +assets/ +example/ +expected_outputs/ +tests/ + +# Documentation +docs/ + +# Build artifacts +*.egg-info/ +dist/ +build/ diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml new file mode 100644 index 00000000..067b7907 --- /dev/null +++ b/.github/workflows/pytest.yml @@ -0,0 +1,35 @@ +name: Pytest + +on: + pull_request: + branches: + - main + - develop + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.10' + + - name: Install apt dependencies + run: | + sudo add-apt-repository ppa:savoury1/ffmpeg4 + sudo apt-get update + sudo apt-get -y install ffmpeg libportaudio2=19.6.0-1.1 + + - name: Install pip dependencies + run: | + python -m pip install --upgrade pip + pip install .[tests] + + - name: Run tests + run: | + pytest diff --git a/.github/workflows/quick-runs.yml b/.github/workflows/quick-runs.yml index 0540897a..24b7f387 100644 --- a/.github/workflows/quick-runs.yml +++ b/.github/workflows/quick-runs.yml @@ -38,6 +38,7 @@ jobs: run: | python -m pip install --upgrade pip pip install . + pip install onnxruntime==1.18.0 - name: Crop audio and rttm run: | sox audio/ES2002a_long.wav audio/ES2002a.wav trim 00:40 00:30 @@ -50,10 +51,10 @@ jobs: rm rttms/ES2002b_long.rttm - name: Run stream run: | - diart.stream audio/ES2002a.wav --output trash --no-plot --hf-token ${{ secrets.HUGGINGFACE }} + diart.stream audio/ES2002a.wav --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx --output trash --no-plot - name: Run benchmark run: | - diart.benchmark audio --reference rttms --batch-size 4 --hf-token ${{ secrets.HUGGINGFACE }} + diart.benchmark audio --reference rttms --batch-size 4 --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx - name: Run tuning run: | - diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --hf-token ${{ secrets.HUGGINGFACE }} + diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..4e839b10 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,75 @@ +# Use NVIDIA CUDA base image +FROM docker.io/nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 + +# Install sudo, git, wget, gcc, g++, and other essential build tools +RUN apt-get update && \ + apt-get install -y sudo git wget build-essential && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +# Install Miniconda +ENV CONDA_DIR=/opt/conda +ENV PATH=$CONDA_DIR/bin:$PATH +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \ + bash /tmp/miniconda.sh -b -p $CONDA_DIR && \ + rm /tmp/miniconda.sh + +# Install Python 3.10 using Conda +RUN conda install python=3.10 + +# Upgrade pip and setuptools to avoid deprecation warnings +RUN pip install --upgrade pip setuptools + +# Set Python 3.11 as default by creating a symbolic link +RUN ln -sf /opt/conda/bin/python3.10 /opt/conda/bin/python && \ + ln -sf /opt/conda/bin/python3.10 /usr/bin/python + +# Verify installations +RUN python --version && \ + gcc --version 
&& \ + g++ --version && \ + pip --version && \ + conda --version + +# Create app directory and copy files +WORKDIR /diart +COPY . . + +# Install diart dependencies +RUN conda install portaudio pysoundfile ffmpeg -c conda-forge +RUN pip install -e . + +# Expose the port the app runs on +EXPOSE 7007 + +# Define environment variable to prevent Python from buffering stdout/stderr +# and writing byte code to file +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Define custom options as env variables with defaults +ENV HOST=0.0.0.0 +ENV PORT=7007 +ENV SEGMENTATION=pyannote/segmentation-3.0 +ENV EMBEDDING=speechbrain/spkrec-resnet-voxceleb +ENV TAU_ACTIVE=0.45 +ENV RHO_UPDATE=0.25 +ENV DELTA_NEW=0.6 +ENV LATENCY=5 +ENV MAX_SPEAKERS=3 + +CMD ["sh", "-c", "python -m diart.console.serve --host ${HOST} --port ${PORT} --segmentation ${SEGMENTATION} --embedding ${EMBEDDING} --tau-active ${TAU_ACTIVE} --rho-update ${RHO_UPDATE} --delta-new ${DELTA_NEW} --latency ${LATENCY} --max-speakers ${MAX_SPEAKERS}"] + +# Example run command with environment variables: +# docker run -p 7007:7007 --restart unless-stopped --gpus all \ +# -e HF_TOKEN= \ +# -e HOST=0.0.0.0 \ +# -e PORT=7007 \ +# -e SEGMENTATION=pyannote/segmentation-3.0 \ +# -e EMBEDDING=speechbrain/spkrec-resnet-voxceleb \ +# -e TAU_ACTIVE=0.45 \ +# -e RHO_UPDATE=0.25 \ +# -e DELTA_NEW=0.6 \ +# -e LATENCY=5 \ +# -e MAX_SPEAKERS=3 \ +# diart-image \ No newline at end of file diff --git a/README.md b/README.md index f9a89dbc..cc425a43 100644 --- a/README.md +++ b/README.md @@ -202,6 +202,7 @@ def embedding_loader(): segmentation = SegmentationModel(segmentation_loader) embedding = EmbeddingModel(embedding_loader) config = SpeakerDiarizationConfig( + # Set the segmentation model used in the paper segmentation=segmentation, embedding=embedding, ) @@ -284,21 +285,27 @@ Obtain overlap-aware speaker embeddings from a microphone stream: ```python import rx.operators as ops import diart.operators as dops -from diart.sources import MicrophoneAudioSource +from diart.sources import MicrophoneAudioSource, FileAudioSource from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding segmentation = SpeakerSegmentation.from_pretrained("pyannote/segmentation") embedding = OverlapAwareSpeakerEmbedding.from_pretrained("pyannote/embedding") -mic = MicrophoneAudioSource() + +source = MicrophoneAudioSource() +# To take input from file: +# source = FileAudioSource("", sample_rate=16000) + +# Make sure the models have been trained with this sample rate +print(source.sample_rate) -stream = mic.stream.pipe( +stream = source.stream.pipe( # Reformat stream to 5s duration and 500ms shift - dops.rearrange_audio_stream(sample_rate=segmentation.model.sample_rate), + dops.rearrange_audio_stream(sample_rate=source.sample_rate), ops.map(lambda wav: (wav, segmentation(wav))), ops.starmap(embedding) ).subscribe(on_next=lambda emb: print(emb.shape)) -mic.read() +source.read() ``` Output: @@ -326,20 +333,57 @@ diart.client microphone --host --port 7007 See `-h` for more options. +### From the Dockerfile + +You can also run the server in a Docker container. First, build the image: +```shell +docker build -t diart -f Dockerfile .
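+# (Optional) also tag the build with the package version from setup.cfg, e.g.:
+docker build -t diart:0.9.1 -f Dockerfile .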
+``` + +Run the server with default configuration: +```shell +docker run -p 7007:7007 --gpus all -e HF_TOKEN= diart +``` + +Run with custom configuration: +```shell +docker run -p 7007:7007 --restart unless-stopped --gpus all \ + -e HF_TOKEN= \ + -e HOST=0.0.0.0 \ + -e PORT=7007 \ + -e SEGMENTATION=pyannote/segmentation-3.0 \ + -e EMBEDDING=speechbrain/spkrec-resnet-voxceleb \ + -e TAU_ACTIVE=0.45 \ + -e RHO_UPDATE=0.25 \ + -e DELTA_NEW=0.6 \ + -e LATENCY=5 \ + -e MAX_SPEAKERS=3 \ + diart +``` + +The server can be configured at runtime using the following environment variables: +- `HOST`: Server host (default: 0.0.0.0) +- `PORT`: Server port (default: 7007) +- `SEGMENTATION`: Segmentation model (default: pyannote/segmentation-3.0) +- `EMBEDDING`: Embedding model (default: speechbrain/spkrec-resnet-voxceleb) +- `TAU_ACTIVE`: Activity threshold (default: 0.45) +- `RHO_UPDATE`: Update threshold (default: 0.25) +- `DELTA_NEW`: New speaker threshold (default: 0.6) +- `LATENCY`: Processing latency in seconds (default: 5) +- `MAX_SPEAKERS`: Maximum number of speakers (default: 3) + ### From python -For customized solutions, a server can also be created in python using the `WebSocketAudioSource`: +For customized solutions, a server can also be created in python using `WebSocketStreamingServer`: ```python -from diart import SpeakerDiarization -from diart.sources import WebSocketAudioSource -from diart.inference import StreamingInference +from diart import SpeakerDiarization, SpeakerDiarizationConfig +from diart.websockets import WebSocketStreamingServer -pipeline = SpeakerDiarization() -source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007) -inference = StreamingInference(pipeline, source) -inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm())) -prediction = inference() +pipeline_class = SpeakerDiarization +pipeline_config = SpeakerDiarizationConfig(step=0.5, sample_rate=16000) +server = WebSocketStreamingServer(pipeline_class, pipeline_config, host="localhost", port=7007) +server.run() ``` ## 🔬 Powered by research diff --git a/assets/models/embedding_uint8.onnx b/assets/models/embedding_uint8.onnx new file mode 100644 index 00000000..ac5ab44d Binary files /dev/null and b/assets/models/embedding_uint8.onnx differ diff --git a/assets/models/segmentation_uint8.onnx b/assets/models/segmentation_uint8.onnx new file mode 100644 index 00000000..8daa3751 Binary files /dev/null and b/assets/models/segmentation_uint8.onnx differ diff --git a/requirements.txt b/requirements.txt index 2d3e4611..03e18829 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -numpy>=1.20.2 -matplotlib>=3.3.3 +numpy>=1.20.2,<2.0.0 +matplotlib>=3.3.3,<3.6.0 rx>=3.2.0 scipy>=1.6.0 sounddevice>=0.4.2 diff --git a/setup.cfg b/setup.cfg index 66b255f2..9f1ad091 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name=diart -version=0.9.0 +version=0.9.1 author=Juan Manuel Coria description=A python framework to build AI for real-time speech long_description=file: README.md @@ -20,8 +20,8 @@ package_dir= =src packages=find: install_requires= - numpy>=1.20.2 - matplotlib>=3.3.3 + numpy>=1.20.2,<2.0.0 + matplotlib>=3.3.3,<3.6.0 rx>=3.2.0 scipy>=1.6.0 sounddevice>=0.4.2 @@ -41,6 +41,11 @@ install_requires= websocket-client>=0.58.0 rich>=12.5.1 +[options.extras_require] +tests= + pytest>=7.4.0,<8.0.0 + onnxruntime==1.18.0 + [options.packages.find] where=src diff --git a/src/diart/console/client.py b/src/diart/console/client.py index b3de36db..39dd4824 100644 --- a/src/diart/console/client.py +++ 
b/src/diart/console/client.py @@ -1,39 +1,60 @@ import argparse from pathlib import Path -from threading import Thread -from typing import Text, Optional +from threading import Event, Thread +from typing import Optional, Text import rx.operators as ops -from websocket import WebSocket +from websocket import WebSocket, WebSocketException from diart import argdoc from diart import sources as src from diart import utils -def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int): - # Create audio source - source_components = source.split(":") - if source_components[0] != "microphone": - audio_source = src.FileAudioSource(source, sample_rate, block_duration=step) - else: - device = int(source_components[1]) if len(source_components) > 1 else None - audio_source = src.MicrophoneAudioSource(step, device) +def send_audio( + ws: WebSocket, source: Text, step: float, sample_rate: int, stop_event: Event +): + try: + # Create audio source + source_components = source.split(":") + if source_components[0] != "microphone": + audio_source = src.FileAudioSource(source, sample_rate, block_duration=step) + else: + device = int(source_components[1]) if len(source_components) > 1 else None + audio_source = src.MicrophoneAudioSource(step, device) - # Encode audio, then send through websocket - audio_source.stream.pipe(ops.map(utils.encode_audio)).subscribe_(ws.send) + # Encode audio, then send through websocket + def on_next(data): + if not stop_event.is_set(): + try: + ws.send(utils.encode_audio(data)) + except WebSocketException: + stop_event.set() - # Start reading audio - audio_source.read() + audio_source.stream.subscribe_(on_next) + # Start reading audio + audio_source.read() + except Exception as e: + print(f"Error in send_audio: {e}") + stop_event.set() -def receive_audio(ws: WebSocket, output: Optional[Path]): - while True: - message = ws.recv() - print(f"Received: {message}", end="") - if output is not None: - with open(output, "a") as file: - file.write(message) + +def receive_audio(ws: WebSocket, output: Optional[Path], stop_event: Event): + try: + while not stop_event.is_set(): + try: + message = ws.recv() + print(f"Received: {message}", end="") + if output is not None: + with open(output, "a") as file: + file.write(message) + except WebSocketException: + break + except Exception as e: + print(f"Error in receive_audio: {e}") + finally: + stop_event.set() def run(): @@ -43,8 +64,12 @@ def run(): type=str, help="Path to an audio file | 'microphone' | 'microphone:'", ) - parser.add_argument("--host", required=True, type=str, help="Server host") - parser.add_argument("--port", required=True, type=int, help="Server port") + parser.add_argument( + "--host", default="0.0.0.0", type=str, help="Server host. Defaults to 0.0.0.0" + ) + parser.add_argument( + "--port", default=7007, type=int, help="Server port. Defaults to 7007" + ) parser.add_argument( "--step", default=0.5, type=float, help=f"{argdoc.STEP}. 
Defaults to 0.5" ) @@ -65,13 +90,53 @@ def run(): # Run websocket client ws = WebSocket() - ws.connect(f"ws://{args.host}:{args.port}") - sender = Thread( - target=send_audio, args=[ws, args.source, args.step, args.sample_rate] - ) - receiver = Thread(target=receive_audio, args=[ws, args.output_file]) - sender.start() - receiver.start() + stop_event = Event() + + try: + ws.connect(f"ws://{args.host}:{args.port}") + + # Wait for READY signal from server + print("Waiting for server to be ready...", end="", flush=True) + while not stop_event.is_set(): + try: + message = ws.recv() + if message.strip() == "READY": + print(" OK") + break + print(f"\nUnexpected message while waiting for READY: {message}") + except WebSocketException as e: + print(f"\nWebSocket error while waiting for server: {e}") + return + + # Start threads for sending and receiving audio + sender = Thread( + target=send_audio, + args=[ws, args.source, args.step, args.sample_rate, stop_event], + ) + receiver = Thread(target=receive_audio, args=[ws, args.output_file, stop_event]) + + sender.start() + receiver.start() + + try: + # Wait for threads to complete or for keyboard interrupt + sender.join() + receiver.join() + except KeyboardInterrupt: + print("\nShutting down...") + stop_event.set() + + except Exception as e: + print(f"Error: {e}") + + finally: + stop_event.set() + try: + ws.close() + except WebSocketException: + print("Error closing WebSocket") + except Exception as e: + print(f"Unexpected error closing WebSocket: {e}") if __name__ == "__main__": diff --git a/src/diart/console/serve.py b/src/diart/console/serve.py index e52980dd..94b33e56 100644 --- a/src/diart/console/serve.py +++ b/src/diart/console/serve.py @@ -5,10 +5,8 @@ from diart import argdoc from diart import models as m -from diart import sources as src from diart import utils -from diart.inference import StreamingInference -from diart.sinks import RTTMWriter +from diart.websockets import WebSocketStreamingServer def run(): @@ -72,9 +70,6 @@ def run(): action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", ) - parser.add_argument( - "--output", type=Path, help=f"{argdoc.OUTPUT}. 
Defaults to no writing" - ) parser.add_argument( "--hf-token", default="true", @@ -96,35 +91,19 @@ def run(): args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token) args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token) - # Resolve pipeline + # Resolve pipeline configuration pipeline_class = utils.get_pipeline_class(args.pipeline) - config = pipeline_class.get_config_class()(**vars(args)) - pipeline = pipeline_class(config) - - # Create websocket audio source - audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port) + pipeline_config = pipeline_class.get_config_class()(**vars(args)) - # Run online inference - inference = StreamingInference( - pipeline, - audio_source, - batch_size=1, - do_profile=False, - do_plot=False, - show_progress=True, + # Initialize Websocket server + server = WebSocketStreamingServer( + pipeline_class=pipeline_class, + pipeline_config=pipeline_config, + host=args.host, + port=args.port, ) - # Write to disk if required - if args.output is not None: - inference.attach_observers( - RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm") - ) - - # Send back responses as RTTM text lines - inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm())) - - # Run server and pipeline - inference() + server.run() if __name__ == "__main__": diff --git a/src/diart/sources.py b/src/diart/sources.py index 82051b2e..79fc5834 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from pathlib import Path from queue import SimpleQueue -from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple +from typing import Any, AnyStr, Dict, Optional, Text, Tuple, Union import numpy as np import sounddevice as sd @@ -12,7 +12,7 @@ from websocket_server import WebsocketServer from . import utils -from .audio import FilePath, AudioLoader +from .audio import AudioLoader, FilePath class AudioSource(ABC): @@ -201,76 +201,6 @@ def close(self): self._mic_stream.close() -class WebSocketAudioSource(AudioSource): - """Represents a source of audio coming from the network using the WebSocket protocol. - - Parameters - ---------- - sample_rate: int - Sample rate of the chunks emitted. - host: Text - The host to run the websocket server. - Defaults to 127.0.0.1. - port: int - The port to run the websocket server. - Defaults to 7007. - key: Text | Path | None - Path to a key if using SSL. - Defaults to no key. - certificate: Text | Path | None - Path to a certificate if using SSL. - Defaults to no certificate. - """ - - def __init__( - self, - sample_rate: int, - host: Text = "127.0.0.1", - port: int = 7007, - key: Optional[Union[Text, Path]] = None, - certificate: Optional[Union[Text, Path]] = None, - ): - # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities. 
- # I would prefer the client to send a JSON with data and sample rate, then resample if needed - super().__init__(f"{host}:{port}", sample_rate) - self.client: Optional[Dict[Text, Any]] = None - self.server = WebsocketServer(host, port, key=key, cert=certificate) - self.server.set_fn_message_received(self._on_message_received) - - def _on_message_received( - self, - client: Dict[Text, Any], - server: WebsocketServer, - message: AnyStr, - ): - # Only one client at a time is allowed - if self.client is None or self.client["id"] != client["id"]: - self.client = client - # Send decoded audio to pipeline - self.stream.on_next(utils.decode_audio(message)) - - def read(self): - """Starts running the websocket server and listening for audio chunks""" - self.server.run_forever() - - def close(self): - """Close the websocket server""" - if self.server is not None: - self.stream.on_completed() - self.server.shutdown_gracefully() - - def send(self, message: AnyStr): - """Send a message through the current websocket. - - Parameters - ---------- - message: AnyStr - Bytes or string to send. - """ - if len(message) > 0: - self.server.send_message(self.client, message) - - class TorchStreamAudioSource(AudioSource): def __init__( self, diff --git a/src/diart/websockets.py b/src/diart/websockets.py new file mode 100644 index 00000000..e0040355 --- /dev/null +++ b/src/diart/websockets.py @@ -0,0 +1,370 @@ +import logging +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union + +import numpy as np +from websocket_server import WebsocketServer + +from . import blocks +from . import sources as src +from . import utils +from .inference import StreamingInference + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) +logger = logging.getLogger(__name__) + + +class ProxyAudioSource(src.AudioSource): + """A proxy audio source that forwards decoded audio chunks to a processing pipeline. + + Parameters + ---------- + uri : str + Unique identifier for this audio source + sample_rate : int + Expected sample rate of the audio chunks + """ + + def __init__( + self, uri: str, sample_rate: int, ): + # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities. + # I would prefer the client to send a JSON with data and sample rate, then resample if needed + super().__init__(uri, sample_rate) + + def process_message(self, message: np.ndarray): + """Process an incoming audio message.""" + # Send audio to pipeline + self.stream.on_next(message) + + def read(self): + """Do nothing; audio is pushed to the stream via process_message().""" + pass + + def close(self): + """Complete the audio stream for this client.""" + self.stream.on_completed() + + +@dataclass +class ClientState: + """Represents the state of a connected client.""" + + audio_source: ProxyAudioSource + inference: StreamingInference + + +class WebSocketStreamingServer: + """Handles real-time speaker diarization inference for multiple audio sources over WebSocket. + + This handler manages WebSocket connections from multiple clients, processing + audio streams and performing speaker diarization in real-time. 
+ + Parameters + ---------- + pipeline_class : type + Pipeline class + pipeline_config : blocks.PipelineConfig + Pipeline configuration + host : str, optional + WebSocket server host, by default "127.0.0.1" + port : int, optional + WebSocket server port, by default 7007 + key : Union[str, Path], optional + SSL key file path for secure WebSocket + certificate : Union[str, Path], optional + SSL certificate file path for secure WebSocket + """ + + def __init__( + self, + pipeline_class: type, + pipeline_config: blocks.PipelineConfig, + host: Text = "127.0.0.1", + port: int = 7007, + key: Optional[Union[Text, Path]] = None, + certificate: Optional[Union[Text, Path]] = None, + ): + # Pipeline configuration + self.pipeline_class = pipeline_class + self.pipeline_config = pipeline_config + + # Server configuration + self.host = host + self.port = port + self.uri = f"{host}:{port}" + self._clients: Dict[Text, ClientState] = {} + + # Initialize WebSocket server + self.server = WebsocketServer(host, port, key=key, cert=certificate) + self._setup_server_handlers() + + def _setup_server_handlers(self) -> None: + """Configure WebSocket server event handlers.""" + self.server.set_fn_new_client(self._on_connect) + self.server.set_fn_client_left(self._on_disconnect) + self.server.set_fn_message_received(self._on_message_received) + + def _create_client_state(self, client_id: Text) -> ClientState: + """Create and initialize state for a new client. + + Parameters + ---------- + client_id : Text + Unique client identifier + + Returns + ------- + ClientState + Initialized client state object + """ + # Create a new pipeline instance with the same config + # This ensures each client has its own state while sharing model weights + pipeline = self.pipeline_class(self.pipeline_config) + + audio_source = ProxyAudioSource( + uri=f"{self.uri}:{client_id}", sample_rate=self.pipeline_config.sample_rate, + ) + + inference = StreamingInference( + pipeline=pipeline, + source=audio_source, + # The following variables are fixed for a client + batch_size=1, + do_profile=False, # for minimal latency + do_plot=False, + show_progress=False, + progress_bar=None, + ) + + return ClientState(audio_source=audio_source, inference=inference) + + def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None: + """Handle new client connection. 
+ + Parameters + ---------- + client : Dict[Text, Any] + Client information dictionary + server : WebsocketServer + WebSocket server instance + """ + client_id = client["id"] + logger.info(f"New client connected: {client_id}") + + if client_id in self._clients: + return + + try: + self._clients[client_id] = self._create_client_state(client_id) + + # Setup RTTM response hook + self._clients[client_id].inference.attach_hooks( + lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm()) + ) + + # Start inference + self._clients[client_id].inference() + logger.info(f"Started inference for client: {client_id}") + + # Send ready notification to client + self.send(client_id, "READY") + except OSError as e: + logger.warning(f"Client {client_id} connection failed: {e}") + # Just cleanup since client is already disconnected + self.close(client_id) + except Exception as e: + logger.error(f"Failed to initialize client {client_id}: {e}") + # Close audio source and remove client + self.close(client_id) + # Send close notification to client + self.send(client_id, "CLOSE") + + def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> None: + """Cleanup client state when a connection is closed. + + Parameters + ---------- + client : Dict[Text, Any] + Client metadata + server : WebsocketServer + Server instance + """ + client_id = client["id"] + logger.info(f"Client disconnected: {client_id}") + # Just cleanup resources, no need to send CLOSE as client is already disconnected + self.close(client_id) + + def _on_message_received( + self, client: Dict[Text, Any], server: WebsocketServer, message: AnyStr + ) -> None: + """Process incoming client messages. + + Parameters + ---------- + client : Dict[Text, Any] + Client information dictionary + server : WebsocketServer + WebSocket server instance + message : AnyStr + Received message data + + Raises + ------ + Warning + If client not found in self._clients. Common cases: + - Message received before client initialization complete + - Message received after client cleanup started + - Race condition in client connection/disconnection + Error + If message processing fails. Error will be logged. + """ + client_id = client["id"] + + if client_id not in self._clients: + logger.warning(f"Received message from unregistered client {client_id}") + return + + try: + # decode message to audio + decoded_audio = utils.decode_audio(message) + self._clients[client_id].audio_source.process_message(decoded_audio) + except OSError as e: + logger.warning(f"Client {client_id} disconnected: {e}") + # Just cleanup since client is already disconnected + self.close(client_id) + except Exception as e: + # Don't close the connection for non-connection related errors + # This allows the client to retry sending the message + logger.error(f"Error processing message from client {client_id}: {e}") + + def send(self, client_id: Text, message: AnyStr) -> None: + """Send a message to a specific client. + + Parameters + ---------- + client_id : Text + Target client identifier + message : AnyStr + Message to send + + Raises + ------ + Warning + If client is not found in server.clients. Common cases: + - Client disconnected but cleanup is not complete + - Network issues caused client to drop + - Race condition between disconnect and message send + Error + If message sending fails. Error will be logged and re-raised. 
+ """ + if not message: + return + + client = next((c for c in self.server.clients if c["id"] == client_id), None) + if client is None: + logger.warning( + f"Failed to send message to client {client_id}: client not found" + ) + return + + try: + self.server.send_message(client, message) + except Exception as e: + logger.error(f"Failed to send message to client {client_id}: {e}") + raise + + def close(self, client_id: Text) -> None: + """Close and cleanup resources for a specific client. + + Parameters + ---------- + client_id : Text + Client identifier to close + + Raises + ------ + Warning + If client not found in self._clients. Common cases: + - Already cleaned up by another error handler + - Multiple error handlers trying to cleanup same client + - Cleanup triggered for client that never fully connected + Error + If cleanup fails. Error will be logged and client will be force-removed. + """ + if client_id not in self._clients: + logger.warning(f"Attempted to close non-existent client {client_id}") + return + + try: + # Clean up pipeline state using built-in reset method + client_state = self._clients[client_id] + client_state.inference.pipeline.reset() + + # Close audio source and remove client + client_state.audio_source.close() + del self._clients[client_id] + + logger.info(f"Cleaned up resources for client: {client_id}") + except Exception as e: + logger.error(f"Error cleaning up resources for client {client_id}: {e}") + # Ensure client is removed even if cleanup fails + self._clients.pop(client_id, None) + + def close_all(self) -> None: + """Shutdown the server and cleanup all client connections.""" + logger.info("Shutting down server...") + try: + for client_id in self._clients.keys(): + # Close audio source and remove client + self.close(client_id) + # Send close notification to client + self.send(client_id, "CLOSE") + + if self.server is not None: + self.server.shutdown_gracefully() + + logger.info("Server shutdown complete") + except Exception as e: + logger.error(f"Error during shutdown: {e}") + + def run(self) -> None: + """Start the WebSocket server. + + The server will attempt to restart on connection errors with exponential backoff. + After max retries are exhausted, it will shut down gracefully. + """ + logger.info(f"Starting WebSocket server on {self.uri}") + max_retries = 3 + retry_count = 0 + base_delay = 1 # in seconds + + while retry_count < max_retries: + try: + self.server.run_forever() + break # If server exits normally, break the retry loop + except OSError as e: + logger.warning(f"WebSocket server connection error: {e}") + retry_count += 1 + if retry_count < max_retries: + delay = base_delay * (2 ** (retry_count - 1)) # Exponential backoff + logger.info( + f"Retrying in {delay} seconds... " + f"(attempt {retry_count}/{max_retries})" + ) + time.sleep(delay) + else: + logger.error( + f"WebSocket server failed to start after {max_retries} attempts. 
" + f"Last error: {e}" + ) + except Exception as e: + logger.error(f"Fatal server error: {e}") + break + finally: + self.close_all() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..3c5a2915 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,48 @@ +import random + +import pytest +import torch + +from diart.models import SegmentationModel, EmbeddingModel + + +class DummySegmentationModel: + def to(self, device): + pass + + def __call__(self, waveform: torch.Tensor) -> torch.Tensor: + assert waveform.ndim == 3 + + batch_size, num_channels, num_samples = waveform.shape + num_frames = random.randint(250, 500) + num_speakers = random.randint(3, 5) + + return torch.rand(batch_size, num_frames, num_speakers) + + +class DummyEmbeddingModel: + def to(self, device): + pass + + def __call__(self, waveform: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + assert waveform.ndim == 3 + assert weights.ndim == 2 + + batch_size, num_channels, num_samples = waveform.shape + batch_size_weights, num_frames = weights.shape + + assert batch_size == batch_size_weights + + embedding_dim = random.randint(128, 512) + + return torch.randn(batch_size, embedding_dim) + + +@pytest.fixture(scope="session") +def segmentation_model() -> SegmentationModel: + return SegmentationModel(DummySegmentationModel) + + +@pytest.fixture(scope="session") +def embedding_model() -> EmbeddingModel: + return EmbeddingModel(DummyEmbeddingModel) diff --git a/tests/data/audio/sample.wav b/tests/data/audio/sample.wav new file mode 100644 index 00000000..150d49a6 Binary files /dev/null and b/tests/data/audio/sample.wav differ diff --git a/tests/data/rttm/latency_0.5.rttm b/tests/data/rttm/latency_0.5.rttm new file mode 100644 index 00000000..058ed2e2 --- /dev/null +++ b/tests/data/rttm/latency_0.5.rttm @@ -0,0 +1,13 @@ +SPEAKER sample 1 6.675 0.533 speaker0 +SPEAKER sample 1 7.625 1.883 speaker0 +SPEAKER sample 1 9.508 1.000 speaker1 +SPEAKER sample 1 10.508 0.567 speaker0 +SPEAKER sample 1 10.625 4.133 speaker1 +SPEAKER sample 1 14.325 3.733 speaker0 +SPEAKER sample 1 18.058 3.450 speaker1 +SPEAKER sample 1 18.325 0.183 speaker0 +SPEAKER sample 1 21.508 0.017 speaker0 +SPEAKER sample 1 21.775 0.233 speaker1 +SPEAKER sample 1 22.008 6.633 speaker0 +SPEAKER sample 1 28.508 1.500 speaker1 +SPEAKER sample 1 29.958 0.050 speaker0 diff --git a/tests/data/rttm/latency_1.rttm b/tests/data/rttm/latency_1.rttm new file mode 100644 index 00000000..40c591e8 --- /dev/null +++ b/tests/data/rttm/latency_1.rttm @@ -0,0 +1,13 @@ +SPEAKER sample 1 6.708 0.450 speaker0 +SPEAKER sample 1 7.625 1.383 speaker0 +SPEAKER sample 1 9.008 1.500 speaker1 +SPEAKER sample 1 10.008 1.067 speaker0 +SPEAKER sample 1 10.592 4.200 speaker1 +SPEAKER sample 1 14.308 3.700 speaker0 +SPEAKER sample 1 18.042 3.250 speaker1 +SPEAKER sample 1 18.508 0.033 speaker0 +SPEAKER sample 1 21.108 0.383 speaker0 +SPEAKER sample 1 21.508 0.033 speaker1 +SPEAKER sample 1 21.775 6.817 speaker0 +SPEAKER sample 1 28.008 2.000 speaker1 +SPEAKER sample 1 29.975 0.033 speaker0 diff --git a/tests/data/rttm/latency_2.rttm b/tests/data/rttm/latency_2.rttm new file mode 100644 index 00000000..dacd8453 --- /dev/null +++ b/tests/data/rttm/latency_2.rttm @@ -0,0 +1,10 @@ +SPEAKER sample 1 6.725 0.433 speaker0 +SPEAKER sample 1 7.592 0.817 speaker0 +SPEAKER sample 1 8.475 1.617 speaker1 +SPEAKER sample 1 9.892 1.150 speaker0 +SPEAKER sample 1 10.625 4.133 speaker1 +SPEAKER sample 1 14.292 3.667 speaker0 +SPEAKER sample 1 18.008 3.533 speaker1 +SPEAKER sample 1 
18.225 0.283 speaker0 +SPEAKER sample 1 21.758 6.867 speaker0 +SPEAKER sample 1 27.875 2.133 speaker1 diff --git a/tests/data/rttm/latency_3.rttm b/tests/data/rttm/latency_3.rttm new file mode 100644 index 00000000..95d432dc --- /dev/null +++ b/tests/data/rttm/latency_3.rttm @@ -0,0 +1,10 @@ +SPEAKER sample 1 6.725 0.433 speaker0 +SPEAKER sample 1 7.625 0.467 speaker0 +SPEAKER sample 1 8.008 2.050 speaker1 +SPEAKER sample 1 9.875 1.167 speaker0 +SPEAKER sample 1 10.592 4.167 speaker1 +SPEAKER sample 1 14.292 3.667 speaker0 +SPEAKER sample 1 17.992 3.550 speaker1 +SPEAKER sample 1 18.192 0.367 speaker0 +SPEAKER sample 1 21.758 6.833 speaker0 +SPEAKER sample 1 27.825 2.183 speaker1 diff --git a/tests/data/rttm/latency_4.rttm b/tests/data/rttm/latency_4.rttm new file mode 100644 index 00000000..2a73c427 --- /dev/null +++ b/tests/data/rttm/latency_4.rttm @@ -0,0 +1,10 @@ +SPEAKER sample 1 6.742 0.400 speaker0 +SPEAKER sample 1 7.625 0.650 speaker0 +SPEAKER sample 1 8.092 1.950 speaker1 +SPEAKER sample 1 9.875 1.167 speaker0 +SPEAKER sample 1 10.575 4.183 speaker1 +SPEAKER sample 1 14.308 3.667 speaker0 +SPEAKER sample 1 17.992 3.550 speaker1 +SPEAKER sample 1 18.208 0.333 speaker0 +SPEAKER sample 1 21.758 6.817 speaker0 +SPEAKER sample 1 27.808 2.200 speaker1 diff --git a/tests/data/rttm/latency_5.rttm b/tests/data/rttm/latency_5.rttm new file mode 100644 index 00000000..78b1f1e1 --- /dev/null +++ b/tests/data/rttm/latency_5.rttm @@ -0,0 +1,10 @@ +SPEAKER sample 1 6.742 0.383 speaker0 +SPEAKER sample 1 7.625 0.667 speaker0 +SPEAKER sample 1 8.092 1.967 speaker1 +SPEAKER sample 1 9.875 1.167 speaker0 +SPEAKER sample 1 10.558 4.200 speaker1 +SPEAKER sample 1 14.308 3.667 speaker0 +SPEAKER sample 1 17.992 3.550 speaker1 +SPEAKER sample 1 18.208 0.317 speaker0 +SPEAKER sample 1 21.758 6.817 speaker0 +SPEAKER sample 1 27.808 2.200 speaker1 diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py new file mode 100644 index 00000000..21d40322 --- /dev/null +++ b/tests/test_aggregation.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest +from pyannote.core import SlidingWindow, SlidingWindowFeature + +from diart.blocks.aggregation import ( + AggregationStrategy, + HammingWeightedAverageStrategy, + FirstOnlyStrategy, + AverageStrategy, + DelayedAggregation, +) + + +def test_strategy_build(): + strategy = AggregationStrategy.build("mean") + assert isinstance(strategy, AverageStrategy) + + strategy = AggregationStrategy.build("hamming") + assert isinstance(strategy, HammingWeightedAverageStrategy) + + strategy = AggregationStrategy.build("first") + assert isinstance(strategy, FirstOnlyStrategy) + + with pytest.raises(Exception): + AggregationStrategy.build("invalid") + + +def test_aggregation(): + duration = 5 + frames = 500 + step = 0.5 + speakers = 2 + start_time = 10 + resolution = duration / frames + + dagg1 = DelayedAggregation(step=step, latency=2, strategy="mean") + dagg2 = DelayedAggregation(step=step, latency=2, strategy="hamming") + dagg3 = DelayedAggregation(step=step, latency=2, strategy="first") + + for dagg in [dagg1, dagg2, dagg3]: + assert dagg.num_overlapping_windows == 4 + + buffers = [ + SlidingWindowFeature( + np.random.rand(frames, speakers), + SlidingWindow( + start=(i + start_time) * step, duration=resolution, step=resolution + ), + ) + for i in range(dagg1.num_overlapping_windows) + ] + + for dagg in [dagg1, dagg2, dagg3]: + assert dagg(buffers).data.shape == (51, 2) diff --git a/tests/test_diarization.py b/tests/test_diarization.py new file mode 100644 index 
00000000..1895c26a --- /dev/null +++ b/tests/test_diarization.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +import random + +import pytest + +from diart import SpeakerDiarizationConfig, SpeakerDiarization +from utils import build_waveform_swf + + +@pytest.fixture +def random_diarization_config( + segmentation_model, embedding_model +) -> SpeakerDiarizationConfig: + duration = round(random.uniform(1, 10), 1) + step = round(random.uniform(0.1, duration), 1) + latency = round(random.uniform(step, duration), 1) + return SpeakerDiarizationConfig( + segmentation=segmentation_model, + embedding=embedding_model, + duration=duration, + step=step, + latency=latency, + ) + + +@pytest.fixture(scope="session") +def min_latency_config(segmentation_model, embedding_model) -> SpeakerDiarizationConfig: + return SpeakerDiarizationConfig( + segmentation=segmentation_model, + embedding=embedding_model, + duration=5, + step=0.5, + latency="min", + ) + + +@pytest.fixture(scope="session") +def max_latency_config(segmentation_model, embedding_model) -> SpeakerDiarizationConfig: + return SpeakerDiarizationConfig( + segmentation=segmentation_model, + embedding=embedding_model, + duration=5, + step=0.5, + latency="max", + ) + + +def test_config( + segmentation_model, embedding_model, min_latency_config, max_latency_config +): + duration = round(random.uniform(1, 10), 1) + step = round(random.uniform(0.1, duration), 1) + latency = round(random.uniform(step, duration), 1) + config = SpeakerDiarizationConfig( + segmentation=segmentation_model, + embedding=embedding_model, + duration=duration, + step=step, + latency=latency, + ) + + assert config.duration == duration + assert config.step == step + assert config.latency == latency + assert min_latency_config.latency == min_latency_config.step + assert max_latency_config.latency == max_latency_config.duration + + +def test_bad_latency(segmentation_model, embedding_model): + duration = round(random.uniform(1, 10), 1) + step = round(random.uniform(0.5, duration - 0.2), 1) + latency_too_low = round(random.uniform(0, step - 0.1), 1) + latency_too_high = round(random.uniform(duration + 0.1, 100), 1) + + config1 = SpeakerDiarizationConfig( + segmentation=segmentation_model, + embedding=embedding_model, + duration=duration, + step=step, + latency=latency_too_low, + ) + config2 = SpeakerDiarizationConfig( + segmentation=segmentation_model, + embedding=embedding_model, + duration=duration, + step=step, + latency=latency_too_high, + ) + + with pytest.raises(AssertionError): + SpeakerDiarization(config1) + + with pytest.raises(AssertionError): + SpeakerDiarization(config2) + + +def test_pipeline_build(random_diarization_config): + pipeline = SpeakerDiarization(random_diarization_config) + + assert pipeline.get_config_class() == SpeakerDiarizationConfig + + hparams = pipeline.hyper_parameters() + hp_names = [hp.name for hp in hparams] + assert len(set(hp_names)) == 3 + + for hparam in hparams: + assert hparam.low == 0 + if hparam.name in ["tau_active", "rho_update"]: + assert hparam.high == 1 + elif hparam.name == "delta_new": + assert hparam.high == 2 + else: + assert False + + assert pipeline.config == random_diarization_config + + +def test_timestamp_shift(random_diarization_config): + pipeline = SpeakerDiarization(random_diarization_config) + + assert pipeline.timestamp_shift == 0 + + new_shift = round(random.uniform(-10, 10), 1) + pipeline.set_timestamp_shift(new_shift) + assert pipeline.timestamp_shift == new_shift + + waveform = build_waveform_swf( + 
random_diarization_config.duration, + random_diarization_config.sample_rate, + ) + prediction, _ = pipeline([waveform])[0] + + for segment, _, label in prediction.itertracks(yield_label=True): + assert segment.start >= new_shift + assert segment.end >= new_shift + + pipeline.reset() + assert pipeline.timestamp_shift == 0 + + +def test_call_min_latency(min_latency_config): + pipeline = SpeakerDiarization(min_latency_config) + waveform1 = build_waveform_swf( + min_latency_config.duration, + min_latency_config.sample_rate, + start_time=0, + ) + waveform2 = build_waveform_swf( + min_latency_config.duration, + min_latency_config.sample_rate, + min_latency_config.step, + ) + + batch = [waveform1, waveform2] + output = pipeline(batch) + + pred1, wave1 = output[0] + pred2, wave2 = output[1] + + assert waveform1.data.shape[0] == wave1.data.shape[0] + assert wave1.data.shape[0] > wave2.data.shape[0] + + pred1_timeline = pred1.get_timeline() + pred2_timeline = pred2.get_timeline() + pred1_duration = round(pred1_timeline[-1].end - pred1_timeline[0].start, 3) + pred2_duration = round(pred2_timeline[-1].end - pred2_timeline[0].start, 3) + + expected_duration = round(min_latency_config.duration, 3) + expected_step = round(min_latency_config.step, 3) + assert not pred1_timeline or pred1_duration <= expected_duration + assert not pred2_timeline or pred2_duration <= expected_step + + +def test_call_max_latency(max_latency_config): + pipeline = SpeakerDiarization(max_latency_config) + waveform1 = build_waveform_swf( + max_latency_config.duration, + max_latency_config.sample_rate, + start_time=0, + ) + waveform2 = build_waveform_swf( + max_latency_config.duration, + max_latency_config.sample_rate, + max_latency_config.step, + ) + + batch = [waveform1, waveform2] + output = pipeline(batch) + + pred1, wave1 = output[0] + pred2, wave2 = output[1] + + assert waveform1.data.shape[0] > wave1.data.shape[0] + assert wave1.data.shape[0] == wave2.data.shape[0] + + pred1_timeline = pred1.get_timeline() + pred2_timeline = pred2.get_timeline() + pred1_duration = pred1_timeline[-1].end - pred1_timeline[0].start + pred2_duration = pred2_timeline[-1].end - pred2_timeline[0].start + + expected_step = round(max_latency_config.step, 3) + assert not pred1_timeline or round(pred1_duration, 3) <= expected_step + assert not pred2_timeline or round(pred2_duration, 3) <= expected_step diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py new file mode 100644 index 00000000..de5ee8b5 --- /dev/null +++ b/tests/test_end_to_end.py @@ -0,0 +1,78 @@ +import math +from pathlib import Path + +import pytest +from pyannote.database.util import load_rttm + +from diart import SpeakerDiarization, SpeakerDiarizationConfig +from diart.inference import StreamingInference +from diart.models import SegmentationModel, EmbeddingModel +from diart.sources import FileAudioSource + +MODEL_DIR = Path(__file__).parent.parent / "assets" / "models" +DATA_DIR = Path(__file__).parent / "data" + + +@pytest.fixture(scope="session") +def segmentation(): + model_path = MODEL_DIR / "segmentation_uint8.onnx" + return SegmentationModel.from_pretrained(model_path) + + +@pytest.fixture(scope="session") +def embedding(): + model_path = MODEL_DIR / "embedding_uint8.onnx" + return EmbeddingModel.from_pretrained(model_path) + + +@pytest.fixture(scope="session") +def make_config(segmentation, embedding): + def _config(latency): + return SpeakerDiarizationConfig( + segmentation=segmentation, + embedding=embedding, + step=0.5, + latency=latency, + tau_active=0.507, + 
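+            # tau_active, rho_update and delta_new appear to be values tuned for the
+            # bundled sample.wav (e.g. via diart.tune); the expected RTTMs under
+            # tests/data/rttm presumably assume these exact thresholds.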
rho_update=0.006, + delta_new=1.057 + ) + return _config + + +@pytest.mark.parametrize("source_file", [DATA_DIR / "audio" / "sample.wav"]) +@pytest.mark.parametrize("latency", [0.5, 1, 2, 3, 4, 5]) +def test_benchmark(make_config, source_file, latency): + config = make_config(latency) + pipeline = SpeakerDiarization(config) + + padding = pipeline.config.get_file_padding(source_file) + source = FileAudioSource( + source_file, + pipeline.config.sample_rate, + padding, + pipeline.config.step, + ) + + pipeline.set_timestamp_shift(-padding[0]) + inference = StreamingInference( + pipeline, + source, + do_profile=False, + do_plot=False, + show_progress=False + ) + + pred = inference() + + expected_file = (DATA_DIR / "rttm" / f"latency_{latency}.rttm") + expected = load_rttm(expected_file).popitem()[1] + + assert len(pred) == len(expected) + for track1, track2 in zip(pred.itertracks(yield_label=True), expected.itertracks(yield_label=True)): + pred_segment, _, pred_spk = track1 + expected_segment, _, expected_spk = track2 + # We can tolerate a difference of up to 50ms + assert math.isclose(pred_segment.start, expected_segment.start, abs_tol=0.05) + assert math.isclose(pred_segment.end, expected_segment.end, abs_tol=0.05) + assert pred_spk == expected_spk diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..e8ae41c2 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations +import random +import numpy as np +from pyannote.core import SlidingWindowFeature, SlidingWindow + + +def build_waveform_swf( + duration: float, sample_rate: int, start_time: float | None = None +) -> SlidingWindowFeature: + start_time = round(random.uniform(0, 600), 1) if start_time is None else start_time + chunk_size = int(duration * sample_rate) + resolution = duration / chunk_size + samples = np.random.randn(chunk_size, 1) + sliding_window = SlidingWindow( + start=start_time, step=resolution, duration=resolution + ) + return SlidingWindowFeature(samples, sliding_window)
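+
+
+# Illustrative usage (assumed 16 kHz, 5 s chunk; not exercised by the test suite):
+#   swf = build_waveform_swf(duration=5.0, sample_rate=16000, start_time=0.0)
+#   assert swf.data.shape == (80000, 1)  # int(5.0 * 16000) mono samples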