Skip to content

Commit 7921181

Browse files
committed
Fix inconsistent indentation in flow_matching.py
Align code indentation for improved readability and maintainability. Adjust spacing in method signatures and data-pointer lists to ensure uniform formatting across the file.
1 parent ab02a5a commit 7921181

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

Diff for: cosyvoice/flow/flow_matching.py

+18-15
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import queue
1415
import threading
1516
import torch
1617
import torch.nn.functional as F
1718
from matcha.models.components.flow_matching import BASECFM
18-
import queue
19+
1920

2021
class EstimatorWrapper:
21-
def __init__(self, estimator_engine, estimator_count=2,):
22+
def __init__(self, estimator_engine, estimator_count=2, ):
2223
self.estimators = queue.Queue()
2324
self.estimator_engine = estimator_engine
2425
for _ in range(estimator_count):
@@ -36,6 +37,7 @@ def release_estimator(self, estimator):
3637
self.estimators.put(estimator)
3738
return
3839

40+
3941
class ConditionalCFM(BASECFM):
4042
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
4143
super().__init__(
@@ -53,7 +55,8 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
5355
self.lock = threading.Lock()
5456

5557
@torch.inference_mode()
56-
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
58+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0,
59+
flow_cache=torch.zeros(1, 80, 0, 2)):
5760
"""Forward diffusion
5861
5962
Args:
@@ -155,12 +158,12 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
155158
estimator.set_input_shape('cond', (2, 80, x.size(2)))
156159

157160
data_ptrs = [x.contiguous().data_ptr(),
158-
mask.contiguous().data_ptr(),
159-
mu.contiguous().data_ptr(),
160-
t.contiguous().data_ptr(),
161-
spks.contiguous().data_ptr(),
162-
cond.contiguous().data_ptr(),
163-
x.data_ptr()]
161+
mask.contiguous().data_ptr(),
162+
mu.contiguous().data_ptr(),
163+
t.contiguous().data_ptr(),
164+
spks.contiguous().data_ptr(),
165+
cond.contiguous().data_ptr(),
166+
x.data_ptr()]
164167

165168
for idx, data_ptr in enumerate(data_ptrs):
166169
estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
@@ -181,12 +184,12 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
181184
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
182185
# run trt engine
183186
self.estimator.execute_v2([x.contiguous().data_ptr(),
184-
mask.contiguous().data_ptr(),
185-
mu.contiguous().data_ptr(),
186-
t.contiguous().data_ptr(),
187-
spks.contiguous().data_ptr(),
188-
cond.contiguous().data_ptr(),
189-
x.data_ptr()])
187+
mask.contiguous().data_ptr(),
188+
mu.contiguous().data_ptr(),
189+
t.contiguous().data_ptr(),
190+
spks.contiguous().data_ptr(),
191+
cond.contiguous().data_ptr(),
192+
x.data_ptr()])
190193
return x
191194

192195
def compute_loss(self, x1, mask, mu, spks=None, cond=None):

0 commit comments

Comments (0)