 import inspect
 import json
 import re
-import sys
 import textwrap
 import warnings
 from collections import Counter
...
     QuantizationMethods,
     get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import CpuArchEnum, current_platform
+from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2988,44 +2987,41 @@ def _get_and_verify_dtype(
     if isinstance(dtype, str):
         dtype = dtype.lower()
         if dtype == "auto":
+            # Set default dtype from model config
             if config_dtype == torch.float32:
                 # Following common practice, we use float16 for float32 models
                 torch_dtype = torch.float16
             else:
                 torch_dtype = config_dtype

             if config.model_type == "plamo2":
-                logger.info(
+                logger.warning(
                     "For PLaMo2, we cast models to bfloat16 instead of using "
                     "float16 by default. This is because float16 does not work."
                 )
                 torch_dtype = torch.bfloat16

+            # Deal with torch dtype fallback for device compatibility.
             from vllm.platforms import current_platform
-            if (current_platform.is_cpu()
-                    and current_platform.get_cpu_architecture()
-                    == CpuArchEnum.POWERPC
-                    and (config_dtype == torch.float16
-                         or config_dtype == torch.float32)):
-                logger.info(
-                    "For POWERPC, we cast models to bfloat16 instead of "
-                    "using float16 by default. Float16 is not currently "
-                    "supported for POWERPC.")
-                torch_dtype = torch.bfloat16
+            if torch_dtype not in current_platform.supported_dtypes:
+                device_name = current_platform.get_device_name()

-            # TODO: change this condition to check if the platform support bf16
-            # instead of checking the OS. For instance M2 shall supports bf16
-            # already. But we need to modify `cpu_extension.cmake` to activate
-            # the feature in the build.
-            if (current_platform.is_cpu() and sys.platform.startswith("darwin")
-                    and current_platform.get_cpu_architecture()
-                    == CpuArchEnum.ARM and config_dtype == torch.bfloat16):
-                logger.info("For macOS with Apple Silicon, currently bfloat16 "
-                            "is not supported. Setting dtype to float16.")
-                torch_dtype = torch.float16
+                if ((capability := current_platform.get_device_capability())
+                        is None):
+                    compute_str = ""
+                else:
+                    version_str = capability.as_version_str()
+                    compute_str = f" (with compute capability {version_str})"
+                fallback_dtype = current_platform.supported_dtypes[0]
+                logger.warning(
+                    "Your %s device%s doesn't support %s. " \
+                    "Falling back to %s for compatibility.",
+                    device_name, compute_str, torch_dtype, fallback_dtype
+                )
+                torch_dtype = fallback_dtype

-            if current_platform.is_hpu() and config_dtype == torch.float16:
-                logger.info(
+            if current_platform.is_hpu() and torch_dtype == torch.float16:
+                logger.warning(
                     "For HPU, we cast models to bfloat16 instead of "
                     "using float16 by default. Please specify `dtype` if you "
                     "want to use float16.")
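
For reference, below is a minimal, self-contained sketch of the fallback pattern this hunk introduces: rather than hard-coding per-platform dtype casts (POWERPC, macOS on ARM, and so on), each platform declares its supported dtypes and any unsupported request falls back to the first entry. FakePlatform and resolve_dtype are illustrative names only, not part of vLLM; the real logic lives in current_platform.supported_dtypes and _get_and_verify_dtype as shown in the diff above.

import torch


class FakePlatform:
    # Hypothetical stand-in for a vllm.platforms Platform subclass.
    # Index 0 is treated as the preferred fallback, mirroring
    # current_platform.supported_dtypes[0] in the diff above.
    supported_dtypes = [torch.bfloat16, torch.float32]

    @classmethod
    def get_device_name(cls) -> str:
        return "fake-accelerator"


def resolve_dtype(requested: torch.dtype, platform=FakePlatform) -> torch.dtype:
    """Return `requested` if the platform supports it, else its preferred dtype."""
    if requested in platform.supported_dtypes:
        return requested
    fallback = platform.supported_dtypes[0]
    print(f"Your {platform.get_device_name()} device doesn't support "
          f"{requested}. Falling back to {fallback} for compatibility.")
    return fallback


assert resolve_dtype(torch.float32) == torch.float32
assert resolve_dtype(torch.float16) == torch.bfloat16  # float16 not declared above

The apparent benefit of this shape is that supporting a new accelerator only requires declaring its supported dtypes on its platform class, instead of adding another special-cased branch to _get_and_verify_dtype.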