diff --git a/src/compressed_tensors/transforms/__init__.py b/src/compressed_tensors/transforms/__init__.py new file mode 100644 index 00000000..3d8a0f2e --- /dev/null +++ b/src/compressed_tensors/transforms/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import Transforms +from .hadamard import Hadamard +from .matrix_multiply import MatrixMultiply +from .random_hadamard import RandomHadamard diff --git a/src/compressed_tensors/transforms/base.py b/src/compressed_tensors/transforms/base.py new file mode 100644 index 00000000..56adcb62 --- /dev/null +++ b/src/compressed_tensors/transforms/base.py @@ -0,0 +1,106 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Union + +import torch +from compressed_tensors.registry.registry import RegistryMixin +from compressed_tensors.transforms.utils import apply_matrix_transform +from compressed_tensors.utils import register_offload_parameter, update_parameter_data + + +__all__ = ["Transforms"] + + +# TODO: We don't need to save all the __call__ args for serialization or even have +# them defined by a recipe. Some of them, such as if the transformation should be the +# first or second matirx in torch.matmul depending on dimensions, can be inferred +# by the layer time likely. + + +class Transforms(RegistryMixin): + def __init__( + self, + transform: torch.Tensor, + learnable: Optional[bool] = True, + device: Optional[Union[str, torch.device]] = "cuda", + dtype: Optional[torch.dtype] = torch.bfloat16, + ): + self.learnable = learnable + """ + Base class for setting up transforms. The registry creates transforms + as parameters which can be attached to modules. + + import torch + + size = 1024 + dtype = torch.bfloat16 + module = torch.nn.Linear(size, size) + name = "weight_transform" + + hadamard_transform = Transforms.load_from_registry( + "random_hadamard", size=size, dtype=dtype + ) + + hadamard_transform.register_to_module(name, module) + module.transform_data = {name: {"call_args": dict, "class": hadamard_transform}} + + transformed_output = hadamard_transform.apply(input_tensor=module.weight) + original_weight = hadamard_transform.inverse_apply( + input_tensor=transformed_output) + + :param transform: transform (e.g. 
torch.Tensor, scalar) to be applied + """ + if self.learnable: + self.transform = torch.nn.Parameter(transform.to(dtype).to(device)) + else: + self.transform = torch.nn.Buffer(transform.to(dtype).to(device)) + + # register to class for easy offloading, serialization, deserialization + def register_to_module(self, name: str, module: torch.nn.Module): + if self.learnable: + register_offload_parameter(module, name, self.transform) + else: + # TODO: have to verify serialization/offloading + module.register_buffer(name, self.transform) + + def update_transform( + self, + data: torch.Tensor, + module: Optional[torch.nn.Module] = None, + name: Optional[str] = None, + ): + if module is None: + self.transform.data.copy_(data) + else: + # If updating the module parameter data, assumes this is also the transform + # data + if name is None: + raise ValueError("Name and module are required to update parma data") + update_parameter_data(module, data, name) + + def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Apply the transform to the module + """ + raise NotImplementedError() + + # TODO: potentially split into its own transform using the same shared set-up + def inverse_apply( + self, input_tensor: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + """ + Apply the inverse operation applied by the apply method + """ + raise NotImplementedError() diff --git a/src/compressed_tensors/transforms/hadamard.py b/src/compressed_tensors/transforms/hadamard.py new file mode 100644 index 00000000..ef0e27a4 --- /dev/null +++ b/src/compressed_tensors/transforms/hadamard.py @@ -0,0 +1,96 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +from compressed_tensors.transforms import Transforms +from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix +from compressed_tensors.transforms.utils import apply_matrix_transform + + +@Transforms.register("hadamard") +class Hadamard(Transforms): + def __init__( + self, + size: int, + empty: Optional[bool] = False, + device: Optional[Union[str, torch.device]] = "cuda", + dtype: Optional[torch.dtype] = torch.bfloat16, + *args, + **kwargs, + ): + """ + Produces a hadamard matrix with dims (size, size), with values + -1 and 1, and the property HH.T == nI i.e the transformation + matrix when multiplied by its transpose is a multiple of the identity. + All rows and columns are orthonormal. The matrix returned + is not normalized and will be deterministic. + + :param size: size of the matrix, if generating a new Hadamard matrix. 
+ The generated matrix will have dimensions (size, size) + :param transform: if loading in a previously generated matrix, will + use that through this transformation, as opposed to creating a new + one + :param dtype: type to cast the rotation matrix to + + """ + if not empty: + # TODO: this is deterministic; we should just serialize the size + transform = torch.Tensor(deterministic_hadamard_matrix(size=size)) + else: + transform = torch.empty((size, size)) + + super().__init__(transform=transform, dtype=dtype, device=device) + + def apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + return apply_matrix_transform( + transform=self.transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) + + def inverse_apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + """ + Apply the inverse operation of `apply` + + :param transform: hadamard tensor + :param input_tensor: tensor to which the transform matrix is applied + :param transpose: whether or not the transform matrix is transposed before + being applied. + :param first: if the transform matrix will be the first or second matrix to be + multiplied + """ + transpose = not transpose + # need to normalize before sending back + return ( + apply_matrix_transform( + transform=self.transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) + / self.transform.shape[0] + ) diff --git a/src/compressed_tensors/transforms/hadamard_utils.py b/src/compressed_tensors/transforms/hadamard_utils.py new file mode 100644 index 00000000..1f042941 --- /dev/null +++ b/src/compressed_tensors/transforms/hadamard_utils.py @@ -0,0 +1,165 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +import numpy +import torch + + +__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"] + +# adapted from: +# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py +def deterministic_hadamard_matrix(size: int) -> numpy.ndarray: + """ + Construct an Hadamard matrix. + + Constructs an n-by-n Hadamard matrix, using Sylvester's + construction. `n` must be a power of 2. 
+ + :param size: order of the matrix; must be a power of 2 + + returns a (size, size) hadamard matrix + """ + if size <= 0: + raise ValueError("Cannot construct deterministic hadamard of size <= 0") + + log2 = int(math.log(size, 2)) + if size != 2**log2: + raise ValueError("Cannot construct deterministic hadamard of size != 2^n") + + H = numpy.array([[1]], dtype=int) + + # Sylvester's construction + for i in range(0, log2): + H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H)))) + + return H + + +# adapted from: +# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py + +# TODO: the following library exists for online rotations and should be considered +# in the future: +# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master + + +def random_hadamard_matrix(size: int) -> torch.Tensor: + """ + Produces a randomly generated Hadamard matrix. + See https://cornell-relaxml.github.io/quip-sharp/ , + Section "Randomized Hadamard Transformation" + + :param size: The dimension of the matrix. Matrix generated will have dimensions + (size, size) + + """ + # TODO: potentially update to add "seed" as an arugment, to allow + # the matrix generated to be reproducible + + # Benefits: support other shapes / non powers of 2, support randomization + Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + Q = Q * 2 - 1 + Q = torch.diag(Q) + return _matmul_hadU(Q) + + +def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]: + # NOTE: we can easily extend the list of supported shapes/sizes + # by adding to these methods + hadK, K = None, None + if n % 20 == 0: + assert _is_pow2(n // 20) + K = 20 + hadK = _get_had20().T if transpose else _get_had20() + elif n % 12 == 0: + assert _is_pow2(n // 12) + K = 12 + hadK = _get_had12().T if transpose else _get_had12() + else: + assert _is_pow2(n) + K = 1 + + return hadK, K + + +def _matmul_hadU(X, transpose=False) -> torch.Tensor: + n = X.shape[-1] + # Check if we have the determined hadamard matrix + hadK, K = _get_hadK(n, transpose) + # Reshape diag matrix with randomized -1/+1 + input = X.clone().view(-1, n, 1) + output = input.clone() + + # for cases when hadK is not predetermined, determine hadamard matrix + while input.shape[1] > K: + input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2]) + output = output.view(input.shape) + output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] + output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] + output = output.view(input.shape[0], input.shape[1], -1) + (input, output) = (output, input) + del output + + # K == 1 when hadK is None; this happens when the size dim (n) + # is not comaptible with any of the maintained hadamard matrices + + if K > 1: + # Do not explicitly repeat - OOM + # input = torch.bmm( + # hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input) + # Use bcast instead + + # for cases when hadK is pre-determined + input = hadK.view(1, K, K).to(input) @ input + + # normalize + return input.view(X.shape) / torch.tensor(n).sqrt() + + +def _is_pow2(n: int) -> bool: + return (n & (n - 1) == 0) and (n > 0) + + +def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray: + had_unpacked = numpy.unpackbits(packed_bits) + had_unpacked = [1 if x == 1 else -1 for x in had_unpacked] + had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size)) + return had_unpacked + + +# http://www.neilsloane.com/hadamard/index.html +def _get_had12() -> torch.Tensor: + # fmt: 
off + had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218, + 62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8) + # fmt: on + # TODO: just unpack during apply + had_12_unpacked = _reshape_bits(had_12, original_size=12) + return torch.tensor(had_12_unpacked) + + +def _get_had20() -> torch.Tensor: + # fmt: off + had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252, + 216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243, + 97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43, + 205, 225, 94, 107, 10, 243], dtype=numpy.uint8) + # fmt: on + # TODO: just unpack during apply + had_20_unpacked = _reshape_bits(had_20, original_size=20) + return torch.tensor(had_20_unpacked) diff --git a/src/compressed_tensors/transforms/matrix_multiply.py b/src/compressed_tensors/transforms/matrix_multiply.py new file mode 100644 index 00000000..a06d61f2 --- /dev/null +++ b/src/compressed_tensors/transforms/matrix_multiply.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from compressed_tensors.transforms import Transforms +from compressed_tensors.transforms.utils import apply_matrix_transform + + +# TODO: fix loading +@Transforms.register("matrix-mul") +class MatrixMultiply(Transforms): + def apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + return apply_matrix_transform( + transform=self.transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) + + def inverse_apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + """ + Apply the inverse operation of `apply` + + :param transform: matrix tensor + :param input_tensor: tensor to which the transform matrix is applied + :param transpose: whether or not the transform matrix is transposed before + being applied. + :param first: if the transform matrix will be the first or second matrix to be + multiplied + """ + + # Note: not implemented for lower precision than float32 + return apply_matrix_transform( + transform=torch.linalg.inv(self.transform), + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) diff --git a/src/compressed_tensors/transforms/random_hadamard.py b/src/compressed_tensors/transforms/random_hadamard.py new file mode 100644 index 00000000..162269c5 --- /dev/null +++ b/src/compressed_tensors/transforms/random_hadamard.py @@ -0,0 +1,99 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +from compressed_tensors.transforms import Transforms +from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix +from compressed_tensors.transforms.utils import apply_matrix_transform + + +@Transforms.register("random-hadamard") +class RandomHadamard(Transforms): + def __init__( + self, + size: int, + empty: Optional[bool] = False, + device: Optional[Union[str, torch.device]] = "cuda", + dtype: Optional[torch.dtype] = torch.bfloat16, + ): + """ + Produces a randomly generated matrix with dims (size, size), with values + between -1 and 1, and the property HH.T == I i.e the transformation + matrix when multiplied by its transpose is the identity. + All rows and columns are orthonormal. The matrix returned + is normalized and has the form (1/sqrt(size)) * M where all elements + of M are -1 or +1. + + :param size: size of the matrix, if generating a new Hadamard matrix. + The generated matrix will have dimensions (size, size) + :param transform: if loading in a previously generated matrix, will + use that through this transformation, as opposed to creating a new + one + :param dtype: type to cast the rotation matrix to + + TODO: We can likely make the serialization of this more efficient: + The generation of this matrix starts with generating a random + matrix with dims (size, size). We could potentially just store + a randomly generated seed and the size, as opposed to storing the entire + matrix, to reproduce an identical matrix during runtime. That way, + we will not have to store the entire matrix. Will need to consider + accuracy implications. + """ + + if not empty: + transform = random_hadamard_matrix(size=size) + else: + transform = torch.empty((size, size)) + + super().__init__(transform=transform, device=device, dtype=dtype) + + def apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + return apply_matrix_transform( + transform=self.transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) + + def inverse_apply( + self, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, + ) -> torch.Tensor: + """ + Apply the inverse operation of `apply` + + :param transform: hadamard tensor + :param input_tensor: tensor to which the transform matrix is applied + :param transpose: whether or not the transform matrix is transposed before + being applied. + :param first: if the transform matrix will be the first or second matrix to be + multiplied + """ + + transpose = not transpose + return apply_matrix_transform( + transform=self.transform, + input_tensor=input_tensor, + transpose=transpose, + first=first, + ) diff --git a/src/compressed_tensors/transforms/utils.py b/src/compressed_tensors/transforms/utils.py new file mode 100644 index 00000000..997c91f1 --- /dev/null +++ b/src/compressed_tensors/transforms/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +__all__ = ["apply_matrix_transform"] + + +def apply_matrix_transform( + transform: torch.Tensor, + input_tensor: torch.Tensor, + transpose: bool = False, + first: bool = True, +) -> torch.Tensor: + """ + Apply a matrix-type transform + + :param transform: transform tensor + :param input_tensor: tensor to which the transform matrix is applied + :param transpose: whether or not the transform matrix is transposed before + being applied. + :param first: if the transform matrix will be the first or second matrix to be + multiplied + + returns a transformed input_tensor + """ + + if transpose: + return ( + torch.matmul(transform.T, input_tensor) + if first + else torch.matmul(input_tensor, transform.T) + ) + + return ( + torch.matmul(transform, input_tensor) + if first + else torch.matmul(input_tensor, transform) + ) diff --git a/tests/test_transforms/test_hadamards.py b/tests/test_transforms/test_hadamards.py new file mode 100644 index 00000000..25c8a0fe --- /dev/null +++ b/tests/test_transforms/test_hadamards.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy +import pytest +import torch +from compressed_tensors.transforms.hadamard_utils import ( + _get_had12, + _get_had20, + deterministic_hadamard_matrix, + random_hadamard_matrix, +) + + +@pytest.mark.parametrize( + "had_func", + [ + _get_had12, + _get_had20, + ], +) +def test_packed_hadamard_compliant(had_func): + had_matrix = had_func() + size = had_matrix.shape[0] + # HH.T == nI + val_1 = had_matrix @ had_matrix.T + assert torch.equal(val_1 / size, torch.eye(size)) + + +@pytest.mark.parametrize( + "size", + [4096, 2048], +) +def test_random_hadamard_matrix_compliant(size): + had_matrix = random_hadamard_matrix(size) + val_1 = torch.round(had_matrix @ had_matrix.T) + assert torch.equal(val_1, torch.eye(size)) + + +@pytest.mark.parametrize( + "size", + [1024], +) +def test_deterministic_hadamard_compliant(size): + had_matrix = deterministic_hadamard_matrix(size) + # HH.T == nI + val_1 = had_matrix @ had_matrix.T + assert numpy.array_equal(val_1 / size, numpy.eye(size)) diff --git a/tests/test_transforms/test_transforms.py b/tests/test_transforms/test_transforms.py new file mode 100644 index 00000000..adf1941e --- /dev/null +++ b/tests/test_transforms/test_transforms.py @@ -0,0 +1,108 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Union + +import pytest +import torch +from compressed_tensors.transforms import ( + Hadamard, + MatrixMultiply, + RandomHadamard, + Transforms, +) +from compressed_tensors.transforms.hadamard_utils import random_hadamard_matrix + + +@pytest.mark.parametrize( + "size,dtype", + [ + [1024, torch.float32], + [2048, torch.float16], + [2048, torch.bfloat16], + [4096, torch.float32], + [5120, torch.float16], + [8192, torch.bfloat16], + ], +) +def test_random_hadamard_transform(size: int, dtype: torch.dtype): + hadamard_transform = Transforms.load_from_registry( + "random-hadamard", size=size, dtype=dtype, device="cpu" + ) + # check initialize + assert hadamard_transform is not None + + val_1 = torch.round(hadamard_transform.transform @ hadamard_transform.transform.T) + + # output will be normalized, multiply by sqrt(size) to ensure form + normalized = math.sqrt(size) * hadamard_transform.transform + # all values should be -1 or +1 + assert torch.all(torch.isin(normalized, torch.Tensor([-1, +1]))) + # check creation; HH.T == I + assert torch.equal(val_1, torch.eye(size)) + + # check apply + x = torch.rand((size, size), dtype=dtype) + transformed_value = hadamard_transform.apply(input_tensor=x) + # TODO: check to make sure the matrix was applied correctly? + assert transformed_value.shape == (size, size) + + +@pytest.mark.parametrize( + "size,dtype", + [ + [1024, torch.bfloat16], + [2048, torch.float16], + ], +) +def test_deterministic_hadamard_transform(size: int, dtype: torch.dtype): + hadamard_transform = Transforms.load_from_registry( + "hadamard", size=size, dtype=dtype, device="cpu" + ) + + # check initialize + assert hadamard_transform is not None + assert torch.all(torch.isin(hadamard_transform.transform, torch.Tensor([-1, +1]))) + + val_1 = hadamard_transform.transform @ hadamard_transform.transform.T + # check creation; HH.T == nI + assert torch.equal(val_1 / size, torch.eye(size)) + + # check apply + x = torch.rand((size, size), dtype=dtype) + transformed_value = hadamard_transform.apply(input_tensor=x) + # TODO: check to make sure the matrix was applied correctly? + assert transformed_value.shape == (size, size) + + +@pytest.mark.parametrize( + "size,dtype", + [ + [1024, torch.float32], + [2048, torch.float16], + [4096, torch.bfloat16], + ], +) +def test_multiplier_transform(size: int, dtype: torch.dtype): + multiplier = torch.eye((size)) + multiplier_transform = Transforms.load_from_registry( + "matrix-mul", transform=multiplier, device="cpu", dtype=dtype + ) + assert multiplier_transform is not None + assert torch.equal(multiplier_transform.transform, multiplier) + + x = torch.rand((size, size), dtype=dtype) + transformed_output = multiplier_transform.apply(x) + assert torch.equal(transformed_output, x)
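
A minimal usage sketch (not part of the diff above) showing how the new registry, `register_to_module`, and the `apply`/`inverse_apply` hooks are intended to fit together. The choice of the `random-hadamard` transform, the CPU device, and the tolerance are illustrative assumptions, not requirements of this PR.

import torch
from compressed_tensors.transforms import Transforms

size, dtype = 1024, torch.float32
module = torch.nn.Linear(size, size)

# Build a transform from the registry and attach it to the module as a parameter.
hadamard = Transforms.load_from_registry(
    "random-hadamard", size=size, dtype=dtype, device="cpu"
)
hadamard.register_to_module("weight_transform", module)

# Rotate the weight, then undo the rotation with the inverse.
rotated = hadamard.apply(input_tensor=module.weight)
recovered = hadamard.inverse_apply(input_tensor=rotated)

# The normalized random Hadamard matrix satisfies H.T @ H == I, so the weight
# round-trips up to floating-point error (tolerance chosen loosely here).
assert torch.allclose(recovered, module.weight, atol=1e-4)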