Skip to content

Commit f86e3dd

Browse files
committed
[CI] Use MAX_JOBS=1 with nvcc 12.3, don't need OLD_GENERATOR_PATH
1 parent b7d29fb commit f86e3dd

File tree

3 files changed: +4 −11 lines changed

.github/workflows/publish.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ jobs:
 export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
 export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
 # Limit MAX_JOBS otherwise the github runner goes OOM
-MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+# nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
+MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
 tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
 wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
 ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}

flash_attn/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
-__version__ = "2.7.2"
+__version__ = "2.7.2.post1"
 
 from flash_attn.flash_attn_interface import (
     flash_attn_func,

setup.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,6 @@ def validate_and_update_archs(archs):
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
 
-# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
-# See https://github.com/pytorch/pytorch/pull/70650
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
-    generator_flag = ["-DOLD_GENERATOR_PATH"]
-
 check_if_cuda_home_none("flash_attn")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
@@ -271,7 +264,7 @@ def validate_and_update_archs(archs):
 "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
 ],
 extra_compile_args={
-"cxx": ["-O3", "-std=c++17"] + generator_flag,
+"cxx": ["-O3", "-std=c++17"],
 "nvcc": append_nvcc_threads(
 [
 "-O3",
@@ -293,7 +286,6 @@ def validate_and_update_archs(archs):
 # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
 # "-DFLASHATTENTION_DISABLE_LOCAL",
 ]
-+ generator_flag
 + cc_flag
 ),
 },

0 commit comments

Comments (0)