
[Transform] Construct on GPU, cache on CPU #352


Open

wants to merge 23 commits into base: main

Changes from 22 commits (23 commits total)
a27db62
use hadamards database file
kylesayrs Jun 11, 2025
ce63955
try manifest
kylesayrs Jun 11, 2025
7ae5863
try setup, update hadamards list
kylesayrs Jun 11, 2025
67675c3
fix setup
kylesayrs Jun 11, 2025
f061db9
add docstrings, cleanup
kylesayrs Jun 11, 2025
4a84ce1
fix setup, thank you @dbarbuzzi
kylesayrs Jun 11, 2025
cde1066
remove numpy, add tests
kylesayrs Jun 11, 2025
1ba6195
solidify dtype, add gpu tests
kylesayrs Jun 11, 2025
c373345
fix docstring
kylesayrs Jun 11, 2025
fbaf47a
add device option
kylesayrs Jun 11, 2025
5a887f4
construct on execution device, cache on offload device
kylesayrs Jun 11, 2025
310fe6d
save construction device changes for later
kylesayrs Jun 11, 2025
b715329
construct on execution device, cache on offload device
kylesayrs Jun 11, 2025
249323c
cite nja sloane
kylesayrs Jun 11, 2025
1823af4
Merge branch 'kylesayrs/extend-hadamard', remote-tracking branch 'ori…
kylesayrs Jun 11, 2025
94a0bf5
Merge remote-tracking branch 'origin' into kylesayrs/extend-hadamard
kylesayrs Jun 11, 2025
cf066e0
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con…
kylesayrs Jun 11, 2025
c1a4a34
remove dreg
kylesayrs Jun 11, 2025
5807ee1
put on device via safe_open
kylesayrs Jun 11, 2025
ccb88ed
nits and docstrings
kylesayrs Jun 12, 2025
feba695
update docstring
kylesayrs Jun 12, 2025
c8f6b53
Merge branch 'kylesayrs/extend-hadamard' into kylesayrs/transform_con…
kylesayrs Jun 12, 2025
b6a0dd4
Merge remote-tracking branch 'origin' into kylesayrs/transform_constr…
kylesayrs Jun 13, 2025
1 change: 1 addition & 0 deletions setup.py
@@ -113,5 +113,6 @@ def _setup_extras() -> Dict:
extras_require=_setup_extras(),
install_requires=_setup_install_requires(),
package_dir={"": "src"},
package_data={"": ["transform/utils/hadamards.safetensors"]},
packages=_setup_packages(),
)
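
The new `package_data` entry ships the Hadamard database inside the built wheel. A minimal, hedged check that the file resolves next to the module at runtime (the module path follows the diffs below; the assert message is illustrative):

from pathlib import Path

import compressed_tensors.transform.utils.hadamard as hadamard_utils

db_path = Path(hadamard_utils.__file__).parent / "hadamards.safetensors"
assert db_path.is_file(), "hadamards.safetensors was not bundled with the package"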
18 changes: 13 additions & 5 deletions src/compressed_tensors/transform/factory/hadamard.py
@@ -22,7 +22,7 @@
apply_transform_weight,
get_matrix_size,
)
from compressed_tensors.utils import get_offloaded_device
from compressed_tensors.utils import get_execution_device, get_offloaded_device
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
from torch import Tensor, device, dtype
from torch.nn import Linear, Module, Parameter
@@ -54,13 +54,21 @@ def create_transform(self, module: Module, args: TransformArgs):
size = get_matrix_size(module, args.location)
dtype = module.weight.dtype
device = get_offloaded_device(module)
exec_device = get_execution_device(module)

weight = self.weights[size, dtype, device]
weight = self.weights.get(size, dtype, device, construct_device=exec_device)
return HadamardTransform(weight, args)

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
data = deterministic_hadamard_matrix(size)
data = data.to(dtype=dtype, device=device)
def _create_weight(
self,
size: int,
dtype: dtype,
device: device,
construct_device: device,
) -> Parameter:
# construct on execution device, cache on offload device
data = deterministic_hadamard_matrix(size, dtype, construct_device)
data = data.to(device=device)
return Parameter(data, requires_grad=self.scheme.requires_grad)


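For context, a minimal standalone sketch of the construct-then-offload pattern `_create_weight` now follows; `execution_device` and `offload_device` are illustrative stand-ins for the values returned by `get_execution_device(module)` and `get_offloaded_device(module)`:

import torch

# build the (potentially large) tensor on the fast execution device ...
execution_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
offload_device = torch.device("cpu")
data = torch.eye(4096, dtype=torch.bfloat16, device=execution_device)  # stand-in for hadamard construction

# ... then cache the result on the offload device, mirroring _create_weight above
weight = torch.nn.Parameter(data.to(device=offload_device), requires_grad=False)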
13 changes: 10 additions & 3 deletions src/compressed_tensors/transform/factory/random_hadamard.py
@@ -28,7 +28,14 @@ class RandomHadamardFactory(HadamardFactory):
:param seed: random seed used to transform weight randomization
"""

def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
data = random_hadamard_matrix(size, self.generator)
data = data.to(dtype=dtype, device=device)
def _create_weight(
self,
size: int,
dtype: dtype,
device: device,
construct_device: device,
) -> Parameter:
# construct on execution device, cache on offload device
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
data = data.to(device=device)
return Parameter(data, requires_grad=self.scheme.requires_grad)
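
A hedged usage sketch of the updated `random_hadamard_matrix` signature; the size 768 = 12 * 64 is illustrative and assumes the bundled database includes the 12x12 Sloane matrix:

import torch

from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix

gen = torch.Generator().manual_seed(0)
H = random_hadamard_matrix(768, torch.float32, torch.device("cpu"), gen)
assert H.shape == (768, 768)  # rows remain orthogonal after the random sign flip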
183 changes: 91 additions & 92 deletions src/compressed_tensors/transform/utils/hadamard.py
@@ -13,149 +13,148 @@
# limitations under the License.

import math
from typing import Optional, Tuple
from pathlib import Path
from typing import Optional

import numpy
import torch
from safetensors import safe_open


__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
REPO_PATH = Path(__file__).parent / "hadamards.safetensors"

# adapted from:
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
def deterministic_hadamard_matrix(size: int) -> torch.Tensor:

__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]


# note that hadamard matrix multiplication can be accelerated using a library such as
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


def deterministic_hadamard_matrix(
size: int,
dtype: torch.dtype = torch.bfloat16,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""
Construct an n-by-n Hadamard matrix, using Sylvester's construction.
`n` must be a power of 2.

Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py # noqa: E501

:param size: order of the matrix, must be a power of 2
:param dtype: data type of matrix
:param device: device to construct matrix on
:return: hadamard matrix of size `size`
"""
if size <= 0:
raise ValueError("Cannot construct deterministic hadamard of size <= 0")

log2 = int(math.log(size, 2))
log2 = int(math.log2(size))
if size != 2**log2:
raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

H = numpy.array([[1]], dtype=int)
H = torch.tensor([[1]], dtype=dtype, device=device)

# Sylvester's construction
for i in range(0, log2):
H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))

return torch.from_numpy(H / math.sqrt(size))
for _ in range(log2):
H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))


# adapted from:
# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py

# TODO: the following library exists for online rotations and should be considered
# in the future:
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
return H / math.sqrt(size)
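# Illustrative usage (not part of this diff): because the result is scaled by
# 1 / sqrt(size), the returned matrix is orthogonal, e.g.
#
#     H = deterministic_hadamard_matrix(64, torch.float32, torch.device("cpu"))
#     assert torch.allclose(H @ H.T, torch.eye(64))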


def random_hadamard_matrix(
size: int, gen: Optional[torch.Generator] = None
size: int,
dtype: torch.dtype = torch.bfloat16,
device: torch.device = torch.device("cpu"),
gen: Optional[torch.Generator] = None,
) -> torch.Tensor:
"""
Produces a randomly generated Hadamard matrix.
See https://cornell-relaxml.github.io/quip-sharp/ ,
Section "Randomized Hadamard Transformation"
Produces a randomly generated Hadamard matrix. Differs from
`deterministic_hadamard_matrix` in that this function supports sizes which are
not powers of 2, as well as randomization using a seeded generator

Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py # noqa: E501
Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/ # noqa: E501

:param size: the dimension of the hadamard matrix
:param dtype: data type of matrix
:param device: device to construct matrix on
:param gen: optional generator for random values
:return: randomly generated hadamard matrix
"""
# Benefits: support other shapes / non powers of 2, support randomization
Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype) # cpu
Q = Q.to(device=device)
Q = Q * 2 - 1
Q = torch.diag(Q)
return _matmul_hadU(Q) / math.sqrt(size)


def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
# NOTE: we can easily extend the list of supported shapes/sizes
# by adding to these methods
hadK, K = None, None
if n % 20 == 0:
assert _is_pow2(n // 20)
K = 20
hadK = _get_had20().T if transpose else _get_had20()
elif n % 12 == 0:
assert _is_pow2(n // 12)
K = 12
hadK = _get_had12().T if transpose else _get_had12()
else:
assert _is_pow2(n)
K = 1
def is_pow2(n: int) -> bool:
"""
Check if a number is a power of 2

return hadK, K
:param n: number to check
:return: True iff `n` is a power of 2
"""
return n > 0 and (n & (n - 1) == 0)


def _fetch_hadamard_divisor(
n: int,
dtype: torch.dtype,
device: torch.device = torch.device("cpu"),
file_path: str = REPO_PATH,
) -> Optional[torch.Tensor]:
"""
Fetch a known hadamard matrix from the given file path. The returned matrix will
be of size `k` such that `n / k` is a power of two. Return None if no such
matrix exists.

Note: This function reopens the safetensors file every time it is called.
This is technically inefficient, but the runtime cost is very small and this is
simpler than forcing callers to manage the file-open context

:param n: size for which a known hadamard divisor should be fetched
:param dtype: data type to cast the matrix to
:param device: device to load the matrix onto
:param file_path: path to the safetensors file of known hadamard matrices
:return: a known hadamard matrix of size `k` if one exists, else None
"""
with safe_open(file_path, framework="pt", device=str(device)) as file:
divisors = sorted((int(key) for key in file.keys()), reverse=True)
for divisor in divisors:
if n % divisor == 0 and is_pow2(n // divisor):
return file.get_tensor(str(divisor)).to(dtype=dtype)

return None
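# Illustrative example (not part of this diff): assuming the bundled database
# contains the 12x12 Sloane matrix, a request for n = 768 selects divisor 12,
# since 768 % 20 != 0 while 768 % 12 == 0 and 768 // 12 == 64 is a power of two:
#
#     from compressed_tensors.transform.utils.hadamard import is_pow2
#     assert 768 % 12 == 0 and is_pow2(768 // 12)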


def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
size = X.size(0)
dtype = X.dtype
device = X.device

def _matmul_hadU(X, transpose=False) -> torch.Tensor:
n = X.shape[-1]
# Check if we have the determined hadamard matrix
hadK, K = _get_hadK(n, transpose)
hadK = _fetch_hadamard_divisor(size, dtype, device=device)
if hadK is None:
raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
K = hadK.size(0)

# Reshape diag matrix with randomized -1/+1
input = X.clone().view(-1, n, 1)
input = X.clone().view(-1, size, 1)
output = input.clone()

# for cases when hadK is not predetermined, determine hadamard matrix
while input.shape[1] > K:
input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
output = output.view(input.shape)
output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
output = output.view(input.shape[0], input.shape[1], -1)
(input, output) = (output, input)
assert input.shape[1] == K
del output

# K == 1 when hadK is None; this happens when the size dim (n)
# is not compatible with any of the maintained hadamard matrices

if K > 1:
# Do not explicitly repeat - OOM
# input = torch.bmm(
# hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
# Use bcast instead

# for cases when hadK is pre-determined
input = hadK.view(1, K, K).to(input) @ input
# Do not explicitly repeat - OOM
# input = torch.bmm(
# hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
# Use bcast instead
input = hadK.view(1, K, K).to(input) @ input

# normalize
return input.view(X.shape)
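# Illustrative sketch (not part of this diff): the while-loop above performs the
# power-of-two butterfly stages of a Sylvester Walsh-Hadamard transform. A naive
# recursive reference for those stages, hypothetical and for intuition only
# (produces the same unnormalized result, in Sylvester ordering):
def _naive_hadamard_apply(x: torch.Tensor) -> torch.Tensor:
    n = x.shape[-1]
    if n == 1:
        return x
    a, b = x[..., : n // 2], x[..., n // 2 :]
    return torch.cat(
        (_naive_hadamard_apply(a + b), _naive_hadamard_apply(a - b)), dim=-1
    )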


def _is_pow2(n: int) -> bool:
return (n & (n - 1) == 0) and (n > 0)


def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
had_unpacked = numpy.unpackbits(packed_bits)
had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
return had_unpacked


# http://www.neilsloane.com/hadamard/index.html
def _get_had12() -> torch.Tensor:
# fmt: off
had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
# fmt: on
# TODO: just unpack during apply
had_12_unpacked = _reshape_bits(had_12, original_size=12)
return torch.tensor(had_12_unpacked)


def _get_had20() -> torch.Tensor:
# fmt: off
had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
# fmt: on
# TODO: just unpack during apply
had_20_unpacked = _reshape_bits(had_20, original_size=20)
return torch.tensor(had_20_unpacked)
Binary file src/compressed_tensors/transform/utils/hadamards.safetensors not shown.
11 changes: 8 additions & 3 deletions src/compressed_tensors/utils/helpers.py
@@ -373,11 +373,16 @@ class ParameterizedDefaultDict(dict):

def __init__(self, default_factory: Callable[[Any], Any]):
self.default_factory = default_factory
self._kwargs = {}

def __missing__(self, key):
def __missing__(self, key: Any) -> Any:
if isinstance(key, tuple):
value = self.default_factory(*key)
value = self.default_factory(*key, **self._kwargs)
else:
value = self.default_factory(key)
value = self.default_factory(key, **self._kwargs)
self[key] = value
return value

def get(self, *args, **kwargs) -> Any:
with patch_attr(self, "_kwargs", kwargs):
return self[args]
Comment on lines +386 to +388
Contributor:
a docstring might be useful here. Basically, kwargs are only used if the key to `get` is missing, and the user has to know that args are keys and kwargs are constructor inputs. Maybe something like

def get(self, *args, constructor_kwargs={}):

would be a clearer signature?

@kylesayrs (Contributor, Author), Jun 12, 2025:
I'll admit that the current implementation is less verbose than it could be for a somewhat unfamiliar class

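For readers less familiar with the class, a hedged usage sketch of the semantics described in the comment above; `make_weight` and its arguments are hypothetical stand-ins for the factory wired up in hadamard.py:

import torch

from compressed_tensors.utils.helpers import ParameterizedDefaultDict

def make_weight(size, dtype, device, construct_device=None):
    # hypothetical factory: build on construct_device, then cache on device
    construct_device = construct_device or device
    return torch.zeros(size, size, dtype=dtype, device=construct_device).to(device)

weights = ParameterizedDefaultDict(make_weight)
# positional args form the cache key; the kwarg only reaches the factory on a miss
w = weights.get(64, torch.float32, torch.device("cpu"), construct_device=torch.device("cpu"))
assert (64, torch.float32, torch.device("cpu")) in weights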