Skip to content

Commit c7852a6

Browse files
[Build] Allow shipping PTX on a per-file basis (#18155)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 8795eb9 commit c7852a6

File tree

2 files changed

+75
-23
lines changed

2 files changed

+75
-23
lines changed

CMakeLists.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
301301
# Only build Marlin kernels if we are building for at least some compatible archs.
302302
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
303303
# are not supported by Machete yet.
304-
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
304+
# 9.0 for latest bf16 atomicAdd PTX
305+
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
305306
if (MARLIN_ARCHS)
306307

307308
#
@@ -445,8 +446,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
445446
#
446447
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
447448
# kernels for the remaining archs that are not already built for 3x.
449+
# (Build 8.9 for FP8)
448450
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
449-
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
451+
"7.5;8.0;8.9+PTX" "${CUDA_ARCHS}")
450452
# subtract out the archs that are already built for 3x
451453
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
452454
if (SCALED_MM_2X_ARCHS)
@@ -675,7 +677,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
675677
CUDA_ARCHS "${CUDA_ARCHS}")
676678

677679
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
678-
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
680+
# 9.0 for latest bf16 atomicAdd PTX
681+
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}")
679682
if (MARLIN_MOE_ARCHS)
680683

681684
#

cmake/utils.cmake

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
228228
"${multiValueArgs}" ${ARGN} )
229229

230230
foreach(_ARCH ${arg_CUDA_ARCHS})
231-
string(REPLACE "." "" _ARCH "${_ARCH}")
232-
set_gencode_flag_for_srcs(
233-
SRCS ${arg_SRCS}
234-
ARCH "compute_${_ARCH}"
235-
CODE "sm_${_ARCH}")
231+
# handle +PTX suffix: generate both sm and ptx codes if requested
232+
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
233+
if(NOT _HAS_PTX EQUAL -1)
234+
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
235+
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
236+
set_gencode_flag_for_srcs(
237+
SRCS ${arg_SRCS}
238+
ARCH "compute_${_STRIPPED_ARCH}"
239+
CODE "sm_${_STRIPPED_ARCH}")
240+
set_gencode_flag_for_srcs(
241+
SRCS ${arg_SRCS}
242+
ARCH "compute_${_STRIPPED_ARCH}"
243+
CODE "compute_${_STRIPPED_ARCH}")
244+
else()
245+
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
246+
set_gencode_flag_for_srcs(
247+
SRCS ${arg_SRCS}
248+
ARCH "compute_${_STRIPPED_ARCH}"
249+
CODE "sm_${_STRIPPED_ARCH}")
250+
endif()
236251
endforeach()
237252

238253
if (${arg_BUILD_PTX_FOR_ARCH})
@@ -251,7 +266,10 @@ endmacro()
251266
#
252267
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
253268
# `<major>.<minor>[letter]` compute the "loose intersection" with the
254-
# `TGT_CUDA_ARCHS` list of gencodes.
269+
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
270+
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
271+
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
272+
# architecture in `SRC_CUDA_ARCHS`.
255273
# The loose intersection is defined as:
256274
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
257275
# where `<=` is the version comparison operator.
@@ -268,44 +286,63 @@ endmacro()
268286
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
269287
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
270288
#
289+
# Example With PTX:
290+
# SRC_CUDA_ARCHS="8.0+PTX"
291+
# TGT_CUDA_ARCHS="9.0"
292+
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
293+
# OUT_CUDA_ARCHS="8.0+PTX"
294+
#
271295
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
272-
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
273-
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
296+
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
297+
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
298+
299+
# handle +PTX suffix: separate base arch for matching, record PTX requests
300+
set(_PTX_ARCHS)
301+
foreach(_arch ${_SRC_CUDA_ARCHS})
302+
if(_arch MATCHES "\\+PTX$")
303+
string(REPLACE "+PTX" "" _base "${_arch}")
304+
list(APPEND _PTX_ARCHS "${_base}")
305+
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
306+
list(APPEND _SRC_CUDA_ARCHS "${_base}")
307+
endif()
308+
endforeach()
309+
list(REMOVE_DUPLICATES _PTX_ARCHS)
310+
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
274311

275312
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
276313
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
277314
set(_CUDA_ARCHS)
278-
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
279-
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
280-
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
281-
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
315+
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
316+
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
317+
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
318+
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
282319
set(_CUDA_ARCHS "9.0a")
283320
endif()
284321
endif()
285322

286-
if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
287-
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
323+
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
324+
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
288325
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
289-
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
326+
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
290327
set(_CUDA_ARCHS "10.0a")
291328
endif()
292329
endif()
293330

294-
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
331+
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
295332

296333
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
297334
# is less or equal to ARCH (but has the same major version since SASS binary
298335
# compatibility is only forward compatible within the same major version).
299-
foreach(_ARCH ${TGT_CUDA_ARCHS_})
336+
foreach(_ARCH ${_TGT_CUDA_ARCHS})
300337
set(_TMP_ARCH)
301338
# Extract the major version of the target arch
302339
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
303-
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
340+
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
304341
# Extract the major version of the source arch
305342
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
306-
# Check major-version match AND version-less-or-equal
343+
# Check version-less-or-equal, and allow PTX arches to match across majors
307344
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
308-
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
345+
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
309346
set(_TMP_ARCH "${_SRC_ARCH}")
310347
endif()
311348
else()
@@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
321358
endforeach()
322359

323360
list(REMOVE_DUPLICATES _CUDA_ARCHS)
361+
362+
# reapply +PTX suffix to architectures that requested PTX
363+
set(_FINAL_ARCHS)
364+
foreach(_arch ${_CUDA_ARCHS})
365+
if(_arch IN_LIST _PTX_ARCHS)
366+
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
367+
else()
368+
list(APPEND _FINAL_ARCHS "${_arch}")
369+
endif()
370+
endforeach()
371+
set(_CUDA_ARCHS ${_FINAL_ARCHS})
372+
324373
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
325374
endfunction()
326375

0 commit comments

Comments
 (0)