@@ -443,14 +443,27 @@ def __init__(
         self.config = config
         embed_dim = config.hidden_size

+        if (num_hidden_layers_override is None
+                or num_hidden_layers_override == config.num_hidden_layers):
+            self.need_post_layernorm = True
+        elif num_hidden_layers_override > config.num_hidden_layers:
+            raise ValueError(
+                "num_hidden_layers_override cannot be greater than "
+                "num_hidden_layers")
+        else:
+            self.need_post_layernorm = False
+
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
         )
-        self.post_layernorm = nn.LayerNorm(embed_dim,
-                                           eps=config.layer_norm_eps)
+        if self.need_post_layernorm:
+            self.post_layernorm = nn.LayerNorm(embed_dim,
+                                               eps=config.layer_norm_eps)
+        else:
+            self.post_layernorm = nn.Identity()
         self.use_head = (True if not hasattr(config, "vision_use_head") else
                          config.vision_use_head)
         if self.use_head:
@@ -470,7 +483,6 @@ def forward(
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)

         last_hidden_state = self.post_layernorm(encoder_outputs)
-
         # TODO: add this back when pooled_output is used in inference
         # if self.use_head:
         #     pooled_output = self.head(last_hidden_state)
@@ -499,6 +511,10 @@ def __init__(
             num_hidden_layers_override=num_hidden_layers_override,
         )

+    @property
+    def need_post_layernorm(self):
+        return self.vision_model.need_post_layernorm
+
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding

@@ -517,6 +533,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         layer_count = len(self.vision_model.encoder.layers)

         for name, loaded_weight in weights:
+            # post_layernorm is optional in SiglipVisionModel
+            if ("vision_model.post_layernorm" in name
+                    and not self.need_post_layernorm):
+                continue
+
             # omit layers when num_hidden_layers_override is set
             if "vision_model.encoder.layers." in name:
                 layer_idx = int(name.split(".")[3])
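
For context, below the diff is a minimal, self-contained sketch of the pattern the change introduces: when an override keeps only a prefix of the encoder layers, the trailing LayerNorm is swapped for nn.Identity(), so forward() needs no branch and the checkpoint's unused post_layernorm weight can simply be skipped at load time. The class and parameter names here (TruncatedEncoder, num_layers_override) are illustrative stand-ins, not vLLM's actual API.

    from typing import Optional

    import torch
    import torch.nn as nn


    class TruncatedEncoder(nn.Module):
        """Keeps only a prefix of the stacked layers; the final LayerNorm
        applies only when the full stack is kept."""

        def __init__(self,
                     hidden_size: int,
                     num_layers: int,
                     num_layers_override: Optional[int] = None):
            super().__init__()
            # Mirror the validation in the diff: the override may shrink
            # the stack but never grow it.
            if (num_layers_override is None
                    or num_layers_override == num_layers):
                self.need_post_layernorm = True
            elif num_layers_override > num_layers:
                raise ValueError(
                    "num_layers_override cannot be greater than num_layers")
            else:
                self.need_post_layernorm = False

            kept = (num_layers
                    if num_layers_override is None else num_layers_override)
            # Stand-in "layers"; the real model stacks transformer blocks.
            self.layers = nn.ModuleList(
                nn.Linear(hidden_size, hidden_size) for _ in range(kept))
            # nn.Identity() is a no-op module, so forward() never branches.
            self.post_layernorm = (nn.LayerNorm(hidden_size)
                                   if self.need_post_layernorm
                                   else nn.Identity())

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                x = layer(x)
            return self.post_layernorm(x)


    full = TruncatedEncoder(hidden_size=8, num_layers=3)
    truncated = TruncatedEncoder(hidden_size=8, num_layers=3,
                                 num_layers_override=2)
    print(full.post_layernorm)       # LayerNorm(...)
    print(truncated.post_layernorm)  # Identity()

Keeping post_layernorm present as a module in both configurations (Identity when unused) means forward() is identical either way; the special-casing is confined to __init__ and, in the diff, to skipping the unused checkpoint weight in load_weights.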