
Commit 49e3dad

ywang96 and Isotr0py committed
fix qk norm for paralleled VIT attention
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent b5ea51b commit 49e3dad

File tree

2 files changed (+46 / -16 lines)

vllm/model_executor/models/intern_vit.py

Lines changed: 23 additions & 8 deletions
@@ -4,14 +4,18 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+from functools import partial
 from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -104,6 +108,8 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
@@ -134,22 +140,31 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(self, x):
         B, N, C = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
 
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)
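
The change is the same in both files: q_norm and k_norm are RMSNorm layers defined over the full hidden size, so once the QKV projection is tensor-parallel each rank only holds a slice of q and k and can no longer apply the norm locally. The new _apply_qk_norm therefore all-gathers q and k across ranks, runs the norm on the full last dimension, and splits the result back so each rank keeps its own shard. The snippet below is only a single-process sketch (not part of the commit; rms_norm, full and shards are made-up names) of why normalizing a rank-local shard is not equivalent to normalizing the gathered tensor:

import torch

def rms_norm(x, eps=1e-6):
    # RMSNorm over the last dimension (scale weight omitted for brevity).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

torch.manual_seed(0)
full = torch.randn(2, 4, 8)        # (batch, seq, hidden)
shards = full.chunk(2, dim=-1)     # what two TP ranks would each hold

# Normalizing each shard in isolation uses per-shard statistics ...
per_shard = torch.cat([rms_norm(s) for s in shards], dim=-1)

# ... while gather -> norm -> split uses the statistics of the full vector.
full_norm = rms_norm(full)
rank0_shard = full_norm.chunk(2, dim=-1)[0]

print(torch.allclose(per_shard, full_norm))             # False in general
print(torch.allclose(rank0_shard, full_norm[..., :4]))  # True

With tp_size == 1 both the gather and the split are skipped, so the single-GPU path is unchanged.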
vllm/model_executor/models/nvlm_d.py

Lines changed: 23 additions & 8 deletions
@@ -4,14 +4,18 @@
 # Copyright (c) 2024 NVIDIA
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
+from functools import partial
 from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.inputs import INPUT_REGISTRY
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -71,6 +75,8 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
@@ -173,22 +179,31 @@ def __init__(self, config: PretrainedConfig, num_dummy_heads: int = 7):
 
         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
 
+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         k = k.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         v = v.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
 
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
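
nvlm_d.py gets the identical treatment for its dummy-head attention. As a pure shape check, the following single-process mock (an assumption for illustration, not vLLM code: the collectives are stood in by torch.cat and Tensor.split) traces the data flow of _apply_qk_norm for tp_size = 2:

import torch

B, N, embed_dim, tp_size = 2, 16, 64, 2
shard_width = embed_dim // tp_size     # width of q (and k) held by each rank

q_shards = [torch.randn(B, N, shard_width) for _ in range(tp_size)]

# tensor_model_parallel_all_gather(q.contiguous()) ~ concatenate every rank's
# shard along the last dimension.
q_full = torch.cat(q_shards, dim=-1)   # (B, N, embed_dim)

# self.q_norm.forward_native(q) ~ RMSNorm over the full embed_dim
# (scale weight omitted).
q_normed = q_full * torch.rsqrt(q_full.pow(2).mean(-1, keepdim=True) + 1e-6)

# split_tensor_along_last_dim(q, num_partitions=tp_size)[tp_rank] ~ each rank
# keeps only its own slice of the normalized tensor.
q_rank0 = q_normed.split(shard_width, dim=-1)[0]
assert q_rank0.shape == (B, N, shard_width)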
