# Copyright (c) Alibaba, Inc. and its affiliates.
import os
- from typing import List, Optional
+ from contextlib import contextmanager
+ from types import MethodType
+ from typing import Dict, List, Optional

import json
import torch
+ import torch.nn as nn

from swift.llm import get_model_tokenizer, get_template
from swift.utils import (check_json_format, get_logger, get_main, get_model_info, push_to_ms_hub, seed_everything,
                         show_layers)
from .infer import merge_lora, prepare_model_template, save_checkpoint
- from .utils import ExportArguments, Template, get_dataset, swift_to_peft_format
+ from .utils import ExportArguments, Template, deep_getattr, get_dataset, get_mllm_arch, swift_to_peft_format

logger = get_logger()

_args: Optional[ExportArguments] = None
template: Optional[Template] = None


+ def _prepare_dataset(examples: List[Dict[str, torch.LongTensor]], batch_size: int = 1, *args, **kwargs):
+     global _args, template
+     assert template is not None
+     examples = [
+         template.data_collator(examples[start:start + batch_size]) for start in range(0, len(examples), batch_size)
+     ]
+     return examples
+
+
def _get_dataset(*args, **kwargs):
    global _args, template
    assert _args is not None
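The added `_prepare_dataset` takes over optimum's batch preparation (it is patched in by `_patch_gptq` further down): it slices the already-encoded examples into chunks of `batch_size` and feeds each chunk to the template's data collator, so multimodal features stay grouped with their token ids. Below is a minimal sketch of that slicing pattern, separate from the patch itself; the toy padding collator is an assumption for illustration, not swift's `Template.data_collator`.

# --- sketch (not part of this patch): batch pre-encoded examples with a collator ---
import torch

def toy_collator(batch):
    # pads 'input_ids' to the longest example in the batch; a stand-in for Template.data_collator
    max_len = max(len(ex['input_ids']) for ex in batch)
    input_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
    attention_mask = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, ex in enumerate(batch):
        ids = torch.tensor(ex['input_ids'])
        input_ids[i, :len(ids)] = ids
        attention_mask[i, :len(ids)] = 1
    return {'input_ids': input_ids, 'attention_mask': attention_mask}

examples = [{'input_ids': [1, 2, 3]}, {'input_ids': [4, 5]}, {'input_ids': [6]}]
batch_size = 2
batches = [toy_collator(examples[start:start + batch_size]) for start in range(0, len(examples), batch_size)]
print([b['input_ids'].shape for b in batches])  # [torch.Size([2, 3]), torch.Size([1, 1])]

Because the comprehension steps through range(0, len(examples), batch_size), a trailing partial batch is kept rather than dropped.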
@@ -39,27 +51,31 @@ def _get_dataset(*args, **kwargs):
    samples = []
    n_run = 0
    for data in dataset:
-         input_ids = template.encode(data)[0].get('input_ids')
+         inputs = template.encode(data)[0]
+         input_ids = inputs['input_ids']
        if input_ids is None or len(input_ids) == 0:
            continue
-         sample = torch.tensor(input_ids)
-         samples.append(sample)
+         if _args.is_multimodal and _args.quant_method == 'gptq':
+             inputs.pop('labels', None)
+             samples.append(inputs)
+         else:
+             samples += input_ids
        n_run += 1
        if n_run == n_samples:
            break
+     if _args.is_multimodal and _args.quant_method == 'gptq':
+         return samples
    # now concatenate all samples and split according to block size
-     cat_samples = torch.cat(samples, dim=0)  # shape: [X]
-     n_split = cat_samples.shape[0] // block_size
+     n_split = len(samples) // block_size
    logger.info(f'Split into {n_split} blocks')
-     if _args.quant_method == 'awq':
-         return [cat_samples[None, i * block_size:(i + 1) * block_size] for i in range(n_split)]
-     else:  # gptq
-         res = []
-         for i in range(n_split):
-             input_ids = cat_samples[None, i * block_size:(i + 1) * block_size]
-             attention_mask = torch.ones_like(input_ids)
-             res.append({'input_ids': input_ids, 'attention_mask': attention_mask})
-         return res
+     res = []
+     for i in range(n_split):
+         input_ids = samples[i * block_size:(i + 1) * block_size]
+         if _args.quant_method == 'awq':
+             res.append(torch.tensor(input_ids)[None])
+         else:
+             res.append({'input_ids': input_ids})
+     return res


def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
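In the rewritten `_get_dataset`, calibration token ids are accumulated as one flat Python list instead of per-sample tensors and then cut into `block_size` chunks: AWQ gets [1, block_size] tensors, GPTQ gets plain {'input_ids': ...} dicts, and the multimodal GPTQ path returns the encoded feature dicts earlier and never reaches this split. A small sketch of the splitting step follows, separate from the patch; `split_into_blocks` is a name introduced here purely for illustration.

# --- sketch (not part of this patch): split flat calibration ids into fixed-size blocks ---
import torch

def split_into_blocks(token_ids, block_size, quant_method):
    n_split = len(token_ids) // block_size          # leftover ids are dropped by the integer division
    res = []
    for i in range(n_split):
        block = token_ids[i * block_size:(i + 1) * block_size]
        if quant_method == 'awq':
            res.append(torch.tensor(block)[None])   # shape [1, block_size]
        else:  # gptq
            res.append({'input_ids': block})        # plain list of ids
    return res

ids = list(range(10))
print(split_into_blocks(ids, 4, 'awq'))   # two tensors of shape [1, 4]
print(split_into_blocks(ids, 4, 'gptq'))  # two dicts, each holding a 4-id list

Any tail shorter than a full block is silently discarded, matching the behaviour of the new code path.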
@@ -80,22 +96,74 @@ def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
        bits=_args.quant_bits, group_size=group_size, zero_point=True, version='GEMM')


+ @contextmanager
+ def _patch_gptq():
+     from optimum.gptq import quantizer
+     _get_dataset_origin = quantizer.get_dataset
+     _prepare_dataset_origin = quantizer.prepare_dataset
+     quantizer.get_dataset = _get_dataset
+     quantizer.prepare_dataset = _prepare_dataset
+     yield
+     quantizer.get_dataset = _get_dataset_origin
+     quantizer.prepare_dataset = _prepare_dataset_origin
+
+
+ def _patch_model_forward(module_list):
+
+     def _new_forward(self, *args, **kwargs):
+         if 'use_cache' in kwargs:
+             kwargs['use_cache'] = False
+         layer_ret = self.__old_forward(*args, **kwargs)
+         return layer_ret + args[len(layer_ret):]
+
+     for module in module_list:
+         if hasattr(module, '_old_forward'):  # device_map
+             __old_forward = module._old_forward
+             module._old_forward = MethodType(_new_forward, module)
+         else:
+             __old_forward = module.forward
+             module.forward = MethodType(_new_forward, module)
+         module.__old_forward = __old_forward
+
+
+ def get_block_name_to_quantize(model: nn.Module, model_type: str) -> Optional[str]:
+     mllm_arch = get_mllm_arch(model_type)
+     prefix = ''
+     if mllm_arch is not None:
+         assert len(mllm_arch.language_model) == 1, f'mllm_arch.language_model: {mllm_arch.language_model}'
+         prefix = mllm_arch.language_model[0]
+         model = deep_getattr(model, prefix)
+
+     module_lists = []
+     for n, m in model.named_modules():
+         if isinstance(m, nn.ModuleList) and len(m) >= 10:
+             module_lists.append((n, m))
+     if module_lists:
+         module_list = max(module_lists, key=lambda x: len(x[1]))
+         _patch_model_forward(module_list[1])
+         return f'{prefix}.{module_list[0]}'
+
+
def gptq_model_quantize(model, tokenizer, batch_size):
-     from optimum.gptq import GPTQQuantizer, quantizer
+     from optimum.gptq import GPTQQuantizer
    global _args
    logger.info(f'Quantization dataset: {_args.dataset}')
-     gptq_quantizer = GPTQQuantizer(bits=_args.quant_bits, dataset=','.join(_args.dataset), batch_size=batch_size)
-     _origin_get_dataset = quantizer.get_dataset
-     quantizer.get_dataset = _get_dataset
-     logger.info('Start quantizing the model...')
-     logger.warning('The process of packing the model takes a long time and there is no progress bar. '
-                    'Please be patient and wait...')
-     gptq_quantizer.quantize_model(model, tokenizer)
-     quantizer.get_dataset = _origin_get_dataset  # recover
+     with _patch_gptq():
+         gptq_quantizer = GPTQQuantizer(
+             bits=_args.quant_bits,
+             dataset=','.join(_args.dataset),
+             batch_size=batch_size,
+             block_name_to_quantize=get_block_name_to_quantize(model, _args.model_type))
+         logger.info('Start quantizing the model...')
+         logger.warning('The process of packing the model takes a long time and there is no progress bar. '
+                        'Please be patient and wait...')
+         if not hasattr(model.config, 'use_cache'):
+             model.config.use_cache = None
+         gptq_quantizer.quantize_model(model, tokenizer)
    return gptq_quantizer


- def replace_and_concat(template: 'Template', template_list: List, placeholder: str, keyword: str):
+ def replace_and_concat(template: Template, template_list: List, placeholder: str, keyword: str):
    final_str = ''
    for t in template_list:
        if isinstance(t, str):
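The new `get_block_name_to_quantize` picks the transformer block stack heuristically: it walks `named_modules()`, keeps every `nn.ModuleList` with at least 10 children, and takes the longest one; for multimodal models it first descends into the language model via `deep_getattr` and prefixes the returned name, and `_patch_model_forward` then forces `use_cache=False` on each block during calibration. A toy illustration of the heuristic, separate from the patch; the `ToyLM` model and its attribute names are invented for this sketch.

# --- sketch (not part of this patch): the "largest ModuleList" heuristic ---
import torch.nn as nn

class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(12)])  # stand-in decoder stack
        self.head = nn.Linear(16, 100)

model = ToyLM()
candidates = [(name, m) for name, m in model.named_modules()
              if isinstance(m, nn.ModuleList) and len(m) >= 10]
name, module_list = max(candidates, key=lambda x: len(x[1]))
print(name, len(module_list))  # 'layers' 12 -> the dotted name handed to GPTQQuantizer

On a real decoder-only checkpoint the same walk typically lands on a name like 'model.layers' or 'transformer.h', which is the kind of dotted module path optimum's block_name_to_quantize argument expects.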