Skip to content

Commit 9b3f351

Browse files
committed
mv AsyncLLMEngine init to CosyVoice2
1 parent 00b454c commit 9b3f351

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

cosyvoice/cli/cosyvoice.py

+22
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,29 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vl
166166
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
167167
if use_vllm:
168168
try:
169+
os.environ["VLLM_USE_V1"] = '1'
170+
from vllm import AsyncLLMEngine
171+
from vllm.engine.arg_utils import AsyncEngineArgs
172+
# EngineArgs
173+
ENGINE_ARGS = {
174+
"block_size": 16,
175+
"swap_space": 0,
176+
# "enforce_eager": True,
177+
"gpu_memory_utilization": 0.4,
178+
"max_num_batched_tokens": 1024,
179+
"max_model_len": 1024,
180+
"max_num_seqs": 256,
181+
"disable_log_requests": True,
182+
"disable_log_stats": True,
183+
"dtype": "bfloat16"
184+
}
169185
self.model = VllmCosyVoice2Model(model_dir, configs['flow'], configs['hift'], fp16)
186+
engine_args = AsyncEngineArgs(
187+
model=model_dir,
188+
**ENGINE_ARGS,
189+
)
190+
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
191+
self.model.llm_engine = self.llm_engine
170192
except Exception as e:
171193
logging.warning(f'use vllm inference failed. \n{e}')
172194
raise e

cosyvoice/llm/llm_vllm.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,6 @@
3131
from cosyvoice.llm.vllm_use_cosyvoice2_model import CosyVoice2Model as CosyVoice2LLM
3232
ModelRegistry.register_model("CosyVoice2Model", CosyVoice2LLM)
3333

34-
# EngineArgs
35-
ENGINE_ARGS = {
36-
"block_size": 16,
37-
"swap_space": 0,
38-
# "enforce_eager": True,
39-
"gpu_memory_utilization": 0.4,
40-
"max_num_batched_tokens": 1024,
41-
"max_model_len": 1024,
42-
"max_num_seqs": 256,
43-
"disable_log_requests": True,
44-
"disable_log_stats": True,
45-
"dtype": "float16"
46-
}
47-
4834
from vllm.sampling_params import RequestOutputKind
4935
# SamplingParams
5036
SAMPLING_PARAMS = {
@@ -72,13 +58,7 @@ def __init__(
7258
self.fp16 = False
7359
self.half = lambda: None
7460
self.mix_ratio = mix_ratio
75-
# ---------------------------------------------
76-
# vllm engine 的参数配置
77-
engine_args = AsyncEngineArgs(
78-
model=model_dir,
79-
**ENGINE_ARGS,
80-
)
81-
self.llm_engine: AsyncLLMEngine = AsyncLLMEngine.from_engine_args(engine_args)
61+
self.llm_engine = None
8262

8363
self.speech_token_size = 6564 # 6561 + 3
8464
self.llm_token_size = 151936 # llm vocab_size

0 commit comments

Comments (0)