
Commit f1abb8e

improve type hinting and fix use_cache (#680)
1 parent: cbd6a75

3 files changed: +19 -7 lines

awq/models/auto.py

+2 -5
@@ -74,18 +74,15 @@ def from_pretrained(
             model_path, trust_remote_code, **model_init_kwargs
         )
 
-        if model_init_kwargs.get("low_cpu_mem_usage") is None:
-            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-        if model_init_kwargs.get("use_cache") is None:
-            model_init_kwargs["use_cache"] = use_cache
-
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
             model_path,
             model_type,
             trust_remote_code=trust_remote_code,
             safetensors=safetensors,
             device_map=device_map,
             download_kwargs=download_kwargs,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            use_cache=use_cache,
             **model_init_kwargs,
         )
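
Net effect in auto.py: the two defaults are no longer patched into model_init_kwargs here; they are forwarded as explicit keyword arguments and defaulted once in the base class (see the guard in awq/models/base.py below). A minimal usage sketch; the model path is a placeholder, not part of this commit:

    from awq import AutoAWQForCausalLM

    # low_cpu_mem_usage and use_cache are plain keyword arguments of
    # from_pretrained; values the caller passes via **model_init_kwargs
    # still win over these defaults in the base class.
    model = AutoAWQForCausalLM.from_pretrained(
        "path/to/model",         # placeholder
        low_cpu_mem_usage=True,  # default
        use_cache=False,         # default
    )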

awq/models/base.py

+15 -1
@@ -36,6 +36,7 @@
     PretrainedConfig,
     AutoProcessor,
     BaseImageProcessor,
+    ProcessorMixin,
     PreTrainedTokenizer,
 )
 from accelerate.big_modeling import (
@@ -112,7 +113,7 @@ def __init__(
         self.search_result = None
         self.config: PretrainedConfig = config
         self.quant_config: AwqConfig = quant_config
-        self.processor: BaseImageProcessor = processor
+        self.processor: ProcessorMixin = processor
 
     def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
         """A utility function for moving the model to a device."""
@@ -342,6 +343,14 @@ def from_pretrained(
             Dict,
             Doc("Used for configure download model"),
         ] = None,
+        low_cpu_mem_usage: Annotated[
+            bool,
+            Doc("Use low_cpu_mem_usage when loading from transformers.")
+        ] = True,
+        use_cache: Annotated[
+            bool,
+            Doc("Use use_cache argument in transformers")
+        ] = False,
         **model_init_kwargs: Annotated[
             Dict,
             Doc(
@@ -367,6 +376,11 @@ def from_pretrained(
         if target_cls_name == "AutoModelForVision2Seq":
             processor = AutoProcessor.from_pretrained(model_weights_path)
 
+        if model_init_kwargs.get("low_cpu_mem_usage") is None:
+            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+        if model_init_kwargs.get("use_cache") is None and target_cls_name != "AutoModelForVision2Seq":
+            model_init_kwargs["use_cache"] = use_cache
+
         # If not quantized, must load with AutoModelForCausalLM
         model = target_cls.from_pretrained(
             model_weights_path,
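
The guard pattern in the last hunk is the companion to the auto.py change: dict.get(key) is None lets a caller-supplied entry in **model_init_kwargs win, while the named parameter only fills the gap, and use_cache is deliberately left untouched for AutoModelForVision2Seq. A standalone sketch of the defaulting behavior; load is a hypothetical stand-in for from_pretrained:

    def load(low_cpu_mem_usage: bool = True, use_cache: bool = False, **model_init_kwargs):
        # A caller-supplied kwarg wins; the named parameter only fills in
        # when the key is absent (or explicitly None).
        if model_init_kwargs.get("low_cpu_mem_usage") is None:
            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        if model_init_kwargs.get("use_cache") is None:
            model_init_kwargs["use_cache"] = use_cache
        return model_init_kwargs

    print(load())                # {'low_cpu_mem_usage': True, 'use_cache': False}
    print(load(use_cache=True))  # {'low_cpu_mem_usage': True, 'use_cache': True}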

awq/models/llava.py

+2 -1
@@ -35,7 +35,8 @@ def move_embed(model: OldLlavaForConditionalGeneration, device: str):
         model.language_model.model.embed_tokens = model.get_input_embeddings().to(
             device
         )
-        model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)
+        if hasattr(model.language_model.model, "rotary_emb"):
+            model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
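
The hasattr guard in llava.py accounts for a transformers layout change: recent releases compute rotary embeddings in a single model-level rotary_emb module, while older releases kept them inside each attention layer, so the model-level attribute may be absent. A quick probe of the installed version; the tiny config is only to keep instantiation cheap:

    import transformers
    from transformers import LlamaConfig, LlamaModel

    # Build a throwaway one-layer Llama and check whether this
    # transformers version exposes a model-level rotary_emb module.
    cfg = LlamaConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=1,
        num_attention_heads=4,
    )
    model = LlamaModel(cfg)
    print(transformers.__version__, hasattr(model, "rotary_emb"))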
