Skip to content

Commit 2c37bf6

Browse files
authored
Prune cuda117 (axolotl-ai-cloud#327)
* drop cuda117/torch 1.13.1 from support, pin flash attention to v2.0.1, rm torchvision/torchaudio install * gptq base build not needed. add sm 9.0 support
1 parent 9f69c4d commit 2c37bf6

File tree

3 files changed

+12
-29
lines changed

3 files changed

+12
-29
lines changed

.github/workflows/base.yml

+3-13
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,12 @@ jobs:
1919
cuda_version: 11.8.0
2020
python_version: "3.9"
2121
pytorch: 2.0.1
22-
axolotl_extras:
22+
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
2323
- cuda: "118"
2424
cuda_version: 11.8.0
2525
python_version: "3.10"
2626
pytorch: 2.0.1
27-
axolotl_extras:
28-
- cuda: "117"
29-
cuda_version: 11.7.1
30-
python_version: "3.9"
31-
pytorch: 1.13.1
32-
axolotl_extras:
33-
- cuda: "118"
34-
cuda_version: 11.8.0
35-
python_version: "3.9"
36-
pytorch: 2.0.1
37-
axolotl_extras: gptq
27+
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
3828
steps:
3929
- name: Checkout
4030
uses: actions/checkout@v3
@@ -63,4 +53,4 @@ jobs:
6353
CUDA=${{ matrix.cuda }}
6454
PYTHON_VERSION=${{ matrix.python_version }}
6555
PYTORCH_VERSION=${{ matrix.pytorch }}
66-
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras }}
56+
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}

.github/workflows/main.yml

+1-11
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ jobs:
2929
python_version: "3.9"
3030
pytorch: 2.0.1
3131
axolotl_extras: gptq
32-
- cuda: cu117
33-
cuda_version: 11.7.1
34-
python_version: "3.9"
35-
pytorch: 1.13.1
36-
axolotl_extras:
3732
runs-on: self-hosted
3833
steps:
3934
- name: Checkout
@@ -55,7 +50,7 @@ jobs:
5550
with:
5651
context: .
5752
build-args: |
58-
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
53+
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}
5954
file: ./docker/Dockerfile
6055
push: ${{ github.event_name != 'pull_request' }}
6156
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
@@ -82,11 +77,6 @@ jobs:
8277
python_version: "3.9"
8378
pytorch: 2.0.1
8479
axolotl_extras: gptq
85-
- cuda: 117
86-
cuda_version: 11.7.1
87-
python_version: "3.9"
88-
pytorch: 1.13.1
89-
axolotl_extras:
9080
runs-on: self-hosted
9181
steps:
9282
- name: Checkout

docker/Dockerfile-base

+8-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
88
ENV PATH="/root/miniconda3/bin:${PATH}"
99

1010
ARG PYTHON_VERSION="3.9"
11-
ARG PYTORCH="2.0.0"
11+
ARG PYTORCH_VERSION="2.0.1"
1212
ARG CUDA="118"
1313

1414
ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -29,18 +29,18 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
2929
WORKDIR /workspace
3030

3131
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
32-
python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu$CUDA
32+
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
3333

3434

3535
FROM base-builder AS flash-attn-builder
3636

3737
WORKDIR /workspace
3838

39-
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
39+
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
4040

4141
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
4242
cd flash-attention && \
43-
git checkout 9ee0ff1 && \
43+
git checkout v2.0.1 && \
4444
python3 setup.py bdist_wheel && \
4545
cd csrc/fused_dense_lib && \
4646
python3 setup.py bdist_wheel && \
@@ -53,7 +53,7 @@ RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
5353

5454
FROM base-builder AS deepspeed-builder
5555

56-
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
56+
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
5757

5858
WORKDIR /workspace
5959

@@ -74,6 +74,9 @@ RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
7474

7575
FROM base-builder
7676

77+
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
78+
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
79+
7780
# recompile apex
7881
RUN python3 -m pip uninstall -y apex
7982
RUN git clone https://github.com/NVIDIA/apex

0 commit comments

Comments
 (0)