Skip to content

Commit 8e9f74c

Browse files
authored
Make ONNX runtime optional (#215)
1 parent 8cad376 commit 8e9f74c

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

requirements.txt

-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,3 @@ optuna>=2.10
1818
websocket-server>=0.6.4
1919
websocket-client>=0.58.0
2020
rich>=12.5.1
21-
onnxruntime-gpu>=1.16.1

setup.cfg

-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ install_requires=
4040
websocket-server>=0.6.4
4141
websocket-client>=0.58.0
4242
rich>=12.5.1
43-
onnxruntime-gpu>=1.16.1
4443

4544
[options.packages.find]
4645
where=src

src/diart/models.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Optional, Text, Union, Callable, List
66

77
import numpy as np
8-
import onnxruntime
98
import torch
109
import torch.nn as nn
1110
from requests import HTTPError
@@ -15,9 +14,16 @@
1514
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
1615
from pyannote.audio.utils.powerset import Powerset
1716

18-
_has_pyannote = True
17+
IS_PYANNOTE_AVAILABLE = True
1918
except ImportError:
20-
_has_pyannote = False
19+
IS_PYANNOTE_AVAILABLE = False
20+
21+
try:
22+
import onnxruntime as ort
23+
24+
IS_ONNX_AVAILABLE = True
25+
except ImportError:
26+
IS_ONNX_AVAILABLE = False
2127

2228

2329
class PowersetAdapter(nn.Module):
@@ -88,11 +94,9 @@ def execution_provider(self) -> str:
8894
return f"{device}ExecutionProvider"
8995

9096
def recreate_session(self):
91-
options = onnxruntime.SessionOptions()
92-
options.graph_optimization_level = (
93-
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
94-
)
95-
self.session = onnxruntime.InferenceSession(
97+
options = ort.SessionOptions()
98+
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
99+
self.session = ort.InferenceSession(
96100
self.path,
97101
sess_options=options,
98102
providers=[self.execution_provider],
@@ -168,7 +172,7 @@ def from_pyannote(
168172
-------
169173
wrapper: SegmentationModel
170174
"""
171-
assert _has_pyannote, "No pyannote.audio installation found"
175+
assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found"
172176
return SegmentationModel(PyannoteLoader(model, use_hf_token))
173177

174178
@staticmethod
@@ -177,6 +181,7 @@ def from_onnx(
177181
input_name: str = "waveform",
178182
output_name: str = "segmentation",
179183
) -> "SegmentationModel":
184+
assert IS_ONNX_AVAILABLE, "No ONNX installation found"
180185
return SegmentationModel(ONNXLoader(model_path, [input_name], output_name))
181186

182187
@staticmethod
@@ -224,7 +229,7 @@ def from_pyannote(
224229
-------
225230
wrapper: EmbeddingModel
226231
"""
227-
assert _has_pyannote, "No pyannote.audio installation found"
232+
assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found"
228233
loader = PyannoteLoader(model, use_hf_token)
229234
return EmbeddingModel(loader)
230235

@@ -234,6 +239,7 @@ def from_onnx(
234239
input_names: List[str] | None = None,
235240
output_name: str = "embedding",
236241
) -> "EmbeddingModel":
242+
assert IS_ONNX_AVAILABLE, "No ONNX installation found"
237243
input_names = input_names or ["waveform", "weights"]
238244
loader = ONNXLoader(model_path, input_names, output_name)
239245
return EmbeddingModel(loader)

0 commit comments

Comments
 (0)