 from vllm.lora.punica import PunicaWrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              parse_fine_tuned_lora_name, replace_submodule)
-from vllm.model_executor.models.interfaces import SupportsLoRA
+from vllm.model_executor.models.interfaces import (SupportsLoRA,
+                                                   supports_multimodal)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer
 from vllm.utils import is_pin_memory_available

@@ -332,6 +334,8 @@ def __init__(
             self.supported_lora_modules.append("rotary_emb")
         self.packed_modules_mapping = copy.deepcopy(
             self.model.packed_modules_mapping)
+        # Used to indicate whether the model is a multimodal model
+        self.supports_mm: bool = supports_multimodal(self.model)
         self.packed_modules: Dict[str, List[str]] = {}
         self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
         # Dict instead of a Set for compatibility with LRUCache.
@@ -437,12 +441,22 @@ def _create_lora_modules(self):
                 continue
             if not self._match_target_modules(module_name):
                 continue
+            # A temporary approach for multimodal models to support LoRA
+            # TODO: Remove this restriction
+            if self._filter_unsupported_mm_module(module_name):
+                logger.warning(
+                    "Regarding multimodal models, vLLM currently only supports "
+                    "adding LoRA to language model, %s will be ignored.",
+                    module_name,
+                )
+                continue
             parts = module_name.split(".")[-1]
             packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
             new_module = replace_submodule(
                 self.model, module_name,
                 from_layer(module, self.lora_slots, self.lora_config,
                            packed_moduled_lst, self.model.config))
+
             # LinearScalingRotaryEmbeddingWithLora is used to handle
             # long context lora. Register relevant metadata.
             if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
@@ -460,6 +474,15 @@ def _create_lora_modules(self):
                                                 module, self.lora_slots,
                                                 self.lora_config,
                                                 self.model.config))
+
+            # In some models, especially multimodal ones, layers with the same
+            # name may have different types, such as nn.Linear and
+            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
+            # LoRA layers, leading to assertion error. The following check
+            # aims to prevent this error
+            if self.supports_mm and not isinstance(new_module,
+                                                   BaseLayerWithLoRA):
+                continue
             self.register_module(module_name, new_module)
             self._register_packed_modules(module_name)
             # All lora layers share the same punica_wrapper based on reference.
@@ -478,9 +501,10 @@ def create_dummy_lora(
         """Create zero-initialized LoRAModel for warmup."""
         model = LoRAModel(lora_id, rank, {}, scaling_factor)
         for module_name, module in self.model.named_modules():
-            if not self._match_target_modules(module_name) or not isinstance(
-                    module, BaseLayerWithLoRA) or isinstance(
-                        module, LinearScalingRotaryEmbeddingWithLora):
+            if (not self._match_target_modules(module_name)
+                    or not isinstance(module, BaseLayerWithLoRA)
+                    or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
+                    or self._filter_unsupported_mm_module(module_name)):
                 continue
             parts = module_name.split(".")
             if module_name not in self.packed_modules:
@@ -541,6 +565,19 @@ def _match_target_modules(self, module_name: str):
                 module_name) or target_module == module_name
             for target_module in self.supported_lora_modules)

+    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
+        """
+        Regarding multimodal models, vLLM currently only supports adding LoRA to
+        language model. LoRA for other modules, such as the vision tower, will
+        be filtered out.
+        """
+        if self.supports_mm:
+            prefix = module_name.split(".")[0]
+            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
+            return (prefix in module_mapping.connector
+                    or prefix in module_mapping.tower_model)
+        return False
+
     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")
         module_name = parts[-1]
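
For context, the new _filter_unsupported_mm_module check reduces to a prefix match against the model's multimodal module mapping. The standalone sketch below mirrors that logic only; FakeMultiModelKeys and the prefix values (vision_tower, multi_modal_projector, language_model) are illustrative stand-ins, not the mapping any particular model actually returns from get_mm_mapping().

from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeMultiModelKeys:
    # Illustrative stand-in for MultiModelKeys from
    # vllm.model_executor.models.module_mapping; only the two fields
    # consulted by _filter_unsupported_mm_module are modeled here.
    connector: List[str] = field(default_factory=list)
    tower_model: List[str] = field(default_factory=list)


def filter_unsupported_mm_module(module_name: str,
                                 mapping: FakeMultiModelKeys) -> bool:
    # Mirror of the new check: a module is filtered out (no LoRA applied)
    # when its top-level prefix belongs to the connector or the vision tower.
    prefix = module_name.split(".")[0]
    return prefix in mapping.connector or prefix in mapping.tower_model


# Hypothetical module prefixes, for illustration only.
mapping = FakeMultiModelKeys(connector=["multi_modal_projector"],
                             tower_model=["vision_tower"])
print(filter_unsupported_mm_module("vision_tower.blocks.0.attn.qkv", mapping))  # True
print(filter_unsupported_mm_module("language_model.layers.0.mlp", mapping))     # False

Because only the leading component of the dotted module path is inspected, anything under the language model passes through untouched, while connector and tower submodules trigger the warning in _create_lora_modules and are skipped.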