@@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
228
228
"${multiValueArgs} " ${ARGN} )
229
229
230
230
foreach (_ARCH ${arg_CUDA_ARCHS} )
231
- string (REPLACE "." "" _ARCH "${_ARCH} " )
232
- set_gencode_flag_for_srcs(
233
- SRCS ${arg_SRCS}
234
- ARCH "compute_${_ARCH} "
235
- CODE "sm_${_ARCH} " )
231
+ # handle +PTX suffix: generate both sm and ptx codes if requested
232
+ string (FIND "${_ARCH} " "+PTX" _HAS_PTX)
233
+ if (NOT _HAS_PTX EQUAL -1)
234
+ string (REPLACE "+PTX" "" _BASE_ARCH "${_ARCH} " )
235
+ string (REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH} " )
236
+ set_gencode_flag_for_srcs(
237
+ SRCS ${arg_SRCS}
238
+ ARCH "compute_${_STRIPPED_ARCH} "
239
+ CODE "sm_${_STRIPPED_ARCH} " )
240
+ set_gencode_flag_for_srcs(
241
+ SRCS ${arg_SRCS}
242
+ ARCH "compute_${_STRIPPED_ARCH} "
243
+ CODE "compute_${_STRIPPED_ARCH} " )
244
+ else ()
245
+ string (REPLACE "." "" _STRIPPED_ARCH "${_ARCH} " )
246
+ set_gencode_flag_for_srcs(
247
+ SRCS ${arg_SRCS}
248
+ ARCH "compute_${_STRIPPED_ARCH} "
249
+ CODE "sm_${_STRIPPED_ARCH} " )
250
+ endif ()
236
251
endforeach ()
237
252
238
253
if (${arg_BUILD_PTX_FOR_ARCH} )
@@ -251,7 +266,10 @@ endmacro()
251
266
#
252
267
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
253
268
# `<major>.<minor>[letter]` compute the "loose intersection" with the
254
- # `TGT_CUDA_ARCHS` list of gencodes.
269
+ # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
270
+ # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
271
+ # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
272
+ # architecture in `SRC_CUDA_ARCHS`.
255
273
# The loose intersection is defined as:
256
274
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
257
275
# where `<=` is the version comparison operator.
@@ -268,44 +286,63 @@ endmacro()
268
286
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
269
287
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
270
288
#
289
+ # Example With PTX:
290
+ # SRC_CUDA_ARCHS="8.0+PTX"
291
+ # TGT_CUDA_ARCHS="9.0"
292
+ # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
293
+ # OUT_CUDA_ARCHS="8.0+PTX"
294
+ #
271
295
function (cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
272
- list (REMOVE_DUPLICATES SRC_CUDA_ARCHS)
273
- set (TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS} )
296
+ set (_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS} " )
297
+ set (_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS} )
298
+
299
+ # handle +PTX suffix: separate base arch for matching, record PTX requests
300
+ set (_PTX_ARCHS)
301
+ foreach (_arch ${_SRC_CUDA_ARCHS} )
302
+ if (_arch MATCHES "\\ +PTX$" )
303
+ string (REPLACE "+PTX" "" _base "${_arch} " )
304
+ list (APPEND _PTX_ARCHS "${_base} " )
305
+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch} " )
306
+ list (APPEND _SRC_CUDA_ARCHS "${_base} " )
307
+ endif ()
308
+ endforeach ()
309
+ list (REMOVE_DUPLICATES _PTX_ARCHS)
310
+ list (REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
274
311
275
312
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
276
313
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
277
314
set (_CUDA_ARCHS)
278
- if ("9.0a" IN_LIST SRC_CUDA_ARCHS )
279
- list (REMOVE_ITEM SRC_CUDA_ARCHS "9.0a" )
280
- if ("9.0" IN_LIST TGT_CUDA_ARCHS_ )
281
- list (REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0" )
315
+ if ("9.0a" IN_LIST _SRC_CUDA_ARCHS )
316
+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a" )
317
+ if ("9.0" IN_LIST TGT_CUDA_ARCHS )
318
+ list (REMOVE_ITEM _TGT_CUDA_ARCHS "9.0" )
282
319
set (_CUDA_ARCHS "9.0a" )
283
320
endif ()
284
321
endif ()
285
322
286
- if ("10.0a" IN_LIST SRC_CUDA_ARCHS )
287
- list (REMOVE_ITEM SRC_CUDA_ARCHS "10.0a" )
323
+ if ("10.0a" IN_LIST _SRC_CUDA_ARCHS )
324
+ list (REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a" )
288
325
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
289
- list (REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0" )
326
+ list (REMOVE_ITEM _TGT_CUDA_ARCHS "10.0" )
290
327
set (_CUDA_ARCHS "10.0a" )
291
328
endif ()
292
329
endif ()
293
330
294
- list (SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
331
+ list (SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
295
332
296
333
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
297
334
# is less or equal to ARCH (but has the same major version since SASS binary
298
335
# compatibility is only forward compatible within the same major version).
299
- foreach (_ARCH ${TGT_CUDA_ARCHS_ } )
336
+ foreach (_ARCH ${_TGT_CUDA_ARCHS } )
300
337
set (_TMP_ARCH)
301
338
# Extract the major version of the target arch
302
339
string (REGEX REPLACE "^([0-9]+)\\ ..*$" "\\ 1" TGT_ARCH_MAJOR "${_ARCH} " )
303
- foreach (_SRC_ARCH ${SRC_CUDA_ARCHS } )
340
+ foreach (_SRC_ARCH ${_SRC_CUDA_ARCHS } )
304
341
# Extract the major version of the source arch
305
342
string (REGEX REPLACE "^([0-9]+)\\ ..*$" "\\ 1" SRC_ARCH_MAJOR "${_SRC_ARCH} " )
306
- # Check major- version match AND version -less-or-equal
343
+ # Check version-less-or-equal, and allow PTX arches to match across majors
307
344
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
308
- if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
345
+ if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
309
346
set (_TMP_ARCH "${_SRC_ARCH} " )
310
347
endif ()
311
348
else ()
@@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
321
358
endforeach ()
322
359
323
360
list (REMOVE_DUPLICATES _CUDA_ARCHS)
361
+
362
+ # reapply +PTX suffix to architectures that requested PTX
363
+ set (_FINAL_ARCHS)
364
+ foreach (_arch ${_CUDA_ARCHS} )
365
+ if (_arch IN_LIST _PTX_ARCHS)
366
+ list (APPEND _FINAL_ARCHS "${_arch} +PTX" )
367
+ else ()
368
+ list (APPEND _FINAL_ARCHS "${_arch} " )
369
+ endif ()
370
+ endforeach ()
371
+ set (_CUDA_ARCHS ${_FINAL_ARCHS} )
372
+
324
373
set (${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
325
374
endfunction ()
326
375
0 commit comments