11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import queue
14
15
import threading
15
16
import torch
16
17
import torch .nn .functional as F
17
18
from matcha .models .components .flow_matching import BASECFM
18
- import queue
19
+
19
20
20
21
class EstimatorWrapper :
21
- def __init__ (self , estimator_engine , estimator_count = 2 ,):
22
+ def __init__ (self , estimator_engine , estimator_count = 2 , ):
22
23
self .estimators = queue .Queue ()
23
24
self .estimator_engine = estimator_engine
24
25
for _ in range (estimator_count ):
@@ -36,6 +37,7 @@ def release_estimator(self, estimator):
36
37
self .estimators .put (estimator )
37
38
return
38
39
40
+
39
41
class ConditionalCFM (BASECFM ):
40
42
def __init__ (self , in_channels , cfm_params , n_spks = 1 , spk_emb_dim = 64 , estimator : torch .nn .Module = None ):
41
43
super ().__init__ (
@@ -53,7 +55,8 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
53
55
self .lock = threading .Lock ()
54
56
55
57
@torch .inference_mode ()
56
- def forward (self , mu , mask , n_timesteps , temperature = 1.0 , spks = None , cond = None , prompt_len = 0 , flow_cache = torch .zeros (1 , 80 , 0 , 2 )):
58
+ def forward (self , mu , mask , n_timesteps , temperature = 1.0 , spks = None , cond = None , prompt_len = 0 ,
59
+ flow_cache = torch .zeros (1 , 80 , 0 , 2 )):
57
60
"""Forward diffusion
58
61
59
62
Args:
@@ -155,12 +158,12 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
155
158
estimator .set_input_shape ('cond' , (2 , 80 , x .size (2 )))
156
159
157
160
data_ptrs = [x .contiguous ().data_ptr (),
158
- mask .contiguous ().data_ptr (),
159
- mu .contiguous ().data_ptr (),
160
- t .contiguous ().data_ptr (),
161
- spks .contiguous ().data_ptr (),
162
- cond .contiguous ().data_ptr (),
163
- x .data_ptr ()]
161
+ mask .contiguous ().data_ptr (),
162
+ mu .contiguous ().data_ptr (),
163
+ t .contiguous ().data_ptr (),
164
+ spks .contiguous ().data_ptr (),
165
+ cond .contiguous ().data_ptr (),
166
+ x .data_ptr ()]
164
167
165
168
for idx , data_ptr in enumerate (data_ptrs ):
166
169
estimator .set_tensor_address (engine .get_tensor_name (idx ), data_ptr )
@@ -181,12 +184,12 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
181
184
self .estimator .set_input_shape ('cond' , (2 , 80 , x .size (2 )))
182
185
# run trt engine
183
186
self .estimator .execute_v2 ([x .contiguous ().data_ptr (),
184
- mask .contiguous ().data_ptr (),
185
- mu .contiguous ().data_ptr (),
186
- t .contiguous ().data_ptr (),
187
- spks .contiguous ().data_ptr (),
188
- cond .contiguous ().data_ptr (),
189
- x .data_ptr ()])
187
+ mask .contiguous ().data_ptr (),
188
+ mu .contiguous ().data_ptr (),
189
+ t .contiguous ().data_ptr (),
190
+ spks .contiguous ().data_ptr (),
191
+ cond .contiguous ().data_ptr (),
192
+ x .data_ptr ()])
190
193
return x
191
194
192
195
def compute_loss (self , x1 , mask , mu , spks = None , cond = None ):
0 commit comments