Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fallback to triton if we fail to compile for CUDA #223

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions python/xgrammar/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@

__all__ = ["apply_token_bitmask_inplace_kernels"]

if torch.cuda.is_available():
from .apply_token_bitmask_inplace_cuda import apply_token_bitmask_inplace_cuda
try:
if torch.cuda.is_available():
from .apply_token_bitmask_inplace_cuda import apply_token_bitmask_inplace_cuda

apply_token_bitmask_inplace_kernels["cuda"] = apply_token_bitmask_inplace_cuda
apply_token_bitmask_inplace_kernels["cuda"] = apply_token_bitmask_inplace_cuda
except ImportError:
# If we can't find nvcc, then don't register the CUDA kernel.
pass
except RuntimeError:
# If we are unable to compile the CUDA kernel, then don't register the CUDA kernel.
pass

try:
from .apply_token_bitmask_inplace_triton import ( # isort: skip
Expand Down
37 changes: 37 additions & 0 deletions python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,42 @@
import torch.utils.cpp_extension


def _check_cuda_toolchain() -> None:
"""check if nvcc is available and if pytorch will likely find it"""
import glob
import os
import shutil
from pathlib import Path

# First check if CUDA is available in PyTorch
if not torch.cuda.is_available():
raise ImportError("CUDA is not available in PyTorch")

# This is similar logic to what pytorch does to find the nvcc compiler
nvcc_path = shutil.which("nvcc")
if nvcc_path is None:
cuda_home = os.environ.get("CUDA_HOME", os.environ.get("CUDA_PATH", None))
if cuda_home is None:
if os.name == "nt":
# This is a very hardcoded asumption about install directories but pytorch does this.
cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*")

if len(cuda_homes) == 0:
cuda_home = ""
else:
cuda_home = cuda_homes[0]
else:
cuda_home = "/usr/local/cuda"

if cuda_home is None:
raise ImportError("No CUDA toolchain found")

nvcc_path = str(Path(cuda_home) / "bin" / "nvcc")

if not os.path.exists(nvcc_path):
raise ImportError(f"nvcc compiler not found at {nvcc_path}")


def _remove_torch_nvcc_flags() -> None:
REMOVE_NVCC_FLAGS = [
"-D__CUDA_NO_HALF_OPERATORS__",
Expand Down Expand Up @@ -50,6 +86,7 @@ def _load_torch_ops() -> None:
)


_check_cuda_toolchain()
_remove_torch_nvcc_flags()
_load_torch_ops()

Expand Down
25 changes: 18 additions & 7 deletions python/xgrammar/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,26 @@ def apply_token_bitmask_inplace(
)

if logits.device.type == "cuda":
if os.environ.get("XGRAMMAR_TOKEN_BITMASK_TRITON") == "1":
if "triton" not in apply_token_bitmask_inplace_kernels:
raise RuntimeError(
"The triton kernel is not available even though XGRAMMAR_TOKEN_BITMASK_TRITON "
"is set to 1"
)
if (
"triton" not in apply_token_bitmask_inplace_kernels
and os.environ.get("XGRAMMAR_TOKEN_BITMASK_TRITON") == "1"
):
raise RuntimeError(
"The triton kernel is not available even though XGRAMMAR_TOKEN_BITMASK_TRITON "
"is set to 1"
)

if (
"cuda" in apply_token_bitmask_inplace_kernels
and os.environ.get("XGRAMMAR_TOKEN_BITMASK_TRITON") != "1"
):
apply_token_bitmask_inplace_kernels["cuda"](logits, bitmask, indices)
elif "triton" in apply_token_bitmask_inplace_kernels:
apply_token_bitmask_inplace_kernels["triton"](logits, bitmask, indices)
else:
apply_token_bitmask_inplace_kernels["cuda"](logits, bitmask, indices)
raise RuntimeError(
"No CUDA kernel is available. Check if you have a CUDA compatible toolchain installed."
)
elif logits.device.type == "cpu":
apply_token_bitmask_inplace_kernels["cpu"](logits, bitmask, indices)
else:
Expand Down
Loading