 from transformers.activations import gelu_new
 from transformers.models.bert import modeling_bert
 from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler
+from transformers.pytorch_utils import Conv1D
 from transformers.utils import is_flash_attn_2_available, logging

 if is_flash_attn_2_available():
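
The body of the `is_flash_attn_2_available()` guard falls outside this hunk. A typical shape for it, assumed here from the Flash Attention 2 integrations elsewhere in transformers (the PR's actual block is not shown), is to defer the optional `flash_attn` imports until the check passes:

```python
# Sketch of a guarded import block, assuming the same pattern transformers uses for
# its Flash Attention 2 code paths; the PR's actual block is not visible in this hunk.
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # flash_attn is an optional dependency, so only import it when it is installed.
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
```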
@@ -50,7 +51,7 @@ def flash_attention_forward(
             The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         is_causal (`bool`, *optional*):
     """
-    batch_size, query_length = query_states.shape[:2]
+    batch_size, query_length, n_heads, head_dim = query_states.shape
    query_states = query_states.to(torch.bfloat16)
    key_states = key_states.to(torch.bfloat16)
    value_states = value_states.to(torch.bfloat16)
@@ -91,7 +92,7 @@ def flash_attention_forward(
         softmax_scale=softmax_scale,
         causal=is_causal,
     )
-    return attn_output.reshape(batch_size, query_length, -1)
+    return attn_output.reshape(batch_size, query_length, n_heads * head_dim)


 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
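
A minimal runnable sketch of the shape contract these two hunks make explicit: the wrapper takes `(batch, seq_len, n_heads, head_dim)` tensors and merges the heads back into `(batch, seq_len, n_heads * head_dim)`. Here `torch.nn.functional.scaled_dot_product_attention` stands in for `flash_attn_func` so the sketch runs without flash-attn installed; `toy_attention_forward` is an illustrative name, not the PR's function:

```python
import torch
import torch.nn.functional as F

def toy_attention_forward(query_states, key_states, value_states):
    # Explicit unpacking, mirroring the changed line in flash_attention_forward.
    batch_size, query_length, n_heads, head_dim = query_states.shape
    # SDPA expects (batch, n_heads, seq_len, head_dim), so transpose in and out;
    # the real flash kernel consumes (batch, seq_len, n_heads, head_dim) directly.
    attn_output = F.scaled_dot_product_attention(
        query_states.transpose(1, 2),
        key_states.transpose(1, 2),
        value_states.transpose(1, 2),
    ).transpose(1, 2)
    # Merge heads explicitly instead of relying on -1.
    return attn_output.reshape(batch_size, query_length, n_heads * head_dim)

q = k = v = torch.randn(2, 16, 12, 64)       # (batch, seq_len, n_heads, head_dim)
print(toy_attention_forward(q, k, v).shape)  # torch.Size([2, 16, 768])
```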
@@ -161,11 +162,10 @@ def __init__(self, config, position_embedding_type=None):
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
+        self.split_size = config.hidden_size
+        self.embed_dim = config.hidden_size
+        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

     def split_heads(self, x: torch.Tensor) -> torch.Tensor:
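
If pretrained BERT weights are to be reused, the three removed `nn.Linear` projections have to be folded into the fused `c_attn`. A sketch of one way to do that, assuming `transformers.pytorch_utils.Conv1D`'s weight layout of `(in_features, out_features)`, i.e. transposed relative to `nn.Linear`. The new `c_proj` has no counterpart among the removed layers (BERT's output projection lives in `BertSelfOutput`), so its initialization is left out here:

```python
# A minimal sketch (an assumption, not part of the PR) of folding pretrained
# query/key/value nn.Linear weights into the fused c_attn Conv1D.
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

hidden_size = 768
query = nn.Linear(hidden_size, hidden_size)  # stand-ins for the pretrained projections
key = nn.Linear(hidden_size, hidden_size)
value = nn.Linear(hidden_size, hidden_size)

c_attn = Conv1D(3 * hidden_size, hidden_size)
with torch.no_grad():
    # nn.Linear weights are (out, in); concatenate along the output dim, then transpose
    # to Conv1D's (in, out) layout.
    c_attn.weight.copy_(torch.cat([query.weight, key.weight, value.weight], dim=0).t())
    c_attn.bias.copy_(torch.cat([query.bias, key.bias, value.bias], dim=0))

x = torch.randn(2, 16, hidden_size)
q, k, v = c_attn(x).split(hidden_size, dim=2)  # same split the patched forward performs
assert torch.allclose(q, query(x), atol=1e-5)
```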
@@ -182,8 +182,8 @@ def forward(
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)

+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
         # such that the encoder's padding tokens are not attended to.
@@ -195,29 +195,32 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.split_heads(self.key(encoder_hidden_states))
-            value_layer = self.split_heads(self.value(encoder_hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.split_heads(self.key(hidden_states))
-            value_layer = self.split_heads(self.value(hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.split_heads(self.key(hidden_states))
-            value_layer = self.split_heads(self.value(hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)

-        query_layer = self.split_heads(mixed_query_layer)
+        query_layer = self.split_heads(query)
+        attn_dropout = self.dropout.p if self.training else 0.0
         # Flash Attention forward pass
         attn_output = flash_attention_forward(
             query_layer,
             key_layer,
             value_layer,
             attention_mask,
-            self.dropout.p,
+            attn_dropout,
             softmax_scale=None,
             is_causal=False,
         )
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.dropout(attn_output)
         # The BertLayer expects a tuple
         return (attn_output,)

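
Putting the pieces of the patched forward together, a self-contained sketch of the fused-projection pattern it follows. The class below is illustrative only; `split_heads`' body is not shown in the diff, so its `(batch, seq, heads, head_dim)` layout is inferred from `flash_attention_forward`'s shape unpacking, and `scaled_dot_product_attention` again stands in for the Flash Attention kernel:

```python
# A toy module (an assumption, not the PR's class) showing the pattern: one fused
# c_attn projection split into q/k/v, per-head reshape, attention with a float
# dropout probability gated on self.training, head merge, then c_proj and dropout.
import torch
from torch import nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D


class ToyFusedSelfAttention(nn.Module):
    def __init__(self, hidden_size: int = 768, num_heads: int = 12, dropout_p: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.split_size = hidden_size
        self.c_attn = Conv1D(3 * hidden_size, hidden_size)
        self.c_proj = Conv1D(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_p)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, seq, heads, head_dim)
        return x.view(x.shape[0], x.shape[1], self.num_heads, self.head_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        # SDPA wants (batch, heads, seq, head_dim), unlike the flash kernel's
        # (batch, seq, heads, head_dim), hence the transposes around the call.
        q, k, v = (self.split_heads(t).transpose(1, 2) for t in (query, key, value))
        # Kernels take a float dropout_p, so it must be zeroed explicitly at eval time.
        attn_dropout = self.dropout.p if self.training else 0.0
        attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_dropout)
        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
        return self.dropout(self.c_proj(attn_output))


x = torch.randn(2, 16, 768)
print(ToyFusedSelfAttention()(x).shape)  # torch.Size([2, 16, 768])
```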