fix bugs #2207

Merged · 5 commits · Oct 8, 2024
Changes from 3 commits
5 changes: 1 addition & 4 deletions swift/llm/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
request_id = request_info['request_id']

kwargs = {'max_tokens': request.max_tokens}
for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty', 'num_beams']:
for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty']:
kwargs[key] = getattr(request, key)
for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
new_value = getattr(request, key)
Expand All @@ -292,9 +292,6 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest, CompletionR
kwargs['logprobs'] = max(1, request.top_logprobs)

generation_config = VllmGenerationConfig(**kwargs)
if generation_config.use_beam_search and request.stream:
error_msg = 'Streaming generation does not support beam search.'
raise ValueError(error_msg)
tokenizer = template.tokenizer
if tokenizer.eos_token is not None and tokenizer.eos_token not in generation_config.stop:
generation_config.stop.append(tokenizer.eos_token)
Expand Down
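
Note: the deploy path no longer forwards `num_beams` to vLLM and drops the beam-search/streaming guard, consistent with the `VllmGenerationConfig` cleanup further down in this PR. A minimal, self-contained sketch of the resulting kwargs assembly; `DummyRequest` is a hypothetical stand-in for the real `ChatCompletionRequest`/`CompletionRequest` models:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyRequest:
    """Hypothetical stand-in for ChatCompletionRequest / CompletionRequest."""
    max_tokens: Optional[int] = 64
    n: int = 1
    best_of: Optional[int] = None
    frequency_penalty: float = 0.
    length_penalty: float = 1.
    presence_penalty: float = 0.


def build_sampling_kwargs(request: DummyRequest) -> dict:
    # Mirrors the loop in inference_vllm_async after this change:
    # num_beams is no longer copied off the request.
    kwargs = {'max_tokens': request.max_tokens}
    for key in ['n', 'best_of', 'frequency_penalty', 'length_penalty', 'presence_penalty']:
        kwargs[key] = getattr(request, key)
    return kwargs


print(build_sampling_kwargs(DummyRequest()))
```
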
6 changes: 3 additions & 3 deletions swift/llm/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,18 +255,18 @@ def llm_export(args: ExportArguments) -> None:
if args.quant_method == 'awq':
from awq import AutoAWQForCausalLM
model, template = prepare_model_template(
args, device_map=args.quant_device_map, verbose=False, automodel_class=AutoAWQForCausalLM)
args, device_map=args.quant_device_map, task='export', automodel_class=AutoAWQForCausalLM)
awq_model_quantize(model, template.tokenizer, args.quant_batch_size)
model.save_quantized(args.quant_output_dir)
elif args.quant_method == 'gptq':
model, template = prepare_model_template(args, device_map=args.quant_device_map, verbose=False)
model, template = prepare_model_template(args, device_map=args.quant_device_map, task='export')
gptq_quantizer = gptq_model_quantize(model, template.tokenizer, args.quant_batch_size)
model.config.quantization_config.pop('dataset', None)
gptq_quantizer.save(model, args.quant_output_dir)
elif args.quant_method == 'bnb':
args.quantization_bit = args.quant_bits
args.bnb_4bit_compute_dtype, args.load_in_4bit, args.load_in_8bit = args.select_bnb()
model, template = prepare_model_template(args, device_map=args.quant_device_map, verbose=False)
model, template = prepare_model_template(args, device_map=args.quant_device_map, task='export')
model.save_pretrained(args.quant_output_dir)
else:
raise ValueError(f'args.quant_method: {args.quant_method}')
Expand Down
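
Note: all three quantization branches (awq, gptq, bnb) now pass `task='export'` instead of `verbose=False`, matching the new `prepare_model_template` signature in `swift/llm/infer.py` below. A toy sketch of why these call sites must change together with the signature; `load_model` is a hypothetical helper, not project code:

```python
from typing import Literal


def load_model(name: str, *, task: Literal['infer', 'export'] = 'infer') -> str:
    """Hypothetical stand-in for the new prepare_model_template keyword surface."""
    return f'{name} prepared for {task}'


print(load_model('qwen2-7b-instruct', task='export'))  # new call shape used by llm_export

try:
    load_model('qwen2-7b-instruct', verbose=False)  # old call shape, now rejected
except TypeError as e:
    print(f'TypeError: {e}')
```
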
swift/llm/infer.py (53 changes: 28 additions & 25 deletions)

@@ -109,7 +109,7 @@ def merge_lora(args: InferArguments,
     if device_map is None:
         device_map = args.merge_device_map
     logger.info(f'merge_device_map: {device_map}')
-    model, template = prepare_model_template(args, device_map=device_map, verbose=False)
+    model, template = prepare_model_template(args, device_map=device_map, task='export')
     logger.info('Merge LoRA...')
     Swift.merge_and_unload(model)
     model = model.model
@@ -130,11 +130,12 @@ def prepare_model_template(args: InferArguments,
     return merged_lora_path


-def prepare_model_template(args: InferArguments,
-                           *,
-                           device_map: Optional[str] = None,
-                           verbose: bool = True,
-                           automodel_class=None) -> Tuple[PreTrainedModel, Template]:
+def prepare_model_template(
+        args: InferArguments,
+        *,
+        device_map: Optional[str] = None,
+        task: Literal['infer', 'export'] = 'infer',  # for inference or export
+        automodel_class=None) -> Tuple[PreTrainedModel, Template]:
     from .sft import get_default_device_map
     if is_torch_npu_available():
         print(f'device_count: {torch.npu.device_count()}')
@@ -188,25 +189,7 @@ def prepare_model_template(args: InferArguments,
         revision=args.model_revision,
         quant_method=args.quant_method,
         **kwargs)
-    if verbose:
-        logger.info(f'model_config: {model.config}')

-    generation_config = GenerationConfig(
-        max_new_tokens=args.max_new_tokens,
-        temperature=args.temperature,
-        top_k=args.top_k,
-        top_p=args.top_p,
-        do_sample=args.do_sample,
-        repetition_penalty=args.repetition_penalty,
-        num_beams=args.num_beams,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token_id=tokenizer.eos_token_id)
-    set_generation_config(model, generation_config)
-    logger.info(f'model.generation_config: {model.generation_config}')
-
-    if model.generation_config.num_beams != 1:
-        args.stream = False
-        logger.info('Setting args.stream: False')
     if model.max_model_len is None:
         model.max_model_len = args.max_model_len
     elif args.max_model_len is not None:
@@ -215,6 +198,26 @@ def prepare_model_template(args: InferArguments,
         else:
             raise ValueError('args.max_model_len exceeds the maximum max_model_len supported by the model.'
                              f'args.max_model_len: {args.max_model_len}, model.max_model_len: {model.max_model_len}')
+    if task == 'infer':
+        logger.info(f'model_config: {model.config}')
+        generation_config = GenerationConfig(
+            max_new_tokens=args.max_new_tokens,
+            temperature=args.temperature,
+            top_k=args.top_k,
+            top_p=args.top_p,
+            do_sample=args.do_sample,
+            repetition_penalty=args.repetition_penalty,
+            num_beams=args.num_beams,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id)
+        model._generation_config_origin = model.generation_config
+        set_generation_config(model, generation_config)
+        logger.info(f'model.generation_config: {model.generation_config}')
+
+        if model.generation_config.num_beams != 1:
+            args.stream = False
+            logger.info('Setting args.stream: False')
+
     # Preparing LoRA
     if is_adapter(args.sft_type) and args.ckpt_dir is not None:
         if isinstance(args, DeployArguments) and args.lora_request_list is not None:
@@ -227,7 +230,7 @@ def prepare_model_template(args: InferArguments,
         model = model.to(model.dtype)
     model.requires_grad_(False)

-    if verbose:
+    if task == 'infer':
         show_layers(model)
         logger.info(model)
         logger.info(get_model_info(model))
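
Note: the `verbose: bool` flag becomes `task: Literal['infer', 'export']`, and the HF `GenerationConfig` setup (including stashing the previous config on `model._generation_config_origin` and turning streaming off when beam search is requested) now runs only for inference, so export/quantization callers skip it. A condensed, hypothetical sketch of the new control flow, with plain dicts standing in for the real model and template objects:

```python
from typing import Literal, Tuple


def prepare_model_template_sketch(
        args: dict,
        *,
        task: Literal['infer', 'export'] = 'infer') -> Tuple[dict, dict]:
    """Condensed sketch of the new control flow; not the real implementation."""
    model = {'name': args['model'], 'generation_config': None}
    template = {'tokenizer': 'tokenizer'}

    if task == 'infer':
        # Only the inference path configures generation. The real code also keeps
        # the previous config on model._generation_config_origin before replacing it.
        model['generation_config'] = {
            'max_new_tokens': args.get('max_new_tokens', 64),
            'num_beams': args.get('num_beams', 1),
        }
        if model['generation_config']['num_beams'] != 1:
            args['stream'] = False  # beam search cannot be streamed

    return model, template


model, _ = prepare_model_template_sketch({'model': 'qwen2-7b', 'num_beams': 4, 'stream': True})
print(model['generation_config'])  # export callers would get generation_config=None
```
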
swift/llm/utils/template.py (4 changes: 4 additions & 0 deletions)

@@ -2028,6 +2028,10 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]:
         res['labels'] = labels[0]
         return res

+    @staticmethod
+    def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
+        return generate_ids
+

 register_template(TemplateType.llama3_1_omni, Llama3_1OmniTemplate(), lazy_tokenize=True)
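
Note: the new `_get_generate_ids` override returns the generated ids unchanged. The base `Template` hook of the same name appears to slice off the prompt tokens using `input_token_len` (an assumption based on the signature, not verified against the base class here), so the override would indicate that Llama3_1Omni's generate output is already prompt-free. A toy comparison under that assumption:

```python
from typing import List


def base_get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
    # Assumed base-class behaviour: drop the prompt prefix from the raw output.
    return generate_ids[input_token_len:]


def llama3_1_omni_get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
    # Behaviour added in this PR: return the ids as-is (already prompt-free).
    return generate_ids


raw_output = [101, 102, 103, 7, 8, 9]  # 3 prompt tokens followed by 3 generated tokens
print(base_get_generate_ids(raw_output, 3))          # [7, 8, 9]
print(llama3_1_omni_get_generate_ids([7, 8, 9], 3))  # [7, 8, 9]
```
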
swift/llm/utils/vllm_utils.py (26 changes: 0 additions & 26 deletions)

@@ -204,7 +204,6 @@ def __init__(
             top_k: int = 50,  # -1: all
             top_p: float = 1.,
             repetition_penalty: float = 1.,
-            num_beams: int = 1,
             *,
             n: int = 1,
             logprobs: Optional[int] = None,
@@ -218,12 +217,6 @@
         max_new_tokens = kwargs.pop('max_new_tokens', None)
         if max_new_tokens is not None:
             max_tokens = max_new_tokens
-        if num_beams > 1:
-            top_k = -1
-            top_p = 1
-            temperature = 0
-            logger.warning('The output of num_beams in vllm may not be consistent with '
-                           'the output of num_beams in transformers.')
         if top_k == 0:
             top_k = -1
         if stop is None:
@@ -233,11 +226,6 @@
         kwargs['top_k'] = top_k
         kwargs['top_p'] = top_p
         kwargs['repetition_penalty'] = repetition_penalty
-        if num_beams > 1:
-            best_of = kwargs.get('best_of')
-            assert 'use_beam_search' not in kwargs and best_of is None
-            kwargs['use_beam_search'] = True
-            kwargs['best_of'] = num_beams
         kwargs['n'] = n
         kwargs['logprobs'] = logprobs
         kwargs['seed'] = seed
@@ -260,7 +248,6 @@ class VllmGenerationConfig(_VllmGenerationConfigMixin, SamplingParams):
     top_k: int = 50  # -1: all
     top_p: float = 1.
     repetition_penalty: float = 1.
-    num_beams: int = 1
     n: int = 1
     logprobs: Optional[int] = None
     seed: Optional[int] = None
@@ -269,15 +256,6 @@ class VllmGenerationConfig(_VllmGenerationConfigMixin, SamplingParams):
     skip_special_tokens: bool = False

     def __post_init__(self):
-        if self.num_beams > 1:
-            self.top_k = -1
-            self.top_p = 1
-            self.temperature = 0
-            logger.warning('The output of num_beams in vllm may not be consistent with '
-                           'the output of num_beams in transformers.')
-            assert self.best_of is None
-            self.use_beam_search = True
-            self.best_of = self.num_beams
         if self.top_k == 0:
             self.top_k = -1
         if self.stop is None:
@@ -435,10 +413,6 @@ def inference_stream_vllm(
         use_tqdm=use_tqdm,
         **kwargs)

-    if generation_config.use_beam_search:
-        error_msg = 'Streaming generation does not support beam search.'
-        raise ValueError(error_msg)
-
     n_finished = 0
     n_steps = 0
     if flush_steps is None:
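
Note: with `num_beams`, `use_beam_search`, and the `best_of` plumbing removed, `VllmGenerationConfig` now only carries parameters that vLLM's `SamplingParams` still accepts; recent vLLM releases dropped `use_beam_search` from `SamplingParams`, which is presumably what this cleanup tracks. Since `num_beams > 1` now has no effect on the vLLM path, a caller-side guard such as the hypothetical one below (not part of this PR) could make the change visible to users:

```python
def validate_vllm_sampling(num_beams: int = 1) -> None:
    """Hypothetical guard: the vLLM backend no longer maps num_beams to beam search."""
    if num_beams > 1:
        raise ValueError('num_beams > 1 is no longer supported on the vLLM path; '
                         'use sampling (temperature/top_k/top_p) or the PyTorch backend instead.')


validate_vllm_sampling(num_beams=1)  # ok
try:
    validate_vllm_sampling(num_beams=4)
except ValueError as e:
    print(e)
```
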