
Commit d331156

[Bugfix] remove post_layernorm in siglip (#8106)
1 parent: ccd7207


vllm/model_executor/models/siglip.py

Lines changed: 24 additions & 3 deletions
@@ -443,14 +443,27 @@ def __init__(
         self.config = config
         embed_dim = config.hidden_size
 
+        if (num_hidden_layers_override is None
+                or num_hidden_layers_override == config.num_hidden_layers):
+            self.need_post_layernorm = True
+        elif num_hidden_layers_override > config.num_hidden_layers:
+            raise ValueError(
+                "num_hidden_layers_override cannot be greater than "
+                "num_hidden_layers")
+        else:
+            self.need_post_layernorm = False
+
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
         )
-        self.post_layernorm = nn.LayerNorm(embed_dim,
-                                           eps=config.layer_norm_eps)
+        if self.need_post_layernorm:
+            self.post_layernorm = nn.LayerNorm(embed_dim,
+                                               eps=config.layer_norm_eps)
+        else:
+            self.post_layernorm = nn.Identity()
         self.use_head = (True if not hasattr(config, "vision_use_head") else
                          config.vision_use_head)
         if self.use_head:
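
The effect of the new guard is that a truncated vision tower (num_hidden_layers_override smaller than config.num_hidden_layers) skips the trained post_layernorm, while forward() can still call self.post_layernorm unconditionally because it becomes an nn.Identity no-op. A minimal standalone sketch of that logic (the build_post_layernorm helper and the example sizes are illustrative, not part of the vLLM code):

import torch
import torch.nn as nn

def build_post_layernorm(hidden_size: int,
                         num_hidden_layers: int,
                         num_hidden_layers_override=None,
                         eps: float = 1e-6) -> nn.Module:
    # Full depth (or no override): keep the trained post_layernorm.
    if (num_hidden_layers_override is None
            or num_hidden_layers_override == num_hidden_layers):
        return nn.LayerNorm(hidden_size, eps=eps)
    if num_hidden_layers_override > num_hidden_layers:
        raise ValueError("num_hidden_layers_override cannot be greater than "
                         "num_hidden_layers")
    # Truncated encoder: replace post_layernorm with a pass-through.
    return nn.Identity()

x = torch.randn(1, 16, 1152)              # illustrative (batch, tokens, hidden)
full = build_post_layernorm(1152, 27)      # nn.LayerNorm
cut = build_post_layernorm(1152, 27, 26)   # nn.Identity
assert torch.equal(cut(x), x)              # Identity leaves the hidden state untouched
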
@@ -470,7 +483,6 @@ def forward(
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
         last_hidden_state = self.post_layernorm(encoder_outputs)
-
         # TODO: add this back when pooled_output is used in inference
         # if self.use_head:
         #     pooled_output = self.head(last_hidden_state)
@@ -499,6 +511,10 @@ def __init__(
             num_hidden_layers_override=num_hidden_layers_override,
         )
 
+    @property
+    def need_post_layernorm(self):
+        return self.vision_model.need_post_layernorm
+
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
 
@@ -517,6 +533,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         layer_count = len(self.vision_model.encoder.layers)
 
         for name, loaded_weight in weights:
+            # post_layernorm is optional in SiglipVisionModel
+            if ("vision_model.post_layernorm" in name
+                    and not self.need_post_layernorm):
+                continue
+
             # omit layers when num_hidden_layers_override is set
             if "vision_model.encoder.layers." in name:
                 layer_idx = int(name.split(".")[3])
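
Because post_layernorm may now be an nn.Identity with no parameters, load_weights has to drop the checkpoint's vision_model.post_layernorm tensors before the parameter-name lookup, in the same way it already drops encoder layers beyond the truncated depth. A standalone sketch of that filtering (filter_siglip_weights is a hypothetical helper, and the layer_idx comparison beyond what the hunk shows is an assumption):

from typing import Iterable, Tuple
import torch

def filter_siglip_weights(weights: Iterable[Tuple[str, torch.Tensor]],
                          need_post_layernorm: bool,
                          layer_count: int):
    for name, loaded_weight in weights:
        # post_layernorm is optional; skip its tensors when it is Identity.
        if ("vision_model.post_layernorm" in name
                and not need_post_layernorm):
            continue
        # Skip encoder layers beyond the truncated depth (assumed check).
        if "vision_model.encoder.layers." in name:
            layer_idx = int(name.split(".")[3])
            if layer_idx >= layer_count:
                continue
        yield name, loaded_weight
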
