@@ -49,6 +49,7 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False):
49
49
self .model .load ('{}/llm.pt' .format (model_dir ),
50
50
'{}/flow.pt' .format (model_dir ),
51
51
'{}/hift.pt' .format (model_dir ))
52
+ self .vllm_codec_engine = None
52
53
if load_jit :
53
54
self .model .load_jit ('{}/llm.text_encoder.{}.zip' .format (model_dir , 'fp16' if self .fp16 is True else 'fp32' ),
54
55
'{}/llm.llm.{}.zip' .format (model_dir , 'fp16' if self .fp16 is True else 'fp32' ),
@@ -149,8 +150,16 @@ def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, use_vl
149
150
self .model .load ('{}/llm.pt' .format (model_dir ),
150
151
'{}/flow.pt' .format (model_dir ),
151
152
'{}/hift.pt' .format (model_dir ))
153
+ self .vllm_codec_engine = None
152
154
if use_vllm :
155
+ from vllm import EngineArgs , LLMEngine
153
156
self .model .export_codec_vllm ('' .join ([model_dir , '/codec_vllm_model' ]))
157
+ engine_args = EngineArgs (model = '' .join ([model_dir , '/codec_vllm_model' ]),
158
+ skip_tokenizer_init = True ,
159
+ gpu_memory_utilization = 0.1 )
160
+ self .vllm_codec_engine = LLMEngine .from_engine_args (engine_args )
161
+ self .model .llm .vllm_codec_engine = self .vllm_codec_engine
162
+
154
163
if load_jit :
155
164
self .model .load_jit ('{}/flow.encoder.{}.zip' .format (model_dir , 'fp16' if self .fp16 is True else 'fp32' ))
156
165
if load_trt :
0 commit comments