@@ -58,12 +58,31 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
58
58
nn .Linear (llm_intermediate_size , llm_hidden_size , bias = False ),
59
59
)
60
60
61
- def _init_vision_model (self , config : PretrainedConfig ,
62
- quant_config : Optional [QuantizationConfig ],
63
- num_hidden_layers : int ):
64
- # We added additional dummy heads to the original num of heads to make
65
- # the number of heads divisible by 8.
66
- return InternVisionModel (config .vision_config ,
67
- quant_config = quant_config ,
68
- num_hidden_layers_override = num_hidden_layers ,
69
- num_dummy_heads = 7 )
61
+ def _init_vision_model (
62
+ self ,
63
+ config : PretrainedConfig ,
64
+ quant_config : Optional [QuantizationConfig ],
65
+ * ,
66
+ is_mono : bool ,
67
+ prefix : str ,
68
+ ):
69
+ if not is_mono :
70
+ vision_feature_layer = config .select_layer
71
+ if vision_feature_layer < 0 :
72
+ num_hidden_layers = config .vision_config .num_hidden_layers \
73
+ + vision_feature_layer + 1
74
+ else :
75
+ num_hidden_layers = vision_feature_layer + 1
76
+
77
+ # We added additional dummy heads to the original num of heads to
78
+ # make the number of heads divisible by 8.
79
+ return InternVisionModel (
80
+ config .vision_config ,
81
+ quant_config = quant_config ,
82
+ num_hidden_layers_override = num_hidden_layers ,
83
+ num_dummy_heads = 7 ,
84
+ prefix = prefix ,
85
+ )
86
+ else :
87
+ msg = "Monolith mode is not applicable to NVLM_D"
88
+ raise NotImplementedError (msg )
0 commit comments