|
1 |
| -# Copied code from |
| 1 | +# Adapted from |
2 | 2 | # https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
|
3 | 3 |
|
4 | 4 | from dataclasses import dataclass, field
|
@@ -44,16 +44,26 @@ class ModelKeys:
|
44 | 44 |
|
45 | 45 | @dataclass
|
46 | 46 | class MultiModelKeys(ModelKeys):
|
47 |
| - language_model: Union[List[str], str] = field(default_factory=list) |
48 |
| - connector: Union[List[str], str] = field(default_factory=list) |
| 47 | + language_model: List[str] = field(default_factory=list) |
| 48 | + connector: List[str] = field(default_factory=list) |
49 | 49 | # vision tower and audio tower
|
50 |
| - tower_model: Union[List[str], str] = field(default_factory=list) |
51 |
| - generator: Union[List[str], str] = field(default_factory=list) |
52 |
| - |
53 |
| - def __post_init__(self): |
54 |
| - for key in ["language_model", "connector", "tower_model", "generator"]: |
55 |
| - v = getattr(self, key) |
56 |
| - if isinstance(v, str): |
57 |
| - setattr(self, key, [v]) |
58 |
| - if v is None: |
59 |
| - setattr(self, key, []) |
| 50 | + tower_model: List[str] = field(default_factory=list) |
| 51 | + generator: List[str] = field(default_factory=list) |
| 52 | + |
| 53 | + @staticmethod |
| 54 | + def from_string_field(language_model: Union[str, List[str]] = None, |
| 55 | + connector: Union[str, List[str]] = None, |
| 56 | + tower_model: Union[str, List[str]] = None, |
| 57 | + generator: Union[str, List[str]] = None, |
| 58 | + **kwargs) -> 'MultiModelKeys': |
| 59 | + |
| 60 | + def to_list(value): |
| 61 | + if value is None: |
| 62 | + return [] |
| 63 | + return [value] if isinstance(value, str) else list(value) |
| 64 | + |
| 65 | + return MultiModelKeys(language_model=to_list(language_model), |
| 66 | + connector=to_list(connector), |
| 67 | + tower_model=to_list(tower_model), |
| 68 | + generator=to_list(generator), |
| 69 | + **kwargs) |
0 commit comments