5
5
from typing import Optional , Text , Union , Callable , List
6
6
7
7
import numpy as np
8
- import onnxruntime
9
8
import torch
10
9
import torch .nn as nn
11
10
from requests import HTTPError
15
14
from pyannote .audio .pipelines .speaker_verification import PretrainedSpeakerEmbedding
16
15
from pyannote .audio .utils .powerset import Powerset
17
16
18
- _has_pyannote = True
17
+ IS_PYANNOTE_AVAILABLE = True
19
18
except ImportError :
20
- _has_pyannote = False
19
+ IS_PYANNOTE_AVAILABLE = False
20
+
21
+ try :
22
+ import onnxruntime as ort
23
+
24
+ IS_ONNX_AVAILABLE = True
25
+ except ImportError :
26
+ IS_ONNX_AVAILABLE = False
21
27
22
28
23
29
class PowersetAdapter (nn .Module ):
@@ -88,11 +94,9 @@ def execution_provider(self) -> str:
88
94
return f"{ device } ExecutionProvider"
89
95
90
96
def recreate_session (self ):
91
- options = onnxruntime .SessionOptions ()
92
- options .graph_optimization_level = (
93
- onnxruntime .GraphOptimizationLevel .ORT_ENABLE_ALL
94
- )
95
- self .session = onnxruntime .InferenceSession (
97
+ options = ort .SessionOptions ()
98
+ options .graph_optimization_level = ort .GraphOptimizationLevel .ORT_ENABLE_ALL
99
+ self .session = ort .InferenceSession (
96
100
self .path ,
97
101
sess_options = options ,
98
102
providers = [self .execution_provider ],
@@ -168,7 +172,7 @@ def from_pyannote(
168
172
-------
169
173
wrapper: SegmentationModel
170
174
"""
171
- assert _has_pyannote , "No pyannote.audio installation found"
175
+ assert IS_PYANNOTE_AVAILABLE , "No pyannote.audio installation found"
172
176
return SegmentationModel (PyannoteLoader (model , use_hf_token ))
173
177
174
178
@staticmethod
@@ -177,6 +181,7 @@ def from_onnx(
177
181
input_name : str = "waveform" ,
178
182
output_name : str = "segmentation" ,
179
183
) -> "SegmentationModel" :
184
+ assert IS_ONNX_AVAILABLE , "No ONNX installation found"
180
185
return SegmentationModel (ONNXLoader (model_path , [input_name ], output_name ))
181
186
182
187
@staticmethod
@@ -224,7 +229,7 @@ def from_pyannote(
224
229
-------
225
230
wrapper: EmbeddingModel
226
231
"""
227
- assert _has_pyannote , "No pyannote.audio installation found"
232
+ assert IS_PYANNOTE_AVAILABLE , "No pyannote.audio installation found"
228
233
loader = PyannoteLoader (model , use_hf_token )
229
234
return EmbeddingModel (loader )
230
235
@@ -234,6 +239,7 @@ def from_onnx(
234
239
input_names : List [str ] | None = None ,
235
240
output_name : str = "embedding" ,
236
241
) -> "EmbeddingModel" :
242
+ assert IS_ONNX_AVAILABLE , "No ONNX installation found"
237
243
input_names = input_names or ["waveform" , "weights" ]
238
244
loader = ONNXLoader (model_path , input_names , output_name )
239
245
return EmbeddingModel (loader )
0 commit comments