Skip to content

Commit 21ddaec

Browse files
authored
Merge pull request #495 from FunAudioLLM/dev/lyuxiang.lx
fix bug
2 parents de76577 + 7e6d60c commit 21ddaec

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

cosyvoice/bin/inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def main():
9999
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
100100
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
101101
tts_speeches = []
102-
for model_output in model.inference(**model_input):
102+
for model_output in model.tts(**model_input):
103103
tts_speeches.append(model_output['tts_speech'])
104104
tts_speeches = torch.concat(tts_speeches, dim=1)
105105
tts_key = '{}_{}'.format(utts[0], tts_index[0])

cosyvoice/cli/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ def __init__(self,
5656
self.hift_cache_dict = {}
5757

5858
def load(self, llm_model, flow_model, hift_model):
59-
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
59+
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
6060
self.llm.to(self.device).eval()
6161
if self.fp16 is True:
6262
self.llm.half()
63-
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
63+
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
6464
self.flow.to(self.device).eval()
6565
# in case hift_model is a hifigan model
66-
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device)}
66+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
6767
self.hift.load_state_dict(hift_state_dict, strict=False)
6868
self.hift.to(self.device).eval()
6969

0 commit comments

Comments
 (0)