     PretrainedConfig,
     AutoProcessor,
     BaseImageProcessor,
+    ProcessorMixin,
     PreTrainedTokenizer,
 )
 from accelerate.big_modeling import (
@@ -112,7 +113,7 @@ def __init__(
         self.search_result = None
         self.config: PretrainedConfig = config
         self.quant_config: AwqConfig = quant_config
-        self.processor: BaseImageProcessor = processor
+        self.processor: ProcessorMixin = processor

     def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
         """A utility function for moving the model to a device."""
@@ -342,6 +343,14 @@ def from_pretrained(
             Dict,
             Doc("Used for configure download model"),
         ] = None,
+        low_cpu_mem_usage: Annotated[
+            bool,
+            Doc("Use low_cpu_mem_usage when loading from transformers.")
+        ] = True,
+        use_cache: Annotated[
+            bool,
+            Doc("Use use_cache argument in transformers")
+        ] = False,
         **model_init_kwargs: Annotated[
             Dict,
             Doc(
@@ -367,6 +376,11 @@ def from_pretrained(
         if target_cls_name == "AutoModelForVision2Seq":
             processor = AutoProcessor.from_pretrained(model_weights_path)

+        if model_init_kwargs.get("low_cpu_mem_usage") is None:
+            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+        if model_init_kwargs.get("use_cache") is None and target_cls_name != "AutoModelForVision2Seq":
+            model_init_kwargs["use_cache"] = use_cache
+
         # If not quantized, must load with AutoModelForCausalLM
         model = target_cls.from_pretrained(
             model_weights_path,
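
Taken together, the new flags expose transformers' memory-friendly loading and KV-cache behavior on this loading path, and the get(...) is None guards only fill in the defaults when those keys are absent from model_init_kwargs. A hedged usage sketch, assuming this base class is reachable through AutoAWQForCausalLM in the awq package; the model path is a placeholder, not a real checkpoint.

# Usage sketch under the assumptions above; "path/to/model" is a placeholder.
from awq import AutoAWQForCausalLM

# Picks up the new defaults: low_cpu_mem_usage=True, use_cache=False.
model = AutoAWQForCausalLM.from_pretrained("path/to/model")

# The defaults can be flipped explicitly through the new parameters.
model = AutoAWQForCausalLM.from_pretrained(
    "path/to/model",
    low_cpu_mem_usage=False,
    use_cache=True,
)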