Commit 37e7d91

convert tensors back to the original dtype in the flash attention implementation
1 parent: 659ce64

File tree: 1 file changed (+2, -1)


Diff for: src/cehrbert/models/hf_models/hf_cehrbert.py (+2, -1)
@@ -51,6 +51,7 @@ def flash_attention_forward(
         The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
     is_causal (`bool`, *optional*):
     """
+    dtype = query_states.dtype
     batch_size, query_length, n_heads, head_dim = query_states.shape
     query_states = query_states.to(torch.bfloat16)
     key_states = key_states.to(torch.bfloat16)
@@ -92,7 +93,7 @@ def flash_attention_forward(
         softmax_scale=softmax_scale,
         causal=is_causal,
     )
-    return attn_output.reshape(batch_size, query_length, n_heads * head_dim)
+    return attn_output.reshape(batch_size, query_length, n_heads * head_dim).to(dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
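For context, here is a minimal, self-contained sketch of the dtype round-trip this commit introduces. It is not the repository's code: torch.nn.functional.scaled_dot_product_attention stands in for the flash-attn kernel called in hf_cehrbert.py, and the function and tensor names are illustrative. The pattern is the same: remember the caller's dtype, run attention in bfloat16 (flash-attention kernels require fp16/bf16), then cast the result back so downstream layers see the original dtype.

```python
# Sketch of the dtype round-trip pattern added in this commit (assumptions:
# SDPA stands in for the flash-attn kernel; names and shapes are illustrative).
import torch
import torch.nn.functional as F


def attention_with_dtype_roundtrip(query_states, key_states, value_states, is_causal=True):
    dtype = query_states.dtype  # remember the caller's dtype (e.g. float32)
    batch_size, query_length, n_heads, head_dim = query_states.shape

    # Cast to bfloat16 for the attention kernel and move heads to dim 1: (B, H, L, D).
    q = query_states.to(torch.bfloat16).transpose(1, 2)
    k = key_states.to(torch.bfloat16).transpose(1, 2)
    v = value_states.to(torch.bfloat16).transpose(1, 2)

    attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    attn_output = attn_output.transpose(1, 2)  # back to (B, L, H, D)

    # Cast back to the original dtype, mirroring the `.to(dtype)` added here.
    return attn_output.reshape(batch_size, query_length, n_heads * head_dim).to(dtype)


# Usage: float32 inputs go in, float32 comes out even though attention ran in bf16.
q = torch.randn(2, 8, 4, 16)
k = torch.randn(2, 8, 4, 16)
v = torch.randn(2, 8, 4, 16)
out = attention_with_dtype_roundtrip(q, k, v)
assert out.dtype == torch.float32 and out.shape == (2, 8, 4 * 16)
```

Without the final cast, a float32 model would receive a bfloat16 attention output, which can trigger dtype-mismatch errors or silent precision loss in subsequent layers; the one-line `.to(dtype)` keeps the function dtype-transparent to its callers.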
