@@ -72,6 +72,8 @@ def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
 
     def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
         position_embedding = self.position_embedding
+        if self.num_patches == H * W:
+            return position_embedding
 
         return torch.cat(
             [
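
The new early return skips the interpolation round-trip when the incoming patch grid already matches the pretrained one; only mismatched grids fall through to `_get_pos_embed`, which resamples the patch tokens spatially. Below is a minimal standalone sketch of that resampling path, assuming a `[1, 1 + num_patches, C]` embedding whose first token is the class token; the helper name and dtype handling are illustrative, not vLLM's exact implementation.

import torch
import torch.nn.functional as F

def interpolate_pos_embed(position_embedding: torch.Tensor,
                          src_grid: int, H: int, W: int) -> torch.Tensor:
    # position_embedding: [1, 1 + src_grid**2, C]; token 0 is the class token.
    cls_tok, patch_tok = position_embedding[:, :1], position_embedding[:, 1:]
    C = patch_tok.shape[-1]
    # [1, N, C] -> [1, C, src_grid, src_grid] so F.interpolate sees a 2D grid.
    patch_tok = patch_tok.reshape(1, src_grid, src_grid, C).permute(0, 3, 1, 2)
    patch_tok = F.interpolate(patch_tok.float(), size=(H, W),
                              mode='bicubic', align_corners=False)
    # Back to [1, H * W, C], then re-attach the class token.
    patch_tok = patch_tok.permute(0, 2, 3, 1).reshape(1, H * W, C)
    return torch.cat([cls_tok, patch_tok.to(cls_tok.dtype)], dim=1)
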
@@ -102,44 +104,55 @@ def __init__(
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        *,
+        num_dummy_heads: int = 0,
+    ) -> None:
         super().__init__()
+
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
-
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
                 f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                 f' {self.num_heads}).')
 
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        # Additional dummy heads are used to enable TP for common GPU counts.
+        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
+        self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
+                                              self.tp_size)
+
         self.scale = self.head_dim**-0.5
         self.qkv = QKVParallelLinear(
             self.embed_dim,
             self.head_dim,
-            self.num_heads,
+            num_dummy_heads + self.num_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
         )
 
         self.qk_normalization = config.qk_normalization
 
         if self.qk_normalization:
-            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.q_norm = RMSNorm(self.dummy_dim,
+                                  eps=config.layer_norm_eps,
+                                  var_hidden_size=self.embed_dim)
+            self.k_norm = RMSNorm(self.dummy_dim,
+                                  eps=config.layer_norm_eps,
+                                  var_hidden_size=self.embed_dim)
 
         self.proj = RowParallelLinear(
-            self.embed_dim,
+            self.dummy_dim,
             self.embed_dim,
             quant_config=quant_config,
         )
 
-    def _apply_qk_norm(self, q, k):
+    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
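
The dummy heads carry no learned signal; they exist purely so the padded head count divides evenly across tensor-parallel ranks. A quick sanity check of the arithmetic, assuming a 25-head configuration such as InternViT-6B (the concrete numbers here are illustrative):

# With 25 real heads, 7 dummy heads pad the total to 32, which splits
# evenly across the common GPU counts of 1, 2, 4, and 8.
num_heads = 25         # e.g. InternViT-6B (assumption for illustration)
num_dummy_heads = 7
head_dim = 128

for tp_size in (1, 2, 4, 8):
    total = num_heads + num_dummy_heads
    assert total % tp_size == 0, f"{total} heads do not split {tp_size} ways"
    print(f"tp={tp_size}: {total // tp_size} heads/rank, "
          f"partition width {(total // tp_size) * head_dim}")
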
@@ -152,7 +165,7 @@ def _apply_qk_norm(self, q, k):
             k = splitter(k)[self.tp_rank]
         return q, k
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, _ = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
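
Because `q_norm`/`k_norm` normalize over the full padded head dimension, each TP rank must see every head before normalizing, hence the all-gather followed by a split back to the local shard in `_apply_qk_norm`. A single-device sketch of that pattern, with plain `torch.cat`/`split` standing in for vLLM's collectives and a generic normalize standing in for RMSNorm:

import torch
import torch.nn.functional as F

tp_size, B, N = 2, 1, 4
heads_per_rank, head_dim = 16, 128
local_dim = heads_per_rank * head_dim

# Each rank holds a [B, N, local_dim] shard of q.
shards = [torch.randn(B, N, local_dim) for _ in range(tp_size)]

q_full = torch.cat(shards, dim=-1)      # what all_gather reconstructs
q_full = F.normalize(q_full, dim=-1)    # stand-in for self.q_norm
q_local = q_full.split(local_dim, dim=-1)[0]  # rank 0 keeps its own slice
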
@@ -174,8 +187,14 @@ def forward(self, x):
 class InternSdpaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: PretrainedConfig):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        *,
+        num_dummy_heads: int = 0,
+    ) -> None:
         super().__init__()
+
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -186,20 +205,27 @@ def __init__(self, config: PretrainedConfig):
                 f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                 f' {self.num_heads}).')
 
+        # Additional dummy heads are used to enable TP for common GPU counts.
+        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
+
         self.scale = self.head_dim**-0.5
         self.qkv = nn.Linear(self.embed_dim,
-                             3 * self.embed_dim,
+                             3 * self.dummy_dim,
                              bias=config.qkv_bias)
 
         self.qk_normalization = config.qk_normalization
 
         if self.qk_normalization:
-            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.q_norm = RMSNorm(self.dummy_dim,
+                                  eps=config.layer_norm_eps,
+                                  var_hidden_size=self.embed_dim)
+            self.k_norm = RMSNorm(self.dummy_dim,
+                                  eps=config.layer_norm_eps,
+                                  var_hidden_size=self.embed_dim)
 
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
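
The `var_hidden_size=self.embed_dim` argument matters because the padded tail of the hidden dimension belongs to dummy heads: computing RMS statistics over all `dummy_dim` channels would dilute the variance, assuming the dummy channels are zero-padded. A reference sketch of the expected semantics (not vLLM's fused kernel):

import torch

def rms_norm_var_override(x: torch.Tensor, weight: torch.Tensor,
                          eps: float, var_hidden_size: int) -> torch.Tensor:
    # Statistics come from the first var_hidden_size channels only, so the
    # zero-valued dummy channels do not skew the normalization.
    variance = x[..., :var_hidden_size].float().pow(2).mean(-1, keepdim=True)
    return (x * torch.rsqrt(variance + eps)).to(x.dtype) * weight
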
@@ -252,15 +278,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class InternVisionEncoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_dummy_heads: int = 0,
+    ) -> None:
         super().__init__()
+
         self.embed_dim = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
 
-        self.attn = self._init_attn(config, quant_config)
+        self.attn = self._init_attn(config,
+                                    quant_config,
+                                    num_dummy_heads=num_dummy_heads)
 
         self.mlp = InternMLP(config, quant_config=quant_config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
@@ -273,16 +306,23 @@ def __init__(self,
         self.ls2 = nn.Parameter(config.initializer_factor *
                                 torch.ones(self.embed_dim))
 
-    def _init_attn(self, config: PretrainedConfig,
-                   quant_config: Optional[QuantizationConfig]):
+    def _init_attn(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        *,
+        num_dummy_heads: int,
+    ):
         # fallback to sdpa attention if tp unavailable
         tp_size = get_tensor_model_parallel_world_size()
         num_heads = config.num_attention_heads
 
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            return InternParallelAttention(config, quant_config=quant_config)
+        if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
+            return InternParallelAttention(config,
+                                           quant_config=quant_config,
+                                           num_dummy_heads=num_dummy_heads)
 
-        return InternSdpaAttention(config)
+        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
 
     def forward(
         self,
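
Restating the updated fallback rule: the parallel (xFormers-backed) attention is only usable when the padded head count splits evenly across TP ranks; otherwise every rank runs the unsharded SDPA implementation. As a standalone predicate mirroring the condition in `_init_attn`:

def can_use_parallel_attn(use_xformers_ops: bool, num_heads: int,
                          num_dummy_heads: int, tp_size: int) -> bool:
    # Mirrors the branch condition in _init_attn.
    return use_xformers_ops and (num_heads + num_dummy_heads) % tp_size == 0

assert can_use_parallel_attn(True, 25, 7, 8)       # 32 heads split 8 ways
assert not can_use_parallel_attn(True, 25, 0, 8)   # 25 heads do not
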
@@ -299,27 +339,30 @@ def forward(
 
 class InternVisionEncoder(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        num_dummy_heads: int = 0,
+    ):
         super().__init__()
+
         self.config = config
 
         if num_hidden_layers_override is None:
             num_hidden_layers = config.num_hidden_layers
         else:
             num_hidden_layers = num_hidden_layers_override
+
         self.layers = nn.ModuleList([
-            self._init_encoder_layer(config, quant_config)
+            InternVisionEncoderLayer(config,
+                                     quant_config,
+                                     num_dummy_heads=num_dummy_heads)
             for _ in range(num_hidden_layers)
         ])
 
-    def _init_encoder_layer(self, config: PretrainedConfig,
-                            quant_config: Optional[QuantizationConfig]):
-        return InternVisionEncoderLayer(config=config,
-                                        quant_config=quant_config)
-
     def forward(self, inputs_embeds: torch.Tensor):
 
         hidden_states = inputs_embeds
@@ -331,30 +374,24 @@ def forward(self, inputs_embeds: torch.Tensor):
 
 class InternVisionModel(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        num_dummy_heads: int = 0,
+    ):
         super().__init__()
-        self.config = config
-
-        self.embeddings = self._init_embeddings(config)
-        self.encoder = self._init_encoder(
-            config,
-            quant_config,
-            num_hidden_layers_override=num_hidden_layers_override,
-        )
 
-    def _init_embeddings(self, config: PretrainedConfig):
-        return InternVisionEmbeddings(config)
+        self.config = config
 
-    def _init_encoder(self, config: PretrainedConfig,
-                      quant_config: Optional[QuantizationConfig],
-                      num_hidden_layers_override: Optional[int]):
-        return InternVisionEncoder(
+        self.embeddings = InternVisionEmbeddings(config)
+        self.encoder = InternVisionEncoder(
             config=config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
+            num_dummy_heads=num_dummy_heads,
         )
 
     def get_input_embeddings(self):
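
With this plumbing, callers can thread `num_dummy_heads` from the top-level model down to every attention layer. A hypothetical construction, assuming an initialized vLLM tensor-parallel context (the config values below are placeholders, not a real checkpoint's):

from transformers import PretrainedConfig

# Placeholder vision config; a real caller would take this from the checkpoint.
config = PretrainedConfig(
    hidden_size=3200, num_attention_heads=25, num_hidden_layers=45,
    intermediate_size=12800, image_size=448, patch_size=14,
    qkv_bias=False, qk_normalization=True, layer_norm_eps=1e-6,
    norm_type='rms_norm', initializer_factor=0.1,
)

# Must run inside an initialized TP group, since the attention layers query
# get_tensor_model_parallel_world_size() during construction.
vision_model = InternVisionModel(
    config,
    quant_config=None,
    num_dummy_heads=7,  # 25 + 7 = 32 heads, TP-friendly for 2/4/8 GPUs
)
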