Commit 262dc29

Merge pull request axolotl-ai-cloud#300 from OpenAccess-AI-Collective/pytorch-201
Pytorch 2.0.1
2 parents 28fd429 + a032c9f commit 262dc29

File tree

4 files changed (+20 -18)

Diff for: .github/workflows/base.yml (+3 -3)

@@ -18,12 +18,12 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "117"
             cuda_version: 11.7.1
@@ -33,7 +33,7 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
     steps:
       - name: Checkout

Diff for: .github/workflows/main.yml (+6 -6)

@@ -17,17 +17,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
           - cuda: cu117
             cuda_version: 11.7.1
@@ -72,17 +72,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
           - cuda: cu117
             cuda_version: 11.7.1

Diff for: docker/Dockerfile-base (+2 -1)

@@ -38,8 +38,9 @@ WORKDIR /workspace
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
-RUN git clone https://github.com/HazyResearch/flash-attention.git && \
+RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
+    git checkout v1.0.9 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
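
The build now clones flash-attention from the Dao-AILab organization, where the repository moved from HazyResearch, and checks out the v1.0.9 tag so the image compiles a fixed v1 release instead of whichever commit happens to be at HEAD. A minimal sketch of a post-build sanity check, assuming the installed flash_attn package exposes __version__ (v1.x releases set it in flash_attn/__init__.py):

    import flash_attn

    # The v1.0.9 pin above should make the wheel built into the image
    # report exactly that version.
    assert flash_attn.__version__ == "1.0.9", flash_attn.__version__
    print("flash-attn", flash_attn.__version__)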

Diff for: src/axolotl/monkeypatch/llama_attn_hijack_xformers.py (+9 -8)

@@ -184,14 +184,15 @@ def sdp_attention_forward(
 
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+            attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
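
For context on the last change: in PyTorch 2.0.x, torch.backends.cuda.sdp_kernel() is a context manager that controls which fused scaled-dot-product-attention backends (flash, memory-efficient, math) may be used; invoked with no arguments, as in the patch, it leaves all backends enabled and lets PyTorch pick one. A minimal, self-contained sketch of the same pattern, with illustrative tensor shapes that are not taken from the commit:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes: (batch, num_heads, seq_len, head_dim).
    query = torch.randn(2, 8, 128, 64)
    key = torch.randn(2, 8, 128, 64)
    value = torch.randn(2, 8, 128, 64)

    # With no arguments, sdp_kernel() keeps the flash, memory-efficient,
    # and math backends all enabled; pass e.g. enable_flash=False to rule
    # one out.
    with torch.backends.cuda.sdp_kernel():
        out = F.scaled_dot_product_attention(query, key, value, is_causal=True)

    print(out.shape)  # torch.Size([2, 8, 128, 64])

At the defaults the wrapper doesn't change behavior; it mainly gives the patch a single place to restrict backends later if one of them proves problematic.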
