diff --git a/README.md b/README.md
index 5ae2386..d5990c0 100644
--- a/README.md
+++ b/README.md
@@ -117,6 +117,9 @@ The detailed support list:
 | [Vicuna-v1.1](/scripts/vicuna_example.sh) | 7B/13B | ✅ | |
 | [LLaVA-v0](/scripts/llava_example.sh) | 13B | ✅ | |
 | [VILA](/scripts/vila_example.sh) | 7B/13B | ✅ | |
+| [Mistral](/scripts/mistral_example.sh) | 7B | ✅ | |
+| [Mixtral](/scripts/mixtral_example.sh) | 8x7B | ✅ | |
+
 
 Note: We only list models that we have prepare the [AWQ searching results](https://huggingface.co/datasets/mit-han-lab/awq-model-zoo/tree/main) in the table above. AWQ also supports models such as LLaVA-v1.5 7B, and you may need to run the [AWQ search](#usage) on your own to quantize these models.
 
diff --git a/awq/entry.py b/awq/entry.py
index 6e51e0c..a2d34bf 100644
--- a/awq/entry.py
+++ b/awq/entry.py
@@ -130,6 +130,8 @@ def build_model_and_enc(model_path):
                 "BloomBlock",
                 "MPTBlock",
                 "DecoderLayer",
+                "MistralDecoderLayer",
+                "MixtralDecoderLayer",
             ],
             **kwargs,
         )
@@ -208,6 +210,7 @@ def build_model_and_enc(model_path):
                 model, max_memory if len(max_memory) > 0 else None
             )
         }
+
         device_map = infer_auto_device_map(
             model,
             # TODO: can we remove this?
@@ -217,6 +220,8 @@ def build_model_and_enc(model_path):
                 "BloomBlock",
                 "MPTBlock",
                 "DecoderLayer",
+                "MistralDecoderLayer",
+                "MixtralDecoderLayer",
             ],
             **kwargs,
         )
diff --git a/awq/quantize/auto_scale.py b/awq/quantize/auto_scale.py
index bb66e10..b4de8fd 100644
--- a/awq/quantize/auto_scale.py
+++ b/awq/quantize/auto_scale.py
@@ -5,6 +5,8 @@
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
+from transformers.models.mistral.modeling_mistral import MistralRMSNorm
+from transformers.models.mixtral.modeling_mixtral import MixtralRMSNorm
 from transformers.activations import GELUActivation
 
 from .qmodule import ScaledActivation
@@ -439,6 +441,95 @@ def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
                 inp=input_feat["mlp.dense_4h_to_h"],
             )
         )
+    elif "mistral" in str(module.__class__).lower():
+        # attention input
+        scales_list.append(
+            _auto_get_scale(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+        # attn out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            scales_list.append(
+                _auto_get_scale(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+        # fc1
+        scales_list.append(
+            _auto_get_scale(
+                prev_op=module.post_attention_layernorm,
+                layers=[module.mlp.gate_proj, module.mlp.up_proj],
+                inp=input_feat["mlp.gate_proj"],
+                module2inspect=module.mlp,
+            )
+        )
+        # fc2
+        scales_list.append(
+            _auto_get_scale(
+                prev_op=module.mlp.up_proj,
+                layers=[module.mlp.down_proj],
+                inp=input_feat["mlp.down_proj"],
+            )
+        )
+    elif "mixtral" in str(module.__class__).lower():
+        # attention input
+        scales_list.append(
+            _auto_get_scale(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+        # attn out
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+            scales_list.append(
+                _auto_get_scale(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.o_proj],
+                    inp=input_feat["self_attn.o_proj"],
+                )
+            )
+        # fc1
+        scales_list.append(
+            _auto_get_scale(
+                prev_op=module.post_attention_layernorm,
+                layers=[
+                    w
+                    for expert in module.block_sparse_moe.experts
+                    for w in [expert.w1, expert.w3]
+                ],
+                inp=input_feat["block_sparse_moe"],
+                module2inspect=module.block_sparse_moe,
+            )
+        )
+        # fc2
+        for i, expert in enumerate(module.block_sparse_moe.experts):
+            scales_list.append(
+                _auto_get_scale(
+                    prev_op=expert.w3,
+                    layers=[expert.w2],
+                    inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
+                )
+            )
     else:
         raise NotImplementedError(f"{type(module)} not supported yet!")
 
@@ -458,7 +549,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         if isinstance(prev_op, nn.Linear):
             assert len(layers) == 1
             scale_fc_fc(prev_op, layers[0], scales)
-        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
+        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm, MistralRMSNorm, MixtralRMSNorm)):
             scale_ln_fcs(prev_op, layers, scales)
         elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
             new_module = ScaledActivation(prev_op, scales)
diff --git a/awq/quantize/pre_quant.py b/awq/quantize/pre_quant.py
index f35531d..bfa8640 100644
--- a/awq/quantize/pre_quant.py
+++ b/awq/quantize/pre_quant.py
@@ -18,8 +18,17 @@
 
 def get_named_linears(module):
-    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
-
+    named_linears = {}
+    for name, m in module.named_modules():
+        if isinstance(m, nn.Linear):
+            # exclude the gate layer
+            if "mixtral" in str(module.__class__).lower():
+                if "gate" in name:
+                    continue
+
+            named_linears[name] = m
+
+    return named_linears
 
 def get_blocks(model):
     if model.__class__.__name__ == "LlamaForCausalLM":
         layers = model.model.layers
@@ -39,6 +48,10 @@ def get_blocks(model):
         layers = model.transformer.h
     elif "neox" in str(model.__class__).lower():
         layers = model.gpt_neox.layers
+    elif "mistral" in str(model.__class__).lower():
+        layers = model.model.layers
+    elif "mixtral" in str(model.__class__).lower():
+        layers = model.model.layers
     else:
         raise NotImplementedError(type(model))
     return layers
@@ -73,6 +86,10 @@ def move_embed(model, device):
         model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
         model.gpt_neox.emb_dropout = model.gpt_neox.emb_dropout.to(device)
         model.embed_out = model.embed_out.to(device)
+    elif "mistral" in str(model.__class__).lower():
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+    elif "mixtral" in str(model.__class__).lower():
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
     else:
         raise NotImplementedError(type(model))
 
@@ -129,10 +146,18 @@ def forward(self, inp, **kwargs):
         model(samples.to(next(model.parameters()).device))
     except ValueError:  # work with early exit
         pass
+
+    # From AutoAWQ
+    # Update the layer kwargs with `prepare_inputs_for_generation` method
+    # that takes care of everything to avoid unexpected errors.
+    layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs)
+    # Pop the input_ids as they are not needed at all.
+    layer_kwargs.pop("input_ids")
+
     del samples
     layers[0] = layers[0].module  # restore
     inps = inps[0]
-
+
     layers[0] = layers[0].cpu()
     move_embed(model, "cpu")
 
@@ -158,6 +183,13 @@ def cache_input_hook(m, x, y, name, feat_dict):
 
         input_feat = defaultdict(list)
         handles = []
+
+        if "mixtral" in str(model.__class__).lower():
+            named_linears = {
+                **named_linears,
+                "block_sparse_moe": layer.block_sparse_moe,
+            }
+
         for name in named_linears:
             handles.append(
                 named_linears[name].register_forward_hook(
diff --git a/scripts/mistral_example.sh b/scripts/mistral_example.sh
new file mode 100644
index 0000000..87008b0
--- /dev/null
+++ b/scripts/mistral_example.sh
@@ -0,0 +1,25 @@
+MODEL=Mistral-7B-Instruct-v0.2
+
+# run AWQ search (optional; we provide pre-computed results)
+python -m awq.entry --model_path /dataset/mistral/$MODEL \
+    --w_bit 4 --q_group_size 128 \
+    --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
+
+# evaluate the AWQ quantized model (simulated pseudo quantization)
+python -m awq.entry --model_path /dataset/mistral/$MODEL \
+    --tasks wikitext \
+    --w_bit 4 --q_group_size 128 \
+    --load_awq awq_cache/$MODEL-w4-g128.pt \
+    --q_backend fake
+
+# generate real quantized weights (w4)
+python -m awq.entry --model_path /dataset/mistral/$MODEL \
+    --w_bit 4 --q_group_size 128 \
+    --load_awq awq_cache/$MODEL-w4-g128.pt \
+    --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
+
+# load and evaluate the real quantized model (lower GPU memory usage)
+python -m awq.entry --model_path /dataset/mistral/$MODEL \
+    --tasks wikitext \
+    --w_bit 4 --q_group_size 128 \
+    --load_quant quant_cache/$MODEL-w4-g128-awq.pt
diff --git a/scripts/mixtral_example.sh b/scripts/mixtral_example.sh
new file mode 100644
index 0000000..14278c2
--- /dev/null
+++ b/scripts/mixtral_example.sh
@@ -0,0 +1,25 @@
+MODEL=Mixtral-8x7B-Instruct-v0.1
+
+# run AWQ search (optional; we provide pre-computed results)
+python -m awq.entry --model_path /dataset/mixtral/$MODEL \
+    --w_bit 4 --q_group_size 128 \
+    --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
+
+# evaluate the AWQ quantized model (simulated pseudo quantization)
+python -m awq.entry --model_path /dataset/mixtral/$MODEL \
+    --tasks wikitext \
+    --w_bit 4 --q_group_size 128 \
+    --load_awq awq_cache/$MODEL-w4-g128.pt \
+    --q_backend fake
+
+# generate real quantized weights (w4)
+python -m awq.entry --model_path /dataset/mixtral/$MODEL \
+    --w_bit 4 --q_group_size 128 \
+    --load_awq awq_cache/$MODEL-w4-g128.pt \
+    --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
+
+# load and evaluate the real quantized model (lower GPU memory usage)
+python -m awq.entry --model_path /dataset/mixtral/$MODEL \
+    --tasks wikitext \
+    --w_bit 4 --q_group_size 128 \
+    --load_quant quant_cache/$MODEL-w4-g128-awq.pt
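
Note on the Mixtral-specific filtering above: the patched `get_named_linears` in `awq/quantize/pre_quant.py` deliberately skips the MoE router (`block_sparse_moe.gate`), so only the attention projections and the per-expert `w1`/`w2`/`w3` linears are scaled and quantized, presumably because the router is tiny and its output decides expert selection. The snippet below is a minimal, self-contained sketch of just that name-filtering rule on a mock Mixtral-style layer; it only requires `torch`, and the `Mock*` classes and their dimensions are invented for illustration — only the exclusion logic mirrors the patch.

```python
import torch.nn as nn


class MockExpert(nn.Module):
    def __init__(self, dim=8, hidden=16):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)


class MockSparseMoeBlock(nn.Module):
    def __init__(self, dim=8, num_experts=2):
        super().__init__()
        # router: stays in full precision, never handed to the quantizer
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(MockExpert(dim) for _ in range(num_experts))


class MockMixtralDecoderLayer(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.block_sparse_moe = MockSparseMoeBlock(dim)


def get_named_linears(module):
    # Same rule as the patched pre_quant.py: collect every nn.Linear in the
    # block except the MoE router ("gate") when the block is Mixtral-like.
    named_linears = {}
    for name, m in module.named_modules():
        if isinstance(m, nn.Linear):
            if "mixtral" in str(module.__class__).lower() and "gate" in name:
                continue
            named_linears[name] = m
    return named_linears


layer = MockMixtralDecoderLayer()
print(sorted(get_named_linears(layer)))
# Prints the expert projections (block_sparse_moe.experts.{i}.w1/w2/w3);
# 'block_sparse_moe.gate' is absent, so it is never quantized.
```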