
Commit a94f2ee

Merge pull request axolotl-ai-cloud#299 from OpenAccess-AI-Collective/flash-attention-2
Flash attention 2
2 parents: 1b63bf1 + cdf85fd


2 files changed, +4 -4 lines changed


docker/Dockerfile-base  +1 -1
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
 RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
-    git checkout v1.0.9 && \
+    git checkout 9ee0ff1 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
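Note: the pin moves flash-attention from the v1.0.9 tag to commit 9ee0ff1, presumably so the built wheel exposes the renamed varlen entry point imported in src/axolotl/flash_attn.py below. A minimal sketch of how an image built from this Dockerfile could be sanity-checked follows; the two function names and the module path come from this commit, while the helper itself is hypothetical and not part of the change.

# Minimal sketch (not part of the commit): report which qkvpacked entry point
# the installed flash-attn wheel exposes. flash_attn.flash_attn_interface and
# the two function names are taken from this diff; the helper is illustrative.
import importlib

def flash_attn_has_varlen_api() -> bool:
    """True if the renamed flash_attn_varlen_qkvpacked_func is importable."""
    iface = importlib.import_module("flash_attn.flash_attn_interface")
    return hasattr(iface, "flash_attn_varlen_qkvpacked_func")

if __name__ == "__main__":
    print("varlen qkvpacked API available:", flash_attn_has_varlen_api())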

src/axolotl/flash_attn.py  +3 -3
@@ -8,7 +8,7 @@
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 
@@ -79,7 +79,7 @@ def forward(
             dtype=torch.int32,
             device=qkv.device,
         )
-        output = flash_attn_unpadded_qkvpacked_func(
+        output = flash_attn_varlen_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
@@ -95,7 +95,7 @@ def forward(
             three=3,
             h=nheads,
         )
-        output_unpad = flash_attn_unpadded_qkvpacked_func(
+        output_unpad = flash_attn_varlen_qkvpacked_func(
             x_unpad,
             cu_q_lens,
             max_s,
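The only functional change in this file is the rename from flash_attn_unpadded_qkvpacked_func to flash_attn_varlen_qkvpacked_func; the call sites keep the same arguments. If code had to run against both the old v1.0.9 pin and the new checkout, a compatibility import along these lines could paper over the rename. This is a sketch only: treating the two names as drop-in equivalents is an assumption here, not something the commit itself states.

# Hedged sketch of an import-time compatibility shim for the rename shown above.
# The call signature (qkv, cu_q_lens, max_s, dropout_p, softmax_scale, causal)
# is taken from the diff; equivalence of the two names is assumed, not asserted
# by this commit.
try:
    # name used by the flash-attention checkout pinned in docker/Dockerfile-base
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
except ImportError:
    # older flash-attn releases, e.g. the previously pinned v1.0.9
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
    )

# usage, mirroring the call in src/axolotl/flash_attn.py:
# output = flash_attn_varlen_qkvpacked_func(
#     qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
# )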
