Commit 65e760c

test gpt2 implementation
1 parent b989b8a commit 65e760c

File tree

1 file changed: +19 -16 lines


src/cehrbert/models/hf_models/hf_cehrbert.py

Lines changed: 19 additions & 16 deletions
@@ -8,6 +8,7 @@
 from transformers.activations import gelu_new
 from transformers.models.bert import modeling_bert
 from transformers.models.bert.modeling_bert import BertEncoder, BertOnlyMLMHead, BertPooler
+from transformers.pytorch_utils import Conv1D
 from transformers.utils import is_flash_attn_2_available, logging
 
 if is_flash_attn_2_available():
@@ -50,7 +51,7 @@ def flash_attention_forward(
             The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         is_causal (`bool`, *optional*):
     """
-    batch_size, query_length = query_states.shape[:2]
+    batch_size, query_length, n_heads, head_dim = query_states.shape
     query_states = query_states.to(torch.bfloat16)
     key_states = key_states.to(torch.bfloat16)
     value_states = value_states.to(torch.bfloat16)
@@ -91,7 +92,7 @@ def flash_attention_forward(
         softmax_scale=softmax_scale,
         causal=is_causal,
     )
-    return attn_output.reshape(batch_size, query_length, -1)
+    return attn_output.reshape(batch_size, query_length, n_heads * head_dim)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
@@ -161,11 +162,10 @@ def __init__(self, config, position_embedding_type=None):
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
+        self.split_size = config.hidden_size
+        self.embed_dim = config.hidden_size
+        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
     def split_heads(self, x: torch.Tensor) -> torch.Tensor:
@@ -182,8 +182,8 @@ def forward(
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
 
+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
         # such that the encoder's padding tokens are not attended to.
@@ -195,29 +195,32 @@ def forward(
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.split_heads(self.key(encoder_hidden_states))
-            value_layer = self.split_heads(self.value(encoder_hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.split_heads(self.key(hidden_states))
-            value_layer = self.split_heads(self.value(hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.split_heads(self.key(hidden_states))
-            value_layer = self.split_heads(self.value(hidden_states))
+            key_layer = self.split_heads(key)
+            value_layer = self.split_heads(value)
 
-        query_layer = self.split_heads(mixed_query_layer)
+        query_layer = self.split_heads(query)
+        attn_dropout = self.attn_dropout.p if self.training else 0.0
         # Flash Attention forward pass
        attn_output = flash_attention_forward(
             query_layer,
             key_layer,
             value_layer,
             attention_mask,
-            self.dropout.p,
+            attn_dropout,
             softmax_scale=None,
             is_causal=False,
         )
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.dropout(attn_output)
         # The BertLayer expects a tuple
         return (attn_output,)
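
For orientation: the change swaps the three separate nn.Linear query/key/value projections for a GPT-2-style fused Conv1D projection (c_attn) whose output is split into Q, K and V, and it adds an output projection (c_proj) plus dropout after the flash-attention call. The sketch below is not code from this repository; it is a minimal, self-contained module illustrating the same pattern, with PyTorch's scaled_dot_product_attention standing in for the repository's flash_attention_forward and with illustrative defaults for hidden size, head count and dropout.

# Minimal sketch of the fused-QKV pattern this commit adopts (hypothetical module,
# not part of the repository). Requires torch >= 2.0 and transformers.
import torch
from torch import nn
from transformers.pytorch_utils import Conv1D  # linear layer with transposed weight, as used by GPT-2


class FusedQKVSelfAttention(nn.Module):
    def __init__(self, hidden_size: int = 768, num_attention_heads: int = 12, dropout: float = 0.1):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.split_size = hidden_size
        self.embed_dim = hidden_size
        # One projection produces Q, K and V at once: (..., hidden) -> (..., 3 * hidden)
        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        # Output projection applied after attention, mirroring the c_proj added in the diff
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, seq, n_heads, head_dim)
        batch, seq, _ = x.shape
        return x.view(batch, seq, self.num_attention_heads, self.attention_head_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Fused projection, then split the last dimension into three hidden-sized chunks
        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        query_layer = self.split_heads(query)  # (batch, seq, n_heads, head_dim)
        key_layer = self.split_heads(key)
        value_layer = self.split_heads(value)

        # Stand-in for flash_attention_forward: plain scaled dot-product attention.
        # Attention dropout is applied only during training, matching the diff's intent.
        p = self.dropout.p if self.training else 0.0
        attn = torch.nn.functional.scaled_dot_product_attention(
            query_layer.transpose(1, 2),  # SDPA expects (batch, n_heads, seq, head_dim)
            key_layer.transpose(1, 2),
            value_layer.transpose(1, 2),
            dropout_p=p,
        ).transpose(1, 2)

        # Merge heads back to (batch, seq, n_heads * head_dim), then project and drop out
        batch, seq, n_heads, head_dim = attn.shape
        attn_output = attn.reshape(batch, seq, n_heads * head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.dropout(attn_output)
        return attn_output


if __name__ == "__main__":
    module = FusedQKVSelfAttention()
    print(module(torch.randn(2, 16, 768)).shape)  # torch.Size([2, 16, 768])

Conv1D comes from transformers.pytorch_utils and is the same utility GPT2Attention uses for c_attn and c_proj; it behaves like nn.Linear with a transposed weight, which is why the diff imports it rather than adding another nn.Linear. Note that the sketch reads the dropout probability from self.dropout, whereas the diff reads self.attn_dropout, an attribute that does not appear in the __init__ shown above.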
223226

0 commit comments
