@@ -8,7 +8,7 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
8
8
ENV PATH="/root/miniconda3/bin:${PATH}"
9
9
10
10
ARG PYTHON_VERSION="3.9"
11
- ARG PYTORCH ="2.0.0 "
11
+ ARG PYTORCH_VERSION="2.0.1"
12
12
ARG CUDA="118"
13
13
14
14
ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -29,18 +29,18 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
29
29
WORKDIR /workspace
30
30
31
31
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
32
- python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu$CUDA
32
+ python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
33
33
34
34
35
35
FROM base-builder AS flash-attn-builder
36
36
37
37
WORKDIR /workspace
38
38
39
- ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
39
+ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
40
40
41
41
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
42
42
cd flash-attention && \
43
- git checkout 9ee0ff1 && \
43
+ git checkout v2.0.1 && \
44
44
python3 setup.py bdist_wheel && \
45
45
cd csrc/fused_dense_lib && \
46
46
python3 setup.py bdist_wheel && \
@@ -53,7 +53,7 @@ RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
53
53
54
54
FROM base-builder AS deepspeed-builder
55
55
56
- ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
56
+ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
57
57
58
58
WORKDIR /workspace
59
59
@@ -74,6 +74,9 @@ RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
74
74
75
75
FROM base-builder
76
76
77
+ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
78
+ ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
79
+
77
80
# recompile apex
78
81
RUN python3 -m pip uninstall -y apex
79
82
RUN git clone https://github.com/NVIDIA/apex
0 commit comments