4 files changed, +20 -18 lines changed
File 1 of 4 (GitHub Actions workflow):

@@ -18,12 +18,12 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "117"
             cuda_version: 11.7.1
@@ -33,7 +33,7 @@
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
     steps:
       - name: Checkout
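
The matrix entries above (and the matching ones in the next workflow file) all move from PyTorch 2.0.0 to 2.0.1. A minimal Python sketch of how such matrix entries expand into per-build identifiers; the entry values are copied from the diff, but the tag format is purely illustrative and not the workflow's actual naming scheme:

```python
# Sketch: expand matrix "include" entries into per-build identifiers.
# Entry values mirror the cu118 entries in the diff; the tag format is hypothetical.
matrix_include = [
    {"cuda": "118", "python_version": "3.9", "pytorch": "2.0.1", "axolotl_extras": ""},
    {"cuda": "118", "python_version": "3.10", "pytorch": "2.0.1", "axolotl_extras": ""},
    {"cuda": "118", "python_version": "3.9", "pytorch": "2.0.1", "axolotl_extras": "gptq"},
]

for entry in matrix_include:
    tag = f"py{entry['python_version']}-cu{entry['cuda']}-{entry['pytorch']}"
    if entry["axolotl_extras"]:
        tag += f"-{entry['axolotl_extras']}"
    print(tag)  # py3.9-cu118-2.0.1, py3.10-cu118-2.0.1, py3.9-cu118-2.0.1-gptq
```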
File 2 of 4 (GitHub Actions workflow):

@@ -17,17 +17,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
           - cuda: cu117
             cuda_version: 11.7.1
@@ -72,17 +72,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
           - cuda: cu117
             cuda_version: 11.7.1
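
Since every cu118 build now targets PyTorch 2.0.1, a quick runtime check inside the resulting image can confirm the installed wheel matches its matrix entry. This is a sketch, not part of the PR; the expected values come from the cu118 entries above:

```python
# Sketch: verify the installed torch wheel matches a cu118 / PyTorch 2.0.1 matrix entry.
import torch

assert torch.__version__.startswith("2.0.1"), torch.__version__  # e.g. "2.0.1+cu118"
assert torch.version.cuda == "11.8", torch.version.cuda          # CUDA toolkit the wheel targets
print(f"torch {torch.__version__} (CUDA {torch.version.cuda})")
```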
File 3 of 4 (Dockerfile):

@@ -38,8 +38,9 @@ WORKDIR /workspace

 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"

-RUN git clone https://github.com/HazyResearch/flash-attention.git && \
+RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
+    git checkout v1.0.9 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
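
The Dockerfile change points the clone at the renamed Dao-AILab repository and pins it to the v1.0.9 tag, so the flash-attention wheels are built from a fixed revision instead of whatever the default branch happens to contain. A post-build sanity check along these lines could confirm the result (a sketch, assuming the installed package exposes `flash_attn.__version__`):

```python
# Post-build sanity check (a sketch, not part of the Dockerfile): confirm that the
# wheel built from the pinned v1.0.9 tag is what actually got installed in the image.
import flash_attn  # assumption: the installed package exposes __version__

assert flash_attn.__version__ == "1.0.9", flash_attn.__version__
print(f"flash-attn {flash_attn.__version__}")
```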
File 4 of 4 (Python attention patch, sdp_attention_forward):

@@ -184,14 +184,15 @@ def sdp_attention_forward(

     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+            attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
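
The only functional change in this file is that the fused attention call now runs inside the `torch.backends.cuda.sdp_kernel()` context manager; called with no arguments it leaves all three backends (flash, memory-efficient, math) enabled and lets PyTorch choose among them. A self-contained sketch of the wrapped call, using made-up shapes rather than the actual axolotl tensors:

```python
# Sketch of the patched code path with hypothetical shapes (not the axolotl forward).
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 8, 128, 64
query_states = torch.randn(bsz, n_heads, seq_len, head_dim)
key_states = torch.randn(bsz, n_heads, seq_len, head_dim)
value_states = torch.randn(bsz, n_heads, seq_len, head_dim)
attention_mask = torch.zeros(bsz, 1, seq_len, seq_len)  # additive mask, broadcast over heads

# With no arguments, sdp_kernel() keeps the flash, memory-efficient and math backends enabled.
with torch.backends.cuda.sdp_kernel():
    attn_output = F.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        is_causal=False,
    )

print(attn_output.shape)  # torch.Size([2, 8, 128, 64])
```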