Skip to content

Commit a9e724c

Browse files
committed
Modify module_mapping logic
1 parent bf4ee9d commit a9e724c

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

vllm/model_executor/models/minicpmv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -627,9 +627,9 @@ def get_mm_mapping(self) -> MultiModelKeys:
627627
"""
628628
Get the module prefix in multimodal models
629629
"""
630-
return MultiModelKeys(language_model="llm",
631-
connector="resampler",
632-
tower_model="vpm")
630+
return MultiModelKeys.from_string_field(language_model="llm",
631+
connector="resampler",
632+
tower_model="vpm")
633633

634634
def init_llm(
635635
self,

vllm/model_executor/models/module_mapping.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copied code from
1+
# Adapted from
22
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
33

44
from dataclasses import dataclass, field
@@ -44,16 +44,26 @@ class ModelKeys:
4444

4545
@dataclass
4646
class MultiModelKeys(ModelKeys):
47-
language_model: Union[List[str], str] = field(default_factory=list)
48-
connector: Union[List[str], str] = field(default_factory=list)
47+
language_model: List[str] = field(default_factory=list)
48+
connector: List[str] = field(default_factory=list)
4949
# vision tower and audio tower
50-
tower_model: Union[List[str], str] = field(default_factory=list)
51-
generator: Union[List[str], str] = field(default_factory=list)
52-
53-
def __post_init__(self):
54-
for key in ["language_model", "connector", "tower_model", "generator"]:
55-
v = getattr(self, key)
56-
if isinstance(v, str):
57-
setattr(self, key, [v])
58-
if v is None:
59-
setattr(self, key, [])
50+
tower_model: List[str] = field(default_factory=list)
51+
generator: List[str] = field(default_factory=list)
52+
53+
@staticmethod
54+
def from_string_field(language_model: Union[str, List[str]] = None,
55+
connector: Union[str, List[str]] = None,
56+
tower_model: Union[str, List[str]] = None,
57+
generator: Union[str, List[str]] = None,
58+
**kwargs) -> 'MultiModelKeys':
59+
60+
def to_list(value):
61+
if value is None:
62+
return []
63+
return [value] if isinstance(value, str) else list(value)
64+
65+
return MultiModelKeys(language_model=to_list(language_model),
66+
connector=to_list(connector),
67+
tower_model=to_list(tower_model),
68+
generator=to_list(generator),
69+
**kwargs)

0 commit comments

Comments
 (0)