Commit 759e749

ywang96 and Isotr0py committed: clean up

Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent 49e3dad commit 759e749

File tree

2 files changed: 27 additions & 31 deletions

vllm/model_executor/models/intern_vit.py

Lines changed: 3 additions & 4 deletions
@@ -110,6 +110,8 @@ def __init__(
         self.head_dim = self.embed_dim // self.num_heads
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
@@ -137,9 +139,6 @@ def __init__(
             quant_config=quant_config,
         )
 
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
-
     def _apply_qk_norm(self, q, k):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
@@ -154,7 +153,7 @@ def _apply_qk_norm(self, q, k):
         return q, k
 
     def forward(self, x):
-        B, N, C = x.shape
+        B, N, _ = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
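The cleanup moves num_heads_per_partition up next to the other tensor-parallel fields in __init__ and keeps QK normalization routed through _apply_qk_norm, which gathers the sharded hidden dim across TP ranks, normalizes over the full embed_dim, and splits the result back per rank. Below is a minimal single-process sketch of that gather-normalize-split pattern, not vLLM code: the toy rms_norm and the plain torch.cat/chunk calls are stand-ins for vLLM's RMSNorm.forward_native, tensor_model_parallel_all_gather, and split_tensor_along_last_dim, and all sizes are made up.

import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalizes over the last dim, as q_norm/k_norm do over embed_dim.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def apply_qk_norm_tp(q_shards: list[torch.Tensor]) -> list[torch.Tensor]:
    # Stand-in for tensor_model_parallel_all_gather: concat shards on last dim.
    q_full = torch.cat(q_shards, dim=-1)
    q_full = rms_norm(q_full)  # the norm sees the full embed_dim, not one shard
    # Stand-in for split_tensor_along_last_dim: one slice per "rank".
    return list(q_full.chunk(len(q_shards), dim=-1))


tp_size, B, N, embed_dim = 2, 1, 4, 8
q = torch.randn(B, N, embed_dim)
shards = list(q.chunk(tp_size, dim=-1))            # per-rank slices of q
out = torch.cat(apply_qk_norm_tp(shards), dim=-1)
assert torch.allclose(out, rms_norm(q))            # matches the unsharded norm

Normalizing a single shard instead would use per-shard statistics and give a different result, which is why the gather/split round-trip is needed when tp_size > 1.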

vllm/model_executor/models/nvlm_d.py

Lines changed: 24 additions & 27 deletions
@@ -88,6 +88,8 @@ def __init__(
         self.num_dummy_heads = num_dummy_heads
         self.dummy_dim = (self.num_dummy_heads +
                           self.num_heads) * self.head_dim
+        self.num_heads_per_partition = divide(
+            self.num_dummy_heads + self.num_heads, self.tp_size)
 
         self.scale = self.head_dim**-0.5
         self.qkv = QKVParallelLinear(
@@ -114,26 +116,31 @@ def __init__(
             quant_config=quant_config,
         )
 
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads_per_partition = divide(
-            self.num_dummy_heads + self.num_heads, self.tp_size)
+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
 
     def forward(self, x):
-        B, N, C = x.shape
+        B, N, _ = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
 
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)
 
@@ -179,31 +186,21 @@ def __init__(self, config: PretrainedConfig, num_dummy_heads: int = 7):
 
         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
 
-    def _apply_qk_norm(self, q, k):
-        if self.tp_size > 1:
-            q = tensor_model_parallel_all_gather(q.contiguous())
-            k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
-        if self.tp_size > 1:
-            splitter = partial(split_tensor_along_last_dim,
-                               num_partitions=self.tp_size)
-            q = splitter(q)[self.tp_rank]
-            k = splitter(k)[self.tp_rank]
-        return q, k
-
     def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
-        if self.qk_normalization:
-            q, k = self._apply_qk_norm(q, k)
-
         q = q.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         k = k.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         v = v.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
 
+        if self.qk_normalization:
+            B_, N_, H_, D_ = q.shape
+            q = self.q_norm.forward_native(q.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+            k = self.k_norm.forward_native(k.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
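In nvlm_d.py the same per-partition head count now also moves into __init__, with the dummy heads folded in before dividing by the TP world size. A small sketch of that arithmetic follows, with a toy divide mimicking vLLM's divisibility check; the concrete sizes (25 real heads, head_dim 128) are assumptions in the spirit of InternViT-6B, while the 7 dummy heads match the num_dummy_heads default visible in this diff.

# Illustrative only; sizes are assumptions, not read from this commit.
def divide(numerator: int, denominator: int) -> int:
    # vLLM's divide() likewise asserts even divisibility before splitting.
    assert numerator % denominator == 0, "heads must split evenly across TP ranks"
    return numerator // denominator


num_heads, num_dummy_heads = 25, 7   # 7 matches the diff's default num_dummy_heads
head_dim, tp_size = 128, 8
dummy_dim = (num_dummy_heads + num_heads) * head_dim             # 32 * 128 = 4096
num_heads_per_partition = divide(num_dummy_heads + num_heads, tp_size)
print(num_heads_per_partition)  # 4 heads per rank; 25 alone would not divide by 8

Padding with dummy heads is what makes the head count divisible across TP ranks in the first place, which is why num_dummy_heads + num_heads, not num_heads, feeds both dummy_dim and the per-partition division.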
