
Commit 7728064

Consolidate code
1 parent 1b57db8 commit 7728064

File tree

3 files changed: +105 -282 lines changed


vllm/model_executor/layers/layernorm.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -24,7 +24,8 @@ def __init__(
 
         self.hidden_size = hidden_size
         self.variance_epsilon = eps
-        self.var_hidden_size = var_hidden_size
+        self.variance_size_override = (None if var_hidden_size == hidden_size
+                                       else var_hidden_size)
 
         self.weight = nn.Parameter(torch.ones(hidden_size))
 
@@ -45,15 +46,15 @@ def forward_native(
             raise ValueError("Expected hidden_size to be "
                              f"{self.hidden_size}, but found: {hidden_size}")
 
-        if self.var_hidden_size is None:
+        if self.variance_size_override is None:
             x_var = x
         else:
-            if hidden_size < self.var_hidden_size:
+            if hidden_size < self.variance_size_override:
                 raise ValueError(
                     "Expected hidden_size to be at least "
-                    f"{self.var_hidden_size}, but found: {hidden_size}")
+                    f"{self.variance_size_override}, but found: {hidden_size}")
 
-            x_var = x[:, :, :self.var_hidden_size]
+            x_var = x[:, :, :self.variance_size_override]
 
         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
 
@@ -69,7 +70,7 @@ def forward_cuda(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if self.var_hidden_size is not None:
+        if self.variance_size_override is not None:
             return self.forward_native(x, residual)
 
         from vllm import _custom_ops as ops
@@ -96,7 +97,7 @@ def forward_xpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if self.var_hidden_size is not None:
+        if self.variance_size_override is not None:
            return self.forward_native(x, residual)
 
         from vllm._ipex_ops import ipex_ops as ops
```
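
The layernorm change above lets RMSNorm compute its RMS statistics over only the leading `var_hidden_size` channels while still applying the learned scale to the full (possibly padded) hidden dimension; passing `var_hidden_size == hidden_size` disables the override. Below is a minimal PyTorch sketch of that behavior, using hypothetical sizes (1024 real channels padded to 1152) rather than vLLM's fused CUDA/XPU kernels:

```python
from typing import Optional

import torch
import torch.nn as nn


class PartialVarianceRMSNorm(nn.Module):
    """Sketch of an RMSNorm whose statistics ignore padded channels."""

    def __init__(self, hidden_size: int, eps: float = 1e-6,
                 var_hidden_size: Optional[int] = None) -> None:
        super().__init__()
        self.variance_epsilon = eps
        # Mirror the commit: no override when var_hidden_size == hidden_size.
        self.variance_size_override = (None if var_hidden_size in (None, hidden_size)
                                       else var_hidden_size)
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        # Variance comes from the real (leading) channels only...
        x_var = (x if self.variance_size_override is None
                 else x[..., :self.variance_size_override])
        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
        # ...but the learned scale is applied across the full hidden dimension.
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * x).to(orig_dtype)


# Hypothetical sizes for illustration: 1024 real channels padded out to 1152.
norm = PartialVarianceRMSNorm(1152, var_hidden_size=1024)
out = norm(torch.randn(2, 5, 1152))
```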

vllm/model_executor/models/intern_vit.py

Lines changed: 91 additions & 54 deletions
```diff
@@ -72,6 +72,8 @@ def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
 
     def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
         position_embedding = self.position_embedding
+        if self.num_patches == H * W:
+            return position_embedding
 
         return torch.cat(
             [
@@ -102,44 +104,55 @@ def __init__(
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        *,
+        num_dummy_heads: int = 0,
+    ) -> None:
         super().__init__()
+
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
-
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
                 f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                 f' {self.num_heads}).')
 
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        # Additional dummy heads are used to enable TP for common GPU counts.
+        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
+        self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
+                                              self.tp_size)
+
         self.scale = self.head_dim**-0.5
         self.qkv = QKVParallelLinear(
             self.embed_dim,
             self.head_dim,
-            self.num_heads,
+            num_dummy_heads + self.num_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
         )
 
         self.qk_normalization = config.qk_normalization
 
         if self.qk_normalization:
-            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.q_norm = RMSNorm(self.dummy_dim,
+                                  eps=config.layer_norm_eps,
+                                  var_hidden_size=self.embed_dim)
+            self.k_norm = RMSNorm(self.dummy_dim,
+                                  eps=config.layer_norm_eps,
+                                  var_hidden_size=self.embed_dim)
 
         self.proj = RowParallelLinear(
-            self.embed_dim,
+            self.dummy_dim,
             self.embed_dim,
             quant_config=quant_config,
         )
 
-    def _apply_qk_norm(self, q, k):
+    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
@@ -152,7 +165,7 @@ def _apply_qk_norm(self, q, k):
             k = splitter(k)[self.tp_rank]
         return q, k
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, _ = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
@@ -174,8 +187,14 @@ def forward(self, x):
 class InternSdpaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: PretrainedConfig):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        *,
+        num_dummy_heads: int = 0,
+    ) -> None:
         super().__init__()
+
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -186,20 +205,27 @@ def __init__(self, config: PretrainedConfig):
                 f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                 f' {self.num_heads}).')
 
+        # Additional dummy heads are used to enable TP for common GPU counts.
+        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
+
         self.scale = self.head_dim**-0.5
         self.qkv = nn.Linear(self.embed_dim,
-                             3 * self.embed_dim,
+                             3 * self.dummy_dim,
                              bias=config.qkv_bias)
 
         self.qk_normalization = config.qk_normalization
 
         if self.qk_normalization:
-            self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-            self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.q_norm = RMSNorm(self.dummy_dim,
+                                  eps=config.layer_norm_eps,
+                                  var_hidden_size=self.embed_dim)
+            self.k_norm = RMSNorm(self.dummy_dim,
+                                  eps=config.layer_norm_eps,
+                                  var_hidden_size=self.embed_dim)
 
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
@@ -252,15 +278,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class InternVisionEncoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_dummy_heads: int = 0,
+    ) -> None:
         super().__init__()
+
         self.embed_dim = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
 
-        self.attn = self._init_attn(config, quant_config)
+        self.attn = self._init_attn(config,
+                                    quant_config,
+                                    num_dummy_heads=num_dummy_heads)
 
         self.mlp = InternMLP(config, quant_config=quant_config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
@@ -273,16 +306,23 @@ def __init__(self,
         self.ls2 = nn.Parameter(config.initializer_factor *
                                 torch.ones(self.embed_dim))
 
-    def _init_attn(self, config: PretrainedConfig,
-                   quant_config: Optional[QuantizationConfig]):
+    def _init_attn(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        *,
+        num_dummy_heads: int,
+    ):
         # fallback to sdpa attention if tp unavailable
         tp_size = get_tensor_model_parallel_world_size()
         num_heads = config.num_attention_heads
 
-        if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            return InternParallelAttention(config, quant_config=quant_config)
+        if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
+            return InternParallelAttention(config,
+                                           quant_config=quant_config,
+                                           num_dummy_heads=num_dummy_heads)
 
-        return InternSdpaAttention(config)
+        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
 
     def forward(
         self,
@@ -299,27 +339,30 @@ def forward(
 
 class InternVisionEncoder(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        num_dummy_heads: int = 0,
+    ):
         super().__init__()
+
         self.config = config
 
         if num_hidden_layers_override is None:
             num_hidden_layers = config.num_hidden_layers
         else:
             num_hidden_layers = num_hidden_layers_override
+
         self.layers = nn.ModuleList([
-            self._init_encoder_layer(config, quant_config)
+            InternVisionEncoderLayer(config,
+                                     quant_config,
+                                     num_dummy_heads=num_dummy_heads)
             for _ in range(num_hidden_layers)
         ])
 
-    def _init_encoder_layer(self, config: PretrainedConfig,
-                            quant_config: Optional[QuantizationConfig]):
-        return InternVisionEncoderLayer(config=config,
-                                        quant_config=quant_config)
-
     def forward(self, inputs_embeds: torch.Tensor):
 
         hidden_states = inputs_embeds
@@ -331,30 +374,24 @@ def forward(self, inputs_embeds: torch.Tensor):
 
 class InternVisionModel(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        num_dummy_heads: int = 0,
+    ):
         super().__init__()
-        self.config = config
-
-        self.embeddings = self._init_embeddings(config)
-        self.encoder = self._init_encoder(
-            config,
-            quant_config,
-            num_hidden_layers_override=num_hidden_layers_override,
-        )
 
-    def _init_embeddings(self, config: PretrainedConfig):
-        return InternVisionEmbeddings(config)
+        self.config = config
 
-    def _init_encoder(self, config: PretrainedConfig,
-                      quant_config: Optional[QuantizationConfig],
-                      num_hidden_layers_override: Optional[int]):
-        return InternVisionEncoder(
+        self.embeddings = InternVisionEmbeddings(config)
+        self.encoder = InternVisionEncoder(
             config=config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
+            num_dummy_heads=num_dummy_heads,
        )
 
     def get_input_embeddings(self):
```
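
The `num_dummy_heads` plumbing in this file pads the attention head count so it divides evenly across tensor-parallel ranks: the QKV and output projections are sized with the padded `dummy_dim`, and the q/k RMSNorm layers pass `var_hidden_size=embed_dim` so their statistics ignore the padded channels. The helper below only illustrates the arithmetic (it is not part of the commit), with hypothetical numbers: 25 heads of dimension 128 split across 8 ranks.

```python
def pad_heads_for_tp(num_heads: int, head_dim: int, tp_size: int):
    """Illustrative helper: smallest dummy-head padding that enables TP."""
    num_dummy_heads = (-num_heads) % tp_size
    dummy_dim = (num_heads + num_dummy_heads) * head_dim
    heads_per_rank = (num_heads + num_dummy_heads) // tp_size
    return num_dummy_heads, dummy_dim, heads_per_rank


# Hypothetical example: 25 heads of size 128 cannot be split across 8 ranks
# directly; padding to 32 heads gives a 4096-wide padded dimension and 4 heads
# per rank. The output projection (dummy_dim -> embed_dim) maps results back
# to the real hidden width.
print(pad_heads_for_tp(25, 128, 8))  # (7, 4096, 4)
```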
