|
@@ -4,14 +4,18 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+from functools import partial
 from typing import Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig

-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -104,6 +108,8 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
@@ -134,22 +140,31 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(self, x):
         B, N, C = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)

+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)

-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)

|
|
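Note: the sketch below is a minimal, single-process illustration of the gather -> normalize -> split pattern that _apply_qk_norm implements in the diff above. Here torch.cat and Tensor.chunk stand in for tensor_model_parallel_all_gather and split_tensor_along_last_dim, and rms_norm is a bare, weightless stand-in for the RMSNorm layer; none of these names are part of the vLLM API.

import torch

def rms_norm(x, eps=1e-6):
    # Weightless RMSNorm, enough to show which elements the statistics cover.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

tp_size, B, N, num_heads, head_dim = 2, 1, 4, 8, 16
hidden = num_heads * head_dim

# q as it would exist without tensor parallelism: (B, N, num_heads * head_dim).
q_full = torch.randn(B, N, hidden)

# Under TP, each rank holds a contiguous slice of the heads along the last dim.
q_shards = list(q_full.chunk(tp_size, dim=-1))

# Gather so every rank sees the full hidden size, normalize over it,
# then split back so each rank keeps only its own heads.
q_gathered = torch.cat(q_shards, dim=-1)        # stands in for the all-gather
q_normed = rms_norm(q_gathered)
q_local = q_normed.chunk(tp_size, dim=-1)[0]    # rank 0's slice after the norm

# Same result as normalizing the unsharded tensor and then slicing.
assert torch.allclose(q_local, rms_norm(q_full).chunk(tp_size, dim=-1)[0])

Normalizing each shard in isolation would compute RMS statistics over only that rank's heads, which is why the patch gathers q and k before q_norm/k_norm and splits them back afterwards when tp_size > 1.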