@@ -88,6 +88,8 @@ def __init__(
         self.num_dummy_heads = num_dummy_heads
         self.dummy_dim = (self.num_dummy_heads +
                           self.num_heads) * self.head_dim
+        self.num_heads_per_partition = divide(
+            self.num_dummy_heads + self.num_heads, self.tp_size)

         self.scale = self.head_dim**-0.5
         self.qkv = QKVParallelLinear(
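The two added lines hoist the head-partitioning arithmetic into `__init__`, next to the dummy-head setup it depends on. A minimal sketch of that arithmetic, assuming `divide` is the usual assert-then-integer-divide tensor-parallel helper; the 25 real heads are an illustrative value, while the default of 7 dummy heads comes from the constructor signature in the last hunk:

```python
def divide(numerator: int, denominator: int) -> int:
    # Dummy heads pad the total head count so it splits evenly across ranks.
    assert numerator % denominator == 0, (
        f"{numerator} is not divisible by {denominator}")
    return numerator // denominator

num_heads, num_dummy_heads, tp_size = 25, 7, 8  # illustrative values
num_heads_per_partition = divide(num_dummy_heads + num_heads, tp_size)
assert num_heads_per_partition == 4  # (7 + 25) / 8
```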
@@ -114,26 +116,31 @@ def __init__(
             quant_config=quant_config,
         )

-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_heads_per_partition = divide(
-            self.num_dummy_heads + self.num_heads, self.tp_size)
+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k

     def forward(self, x):
-        B, N, C = x.shape
+        B, N, _ = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)

+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)

-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)

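The new `_apply_qk_norm` performs QK normalization on the full, unsharded hidden width: each rank all-gathers its q/k shard along the last axis, applies the norm across every head at once, then splits the result and keeps only its own partition. Below is a single-process sketch of that gather → norm → split shape flow; `torch.chunk` stands in for `split_tensor_along_last_dim`, a plain `LayerNorm` stands in for the model's `q_norm`, and the all-gather is elided since there is only one local "rank":

```python
import torch

tp_size, tp_rank = 2, 0
dummy_dim = 64                          # (num_dummy_heads + num_heads) * head_dim
q = torch.randn(1, 4, dummy_dim)        # q as if already all-gathered
q_norm = torch.nn.LayerNorm(dummy_dim)  # stand-in: the norm spans *all* heads
q = q_norm(q)                           # normalize over the full last dim
q_local = torch.chunk(q, tp_size, dim=-1)[tp_rank]
assert q_local.shape == (1, 4, dummy_dim // tp_size)  # back to the local shard
```

Running the norm before the per-head `view` is what lets a single norm weight of size `dummy_dim` be applied even though each rank only holds `num_heads_per_partition` heads locally.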
@@ -179,31 +186,21 @@ def __init__(self, config: PretrainedConfig, num_dummy_heads: int = 7):

         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)

-    def _apply_qk_norm(self, q, k):
-        if self.tp_size > 1:
-            q = tensor_model_parallel_all_gather(q.contiguous())
-            k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
-        if self.tp_size > 1:
-            splitter = partial(split_tensor_along_last_dim,
-                               num_partitions=self.tp_size)
-            q = splitter(q)[self.tp_rank]
-            k = splitter(k)[self.tp_rank]
-        return q, k
-
     def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)

-        if self.qk_normalization:
-            q, k = self._apply_qk_norm(q, k)
-
         q = q.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         k = k.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         v = v.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)

+        if self.qk_normalization:
+            B_, N_, H_, D_ = q.shape
+            q = self.q_norm.forward_native(q.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
+            k = self.k_norm.forward_native(k.flatten(-2,
+                                                     -1)).view(B_, N_, H_, D_)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
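In the non-parallel attention the same whole-width normalization needs no communication, so it is done in place by flattening the head and head-dim axes, normalizing, and restoring the shape; the hunk also moves the block after the `view` calls rather than before them. A shape-only sketch with hypothetical sizes, again using `LayerNorm` as a stand-in for the RMS-style `q_norm`:

```python
import torch

B, N, H, D = 1, 4, 32, 48               # hypothetical batch/seq/head sizes
q = torch.randn(B, N, H, D)
q_norm = torch.nn.LayerNorm(H * D)      # stand-in for q_norm over all heads
# Merge (H, D) into one feature axis, normalize across it, then restore
# the per-head layout expected by the attention that follows.
q = q_norm(q.flatten(-2, -1)).view(B, N, H, D)
assert q.shape == (B, N, H, D)
```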