
Commit 2ef4bb8

Isotr0py authored and mawong-amd committed
[Misc] Auto fallback to float16 for pre-Ampere GPUs when a bfloat16 config is detected (vllm-project#17265)
Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent 9c183a7 commit 2ef4bb8
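
In short: when dtype="auto" resolves to a dtype the current device cannot run, such as a bfloat16 checkpoint on a pre-Ampere GPU, vLLM now falls back to the first entry of the platform's supported_dtypes list rather than requiring the user to pass dtype=float16 manually. A minimal sketch of that behavior (illustration only, not code from this commit):

import torch

def resolve_auto_dtype(config_dtype: torch.dtype,
                       supported_dtypes: list[torch.dtype]) -> torch.dtype:
    """Simplified model of the "auto" dtype resolution after this commit."""
    # Following common practice, float32 checkpoints are served in float16.
    torch_dtype = (torch.float16
                   if config_dtype == torch.float32 else config_dtype)
    if torch_dtype not in supported_dtypes:
        # Fall back to the platform's preferred dtype (first list entry).
        torch_dtype = supported_dtypes[0]
    return torch_dtype

# A bfloat16 config on a pre-Ampere GPU (fp16/fp32 only) now yields float16:
assert resolve_auto_dtype(torch.bfloat16,
                          [torch.float16, torch.float32]) == torch.float16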

File tree

4 files changed (+57, −26 lines)

vllm/config.py

Lines changed: 21 additions & 25 deletions
@@ -7,7 +7,6 @@
 import inspect
 import json
 import re
-import sys
 import textwrap
 import warnings
 from collections import Counter
@@ -34,7 +33,7 @@
                                          QuantizationMethods,
                                          get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import CpuArchEnum, current_platform
+from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2988,44 +2987,41 @@ def _get_and_verify_dtype(
     if isinstance(dtype, str):
         dtype = dtype.lower()
         if dtype == "auto":
+            # Set the default dtype from the model config.
             if config_dtype == torch.float32:
                 # Following common practice, we use float16 for float32 models
                 torch_dtype = torch.float16
             else:
                 torch_dtype = config_dtype
 
             if config.model_type == "plamo2":
-                logger.info(
+                logger.warning(
                     "For PLaMo2, we cast models to bfloat16 instead of using "
                     "float16 by default. This is because float16 does not work."
                 )
                 torch_dtype = torch.bfloat16
 
+            # Deal with torch dtype fallback for device compatibility.
             from vllm.platforms import current_platform
-            if (current_platform.is_cpu()
-                    and current_platform.get_cpu_architecture()
-                    == CpuArchEnum.POWERPC
-                    and (config_dtype == torch.float16
-                         or config_dtype == torch.float32)):
-                logger.info(
-                    "For POWERPC, we cast models to bfloat16 instead of "
-                    "using float16 by default. Float16 is not currently "
-                    "supported for POWERPC.")
-                torch_dtype = torch.bfloat16
+            if torch_dtype not in current_platform.supported_dtypes:
+                device_name = current_platform.get_device_name()
 
-            # TODO: change this condition to check if the platform support bf16
-            # instead of checking the OS. For instance M2 shall supports bf16
-            # already. But we need to modify `cpu_extension.cmake` to activate
-            # the feature in the build.
-            if (current_platform.is_cpu() and sys.platform.startswith("darwin")
-                    and current_platform.get_cpu_architecture()
-                    == CpuArchEnum.ARM and config_dtype == torch.bfloat16):
-                logger.info("For macOS with Apple Silicon, currently bfloat16 "
-                            "is not supported. Setting dtype to float16.")
-                torch_dtype = torch.float16
+                if ((capability := current_platform.get_device_capability())
+                        is None):
+                    compute_str = ""
+                else:
+                    version_str = capability.as_version_str()
+                    compute_str = f" (with compute capability {version_str})"
+                fallback_dtype = current_platform.supported_dtypes[0]
+                logger.warning(
+                    "Your %s device%s doesn't support %s. "
+                    "Falling back to %s for compatibility.",
+                    device_name, compute_str, torch_dtype, fallback_dtype
+                )
+                torch_dtype = fallback_dtype
 
-            if current_platform.is_hpu() and config_dtype == torch.float16:
-                logger.info(
+            if current_platform.is_hpu() and torch_dtype == torch.float16:
+                logger.warning(
                     "For HPU, we cast models to bfloat16 instead of "
                     "using float16 by default. Please specify `dtype` if you "
                     "want to use float16.")

vllm/platforms/cpu.py

Lines changed: 15 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 from vllm.logger import init_logger
 
-from .interface import Platform, PlatformEnum, _Backend
+from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
 
 logger = init_logger(__name__)
 
@@ -26,6 +26,20 @@ class CpuPlatform(Platform):
     device_type: str = "cpu"
     dispatch_key: str = "CPU"
 
+    @property
+    def supported_dtypes(self) -> list[torch.dtype]:
+        if self.get_cpu_architecture() == CpuArchEnum.POWERPC:
+            return [torch.bfloat16, torch.float32]
+        elif sys.platform.startswith(
+                "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM:
+            # TODO: change this condition to check whether the platform
+            # supports bf16 instead of checking the OS. For instance, M2
+            # should support bf16 already, but `cpu_extension.cmake` needs
+            # to be modified to activate the feature in the build.
+            return [torch.float16, torch.float32]
+        # x86/aarch64 CPUs support both bf16 and fp16 natively.
+        return [torch.bfloat16, torch.float16, torch.float32]
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"

vllm/platforms/cuda.py

Lines changed: 13 additions & 0 deletions
@@ -73,6 +73,19 @@ class CudaPlatformBase(Platform):
     ray_device_key: str = "GPU"
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
 
+    @property
+    def supported_dtypes(self) -> List[torch.dtype]:
+        if self.has_device_capability(80):
+            # Ampere and Hopper or later NVIDIA GPUs.
+            return [torch.bfloat16, torch.float16, torch.float32]
+        elif (not self.has_device_capability(80)
+              and self.has_device_capability(60)):
+            # Pascal, Volta and Turing NVIDIA GPUs; BF16 is not supported.
+            return [torch.float16, torch.float32]
+        # Kepler and Maxwell NVIDIA GPUs support only FP32,
+        # though vLLM doesn't support these GPUs anyway.
+        return [torch.float32]
+
     @classmethod
     def get_device_capability(cls,
                               device_id: int = 0
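
has_device_capability(80) is a threshold check against compute capability 8.0 (encoded as major * 10 + minor). A standalone sketch of the same mapping keyed on the raw capability pair (illustration only; thresholds as in the diff above):

import torch

def cuda_supported_dtypes(major: int, minor: int) -> list[torch.dtype]:
    """Sketch of CudaPlatformBase.supported_dtypes keyed on raw capability."""
    capability = major * 10 + minor
    if capability >= 80:
        # Ampere, Hopper and newer: native BF16 support.
        return [torch.bfloat16, torch.float16, torch.float32]
    if capability >= 60:
        # Pascal, Volta and Turing: FP16 but no BF16.
        return [torch.float16, torch.float32]
    # Kepler and Maxwell: FP32 only (and unsupported by vLLM anyway).
    return [torch.float32]

assert cuda_supported_dtypes(7, 5) == [torch.float16, torch.float32]  # Turing T4
assert cuda_supported_dtypes(8, 0)[0] == torch.bfloat16               # Ampere A100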

vllm/platforms/interface.py

Lines changed: 8 additions & 0 deletions
@@ -122,6 +122,14 @@ class Platform:
 
     additional_env_vars: list[str] = []
 
+    @property
+    def supported_dtypes(self) -> list[torch.dtype]:
+        """Returns the supported dtypes for the current platform."""
+        # Be careful with the order of the dtypes: the first one is used
+        # as the default fallback for the current platform when an
+        # unsupported dtype is requested via dtype="auto".
+        return [torch.bfloat16, torch.float16, torch.float32]
+
     def is_cuda(self) -> bool:
         return self._enum == PlatformEnum.CUDA
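
Since the base class ships this permissive default, only constrained platforms need an override. A hypothetical out-of-tree platform (the class name is illustrative, not from this commit) would hook into the fallback like so:

import torch

from vllm.platforms.interface import Platform

class Fp16OnlyPlatform(Platform):
    """Hypothetical accelerator without bfloat16 support (illustration only)."""

    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        # The first entry is the fallback applied when dtype="auto"
        # resolves to a dtype this device cannot run (e.g. bfloat16).
        return [torch.float16, torch.float32]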
