[Transform] Extend set of known Hadamard matrices #351

Merged · 17 commits · Jun 13, 2025
1 change: 1 addition & 0 deletions setup.py
@@ -113,5 +113,6 @@ def _setup_extras() -> Dict:
         extras_require=_setup_extras(),
         install_requires=_setup_install_requires(),
         package_dir={"": "src"},
+        package_data={"": ["transform/utils/hadamards.safetensors"]},
         packages=_setup_packages(),
     )
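Context for this one-liner: `package_data` is what ships `hadamards.safetensors` inside the built wheel, so the runtime lookup added in this PR can resolve the file relative to the installed module. A minimal sanity check, not part of this PR, assuming the package is installed with this change applied:

# Verify the packed Hadamard matrices ship with the installed package.
# REPO_PATH is defined in src/compressed_tensors/transform/utils/hadamard.py
# as Path(__file__).parent / "hadamards.safetensors".
from compressed_tensors.transform.utils.hadamard import REPO_PATH

assert REPO_PATH.exists(), "hadamards.safetensors missing; check package_data"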
2 changes: 1 addition & 1 deletion src/compressed_tensors/transform/factory/hadamard.py
@@ -59,7 +59,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         return HadamardTransform(weight, args)
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size)
+        data = deterministic_hadamard_matrix(size, dtype, device)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
2 changes: 1 addition & 1 deletion src/compressed_tensors/transform/factory/random.py

@@ -29,6 +29,6 @@ class RandomHadamardFactory(HadamardFactory):
     """
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, self.generator)
+        data = random_hadamard_matrix(size, dtype, device, self.generator)
        data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
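Both factories now forward `dtype` and `device` into the matrix constructors, so the weight is materialized directly at the target precision and placement; the following `data.to(dtype=dtype, device=device)` becomes a no-op kept for safety. A rough sketch of the behavior these call sites rely on, assuming a CUDA device is available:

import torch
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix

# Construct directly on the target device/dtype rather than building a
# float64 CPU matrix and casting afterwards.
weight = random_hadamard_matrix(
    size=1024, dtype=torch.bfloat16, device=torch.device("cuda")
)
assert weight.dtype == torch.bfloat16
assert weight.device.type == "cuda"
# A second .to() with the same dtype/device returns the same tensor (no copy).
assert weight.to(dtype=torch.bfloat16, device=torch.device("cuda")) is weight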
183 changes: 91 additions & 92 deletions src/compressed_tensors/transform/utils/hadamard.py
@@ -13,149 +13,148 @@
 # limitations under the License.
 
 import math
-from typing import Optional, Tuple
+from pathlib import Path
+from typing import Optional
 
-import numpy
 import torch
+from safetensors import safe_open
 
 
-__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
+REPO_PATH = Path(__file__).parent / "hadamards.safetensors"
 
-# adapted from:
-# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
-def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
+
+__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
+
+
+# note that hadamard matrix multiplication can be accelerated using a library such as
+# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
+def deterministic_hadamard_matrix(
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
     """
     Construct an n-by-n Hadamard matrix, using Sylvester's construction.
     `n` must be a power of 2.
 
+    Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501
+
     :param size: order of the matrix, must be a power of 2
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
     :return: hadamard matrix of size `size`
     """
     if size <= 0:
         raise ValueError("Cannot construct deterministic hadamard of size <= 0")
 
-    log2 = int(math.log(size, 2))
+    log2 = int(math.log2(size))
     if size != 2**log2:
         raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
 
-    H = numpy.array([[1]], dtype=int)
+    H = torch.tensor([[1]], dtype=dtype, device=device)
 
     # Sylvester's construction
-    for i in range(0, log2):
-        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
-
-    return torch.from_numpy(H / math.sqrt(size))
+    for _ in range(log2):
+        H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
 
-
-# adapted from:
-# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
-
-# TODO: the following library exists for online rotations and should be considered
-# in the future:
-# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+    return H / math.sqrt(size)
 
 
 def random_hadamard_matrix(
-    size: int, gen: Optional[torch.Generator] = None
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    device: torch.device = torch.device("cpu"),
+    gen: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
     """
-    Produces a randomly generated Hadamard matrix.
-    See https://cornell-relaxml.github.io/quip-sharp/ ,
-    Section "Randomized Hadamard Transformation"
+    Produces a randomly generated Hadamard matrix. Differs from
+    `deterministic_hadamard_matrix` in that this function supports non powers of 2
+    and randomization using a seeded generator
+
+    Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
+    Known matrices were retrieved from N. J. A. Sloane's Library of Hadamard Matrices http://www.neilsloane.com/hadamard/  # noqa: E501
 
     :param size: The dimension of the hadamard matrix
+    :param dtype: data type of matrix
+    :param device: device to construct matrix on
     :param gen: Optional generator for random values
     :return: randomly generated hadamard matrix
     """
-    # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)  # cpu
+    Q = Q.to(device=device)
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     return _matmul_hadU(Q) / math.sqrt(size)
 
 
-def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
-    # NOTE: we can easily extend the list of supported shapes/sizes
-    # by adding to these methods
-    hadK, K = None, None
-    if n % 20 == 0:
-        assert _is_pow2(n // 20)
-        K = 20
-        hadK = _get_had20().T if transpose else _get_had20()
-    elif n % 12 == 0:
-        assert _is_pow2(n // 12)
-        K = 12
-        hadK = _get_had12().T if transpose else _get_had12()
-    else:
-        assert _is_pow2(n)
-        K = 1
-
-    return hadK, K
+def is_pow2(n: int) -> bool:
+    """
+    Check if a number is a power of 2
+
+    :param n: number to check
+    :return: True iff `n` is a power of 2
+    """
+    return n > 0 and (n & (n - 1) == 0)
+
+
+def _fetch_hadamard_divisor(
+    n: int,
+    dtype: torch.dtype,
+    device: torch.device = torch.device("cpu"),
+    file_path: str = REPO_PATH,
+) -> Optional[torch.Tensor]:
+    """
+    Fetch a known hadamard matrix from the given file path. The returned matrix will
+    be of size `k` such that `n / k` is a power of two. Return None if no such
+    matrix exists.
+
+    Note: This function reopens the safetensors file every time it is called.
+    This is technically inefficient, but a very small runtime cost and simpler
+    than forcing callers to manage the file open context
+
+    :param n: size of known hadamard matrix
+    :return: a known hadamard matrix of size `n` if one exists, else None
+    """
+    with safe_open(file_path, framework="pt", device=str(device)) as file:
+        divisors = sorted((int(key) for key in file.keys()), reverse=True)
+        for divisor in divisors:
+            if n % divisor == 0 and is_pow2(n // divisor):
+                return file.get_tensor(str(divisor)).to(dtype=dtype)
+
+    return None
 
 
-def _matmul_hadU(X, transpose=False) -> torch.Tensor:
-    n = X.shape[-1]
+def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
+    size = X.size(0)
+    dtype = X.dtype
+    device = X.device
+
     # Check if we have the determined hadamard matrix
-    hadK, K = _get_hadK(n, transpose)
+    hadK = _fetch_hadamard_divisor(size, dtype, device=device)
+    if hadK is None:
+        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+    K = hadK.size(0)
+
     # Reshape diag matrix with randomized -1/+1
-    input = X.clone().view(-1, n, 1)
+    input = X.clone().view(-1, size, 1)
     output = input.clone()
-
-    # for cases when hadK is not predetermined, determine hadamard matrix
     while input.shape[1] > K:
         input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
         output = output.view(input.shape)
         output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
         output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
         output = output.view(input.shape[0], input.shape[1], -1)
         (input, output) = (output, input)
     assert input.shape[1] == K
     del output
 
-    # K == 1 when hadK is None; this happens when the size dim (n)
-    # is not comaptible with any of the maintained hadamard matrices
-
-    if K > 1:
-        # Do not explicitly repeat - OOM
-        # input = torch.bmm(
-        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
-        # Use bcast instead
-
-        # for cases when hadK is pre-determined
-        input = hadK.view(1, K, K).to(input) @ input
+    # Do not explicitly repeat - OOM
+    # input = torch.bmm(
+    #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+    # Use bcast instead
+    input = hadK.view(1, K, K).to(input) @ input
 
     # normalize
     return input.view(X.shape)
-
-
-def _is_pow2(n: int) -> bool:
-    return (n & (n - 1) == 0) and (n > 0)
-
-
-def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
-    had_unpacked = numpy.unpackbits(packed_bits)
-    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
-    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
-    return had_unpacked
-
-
-# http://www.neilsloane.com/hadamard/index.html
-def _get_had12() -> torch.Tensor:
-    # fmt: off
-    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
-        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_12_unpacked = _reshape_bits(had_12, original_size=12)
-    return torch.tensor(had_12_unpacked)
-
-
-def _get_had20() -> torch.Tensor:
-    # fmt: off
-    had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
-        216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
-        97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
-        205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
-    # fmt: on
-    # TODO: just unpack during apply
-    had_20_unpacked = _reshape_bits(had_20, original_size=20)
-    return torch.tensor(had_20_unpacked)
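For reviewers who want to poke at the new API: both constructors return H / sqrt(n), so orthonormality is the invariant to check. A small sketch, not part of the diff, assuming a 12x12 base matrix is among the packed matrices (as it was in the old hard-coded `_get_had12`):

import torch
from compressed_tensors.transform.utils.hadamard import (
    deterministic_hadamard_matrix,
    random_hadamard_matrix,
)

# Power-of-two size: pure Sylvester construction, no file lookup needed.
H = deterministic_hadamard_matrix(2048, dtype=torch.float32)
assert torch.allclose(H @ H.T, torch.eye(2048), atol=1e-4)

# Non-power-of-two size: 768 = 12 * 64, so _matmul_hadU reduces the input
# via log2(64) butterfly steps, then multiplies by the known 12x12 matrix.
gen = torch.Generator().manual_seed(42)
R = random_hadamard_matrix(768, dtype=torch.float32, gen=gen)
assert torch.allclose(R @ R.T, torch.eye(768), atol=1e-4)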
Binary file added: src/compressed_tensors/transform/utils/hadamards.safetensors (not shown).
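This blob replaces the packed-bit arrays from `_get_had12`/`_get_had20`. Its keys can be listed directly; a quick inspection sketch (the printed set depends on what this PR actually packed):

from safetensors import safe_open

from compressed_tensors.transform.utils.hadamard import REPO_PATH

# Each key is the string order of a known Hadamard matrix; values are the
# +/-1 matrices themselves. _fetch_hadamard_divisor tries them largest-first.
with safe_open(REPO_PATH, framework="pt", device="cpu") as file:
    print(sorted(int(key) for key in file.keys()))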
70 changes: 38 additions & 32 deletions tests/test_transform/utils/test_hadamard.py
@@ -13,46 +13,48 @@
 # limitations under the License.
 
 
-import numpy
 import pytest
 import torch
 from compressed_tensors.transform.utils.hadamard import (
-    _get_had12,
-    _get_had20,
     deterministic_hadamard_matrix,
+    is_pow2,
     random_hadamard_matrix,
 )
+from tests.testing_utils import requires_gpu
 
 
-@pytest.mark.parametrize(
-    "had_func",
-    [
-        _get_had12,
-        _get_had20,
-    ],
-)
-def test_packed_hadamard_compliant(had_func):
-    had_matrix = had_func()
-    size = had_matrix.size(0)
-    # HH.T == nI
-    product = had_matrix @ had_matrix.T
-    assert torch.equal(product, size * torch.eye(size))
+_sizes_to_test = [
+    768,  # gpt2 small
+    1024,  # gpt2 medium
+    1280,  # qwen_2_5_vl vision
+    1600,  # gpt2 xl
+    2048,  # gpt3 small
+    3584,  # qwen_2_5_vl
+    3840,  # qwen_2_5_vl vision qkv
+    4096,  # llama3
+    7168,  # deepseek_v3
+    14336,  # llama3 intermediate
+    18432,  # deepseek_v3 intermediate
+    18944,  # qwen_2_5_vl intermediate
+]
+_atol = 1e-1  # bfloat16 is low precision for large matrices
 
 
-@pytest.mark.parametrize(
-    "size",
-    [4096, 2048],
-)
+@requires_gpu
+@pytest.mark.parametrize("size", _sizes_to_test)
 def test_random_hadamard_matrix_compliant(size):
-    had_matrix = random_hadamard_matrix(size)
-    product = torch.round(had_matrix @ had_matrix.T)
-    assert torch.equal(product, torch.eye(size))
+    # (H / sqrt(n))(H.T / sqrt(n)) == I
+    matrix = random_hadamard_matrix(size, device="cuda")
+    product = matrix @ matrix.T
+    eye = torch.eye(size, dtype=product.dtype, device="cuda")
+    assert torch.allclose(product, eye, atol=_atol)
 
 
 def test_random_hadamard_generator():
+    # check that generation is deterministic with a seed
     generator = torch.Generator().manual_seed(42)
-    one = random_hadamard_matrix(2048, generator)
-    two = random_hadamard_matrix(2048, generator)
+    one = random_hadamard_matrix(2048, gen=generator)
+    two = random_hadamard_matrix(2048, gen=generator)
 
     one_true = torch.tensor(
         [
@@ -73,12 +75,16 @@ def test_random_hadamard_generator():
     assert torch.all(two[:3, :3].sign() == two_true.sign())
 
 
-@pytest.mark.parametrize(
-    "size",
-    [1024],
-)
+@requires_gpu
+@pytest.mark.parametrize("size", _sizes_to_test)
 def test_deterministic_hadamard_compliant(size):
-    had_matrix = deterministic_hadamard_matrix(size)
-    # HH.T == nI
-    product = had_matrix @ had_matrix.T
-    assert numpy.array_equal(product, numpy.eye(size))
+    if not is_pow2(size):
+        with pytest.raises(ValueError):
+            matrix = deterministic_hadamard_matrix(size, device="cuda")
+        return
+
+    # (H / sqrt(n))(H.T / sqrt(n)) == I
+    matrix = deterministic_hadamard_matrix(size, device="cuda")
+    product = matrix @ matrix.T
+    eye = torch.eye(size, dtype=product.dtype, device="cuda")
+    assert torch.allclose(product, eye, atol=_atol)
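Each entry in `_sizes_to_test` factors as known-base times power-of-two, which is exactly the condition `_fetch_hadamard_divisor` checks. A sketch of the factorization for a few sizes; the candidate bases here are illustrative, and the authoritative set is whatever `hadamards.safetensors` contains:

from compressed_tensors.transform.utils.hadamard import is_pow2

for size in [768, 1280, 3584, 18944]:
    # e.g. 768 = 12 * 64, 1280 = 20 * 64, 3584 = 28 * 128, 18944 = 148 * 128;
    # like the real lookup, try candidate bases largest-first.
    base = next(k for k in (148, 28, 20, 12) if size % k == 0 and is_pow2(size // k))
    print(f"{size} = {base} * {size // base}")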