[Transforms] Transform Registry Support #274
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import Transforms
from .hadamard import Hadamard
from .matrix_multiply import MatrixMultiply
from .random_hadamard import RandomHadamard
@@ -0,0 +1,106 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union

import torch
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.transforms.utils import apply_matrix_transform
from compressed_tensors.utils import register_offload_parameter, update_parameter_data


__all__ = ["Transforms"]

# TODO: We don't need to save all the __call__ args for serialization or even have
# them defined by a recipe. Some of them, such as whether the transform should be the
# first or second matrix in torch.matmul depending on dimensions, can likely be
# inferred from the layer type.


class Transforms(RegistryMixin):
    def __init__(
        self,
        transform: torch.Tensor,
        learnable: Optional[bool] = True,
        device: Optional[Union[str, torch.device]] = "cuda",
        dtype: Optional[torch.dtype] = torch.bfloat16,
    ):
        """
        Base class for setting up transforms. The registry creates transforms
        as parameters which can be attached to modules.

        import torch

        size = 1024
        dtype = torch.bfloat16
        module = torch.nn.Linear(size, size)
        name = "weight_transform"

        hadamard_transform = Transforms.load_from_registry(
            "random_hadamard", size=size, dtype=dtype
        )

        hadamard_transform.register_to_module(name, module)
        module.transform_data = {name: {"call_args": dict, "class": hadamard_transform}}

        transformed_output = hadamard_transform.apply(input_tensor=module.weight)
        original_weight = hadamard_transform.inverse_apply(
            input_tensor=transformed_output)

        :param transform: transform (e.g. torch.Tensor, scalar) to be applied
        :param learnable: whether the transform is registered as a learnable
            parameter or as a (non-learnable) buffer
        :param device: device to move the transform to
        :param dtype: dtype to cast the transform to
        """
        self.learnable = learnable
        if self.learnable:
            self.transform = torch.nn.Parameter(transform.to(dtype).to(device))
        else:
            self.transform = torch.nn.Buffer(transform.to(dtype).to(device))

    # register to class for easy offloading, serialization, deserialization
    def register_to_module(self, name: str, module: torch.nn.Module):
        if self.learnable:
            register_offload_parameter(module, name, self.transform)
        else:
            # TODO: have to verify serialization/offloading
            module.register_buffer(name, self.transform)

    def update_transform(
        self,
        data: torch.Tensor,
        module: Optional[torch.nn.Module] = None,
        name: Optional[str] = None,
    ):
        if module is None:
            self.transform.data.copy_(data)
        else:
            # If updating the module parameter data, assume this is also the
            # transform data
            if name is None:
                raise ValueError("Name and module are required to update param data")

            update_parameter_data(module, data, name)

    def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Apply the transform to the given input tensor
        """
        raise NotImplementedError()

    # TODO: potentially split into its own transform using the same shared set-up
    def inverse_apply(
        self, input_tensor: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """
        Apply the inverse of the operation applied by the apply method
        """
        raise NotImplementedError()
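Since `Transforms` mixes in `RegistryMixin`, concrete transforms register themselves under a string name and are constructed through `Transforms.load_from_registry`, as the docstring above sketches. Below is a minimal sketch of that registration flow; the `"identity"` name and the `Identity` class are hypothetical, invented here only to illustrate the pattern, and are not part of this PR.

import torch
from compressed_tensors.transforms import Transforms


@Transforms.register("identity")  # hypothetical transform, for illustration only
class Identity(Transforms):
    def __init__(self, size: int, **kwargs):
        # the base __init__ wraps the tensor in a Parameter (or Buffer)
        # and moves it to the requested device/dtype
        super().__init__(transform=torch.eye(size), **kwargs)

    def apply(self, input_tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return self.transform @ input_tensor

    def inverse_apply(
        self, input_tensor: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        return self.transform @ input_tensor


module = torch.nn.Linear(8, 8)
identity = Transforms.load_from_registry(
    "identity", size=8, device="cpu", dtype=torch.float32
)
identity.register_to_module("weight_transform", module)

The `learnable` flag on the base class decides whether `register_to_module` attaches the transform as an offload-aware parameter or as a plain buffer.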
@@ -0,0 +1,96 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix
from compressed_tensors.transforms.utils import apply_matrix_transform


@Transforms.register("hadamard") | ||
class Hadamard(Transforms): | ||
def __init__( | ||
self, | ||
size: int, | ||
empty: Optional[bool] = False, | ||
device: Optional[Union[str, torch.device]] = "cuda", | ||
dtype: Optional[torch.dtype] = torch.bfloat16, | ||
*args, | ||
**kwargs, | ||
): | ||
""" | ||
Produces a hadamard matrix with dims (size, size), with values | ||
-1 and 1, and the property HH.T == nI i.e the transformation | ||
matrix when multiplied by its transpose is a multiple of the identity. | ||
All rows and columns are orthonormal. The matrix returned | ||
is not normalized and will be deterministic. | ||
|
||
:param size: size of the matrix, if generating a new Hadamard matrix. | ||
The generated matrix will have dimensions (size, size) | ||
:param transform: if loading in a previously generated matrix, will | ||
use that through this transformation, as opposed to creating a new | ||
one | ||
:param dtype: type to cast the rotation matrix to | ||
|
||
""" | ||
if not empty: | ||
# TODO: this is deterministic; we should just serialize the size | ||
transform = torch.Tensor(deterministic_hadamard_matrix(size=size)) | ||
else: | ||
transform = torch.empty((size, size)) | ||
|
||
super().__init__(transform=transform, dtype=dtype, device=device) | ||
|
||
def apply( | ||
self, | ||
input_tensor: torch.Tensor, | ||
transpose: bool = False, | ||
first: bool = True, | ||
) -> torch.Tensor: | ||
return apply_matrix_transform( | ||
transform=self.transform, | ||
input_tensor=input_tensor, | ||
transpose=transpose, | ||
first=first, | ||
) | ||

    def inverse_apply(
        self,
        input_tensor: torch.Tensor,
        transpose: bool = False,
        first: bool = True,
    ) -> torch.Tensor:
        """
        Apply the inverse of the operation applied by the apply method

        :param input_tensor: tensor to which the transform matrix is applied
        :param transpose: whether or not the transform matrix is transposed before
            being applied
        :param first: if the transform matrix will be the first or second matrix to
            be multiplied
        """
        transpose = not transpose
        # need to normalize before sending back
        return (
            apply_matrix_transform(
                transform=self.transform,
                input_tensor=input_tensor,
                transpose=transpose,
                first=first,
            )
            / self.transform.shape[0]
        )

Review comment on the `transpose` argument: "Afaict transpose is only used when applying an inverse hadamard, and therefore should not be controllable by the user to avoid footgunning."
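The property called out in the docstring, HH.T == nI, is what lets `inverse_apply` undo `apply` with a transpose and a division by n. A small sanity-check sketch, assuming CPU and float32 arguments are acceptable in place of the CUDA/bfloat16 defaults:

import torch
from compressed_tensors.transforms import Transforms

size = 16
hadamard = Transforms.load_from_registry(
    "hadamard", size=size, device="cpu", dtype=torch.float32
)

# unnormalized Hadamard matrix: H @ H.T == n * I
H = hadamard.transform
assert torch.allclose(H @ H.T, size * torch.eye(size))

# apply followed by inverse_apply recovers the original tensor
x = torch.randn(size, size)
y = hadamard.apply(input_tensor=x)
x_recovered = hadamard.inverse_apply(input_tensor=y)
assert torch.allclose(x, x_recovered, atol=1e-4)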
@@ -0,0 +1,165 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple

import numpy
import torch


__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]

# adapted from:
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
def deterministic_hadamard_matrix(size: int) -> numpy.ndarray:
    """
    Construct an Hadamard matrix.

    Constructs an n-by-n Hadamard matrix, using Sylvester's
    construction. `n` must be a power of 2.

    :param size: order of the matrix; must be a power of 2
    :return: a (size, size) Hadamard matrix
    """
    if size <= 0:
        raise ValueError("Cannot construct deterministic hadamard of size <= 0")

    log2 = int(math.log(size, 2))
    if size != 2**log2:
        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

    H = numpy.array([[1]], dtype=int)

    # Sylvester's construction: H_{2n} = [[H_n, H_n], [H_n, -H_n]]
    for _ in range(log2):
        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))

    return H


# adapted from:
# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py

# TODO: the following library exists for online rotations and should be considered
# in the future:
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


def random_hadamard_matrix(size: int) -> torch.Tensor:
    """
    Produces a randomly generated Hadamard matrix.
    See https://cornell-relaxml.github.io/quip-sharp/ ,
    Section "Randomized Hadamard Transformation"

    :param size: The dimension of the matrix. Matrix generated will have dimensions
        (size, size)
    """
    # TODO: potentially update to add "seed" as an argument, to allow
    # the matrix generated to be reproducible

    # Benefits: support other shapes / non powers of 2, support randomization
    # Start from a random diagonal of +/-1 and apply the Hadamard transform to it
    Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
    Q = Q * 2 - 1
    Q = torch.diag(Q)
    return _matmul_hadU(Q)

def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
    # NOTE: we can easily extend the list of supported shapes/sizes
    # by adding to these methods
    hadK, K = None, None
    if n % 20 == 0:
        assert _is_pow2(n // 20)
        K = 20
        hadK = _get_had20().T if transpose else _get_had20()
    elif n % 12 == 0:
        assert _is_pow2(n // 12)
        K = 12
        hadK = _get_had12().T if transpose else _get_had12()
    else:
        assert _is_pow2(n)
        K = 1

    return hadK, K


def _matmul_hadU(X, transpose=False) -> torch.Tensor:
    n = X.shape[-1]
    # Check if we have the determined hadamard matrix
    hadK, K = _get_hadK(n, transpose)
    # Reshape diag matrix with randomized -1/+1
    input = X.clone().view(-1, n, 1)
    output = input.clone()

    # for cases when hadK is not predetermined, determine hadamard matrix
    while input.shape[1] > K:
        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
        output = output.view(input.shape)
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.view(input.shape[0], input.shape[1], -1)
        (input, output) = (output, input)
    del output

    # K == 1 when hadK is None; this happens when the size dim (n)
    # is not compatible with any of the maintained hadamard matrices

    if K > 1:
        # Do not explicitly repeat - OOM
        # input = torch.bmm(
        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
        # Use bcast instead

        # for cases when hadK is pre-determined
        input = hadK.view(1, K, K).to(input) @ input

    # normalize
    return input.view(X.shape) / torch.tensor(n).sqrt()


def _is_pow2(n: int) -> bool:
    return (n & (n - 1) == 0) and (n > 0)


def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
    had_unpacked = numpy.unpackbits(packed_bits)
    had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
    had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
    return had_unpacked


# http://www.neilsloane.com/hadamard/index.html
def _get_had12() -> torch.Tensor:
    # fmt: off
    had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
        62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
    # fmt: on
    # TODO: just unpack during apply
    had_12_unpacked = _reshape_bits(had_12, original_size=12)
    return torch.tensor(had_12_unpacked)


def _get_had20() -> torch.Tensor:
    # fmt: off
    had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
        216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
        97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
        205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
    # fmt: on
    # TODO: just unpack during apply
    had_20_unpacked = _reshape_bits(had_20, original_size=20)
    return torch.tensor(had_20_unpacked)
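Both constructions can be checked against the same property: an n-by-n Hadamard matrix H satisfies H @ H.T == n * I, and `random_hadamard_matrix` additionally divides by sqrt(n), so its output should be orthogonal. A quick verification sketch, assuming the packed had12/had20 base matrices decode to valid Hadamard matrices as intended:

import torch
from compressed_tensors.transforms.hadamard_utils import (
    deterministic_hadamard_matrix,
    random_hadamard_matrix,
)

# Sylvester construction: entries are +/-1 and H @ H.T == n * I (unnormalized)
H = torch.tensor(deterministic_hadamard_matrix(size=8), dtype=torch.float64)
assert torch.allclose(H @ H.T, 8 * torch.eye(8, dtype=torch.float64))

# randomized construction: normalized by sqrt(n), so the result is orthogonal;
# size=12 exercises the packed had12 base matrix (12 = 12 * 2**0)
Q = random_hadamard_matrix(size=12)
assert torch.allclose(Q @ Q.T, torch.eye(12, dtype=Q.dtype))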