Skip to content

Commit 9695074

Browse files
committed
Revert "mv AsyncLLMEngine init to CosyVoice2"
This reverts commit 9b3f351.
1 parent 9b3f351 commit 9695074

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

cosyvoice/cli/cosyvoice.py

-22
Original file line numberDiff line numberDiff line change
@@ -166,29 +166,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vl
166166
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
167167
if use_vllm:
168168
try:
169-
os.environ["VLLM_USE_V1"] = '1'
170-
from vllm import AsyncLLMEngine
171-
from vllm.engine.arg_utils import AsyncEngineArgs
172-
# EngineArgs
173-
ENGINE_ARGS = {
174-
"block_size": 16,
175-
"swap_space": 0,
176-
# "enforce_eager": True,
177-
"gpu_memory_utilization": 0.4,
178-
"max_num_batched_tokens": 1024,
179-
"max_model_len": 1024,
180-
"max_num_seqs": 256,
181-
"disable_log_requests": True,
182-
"disable_log_stats": True,
183-
"dtype": "bfloat16"
184-
}
185169
self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
186-
engine_args = AsyncEngineArgs(
187-
model=model_dir,
188-
**ENGINE_ARGS,
189-
)
190-
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
191-
self.model.llm_engine = self.llm_engine
192170
except Exception as e:
193171
logging.warning(f'use vllm inference failed. \n{e}')
194172
raise e

cosyvoice/llm/llm_vllm.py

+21 -1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@
3131
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
3232
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
3333

34+
# EngineArgs
35+
ENGINE_ARGS = {
36+
"block_size": 16,
37+
"swap_space": 0,
38+
# "enforce_eager": True,
39+
"gpu_memory_utilization": 0.4,
40+
"max_num_batched_tokens": 1024,
41+
"max_model_len": 1024,
42+
"max_num_seqs": 256,
43+
"disable_log_requests": True,
44+
"disable_log_stats": True,
45+
"dtype": "float16"
46+
}
47+
3448
from vllm.sampling_params import RequestOutputKind
3549
# SamplingParams
3650
SAMPLING_PARAMS = {
@@ -58,7 +72,13 @@ def __init__(
5872
self.fp16 = False
5973
self.half = lambda: None
6074
self.mix_ratio = mix_ratio
61-
self.llm_engine = None
75+
# ---------------------------------------------
76+
# vllm engine 的参数配置
77+
engine_args = AsyncEngineArgs(
78+
model=model_dir,
79+
**ENGINE_ARGS,
80+
)
81+
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
6282

6383
self.speech_token_size = 6564 # 6561 + 3
6484
self.llm_token_size = 151936 # llm vocab_size

0 commit comments

Comments (0)