
Commit b51c1c1

fix bugs (#2207)
1 parent 2849418 commit b51c1c1

File tree

swift/llm/deploy.py
swift/llm/export.py
swift/llm/infer.py
swift/llm/utils/template.py
swift/llm/utils/vllm_utils.py

5 files changed: +32 -55 lines


swift/llm/deploy.py

Lines changed: 1 addition & 4 deletions
@@ -275,7 +275,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
     request_id = request_info['request_id']
 
     kwargs = {'max_tokens': request.max_tokens}
-    for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
+    for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty']:
         kwargs[key] = getattr(request, key)
     for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
         new_value = getattr(request, key)
@@ -292,9 +292,6 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
         kwargs['logprobs'] = max(1, request.top_logprobs)
 
     generation_config = VllmGenerationConfig(**kwargs)
-    if generation_config.use_beam_search and request.stream:
-        error_msg = 'Streaming generation does not support beam search.'
-        raise ValueError(error_msg)
     tokenizer = template.tokenizer
     if tokenizer.eos_token is not None and tokenizer.eos_token not in generation_config.stop:
         generation_config.stop.append(tokenizer.eos_token)
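For orientation, here is the surviving request-to-kwargs mapping as a standalone sketch. DummyRequest is a hypothetical stand-in for the fields that the request object exposes to this loop and is not part of swift; with 'num_beams' dropped from the forwarded keys, beam-search settings never reach VllmGenerationConfig, which is why the streaming guard could be removed as well.

# Sketch only: mirror request sampling fields into kwargs for VllmGenerationConfig.
class DummyRequest:  # hypothetical stand-in for the deployed request object
    max_tokens = 128
    n = 1
    best_of = None
    frequency_penalty = 0.
    length_penalty = 1.
    presence_penalty = 0.
    temperature = 0.7
    top_k = 20
    top_p = 0.9
    repetition_penalty = 1.05

request = DummyRequest()
kwargs = {'max_tokens': request.max_tokens}
for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty']:
    kwargs[key] = getattr(request, key)
for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
    kwargs[key] = getattr(request, key)  # swift applies further checks to these values (truncated in the hunk above)
print(kwargs)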

swift/llm/export.py

Lines changed: 3 additions & 3 deletions
@@ -255,18 +255,18 @@ def llm_export(args: ExportArguments) -> None:
     if args.quant_method == 'awq':
         from awq import AutoAWQForCausalLM
         model, template = prepare_model_template(
-            args, device_map=args.quant_device_map, verbose=False, automodel_class=AutoAWQForCausalLM)
+            args, device_map=args.quant_device_map, task='export', automodel_class=AutoAWQForCausalLM)
         awq_model_quantize(model, template.tokenizer, args.quant_batch_size)
         model.save_quantized(args.quant_output_dir)
     elif args.quant_method == 'gptq':
-        model, template = prepare_model_template(args, device_map=args.quant_device_map, verbose=False)
+        model, template = prepare_model_template(args, device_map=args.quant_device_map, task='export')
         gptq_quantizer = gptq_model_quantize(model, template.tokenizer, args.quant_batch_size)
         model.config.quantization_config.pop('dataset', None)
         gptq_quantizer.save(model, args.quant_output_dir)
     elif args.quant_method == 'bnb':
         args.quantization_bit = args.quant_bits
         args.bnb_4bit_compute_dtype, args.load_in_4bit, args.load_in_8bit = args.select_bnb()
-        model, template = prepare_model_template(args, device_map=args.quant_device_map, verbose=False)
+        model, template = prepare_model_template(args, device_map=args.quant_device_map, task='export')
         model.save_pretrained(args.quant_output_dir)
     else:
         raise ValueError(f'args.quant_method: {args.quant_method}')
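As a rough illustration of the signature change (not the real swift function, which loads an actual model): the boolean verbose flag becomes a task literal, so every export-time caller, including all three quantization branches above, now passes task='export' and skips the inference-only setup.

from typing import Literal, Optional

def prepare_model_template_stub(device_map: Optional[str] = None,
                                *,
                                task: Literal['infer', 'export'] = 'infer') -> str:
    # illustrative stand-in for swift's prepare_model_template; only the control flow is shown
    state = f'model on {device_map or "auto"}'
    if task == 'infer':
        state += ' + generation_config'  # only the inference path attaches generation settings
    return state

print(prepare_model_template_stub('cpu', task='export'))  # 'model on cpu'
print(prepare_model_template_stub('cpu'))                 # 'model on cpu + generation_config'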

swift/llm/infer.py

Lines changed: 23 additions & 21 deletions
@@ -109,7 +109,7 @@ def merge_lora(args: InferArguments,
     if device_map is None:
         device_map = args.merge_device_map
     logger.info(f'merge_device_map: {device_map}')
-    model, template = prepare_model_template(args, device_map=device_map, verbose=False)
+    model, template = prepare_model_template(args, device_map=device_map, task='export')
     logger.info('Merge LoRA...')
     Swift.merge_and_unload(model)
     model = model.model
@@ -133,7 +133,7 @@ def merge_lora(args: InferArguments,
 def prepare_model_template(args: InferArguments,
                            *,
                            device_map: Optional[str] = None,
-                           verbose: bool = True,
+                           task: Literal['infer', 'export'] = 'infer',
                            automodel_class=None) -> Tuple[PreTrainedModel, Template]:
     from .sft import get_default_device_map
     if is_torch_npu_available():
@@ -188,25 +188,7 @@ def prepare_model_template(args: InferArguments,
         revision=args.model_revision,
         quant_method=args.quant_method,
         **kwargs)
-    if verbose:
-        logger.info(f'model_config: {model.config}')
-
-    generation_config = GenerationConfig(
-        max_new_tokens=args.max_new_tokens,
-        temperature=args.temperature,
-        top_k=args.top_k,
-        top_p=args.top_p,
-        do_sample=args.do_sample,
-        repetition_penalty=args.repetition_penalty,
-        num_beams=args.num_beams,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id)
-    set_generation_config(model, generation_config)
-    logger.info(f'model.generation_config: {model.generation_config}')
 
-    if model.generation_config.num_beams != 1:
-        args.stream = False
-        logger.info('Setting args.stream: False')
     if model.max_model_len is None:
         model.max_model_len = args.max_model_len
     elif args.max_model_len is not None:
@@ -215,6 +197,26 @@ def prepare_model_template(args: InferArguments,
     else:
         raise ValueError('args.max_model_len exceeds the maximum max_model_len supported by the model.'
                          f'args.max_model_len: {args.max_model_len}, model.max_model_len: {model.max_model_len}')
+    if task == 'infer':
+        logger.info(f'model_config: {model.config}')
+        generation_config = GenerationConfig(
+            max_new_tokens=args.max_new_tokens,
+            temperature=args.temperature,
+            top_k=args.top_k,
+            top_p=args.top_p,
+            do_sample=args.do_sample,
+            repetition_penalty=args.repetition_penalty,
+            num_beams=args.num_beams,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id)
+        model._generation_config_origin = model.generation_config
+        set_generation_config(model, generation_config)
+        logger.info(f'model.generation_config: {model.generation_config}')
+
+        if model.generation_config.num_beams != 1:
+            args.stream = False
+            logger.info('Setting args.stream: False')
+
     # Preparing LoRA
     if is_adapter(args.sft_type) and args.ckpt_dir is not None:
         if isinstance(args, DeployArguments) and args.lora_request_list is not None:
@@ -227,7 +229,7 @@ def prepare_model_template(args: InferArguments,
     model = model.to(model.dtype)
     model.requires_grad_(False)
 
-    if verbose:
+    if task == 'infer':
         show_layers(model)
         logger.info(model)
         logger.info(get_model_info(model))
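The relocated block is ordinary transformers configuration. Below is a minimal, hedged sketch of what now runs only under task == 'infer'; the placeholder values stand in for InferArguments and the tokenizer, and set_generation_config / _generation_config_origin are shown as comments because they are swift internals.

from transformers import GenerationConfig

task = 'infer'  # prepare_model_template(..., task='export') skips this entire block
if task == 'infer':
    generation_config = GenerationConfig(
        max_new_tokens=64,          # args.max_new_tokens
        temperature=0.3,            # args.temperature
        top_k=20,
        top_p=0.7,
        do_sample=True,
        repetition_penalty=1.05,
        num_beams=1,
        pad_token_id=0,             # tokenizer.pad_token_id
        eos_token_id=2)             # tokenizer.eos_token_id
    # swift also stashes the model's original config before overriding it:
    #   model._generation_config_origin = model.generation_config
    #   set_generation_config(model, generation_config)
    if generation_config.num_beams != 1:
        stream = False  # beam search cannot be streamed token by token
    print(generation_config)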

swift/llm/utils/template.py

Lines changed: 5 additions & 1 deletion
@@ -2028,6 +2028,10 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]:
         res['labels'] = labels[0]
         return res
 
+    @staticmethod
+    def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
+        return generate_ids
+
 
 register_template(TemplateType.llama3_1_omni, Llama3_1OmniTemplate(), lazy_tokenize=True)
 
@@ -2642,7 +2646,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         videos_path = example.get('videos') or []
         if len(videos_path) > 0:
             video_processor = self.tokenizer.processor.video_processor
-            video_inputs = video_processor(videos, return_tensors='pt').to(self.model.dtype)
+            video_inputs = video_processor(videos_path, return_tensors='pt').to(self.model.dtype)
             inputs['pixel_values_videos'] = video_inputs['pixel_values_videos']
         if len(images) > 0:
             image_processor = self.tokenizer.processor.image_processor
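A hedged sketch of what the new override changes. The assumption here is that the base Template trims the prompt tokens off the generated sequence (that default is inferred for illustration, not copied from swift); the Llama-3.1-Omni template now returns the ids untouched. The _encode fix in the second hunk simply passes the defined videos_path variable to the video processor instead of the undefined videos.

from typing import List

class BaseTemplateSketch:
    @staticmethod
    def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
        # assumed default behaviour: drop the prompt portion of the output
        return generate_ids[input_token_len:]

class Llama3_1OmniTemplateSketch(BaseTemplateSketch):
    @staticmethod
    def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
        # override added in this commit: keep the sequence as-is
        return generate_ids

ids = [1, 2, 3, 10, 11]
print(BaseTemplateSketch._get_generate_ids(ids, 3))          # [10, 11]
print(Llama3_1OmniTemplateSketch._get_generate_ids(ids, 3))  # [1, 2, 3, 10, 11]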

swift/llm/utils/vllm_utils.py

Lines changed: 0 additions & 26 deletions
@@ -204,7 +204,6 @@ def __init__(
         top_k: int = 50,  # -1: all
         top_p: float = 1.,
         repetition_penalty: float = 1.,
-        num_beams: int = 1,
         *,
         n: int = 1,
         logprobs: Optional[int] = None,
@@ -218,12 +217,6 @@
         max_new_tokens = kwargs.pop('max_new_tokens', None)
         if max_new_tokens is not None:
             max_tokens = max_new_tokens
-        if num_beams > 1:
-            top_k = -1
-            top_p = 1
-            temperature = 0
-            logger.warning('The output of num_beams in vllm may not be consistent with '
-                           'the output of num_beams in transformers.')
         if top_k == 0:
             top_k = -1
         if stop is None:
@@ -233,11 +226,6 @@
         kwargs['top_k'] = top_k
         kwargs['top_p'] = top_p
         kwargs['repetition_penalty'] = repetition_penalty
-        if num_beams > 1:
-            best_of = kwargs.get('best_of')
-            assert 'use_beam_search' not in kwargs and best_of is None
-            kwargs['use_beam_search'] = True
-            kwargs['best_of'] = num_beams
         kwargs['n'] = n
         kwargs['logprobs'] = logprobs
         kwargs['seed'] = seed
@@ -260,7 +248,6 @@ class VllmGenerationConfig(_VllmGenerationConfigMixin, SamplingParams):
     top_k: int = 50  # -1: all
     top_p: float = 1.
     repetition_penalty: float = 1.
-    num_beams: int = 1
     n: int = 1
     logprobs: Optional[int] = None
     seed: Optional[int] = None
@@ -269,15 +256,6 @@ class VllmGenerationConfig(_VllmGenerationConfigMixin, SamplingParams):
     skip_special_tokens: bool = False
 
     def __post_init__(self):
-        if self.num_beams > 1:
-            self.top_k = -1
-            self.top_p = 1
-            self.temperature = 0
-            logger.warning('The output of num_beams in vllm may not be consistent with '
-                           'the output of num_beams in transformers.')
-            assert self.best_of is None
-            self.use_beam_search = True
-            self.best_of = self.num_beams
         if self.top_k == 0:
             self.top_k = -1
         if self.stop is None:
@@ -435,10 +413,6 @@ def inference_stream_vllm(
         use_tqdm=use_tqdm,
         **kwargs)
 
-    if generation_config.use_beam_search:
-        error_msg = 'Streaming generation does not support beam search.'
-        raise ValueError(error_msg)
-
     n_finished = 0
     n_steps = 0
     if flush_steps is None:
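With every num_beams/use_beam_search branch deleted, the normalisation that survives in VllmGenerationConfig is small. Here is a self-contained sketch of it in plain Python (no vLLM import); the max_new_tokens alias and the top_k handling are taken from the remaining lines above, everything else is illustrative.

from typing import List, Optional

def build_sampling_kwargs(max_tokens: Optional[int] = 64,
                          temperature: float = 1.,
                          top_k: int = 50,   # -1: all
                          top_p: float = 1.,
                          repetition_penalty: float = 1.,
                          n: int = 1,
                          stop: Optional[List[str]] = None,
                          **kwargs) -> dict:
    # alias kept from __init__: max_new_tokens, if given, wins over max_tokens
    max_new_tokens = kwargs.pop('max_new_tokens', None)
    if max_new_tokens is not None:
        max_tokens = max_new_tokens
    if top_k == 0:  # 0 is normalised to vLLM's "consider all tokens" sentinel
        top_k = -1
    if stop is None:
        stop = []
    return dict(max_tokens=max_tokens, temperature=temperature, top_k=top_k,
                top_p=top_p, repetition_penalty=repetition_penalty, n=n, stop=stop)

print(build_sampling_kwargs(max_new_tokens=128, top_k=0))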
