From 1adfa30b1018e8e81cb5ab43637512d278aa1c4e Mon Sep 17 00:00:00 2001 From: Dipika Date: Tue, 11 Mar 2025 19:33:08 +0000 Subject: [PATCH 1/6] transform registry support --- src/compressed_tensors/transforms/__init__.py | 18 ++ src/compressed_tensors/transforms/base.py | 87 +++++++++ src/compressed_tensors/transforms/hadamard.py | 82 +++++++++ .../transforms/hadamard_utils.py | 166 ++++++++++++++++++ .../transforms/matrix_multiply.py | 47 +++++ .../transforms/random_hadamard.py | 87 +++++++++ src/compressed_tensors/transforms/utils.py | 51 ++++++ tests/test_transforms/test_hadamards.py | 60 +++++++ tests/test_transforms/test_transforms.py | 111 ++++++++++++ 9 files changed, 709 insertions(+) create mode 100644 src/compressed_tensors/transforms/__init__.py create mode 100644 src/compressed_tensors/transforms/base.py create mode 100644 src/compressed_tensors/transforms/hadamard.py create mode 100644 src/compressed_tensors/transforms/hadamard_utils.py create mode 100644 src/compressed_tensors/transforms/matrix_multiply.py create mode 100644 src/compressed_tensors/transforms/random_hadamard.py create mode 100644 src/compressed_tensors/transforms/utils.py create mode 100644 tests/test_transforms/test_hadamards.py create mode 100644 tests/test_transforms/test_transforms.py diff --git a/src/compressed_tensors/transforms/__init__.py b/src/compressed_tensors/transforms/__init__.py new file mode 100644 index 00000000..3d8a0f2e --- /dev/null +++ b/src/compressed_tensors/transforms/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import Transforms +from .hadamard import Hadamard +from .matrix_multiply import MatrixMultiply +from .random_hadamard import RandomHadamard diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py new file mode 100644 index 00000000..16bcf927 --- /dev/null +++ b/src/compressed_tensors/transforms/base.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Union + +import torch +from compressed_tensors.registry.registry import RegistryMixin +from compressed_tensors.transforms.utils import apply_matrix_transform + + +__all__ = ["Transforms"] + + +# TODO: We don't need to save all the __call__ args for serialization or even have +# them defined by a recipe. 
Some of them, such as whether the transformation should be the
# first or second matrix in torch.matmul depending on dimensions, can likely
# be inferred from the layer.

MATIRX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"]


class Transforms(RegistryMixin):
    def __new__(
        cls,
        transform: torch.Tensor,
        device: Optional[Union[str, torch.device]] = "cuda",
        dtype: Optional[torch.dtype] = torch.bfloat16,
        *args,
        **kwargs,
    ):
        """
        Base class for setting up transforms. The registry creates transforms
        as parameters which can be attached to modules.

        import torch

        size = 1024
        dtype = torch.bfloat16
        module = torch.nn.Linear(size, size)

        hadamard_transform = Transforms.load_from_registry(
            "random_hadamard", size=size, dtype=dtype
        )
        hadamard_apply = Transforms.fetch_apply("random_hadamard")
        module.weight_transform = hadamard_transform

        transformed_output = hadamard_apply(input_tensor=module.weight,
            transform=module.weight_transform)

        hadamard_inverse = Transforms.fetch_inverse_apply("random_hadamard")
        original_weight = hadamard_inverse(input_tensor=transformed_output,
            transform=module.weight_transform,
            transpose=True)

        :param transform: transform (e.g. torch.Tensor, scalar) to be applied
        """
        return torch.nn.Parameter(transform.to(device).to(dtype), requires_grad=False)

    @classmethod
    def fetch_apply(cls, name: str):
        if name in MATIRX_TRANSFORMS:
            return apply_matrix_transform
        raise NotImplementedError("Only matrix transforms are supported")

    @classmethod
    def fetch_inverse_apply(cls, name: str):
        return cls.get_value_from_registry(name=name).inverse_apply

    @staticmethod
    def inverse_apply(
        transform: torch.Tensor, input_tensor: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """
        Apply the inverse operation applied by the apply method
        """
        raise NotImplementedError()
diff --git a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py
new file mode 100644
index 00000000..b1071ab0
--- /dev/null
+++ b/src/compressed_tensors/transforms/hadamard.py
@@ -0,0 +1,82 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix
from compressed_tensors.transforms.utils import apply_matrix_transform


@Transforms.register("hadamard")
class Hadamard(Transforms):
    def __new__(
        cls,
        size: int,
        empty: Optional[bool] = False,
        device: Optional[Union[str, torch.device]] = "cuda",
        dtype: Optional[torch.dtype] = torch.bfloat16,
    ):
        """
        Produces a Hadamard matrix with dims (size, size), with values
        -1 and 1, and the property HH.T == nI, i.e. the transformation
        matrix when multiplied by its transpose is a multiple of the identity.
        All rows and columns are mutually orthogonal.
The matrix returned + is not normalized and will be deterministic. + + :param size: size of the matrix, if generating a new Hadamard matrix. + The generated matrix will have dimensions (size, size) + :param transform: if loading in a previously generated matrix, will + use that through this transformation, as opposed to creating a new + one + :param dtype: type to cast the rotation matrix to + + """ + if not empty: + # TODO: this is deterministic; we should just serialize the size + transform = torch.Tensor(deterministic_hadamard_matrix(size=size)) + else: + transform = torch.empty((size, size)) + + return super().__new__(cls, transform=transform, device=device, dtype=dtype) + + @staticmethod + def inverse_apply( + transform: torch.Tensor, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + """ + Apply the inverse operation of `apply` + + :param transform: hadamard tensor + :param input_tensor: tensor to which the transform matrix is applied + :param transpose: whether or not the transform matrix is transposed before + being applied. + :param first: if the transform matrix will be the first or second matrix to be + multiplied + """ + transpose = not transpose + # need to normalize before sending back + return ( + apply_matrix_transform( + transform=transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) + / transform.shape[0] + ) diff --git a/src/compressed_tensors/transforms/hadamard_utils.py b/src/compressed_tensors/transforms/hadamard_utils.py new file mode 100644 index 00000000..2cbd74d8 --- /dev/null +++ b/src/compressed_tensors/transforms/hadamard_utils.py @@ -0,0 +1,166 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import numpy +import torch + + +__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"] + +# adapted from: +# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py +def deterministic_hadamard_matrix(size: int): + """ + Construct an Hadamard matrix. + + Constructs an n-by-n Hadamard matrix, using Sylvester's + construction. `n` must be a power of 2. 
:param size: order of the matrix; must be a power of 2

    returns a (size, size) hadamard matrix
    """

    dtype = int
    if size < 1:
        lg2 = 0
    else:
        lg2 = int(math.log(size, 2))
    if 2**lg2 != size:
        raise ValueError("size must be a positive integer and a power of 2")

    H = numpy.array([[1]], dtype=dtype)

    # Sylvester's construction
    for i in range(0, lg2):
        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))

    return H


# adapted from:
# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py

# TODO: the following library exists for online rotations and should be considered
# in the future:
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


def random_hadamard_matrix(size: int) -> torch.Tensor:
    """
    Produces a randomly generated Hadamard matrix.
    See https://cornell-relaxml.github.io/quip-sharp/ ,
    Section "Randomized Hadamard Transformation"

    :param size: The dimension of the matrix. Matrix generated will have dimensions
        (size, size)

    """
    # TODO: potentially update to add "seed" as an argument, to allow
    # the matrix generated to be reproducible

    # Benefits: support other shapes / non powers of 2, support randomization
    Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
    Q = Q * 2 - 1
    Q = torch.diag(Q)
    return _matmul_hadU(Q)


def _get_hadK(n, transpose=False):
    # NOTE: we can easily extend the list of supported shapes/sizes
    # by adding to these methods
    hadK, K = None, None
    if n % 20 == 0:
        assert _is_pow2(n // 20)
        K = 20
        hadK = _get_had20().T if transpose else _get_had20()
    elif n % 12 == 0:
        assert _is_pow2(n // 12)
        K = 12
        hadK = _get_had12().T if transpose else _get_had12()
    else:
        assert _is_pow2(n)
        K = 1

    return hadK, K


def _matmul_hadU(X, transpose=False):
    n = X.shape[-1]
    # Check if we have the determined hadamard matrix
    hadK, K = _get_hadK(n, transpose)
    # Reshape diag matrix with randomized -1/+1
    input = X.clone().view(-1, n, 1)
    output = input.clone()

    # for cases when hadK is not predetermined, determine hadamard matrix
    while input.shape[1] > K:
        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
        output = output.view(input.shape)
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.view(input.shape[0], input.shape[1], -1)
        (input, output) = (output, input)
    del output

    # K == 1 when hadK is None; this happens when the size dim (n)
    # is not compatible with any of the maintained hadamard matrices

    if K > 1:
        # Do not explicitly repeat - OOM
        # input = torch.bmm(
        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
        # Use bcast instead

        # for cases when hadK is pre-determined
        input = hadK.view(1, K, K).to(input) @ input

    # normalize
    return input.view(X.shape) / torch.tensor(n).sqrt()


def _is_pow2(n):
    return (n & (n - 1) == 0) and (n > 0)


def _reshape_bits(packed_bits, original_size):
    had_unpacked = numpy.unpackbits(packed_bits)
    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
    return had_unpacked


# http://www.neilsloane.com/hadamard/index.html
def _get_had12():
    # fmt: off
    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
    # fmt: on
    # TODO: just unpack
during apply + had_12_unpacked = _reshape_bits(had_12, original_size=12) + return torch.FloatTensor(had_12_unpacked) + + +def _get_had20(): + # fmt: off + had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252, + 216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243, + 97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43, + 205, 225, 94, 107, 10, 243], dtype=numpy.uint8) + # fmt: on + # TODO: just unpack during apply + had_20_unpacked = _reshape_bits(had_20, original_size=20) + return torch.FloatTensor(had_20_unpacked) diff --git a/src/compressed_tensors/transforms/matrix_multiply.py b/src/compressed_tensors/transforms/matrix_multiply.py new file mode 100644 index 00000000..dd5611ec --- /dev/null +++ b/src/compressed_tensors/transforms/matrix_multiply.py @@ -0,0 +1,47 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from compressed_tensors.transforms import Transforms + + +# TODO: fix loading +@Transforms.register("matrix-mul") +class MatrixMultiply(Transforms): + @staticmethod + def inverse_apply( + transform: torch.Tensor, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + """ + Apply the inverse operation of `apply` + + :param transform: matrix tensor + :param input_tensor: tensor to which the transform matrix is applied + :param transpose: whether or not the transform matrix is transposed before + being applied. + :param first: if the transform matrix will be the first or second matrix to be + multiplied + """ + + # Note: not implemented for lower precision than float32 + transform = torch.linalg.inv(transform) + return apply_matrix_transform( + transform=transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py new file mode 100644 index 00000000..f61e878d --- /dev/null +++ b/src/compressed_tensors/transforms/random_hadamard.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Optional, Union + +import torch +from compressed_tensors.transforms import Transforms +from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix +from compressed_tensors.transforms.utils import apply_matrix_transform + + +@Transforms.register("random-hadamard") +class RandomHadamard(Transforms): + def __new__( + cls, + size: int, + empty: Optional[bool] = False, + device: Optional[Union[str, torch.device]] = "cuda", + dtype: Optional[torch.dtype] = torch.bfloat16, + ): + """ + Produces a randomly generated matrix with dims (size, size), with values + between -1 and 1, and the property HH.T == I i.e the transformation + matrix when multiplied by its transpose is the identity. + All rows and columns are orthonormal. The matrix returned + is normalized and has the form (1/sqrt(size)) * M where all elements + of M are -1 or +1. + + :param size: size of the matrix, if generating a new Hadamard matrix. + The generated matrix will have dimensions (size, size) + :param transform: if loading in a previously generated matrix, will + use that through this transformation, as opposed to creating a new + one + :param dtype: type to cast the rotation matrix to + + TODO: We can likely make the serialization of this more efficient: + The generation of this matrix starts with generating a random + matrix with dims (size, size). We could potentially just store + a randomly generated seed and the size, as opposed to storing the entire + matrix, to reproduce an identical matrix during runtime. That way, + we will not have to store the entire matrix. Will need to consider + accuracy implications. + """ + + if not empty: + transform = random_hadamard_matrix(size=size) + else: + transform = torch.empty((size, size)) + + return super().__new__(cls, transform=transform, device=device, dtype=dtype) + + @staticmethod + def inverse_apply( + transform: torch.Tensor, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + """ + Apply the inverse operation of `apply` + + :param transform: hadamard tensor + :param input_tensor: tensor to which the transform matrix is applied + :param transpose: whether or not the transform matrix is transposed before + being applied. + :param first: if the transform matrix will be the first or second matrix to be + multiplied + """ + + transpose = not transpose + return apply_matrix_transform( + transform=transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) diff --git a/src/compressed_tensors/transforms/utils.py b/src/compressed_tensors/transforms/utils.py new file mode 100644 index 00000000..997c91f1 --- /dev/null +++ b/src/compressed_tensors/transforms/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + + +__all__ = ["apply_matrix_transform"] + + +def apply_matrix_transform( + transform: torch.Tensor, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, +) -> torch.Tensor: + """ + Apply a matrix-type transform + + :param transform: transform tensor + :param input_tensor: tensor to which the transform matrix is applied + :param transpose: whether or not the transform matrix is transposed before + being applied. + :param first: if the transform matrix will be the first or second matrix to be + multiplied + + returns a transformed input_tensor + """ + + if transpose: + return ( + torch.matmul(transform.T, input_tensor) + if first + else torch.matmul(input_tensor, transform.T) + ) + + return ( + torch.matmul(transform, input_tensor) + if first + else torch.matmul(input_tensor, transform) + ) diff --git a/tests/test_transforms/test_hadamards.py b/tests/test_transforms/test_hadamards.py new file mode 100644 index 00000000..77f3394a --- /dev/null +++ b/tests/test_transforms/test_hadamards.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy +import pytest +import torch +from compressed_tensors.transforms.hadamard_utils import ( + _get_had12, + _get_had20, + deterministic_hadamard_matrix, + random_hadamard_matrix, +) + + +@pytest.mark.parametrize( + "had_func", + [ + _get_had12, + _get_had20, + ], +) +def test_packed_hadamard_compliant(had_func): + had_matrix = had_func() + size = had_matrix.shape[0] + # HH.T == nI + val_1 = had_matrix @ had_matrix.T + assert torch.equal(val_1 / size, torch.eye(size)) + + +@pytest.mark.parametrize( + "size", + [4096, 2048], +) +def test_random_hadamard_matrix_compliant(size): + had_matrix = random_hadamard_matrix(size) + val_1 = torch.round(had_matrix @ had_matrix.T) + assert torch.equal(val_1, torch.eye(size)) + + +@pytest.mark.parametrize( + "size", + [1024, 2048], +) +def test_deterministic_hadamard_compliant(size): + had_matrix = deterministic_hadamard_matrix(size) + # HH.T == nI + val_1 = had_matrix @ had_matrix.T + assert numpy.array_equal(val_1 / size, numpy.eye(size)) diff --git a/tests/test_transforms/test_transforms.py b/tests/test_transforms/test_transforms.py new file mode 100644 index 00000000..c310f023 --- /dev/null +++ b/tests/test_transforms/test_transforms.py @@ -0,0 +1,111 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Union + +import pytest +import torch +from compressed_tensors.transforms import ( + Hadamard, + MatrixMultiply, + RandomHadamard, + Transforms, +) +from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix + + +@pytest.mark.parametrize( + "size,dtype", + [ + [1024, torch.float32], + [2048, torch.float16], + [2048, torch.bfloat16], + [4096, torch.float32], + [5120, torch.float16], + [8192, torch.bfloat16], + ], +) +def test_random_hadamard_transform(size: int, dtype: torch.dtype): + hadamard_transform = Transforms.load_from_registry( + "random-hadamard", size=size, dtype=dtype, device="cpu" + ) + # check initialize + assert hadamard_transform is not None + + val_1 = torch.round(hadamard_transform @ hadamard_transform.T) + + # output will be normalized, multiply by sqrt(size) to ensure form + normalized = math.sqrt(size) * hadamard_transform + # all values should be -1 or +1 + assert torch.all(torch.isin(normalized, torch.Tensor([-1, +1]))) + # check creation; HH.T == I + assert torch.equal(val_1, torch.eye(size)) + + # check apply + x = torch.rand((size, size), dtype=dtype) + apply = Transforms.fetch_apply("random-hadamard") + transformed_value = apply(input_tensor=x, transform=hadamard_transform) + # TODO: check to make sure the matrix was applied correctly? + assert transformed_value.shape == (size, size) + + +@pytest.mark.parametrize( + "size,dtype", + [ + [1024, torch.bfloat16], + [2048, torch.float16], + ], +) +def test_deterministic_hadamard_transform(size: int, dtype: torch.dtype): + hadamard_transform = Transforms.load_from_registry( + "hadamard", size=size, dtype=dtype, device="cpu" + ) + + # check initialize + assert hadamard_transform is not None + assert torch.all(torch.isin(hadamard_transform, torch.Tensor([-1, +1]))) + + val_1 = hadamard_transform @ hadamard_transform.T + # check creation; HH.T == nI + assert torch.equal(val_1 / size, torch.eye(size)) + + # check apply + x = torch.rand((size, size), dtype=dtype) + apply = Transforms.fetch_apply("hadamard") + transformed_value = apply(input_tensor=x, transform=hadamard_transform) + # TODO: check to make sure the matrix was applied correctly? 
+ assert transformed_value.shape == (size, size) + + +@pytest.mark.parametrize( + "size,dtype", + [ + [1024, torch.float32], + [2048, torch.float16], + [4096, torch.bfloat16], + ], +) +def test_multiplier_transform(size: int, dtype: torch.dtype): + multiplier = torch.eye((size)) + multiplier_transform = Transforms.load_from_registry( + "matrix-mul", transform=multiplier, device="cpu", dtype=dtype + ) + assert multiplier_transform is not None + assert torch.equal(multiplier_transform, multiplier) + + x = torch.rand((size, size), dtype=dtype) + apply = Transforms.fetch_apply("matrix-mul") + transformed_value = apply(input_tensor=x, transform=multiplier_transform) + assert torch.equal(transformed_value, x) From ab6101ee865310f5cd0c463670333356b919b537 Mon Sep 17 00:00:00 2001 From: Dipika Date: Thu, 20 Mar 2025 18:39:51 +0000 Subject: [PATCH 2/6] fix typo, change class, remove long test case --- src/compressed_tensors/transforms/base.py | 6 +++--- tests/test_transforms/test_hadamards.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py index 16bcf927..be2310e4 100644 --- a/src/compressed_tensors/transforms/base.py +++ b/src/compressed_tensors/transforms/base.py @@ -27,10 +27,10 @@ # first or second matirx in torch.matmul depending on dimensions, can be inferred # by the layer time likely. -MATIRX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"] +MATRIX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"] -class Transforms(RegistryMixin): +class Transforms(torch.nn.Parameter, RegistryMixin): def __new__( cls, transform: torch.Tensor, @@ -69,7 +69,7 @@ def __new__( @classmethod def fetch_apply(cls, name: str): - if name in MATIRX_TRANSFORMS: + if name in MATRIX_TRANSFORMS: return apply_matrix_transform raise NotImplementedError("Only matrix transforms are supported") diff --git a/tests/test_transforms/test_hadamards.py b/tests/test_transforms/test_hadamards.py index 77f3394a..25c8a0fe 100644 --- a/tests/test_transforms/test_hadamards.py +++ b/tests/test_transforms/test_hadamards.py @@ -51,7 +51,7 @@ def test_random_hadamard_matrix_compliant(size): @pytest.mark.parametrize( "size", - [1024, 2048], + [1024], ) def test_deterministic_hadamard_compliant(size): had_matrix = deterministic_hadamard_matrix(size) From 1e1760bac6cad594ba9761b6e0bcc04e95de9d34 Mon Sep 17 00:00:00 2001 From: Dipika Date: Fri, 21 Mar 2025 22:09:23 +0000 Subject: [PATCH 3/6] clean-up --- src/compressed_tensors/transforms/base.py | 62 ++++++++++--------- src/compressed_tensors/transforms/hadamard.py | 26 ++++++-- .../transforms/matrix_multiply.py | 19 +++++- .../transforms/random_hadamard.py | 24 +++++-- tests/test_transforms/test_transforms.py | 21 +++---- 5 files changed, 97 insertions(+), 55 deletions(-) diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py index be2310e4..ce25642a 100644 --- a/src/compressed_tensors/transforms/base.py +++ b/src/compressed_tensors/transforms/base.py @@ -17,6 +17,7 @@ import torch from compressed_tensors.registry.registry import RegistryMixin from compressed_tensors.transforms.utils import apply_matrix_transform +from compressed_tensors.utils import register_offload_parameter, update_parameter_data __all__ = ["Transforms"] @@ -27,18 +28,16 @@ # first or second matirx in torch.matmul depending on dimensions, can be inferred # by the layer time likely. 
-MATRIX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"] - -class Transforms(torch.nn.Parameter, RegistryMixin): - def __new__( - cls, +class Transforms(RegistryMixin): + def __init__( + self, transform: torch.Tensor, + learnable: Optional[bool] = True, device: Optional[Union[str, torch.device]] = "cuda", dtype: Optional[torch.dtype] = torch.bfloat16, - *args, - **kwargs, ): + self.learnable = learnable """ Base class for setting up transforms. The registry creates transforms as parameters which can be attached to modules. @@ -48,38 +47,45 @@ def __new__( size = 1024 dtype = torch.bfloat16 module = torch.nn.Linear(size, size) + name = "weight_transform" hadamard_transform = Transforms.load_from_registry( "random_hadamard", size=size, dtype=dtype ) - hadamard_apply = Transforms.fetch_apply("random_hadamard") - module.weight_transform = hadamard_transform - transformed_output = hadamard_apply(input_tensor=module.weight, - transform=moduel.weight_transform) + hadamard_transform.register_to_module(name, module) + module.transform_data = {name: {"call_args": dict, "class": hadamard_transform}} - hadamard_inverse = Transforms.fetch_inverse_apply("random_hadamard") - original_weight = hadamard_inverse(input_tensor=transformed_output, - transform=model.weight_trainsform, - transpose=True) + transformed_output = hadamard_transform.apply(input_tensor=module.weight) + original_weight = hadamard_transform.inverse_apply( + input_tensor=transformed_output) :param transform: transform (e.g. torch.Tensor, scalar) to be applied """ - return torch.nn.Parameter(transform.to(device).to(dtype), requires_grad=False) - - @classmethod - def fetch_apply(cls, name: str): - if name in MATRIX_TRANSFORMS: - return apply_matrix_transform - raise NotImplementedError("Only matrix transforms are supported") - - @classmethod - def fetch_inverse_apply(cls, name: str): - return cls.get_value_from_registry(name=name).inverse_apply + if self.learnable: + self.transform = torch.nn.Parameter( + transform.to(dtype).to(device), requires_grad=False + ) + else: + self.transform = torch.nn.Buffer(transform.to(dtype).to(device)) + + # register to class for easy offloading, serialization, deserialization + def register_to_module(self, name: str, module: torch.nn.Module): + if self.learnable: + register_offload_parameter(module, name, self.transform) + else: + # TODO: have to verify serialization/offloading + module.register_buffer(name, self.transform) + + def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Apply the transform to the module + """ + raise NotImplementedError() - @staticmethod + # TODO: potentially split into its own transform using the same shared set-up def inverse_apply( - transform: torch.Tensor, input_tensor: torch.Tensor, *args, **kwargs + self, input_tensor: torch.Tensor, *args, **kwargs ) -> torch.Tensor: """ Apply the inverse operation applied by the apply method diff --git a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py index b1071ab0..19e3c98d 100644 --- a/src/compressed_tensors/transforms/hadamard.py +++ b/src/compressed_tensors/transforms/hadamard.py @@ -22,12 +22,14 @@ @Transforms.register("hadamard") class Hadamard(Transforms): - def __new__( - cls, + def __init__( + self, size: int, empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cuda", dtype: Optional[torch.dtype] = torch.bfloat16, + *args, + **kwargs, ): """ Produces a hadamard matrix with dims (size, size), with values @@ -50,11 
+52,23 @@ def __new__( else: transform = torch.empty((size, size)) - return super().__new__(cls, transform=transform, device=device, dtype=dtype) + super().__init__(transform=transform, dtype=dtype, device=device) + + def apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + return apply_matrix_transform( + transform=self.transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) - @staticmethod def inverse_apply( - transform: torch.Tensor, + self, input_tensor: torch.Tensor, transpose: bool = False, first: bool = True, @@ -73,7 +87,7 @@ def inverse_apply( # need to normalize before sending back return ( apply_matrix_transform( - transform=transform, + transform=self.transform, input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/src/compressed_tensors/transforms/matrix_multiply.py b/src/compressed_tensors/transforms/matrix_multiply.py index dd5611ec..1bfeffa4 100644 --- a/src/compressed_tensors/transforms/matrix_multiply.py +++ b/src/compressed_tensors/transforms/matrix_multiply.py @@ -14,14 +14,27 @@ import torch from compressed_tensors.transforms import Transforms +from compressed_tensors.transforms.utils import apply_matrix_transform # TODO: fix loading @Transforms.register("matrix-mul") class MatrixMultiply(Transforms): - @staticmethod + def apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + return apply_matrix_transform( + transform=self.transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) + def inverse_apply( - transform: torch.Tensor, + self, input_tensor: torch.Tensor, transpose: bool = False, first: bool = True, @@ -40,7 +53,7 @@ def inverse_apply( # Note: not implemented for lower precision than float32 transform = torch.linalg.inv(transform) return apply_matrix_transform( - transform=transform, + transform=self.transform, input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py index f61e878d..162269c5 100644 --- a/src/compressed_tensors/transforms/random_hadamard.py +++ b/src/compressed_tensors/transforms/random_hadamard.py @@ -22,8 +22,8 @@ @Transforms.register("random-hadamard") class RandomHadamard(Transforms): - def __new__( - cls, + def __init__( + self, size: int, empty: Optional[bool] = False, device: Optional[Union[str, torch.device]] = "cuda", @@ -58,11 +58,23 @@ def __new__( else: transform = torch.empty((size, size)) - return super().__new__(cls, transform=transform, device=device, dtype=dtype) + super().__init__(transform=transform, device=device, dtype=dtype) + + def apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + return apply_matrix_transform( + transform=self.transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) - @staticmethod def inverse_apply( - transform: torch.Tensor, + self, input_tensor: torch.Tensor, transpose: bool = False, first: bool = True, @@ -80,7 +92,7 @@ def inverse_apply( transpose = not transpose return apply_matrix_transform( - transform=transform, + transform=self.transform, input_tensor=input_tensor, transpose=transpose, first=first, diff --git a/tests/test_transforms/test_transforms.py b/tests/test_transforms/test_transforms.py index c310f023..adf1941e 100644 --- a/tests/test_transforms/test_transforms.py +++ 
b/tests/test_transforms/test_transforms.py @@ -44,10 +44,10 @@ def test_random_hadamard_transform(size: int, dtype: torch.dtype): # check initialize assert hadamard_transform is not None - val_1 = torch.round(hadamard_transform @ hadamard_transform.T) + val_1 = torch.round(hadamard_transform.transform @ hadamard_transform.transform.T) # output will be normalized, multiply by sqrt(size) to ensure form - normalized = math.sqrt(size) * hadamard_transform + normalized = math.sqrt(size) * hadamard_transform.transform # all values should be -1 or +1 assert torch.all(torch.isin(normalized, torch.Tensor([-1, +1]))) # check creation; HH.T == I @@ -55,8 +55,7 @@ def test_random_hadamard_transform(size: int, dtype: torch.dtype): # check apply x = torch.rand((size, size), dtype=dtype) - apply = Transforms.fetch_apply("random-hadamard") - transformed_value = apply(input_tensor=x, transform=hadamard_transform) + transformed_value = hadamard_transform.apply(input_tensor=x) # TODO: check to make sure the matrix was applied correctly? assert transformed_value.shape == (size, size) @@ -75,16 +74,15 @@ def test_deterministic_hadamard_transform(size: int, dtype: torch.dtype): # check initialize assert hadamard_transform is not None - assert torch.all(torch.isin(hadamard_transform, torch.Tensor([-1, +1]))) + assert torch.all(torch.isin(hadamard_transform.transform, torch.Tensor([-1, +1]))) - val_1 = hadamard_transform @ hadamard_transform.T + val_1 = hadamard_transform.transform @ hadamard_transform.transform.T # check creation; HH.T == nI assert torch.equal(val_1 / size, torch.eye(size)) # check apply x = torch.rand((size, size), dtype=dtype) - apply = Transforms.fetch_apply("hadamard") - transformed_value = apply(input_tensor=x, transform=hadamard_transform) + transformed_value = hadamard_transform.apply(input_tensor=x) # TODO: check to make sure the matrix was applied correctly? assert transformed_value.shape == (size, size) @@ -103,9 +101,8 @@ def test_multiplier_transform(size: int, dtype: torch.dtype): "matrix-mul", transform=multiplier, device="cpu", dtype=dtype ) assert multiplier_transform is not None - assert torch.equal(multiplier_transform, multiplier) + assert torch.equal(multiplier_transform.transform, multiplier) x = torch.rand((size, size), dtype=dtype) - apply = Transforms.fetch_apply("matrix-mul") - transformed_value = apply(input_tensor=x, transform=multiplier_transform) - assert torch.equal(transformed_value, x) + transformed_output = multiplier_transform.apply(x) + assert torch.equal(transformed_output, x) From 749420b9d126b448d89c537e41fe19dadcd40edd Mon Sep 17 00:00:00 2001 From: Dipika Date: Sat, 22 Mar 2025 20:56:38 +0000 Subject: [PATCH 4/6] update --- src/compressed_tensors/transforms/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py index ce25642a..5cdda7f0 100644 --- a/src/compressed_tensors/transforms/base.py +++ b/src/compressed_tensors/transforms/base.py @@ -63,9 +63,7 @@ def __init__( :param transform: transform (e.g. 
torch.Tensor, scalar) to be applied
         """
         if self.learnable:
-            self.transform = torch.nn.Parameter(
-                transform.to(dtype).to(device), requires_grad=False
-            )
+            self.transform = torch.nn.Parameter(transform.to(dtype).to(device))
         else:
             self.transform = torch.nn.Buffer(transform.to(dtype).to(device))

From 7ecb1b07c9d7b4c1fcf1c034ee385ec99c57b5bf Mon Sep 17 00:00:00 2001
From: Dipika
Date: Sun, 23 Mar 2025 02:28:08 +0000
Subject: [PATCH 5/6] fix; add update

---
 src/compressed_tensors/transforms/base.py     | 15 +++++++++++++++
 src/compressed_tensors/transforms/hadamard.py |  2 +-
 .../transforms/matrix_multiply.py             |  3 +--
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py
index 5cdda7f0..56adcb62 100644
--- a/src/compressed_tensors/transforms/base.py
+++ b/src/compressed_tensors/transforms/base.py
@@ -75,6 +75,21 @@ def register_to_module(self, name: str, module: torch.nn.Module):
             # TODO: have to verify serialization/offloading
             module.register_buffer(name, self.transform)

+    def update_transform(
+        self,
+        data: torch.Tensor,
+        module: Optional[torch.nn.Module] = None,
+        name: Optional[str] = None,
+    ):
+        if module is None:
+            self.transform.data.copy_(data)
+        else:
+            # If updating the module parameter data, assume this is also the transform
+            # data
+            if name is None:
+                raise ValueError("Name and module are required to update param data")
+            update_parameter_data(module, data, name)
+
     def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Apply the transform to the module
diff --git a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py
index 19e3c98d..ef0e27a4 100644
--- a/src/compressed_tensors/transforms/hadamard.py
+++ b/src/compressed_tensors/transforms/hadamard.py
@@ -92,5 +92,5 @@ def inverse_apply(
                 transpose=transpose,
                 first=first,
             )
-            / transform.shape[0]
+            / self.transform.shape[0]
         )
diff --git a/src/compressed_tensors/transforms/matrix_multiply.py b/src/compressed_tensors/transforms/matrix_multiply.py
index 1bfeffa4..a06d61f2 100644
--- a/src/compressed_tensors/transforms/matrix_multiply.py
+++ b/src/compressed_tensors/transforms/matrix_multiply.py
@@ -51,9 +51,8 @@ def inverse_apply(
         """

         # Note: not implemented for lower precision than float32
-        transform = torch.linalg.inv(transform)
         return apply_matrix_transform(
-            transform=self.transform,
+            transform=torch.linalg.inv(self.transform),
             input_tensor=input_tensor,
             transpose=transpose,
             first=first,

From 2988abaa2789004d94c070c86cfe17e2a21bbfc6 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Mon, 28 Apr 2025 10:01:59 -0400
Subject: [PATCH 6/6] update typehints, ignore floattensor

Signed-off-by: Kyle Sayers
---
 .../transforms/hadamard_utils.py              | 35 +++++++++----------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/src/compressed_tensors/transforms/hadamard_utils.py b/src/compressed_tensors/transforms/hadamard_utils.py
index 2cbd74d8..1f042941 100644
--- a/src/compressed_tensors/transforms/hadamard_utils.py
+++ b/src/compressed_tensors/transforms/hadamard_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import math
+from typing import Tuple

 import numpy
 import torch
@@ -22,7 +23,7 @@

 # adapted from:
 # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
-def deterministic_hadamard_matrix(size: int):
+def deterministic_hadamard_matrix(size: int) -> numpy.ndarray:
     """
     Construct an Hadamard matrix.
@@ -33,19 +34,17 @@ def deterministic_hadamard_matrix(size: int): returns a (size, size) hadamard matrix """ + if size <= 0: + raise ValueError("Cannot construct deterministic hadamard of size <= 0") - dtype = int - if size < 1: - lg2 = 0 - else: - lg2 = int(math.log(size, 2)) - if 2**lg2 != size: - raise ValueError("size must be an positive integer and a power of 2") + log2 = int(math.log(size, 2)) + if size != 2**log2: + raise ValueError("Cannot construct deterministic hadamard of size != 2^n") - H = numpy.array([[1]], dtype=dtype) + H = numpy.array([[1]], dtype=int) # Sylvester's construction - for i in range(0, lg2): + for i in range(0, log2): H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H)))) return H @@ -79,7 +78,7 @@ def random_hadamard_matrix(size: int) -> torch.Tensor: return _matmul_hadU(Q) -def _get_hadK(n, transpose=False): +def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]: # NOTE: we can easily extend the list of supported shapes/sizes # by adding to these methods hadK, K = None, None @@ -98,7 +97,7 @@ def _get_hadK(n, transpose=False): return hadK, K -def _matmul_hadU(X, transpose=False): +def _matmul_hadU(X, transpose=False) -> torch.Tensor: n = X.shape[-1] # Check if we have the determined hadamard matrix hadK, K = _get_hadK(n, transpose) @@ -132,11 +131,11 @@ def _matmul_hadU(X, transpose=False): return input.view(X.shape) / torch.tensor(n).sqrt() -def _is_pow2(n): +def _is_pow2(n: int) -> bool: return (n & (n - 1) == 0) and (n > 0) -def _reshape_bits(packed_bits, original_size): +def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray: had_unpacked = numpy.unpackbits(packed_bits) had_unpacked = [1 if x == 1 else -1 for x in had_unpacked] had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size)) @@ -144,17 +143,17 @@ def _reshape_bits(packed_bits, original_size): # http://www.neilsloane.com/hadamard/index.html -def _get_had12(): +def _get_had12() -> torch.Tensor: # fmt: off had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218, 62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8) # fmt: on # TODO: just unpack during apply had_12_unpacked = _reshape_bits(had_12, original_size=12) - return torch.FloatTensor(had_12_unpacked) + return torch.tensor(had_12_unpacked) -def _get_had20(): +def _get_had20() -> torch.Tensor: # fmt: off had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252, 216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243, @@ -163,4 +162,4 @@ def _get_had20(): # fmt: on # TODO: just unpack during apply had_20_unpacked = _reshape_bits(had_20, original_size=20) - return torch.FloatTensor(had_20_unpacked) + return torch.tensor(had_20_unpacked)
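
A minimal end-to-end sketch of how the API fits together once the whole series is applied. The registry key "hadamard", the "weight_transform" name, and the CPU/float32 settings below are illustrative choices taken from the docstrings and tests above, not additions made by these patches:

    import torch
    from compressed_tensors.transforms import Transforms

    size, dtype = 1024, torch.float32
    module = torch.nn.Linear(size, size)

    # build a deterministic Hadamard transform and attach it to the module
    hadamard = Transforms.load_from_registry(
        "hadamard", size=size, dtype=dtype, device="cpu"
    )
    hadamard.register_to_module("weight_transform", module)

    # rotate the weight, then undo the rotation; inverse_apply divides by size
    rotated = hadamard.apply(input_tensor=module.weight)
    restored = hadamard.inverse_apply(input_tensor=rotated)
    assert torch.allclose(restored, module.weight, atol=1e-2)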