Skip to content

Commit be0f9e0

Browse files
Fix workflow (#345)
1 parent 171c7af commit be0f9e0

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

.github/workflows/build.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ jobs:
113113
}
114114
115115
python setup.py sdist bdist_wheel
116-
116+
117117
- name: Upload Assets
118118
uses: shogo82148/actions-upload-release-asset@v1
119119
with:

scripts/download_wheels.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/bin/bash
22

33
# Set variables
4-
AWQ_VERSION="0.1.6"
4+
AWQ_VERSION="0.2.0"
55
RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ/releases/tags/v${AWQ_VERSION}"
66

77
# Create a directory to download the wheels

setup.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import torch
33
import platform
44
import requests
5-
import importlib_metadata
65
from pathlib import Path
76
from setuptools import setup, find_packages
7+
from torch.utils.cpp_extension import CUDAExtension
88

99

1010
def get_latest_kernels_version(repo):
@@ -88,15 +88,20 @@ def get_kernels_whl_url(
8888
"torch>=2.0.1",
8989
"transformers>=4.35.0",
9090
"tokenizers>=0.12.1",
91+
"typing_extensions>=4.8.0"
9192
"accelerate",
9293
"datasets",
9394
"zstandard",
9495
]
9596

9697
try:
97-
importlib_metadata.version("autoawq-kernels")
98+
if ROCM_VERSION:
99+
import exlv2_ext
100+
else:
101+
import awq_ext
102+
98103
KERNELS_INSTALLED = True
99-
except importlib_metadata.PackageNotFoundError:
104+
except ImportError:
100105
KERNELS_INSTALLED = False
101106

102107
# kernels can be downloaded from pypi for cuda+121 only
@@ -133,5 +138,14 @@ def get_kernels_whl_url(
133138
"eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
134139
"dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"]
135140
},
141+
# NOTE: We create an empty CUDAExtension because torch helps us with
142+
# creating the right boilerplate to enable correct targeting of
143+
# the autoawq-kernels package
144+
ext_modules=[
145+
CUDAExtension(
146+
name="__build_artifact_for_awq_kernel_targeting",
147+
sources=[],
148+
)
149+
],
136150
**common_setup_kwargs,
137151
)

0 commit comments

Comments
 (0)