
vLLM Windows CUDA support [tested] #2158


Merged 14 commits on May 12, 2025
unsloth/models/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -18,5 +18,5 @@
 from .qwen2 import FastQwen2Model
 from .granite import FastGraniteModel
 from .dpo import PatchDPOTrainer, PatchKTOTrainer
-from ._utils import is_bfloat16_supported, __version__
+from ._utils import is_bfloat16_supported, is_vLLM_available, __version__
 from .rl import PatchFastRL, vLLMSamplingParams
unsloth/models/_utils.py (4 changes: 4 additions & 0 deletions)
@@ -17,6 +17,7 @@
 __all__ = [
     "SUPPORTS_BFLOAT16",
     "is_bfloat16_supported",
+    "is_vLLM_available",

     "prepare_model_for_kbit_training",
     "xformers",
@@ -790,6 +791,9 @@ def is_bfloat16_supported():
     return SUPPORTS_BFLOAT16
 pass

+def is_vLLM_available():
+    return _is_package_available("vllm")
+pass

 # Patches models to add RoPE Scaling
 def patch_linear_scaling(
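For context, the new helper delegates to transformers' _is_package_available, which checks whether a package can be found without importing it. A minimal standalone sketch of the same idea using only importlib (the function name here is illustrative, and the real transformers helper additionally verifies the distribution metadata):

import importlib.util

def vllm_is_importable() -> bool:
    # True when a "vllm" module can be located in the current environment;
    # importlib.util.find_spec does not import the package itself.
    return importlib.util.find_spec("vllm") is not None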
unsloth/models/llama.py (4 changes: 1 addition & 3 deletions)
@@ -1654,9 +1654,7 @@ def from_pretrained(
            )
        pass
        if fast_inference:
-            from transformers.utils.import_utils import _is_package_available
-            _vllm_available = _is_package_available("vllm")
-            if _vllm_available == False:
+            if is_vLLM_available() == False:
                print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
                fast_inference = False
            major_version, minor_version = torch.cuda.get_device_capability()
unsloth/models/loader.py (6 changes: 2 additions & 4 deletions)
@@ -14,6 +14,7 @@

 from ._utils import (
     is_bfloat16_supported,
+    is_vLLM_available,
     HAS_FLASH_ATTENTION,
     HAS_FLASH_ATTENTION_SOFTCAPPING,
     USE_MODELSCOPE,
@@ -338,10 +339,7 @@ def from_pretrained(
        pass

        if fast_inference:
-            import platform
-            from transformers.utils.import_utils import _is_package_available
-            _vllm_available = _is_package_available("vllm")
-            if _vllm_available == False:
+            if is_vLLM_available() == False:
                print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
                fast_inference = False
        pass
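Both call sites now share the same fallback behaviour. A minimal sketch of that gate, pulled out into a standalone helper for illustration (resolve_fast_inference is not part of the PR; it assumes an unsloth install that includes this change):

from unsloth.models._utils import is_vLLM_available

def resolve_fast_inference(fast_inference: bool) -> bool:
    # Mirrors the check in llama.py and loader.py: fall back to Unsloth
    # inference when vLLM is not installed.
    if fast_inference and not is_vLLM_available():
        print("Unsloth: vLLM is not installed! Will use Unsloth inference!")
        fast_inference = False
    return fast_inference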