@@ -149,13 +149,6 @@ def validate_and_update_archs(archs):
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
 
-# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
-# See https://github.com/pytorch/pytorch/pull/70650
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
-    generator_flag = ["-DOLD_GENERATOR_PATH"]
-
 check_if_cuda_home_none("flash_attn")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
@@ -271,7 +264,7 @@ def validate_and_update_archs(archs):
             "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
         ],
         extra_compile_args={
-            "cxx": ["-O3", "-std=c++17"] + generator_flag,
+            "cxx": ["-O3", "-std=c++17"],
             "nvcc": append_nvcc_threads(
                 [
                     "-O3",
@@ -293,7 +286,6 @@ def validate_and_update_archs(archs):
                     # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                     # "-DFLASHATTENTION_DISABLE_LOCAL",
                 ]
-                + generator_flag
                 + cc_flag
             ),
         },
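For context, here is a minimal standalone sketch of the probe this diff removes: it only reports which of the two generator-header locations exists in the local torch installation. The header was relocated to ATen/cuda/ in newer PyTorch releases (see the linked https://github.com/pytorch/pytorch/pull/70650), which is why the old-path check and the -DOLD_GENERATOR_PATH define can be dropped. Variable names below are illustrative and not part of the repository.

# Sketch only (assumed illustrative names, not flash-attn code): re-create the removed
# header probe as a standalone check against your local torch install.
import os
import torch

include_dir = os.path.join(torch.__path__[0], "include")
old_header = os.path.join(include_dir, "ATen", "CUDAGeneratorImpl.h")          # pre-move location
new_header = os.path.join(include_dir, "ATen", "cuda", "CUDAGeneratorImpl.h")  # current location

print("old ATen/CUDAGeneratorImpl.h present:", os.path.exists(old_header))
print("new ATen/cuda/CUDAGeneratorImpl.h present:", os.path.exists(new_header))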