
Commit 9989ef9

Support quant mllm (#2177)
1 parent 45354f7 commit 9989ef9

File tree: 5 files changed, +107 -53 lines changed

swift/llm/export.py

Lines changed: 94 additions & 26 deletions
@@ -1,22 +1,34 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-from typing import List, Optional
+from contextlib import contextmanager
+from types import MethodType
+from typing import Dict, List, Optional
 
 import json
 import torch
+import torch.nn as nn
 
 from swift.llm import get_model_tokenizer, get_template
 from swift.utils import (check_json_format, get_logger, get_main, get_model_info, push_to_ms_hub, seed_everything,
                          show_layers)
 from .infer import merge_lora, prepare_model_template, save_checkpoint
-from .utils import ExportArguments, Template, get_dataset, swift_to_peft_format
+from .utils import ExportArguments, Template, deep_getattr, get_dataset, get_mllm_arch, swift_to_peft_format
 
 logger = get_logger()
 
 _args: Optional[ExportArguments] = None
 template: Optional[Template] = None
 
 
+def _prepare_dataset(examples: List[Dict[str, torch.LongTensor]], batch_size: int = 1, *args, **kwargs):
+    global _args, template
+    assert template is not None
+    examples = [
+        template.data_collator(examples[start:start + batch_size]) for start in range(0, len(examples), batch_size)
+    ]
+    return examples
+
+
 def _get_dataset(*args, **kwargs):
     global _args, template
     assert _args is not None
@@ -39,27 +51,31 @@ def _get_dataset(*args, **kwargs):
     samples = []
     n_run = 0
     for data in dataset:
-        input_ids = template.encode(data)[0].get('input_ids')
+        inputs = template.encode(data)[0]
+        input_ids = inputs['input_ids']
         if input_ids is None or len(input_ids) == 0:
             continue
-        sample = torch.tensor(input_ids)
-        samples.append(sample)
+        if _args.is_multimodal and _args.quant_method == 'gptq':
+            inputs.pop('labels', None)
+            samples.append(inputs)
+        else:
+            samples += input_ids
         n_run += 1
         if n_run == n_samples:
             break
+    if _args.is_multimodal and _args.quant_method == 'gptq':
+        return samples
     # now concatenate all samples and split according to block size
-    cat_samples = torch.cat(samples, dim=0)  # shape: [X]
-    n_split = cat_samples.shape[0] // block_size
+    n_split = len(samples) // block_size
     logger.info(f'Split into {n_split} blocks')
-    if _args.quant_method == 'awq':
-        return [cat_samples[None, i * block_size:(i + 1) * block_size] for i in range(n_split)]
-    else:  # gptq
-        res = []
-        for i in range(n_split):
-            input_ids = cat_samples[None, i * block_size:(i + 1) * block_size]
-            attention_mask = torch.ones_like(input_ids)
-            res.append({'input_ids': input_ids, 'attention_mask': attention_mask})
-        return res
+    res = []
+    for i in range(n_split):
+        input_ids = samples[i * block_size:(i + 1) * block_size]
+        if _args.quant_method == 'awq':
+            res.append(torch.tensor(input_ids)[None])
+        else:
+            res.append({'input_ids': input_ids})
+    return res
 
 
 def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
@@ -80,22 +96,74 @@ def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
         bits=_args.quant_bits, group_size=group_size, zero_point=True, version='GEMM')
 
 
+@contextmanager
+def _patch_gptq():
+    from optimum.gptq import quantizer
+    _get_dataset_origin = quantizer.get_dataset
+    _prepare_dataset_origin = quantizer.prepare_dataset
+    quantizer.get_dataset = _get_dataset
+    quantizer.prepare_dataset = _prepare_dataset
+    yield
+    quantizer.get_dataset = _get_dataset_origin
+    quantizer.prepare_dataset = _prepare_dataset_origin
+
+
+def _patch_model_forward(module_list):
+
+    def _new_forward(self, *args, **kwargs):
+        if 'use_cache' in kwargs:
+            kwargs['use_cache'] = False
+        layer_ret = self.__old_forward(*args, **kwargs)
+        return layer_ret + args[len(layer_ret):]
+
+    for module in module_list:
+        if hasattr(module, '_old_forward'):  # device_map
+            __old_forward = module._old_forward
+            module._old_forward = MethodType(_new_forward, module)
+        else:
+            __old_forward = module.forward
+            module.forward = MethodType(_new_forward, module)
+        module.__old_forward = __old_forward
+
+
+def get_block_name_to_quantize(model: nn.Module, model_type: str) -> Optional[str]:
+    mllm_arch = get_mllm_arch(model_type)
+    prefix = ''
+    if mllm_arch is not None:
+        assert len(mllm_arch.language_model) == 1, f'mllm_arch.language_model: {mllm_arch.language_model}'
+        prefix = mllm_arch.language_model[0]
+        model = deep_getattr(model, prefix)
+
+    module_lists = []
+    for n, m in model.named_modules():
+        if isinstance(m, nn.ModuleList) and len(m) >= 10:
+            module_lists.append((n, m))
+    if module_lists:
+        module_list = max(module_lists, key=lambda x: len(x[1]))
+        _patch_model_forward(module_list[1])
+        return f'{prefix}.{module_list[0]}'
+
+
 def gptq_model_quantize(model, tokenizer, batch_size):
-    from optimum.gptq import GPTQQuantizer, quantizer
+    from optimum.gptq import GPTQQuantizer
     global _args
     logger.info(f'Quantization dataset: {_args.dataset}')
-    gptq_quantizer = GPTQQuantizer(bits=_args.quant_bits, dataset=','.join(_args.dataset), batch_size=batch_size)
-    _origin_get_dataset = quantizer.get_dataset
-    quantizer.get_dataset = _get_dataset
-    logger.info('Start quantizing the model...')
-    logger.warning('The process of packing the model takes a long time and there is no progress bar. '
-                   'Please be patient and wait...')
-    gptq_quantizer.quantize_model(model, tokenizer)
-    quantizer.get_dataset = _origin_get_dataset  # recover
+    with _patch_gptq():
+        gptq_quantizer = GPTQQuantizer(
+            bits=_args.quant_bits,
+            dataset=','.join(_args.dataset),
+            batch_size=batch_size,
+            block_name_to_quantize=get_block_name_to_quantize(model, _args.model_type))
+        logger.info('Start quantizing the model...')
+        logger.warning('The process of packing the model takes a long time and there is no progress bar. '
+                       'Please be patient and wait...')
+        if not hasattr(model.config, 'use_cache'):
+            model.config.use_cache = None
+        gptq_quantizer.quantize_model(model, tokenizer)
     return gptq_quantizer
 
 
-def replace_and_concat(template: 'Template', template_list: List, placeholder: str, keyword: str):
+def replace_and_concat(template: Template, template_list: List, placeholder: str, keyword: str):
     final_str = ''
     for t in template_list:
         if isinstance(t, str):
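
Aside: the new get_block_name_to_quantize helper above finds the transformer layer stack by taking the largest nn.ModuleList inside (the language-model submodule of) a multimodal model, and hands its dotted name to the GPTQ quantizer. A minimal standalone sketch of that heuristic, using only public PyTorch APIs and a hypothetical toy model rather than swift's get_mllm_arch/deep_getattr helpers:

import torch.nn as nn

def find_layer_stack(model: nn.Module, min_len: int = 10) -> str:
    # Collect every nn.ModuleList with at least `min_len` entries and assume
    # the longest one is the decoder layer stack worth quantizing.
    candidates = [(name, m) for name, m in model.named_modules()
                  if isinstance(m, nn.ModuleList) and len(m) >= min_len]
    assert candidates, 'no layer stack found'
    name, _ = max(candidates, key=lambda x: len(x[1]))
    return name  # dotted module path, e.g. 'model.layers'

# Toy example: a module holding a stack of 12 "layers".
toy = nn.Module()
toy.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
print(find_layer_stack(toy))  # -> 'layers'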

swift/llm/utils/client_utils.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 
 
 def _get_request_kwargs(api_key: Optional[str] = None) -> Dict[str, Any]:
-    timeout = float(os.getenv('TIMEOUT', '300'))
+    timeout = float(os.getenv('TIMEOUT', '1800'))
     request_kwargs = {}
     if timeout > 0:
         request_kwargs['timeout'] = timeout
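
Note: the default HTTP request timeout for the client helpers rises from 300 to 1800 seconds. It can still be overridden through the TIMEOUT environment variable that this function reads; a small illustration (the value 600 is arbitrary, not prescribed by the commit):

import os

# Set before the client builds its request kwargs; a non-positive value
# leaves the request without an explicit timeout.
os.environ['TIMEOUT'] = '600'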

swift/llm/utils/media.py

Lines changed: 1 addition & 1 deletion
@@ -94,10 +94,10 @@ def __call__(self, d: Dict[str, Any], medias: Union[tuple, list]) -> None:
             raise NotImplementedError
         else:
             pass
-        standard_tag = self.standard_tags[self.media_type]
 
         all_queries = ''.join([h[0] for h in history]) + query
         if self.media_tag in all_queries:
+            standard_tag = self.standard_tags[self.media_type]
             assert all_queries.count(self.media_tag) == media_cnt
             for h in history:
                 h[0] = h[0].replace(self.media_tag, standard_tag)

swift/llm/utils/model.py

Lines changed: 9 additions & 10 deletions
@@ -1045,7 +1045,8 @@ def _output_device_map_hook(module, input, output):
 def get_model_tokenizer_pixtral(model_dir: str, *args, **kwargs):
     from transformers import AutoProcessor, LlavaForConditionalGeneration
     processor = AutoProcessor.from_pretrained(model_dir)
-    kwargs['automodel_class'] = LlavaForConditionalGeneration
+    if 'automodel_class' not in kwargs:
+        kwargs['automodel_class'] = LlavaForConditionalGeneration
     kwargs['tokenizer'] = processor.tokenizer
     model, tokenizer = get_model_tokenizer_from_repo(model_dir, *args, **kwargs)
     tokenizer.processor = processor
@@ -1117,14 +1118,10 @@ def get_model_tokenizer_llava_llama(model_dir: str,
 
     model_config = LlavaConfig.from_pretrained(model_dir)
     processor = AutoProcessor.from_pretrained(model_dir)
+    if 'automodel_class' not in kwargs:
+        kwargs['automodel_class'] = LlavaForConditionalGeneration
     model, tokenizer = get_model_tokenizer_with_flash_attn(
-        model_dir,
-        torch_dtype,
-        model_kwargs,
-        load_model,
-        model_config=model_config,
-        automodel_class=LlavaForConditionalGeneration,
-        **kwargs)
+        model_dir, torch_dtype, model_kwargs, load_model, model_config=model_config, **kwargs)
     tokenizer.processor = processor
     return model, tokenizer
 
@@ -6275,7 +6272,8 @@ def get_model_tokenizer_llama3_2_vision(*args, **kwargs):
     hf_model_id='llava-hf/llava-1.5-7b-hf')
 def get_model_tokenizer_llava_1_5(*args, **kwargs):
     from transformers import LlavaForConditionalGeneration
-    kwargs['automodel_class'] = LlavaForConditionalGeneration
+    if 'automodel_class' not in kwargs:
+        kwargs['automodel_class'] = LlavaForConditionalGeneration
     return get_model_tokenizer_llava_hf(*args, **kwargs)
 
 
@@ -6387,7 +6385,8 @@ def get_model_tokenizer_llava_onevision(*args, **kwargs):
     tags=['multi-modal', 'vision'])
 def get_model_tokenizer_llava_next(*args, **kwargs):
     from transformers import LlavaNextForConditionalGeneration
-    kwargs['automodel_class'] = LlavaNextForConditionalGeneration
+    if 'automodel_class' not in kwargs:
+        kwargs['automodel_class'] = LlavaNextForConditionalGeneration
     return get_model_tokenizer_llava_hf(*args, **kwargs)
 
 
swift/llm/utils/preprocess.py

Lines changed: 2 additions & 15 deletions
@@ -66,6 +66,8 @@ def new_preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
         column_state = self.column_state
         row = preprocess(self, row)
         for k, v in row.items():
+            if k in ['images', 'videos', 'audios']:
+                continue
             k_i = self.key_mapping[k]
             if column_state[k_i]:
                 continue
@@ -196,11 +198,6 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
         }
         medias = self.parse_medias(d)
         self.media_replacer(row, medias)
-        if self.media_type:
-            if not isinstance(self.media_key, str):
-                row[self.media_name] = medias
-            else:
-                row[self.media_key] = medias
         return row
 
     def __call__(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
@@ -295,11 +292,6 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
             })
             medias = self.parse_medias(d)
             self.media_replacer(row, medias)
-            if self.media_type:
-                if not isinstance(self.media_key, str):
-                    row[self.media_name] = medias
-                else:
-                    row[self.media_key] = medias
             return row
         except (AssertionError, SyntaxError) as e:
             logger.error(e)
@@ -355,11 +347,6 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]:
             }
             medias = self.parse_medias(d)
             self.media_replacer(row, medias)
-            if self.media_type:
-                if not isinstance(self.media_key, str):
-                    row[self.media_name] = medias
-                else:
-                    row[self.media_key] = medias
         except Exception:
             if self.error_strategy == 'raise':
                 raise ValueError(f'conversations: {conversations}')
