
Commit 1adfa30

transform registry support
1 parent 95ba68d commit 1adfa30


9 files changed: +709 -0 lines changed

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import Transforms
from .hadamard import Hadamard
from .matrix_multiply import MatrixMultiply
from .random_hadamard import RandomHadamard
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union

import torch
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.transforms.utils import apply_matrix_transform


__all__ = ["Transforms"]


# TODO: We don't need to save all the __call__ args for serialization or even have
# them defined by a recipe. Some of them, such as whether the transform should be the
# first or second matrix in torch.matmul, can likely be inferred from the layer's
# dimensions.

MATRIX_TRANSFORMS = ["matrix-mul", "hadamard", "random-hadamard"]


class Transforms(RegistryMixin):
    def __new__(
        cls,
        transform: torch.Tensor,
        device: Optional[Union[str, torch.device]] = "cuda",
        dtype: Optional[torch.dtype] = torch.bfloat16,
        *args,
        **kwargs,
    ):
        """
        Base class for setting up transforms. The registry creates transforms
        as parameters which can be attached to modules.

        import torch

        size = 1024
        dtype = torch.bfloat16
        module = torch.nn.Linear(size, size)

        hadamard_transform = Transforms.load_from_registry(
            "random_hadamard", size=size, dtype=dtype
        )
        hadamard_apply = Transforms.fetch_apply("random_hadamard")
        module.weight_transform = hadamard_transform

        transformed_output = hadamard_apply(
            input_tensor=module.weight, transform=module.weight_transform
        )

        hadamard_inverse = Transforms.fetch_inverse_apply("random_hadamard")
        original_weight = hadamard_inverse(
            input_tensor=transformed_output,
            transform=module.weight_transform,
            transpose=True,
        )

        :param transform: transform (e.g. torch.Tensor, scalar) to be applied
        """
        return torch.nn.Parameter(transform.to(device).to(dtype), requires_grad=False)

    @classmethod
    def fetch_apply(cls, name: str):
        if name in MATRIX_TRANSFORMS:
            return apply_matrix_transform
        raise NotImplementedError("Only matrix transforms are supported")

    @classmethod
    def fetch_inverse_apply(cls, name: str):
        return cls.get_value_from_registry(name=name).inverse_apply

    @staticmethod
    def inverse_apply(
        transform: torch.Tensor, input_tensor: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """
        Apply the inverse of the operation applied by the apply method
        """
        raise NotImplementedError()
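A quick sketch of the registry mechanism introduced above, for illustration only. It assumes RegistryMixin exposes register and load_from_registry as used in the docstring; the Scale class and the "scale" name are hypothetical, not part of this commit:

import torch
from compressed_tensors.transforms import Transforms

# hypothetical transform: a fixed scaling matrix, registered under a new name
@Transforms.register("scale")
class Scale(Transforms):
    def __new__(cls, size: int, scale: float = 2.0, **kwargs):
        # the base __new__ wraps the tensor in a non-trainable torch.nn.Parameter
        return super().__new__(cls, transform=scale * torch.eye(size), **kwargs)

# the registered name is then usable like "hadamard" or "random-hadamard" above
scale_transform = Transforms.load_from_registry(
    "scale", size=16, dtype=torch.float32, device="cpu"
)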
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.hadamard_utils import deterministic_hadamard_matrix
from compressed_tensors.transforms.utils import apply_matrix_transform


@Transforms.register("hadamard")
class Hadamard(Transforms):
    def __new__(
        cls,
        size: int,
        empty: Optional[bool] = False,
        device: Optional[Union[str, torch.device]] = "cuda",
        dtype: Optional[torch.dtype] = torch.bfloat16,
    ):
        """
        Produces a Hadamard matrix with dims (size, size), with values
        -1 and 1, and the property HH.T == nI, i.e. the transformation
        matrix multiplied by its transpose is a multiple of the identity.
        All rows and columns are orthogonal. The matrix returned
        is not normalized and will be deterministic.

        :param size: size of the matrix, if generating a new Hadamard matrix.
            The generated matrix will have dimensions (size, size)
        :param empty: if True, skip generation and return an empty (size, size)
            matrix, e.g. when loading in a previously generated transform as
            opposed to creating a new one
        :param dtype: type to cast the rotation matrix to
        """
        if not empty:
            # TODO: this is deterministic; we should just serialize the size
            transform = torch.Tensor(deterministic_hadamard_matrix(size=size))
        else:
            transform = torch.empty((size, size))

        return super().__new__(cls, transform=transform, device=device, dtype=dtype)

    @staticmethod
    def inverse_apply(
        transform: torch.Tensor,
        input_tensor: torch.Tensor,
        transpose: bool = False,
        first: bool = True,
    ) -> torch.Tensor:
        """
        Apply the inverse of the `apply` operation

        :param transform: hadamard tensor
        :param input_tensor: tensor to which the transform matrix is applied
        :param transpose: whether or not the transform matrix is transposed before
            being applied.
        :param first: if the transform matrix will be the first or second matrix to be
            multiplied
        """
        transpose = not transpose
        # need to normalize before sending back
        return (
            apply_matrix_transform(
                transform=transform,
                input_tensor=input_tensor,
                transpose=transpose,
                first=first,
            )
            / transform.shape[0]
        )
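A round-trip sketch for the class above, for illustration only. It assumes apply_matrix_transform applies the transform as a matmul in the way the base-class docstring uses it (first=True meaning H @ W); since HH.T == nI, inverse_apply should then recover the original tensor:

import torch
from compressed_tensors.transforms import Transforms
from compressed_tensors.transforms.utils import apply_matrix_transform

size = 8
weight = torch.rand(size, size, dtype=torch.float32)

hadamard = Transforms.load_from_registry(
    "hadamard", size=size, dtype=torch.float32, device="cpu"
)
# assumed semantics: first=True applies the transform as H @ W
transformed = apply_matrix_transform(
    transform=hadamard, input_tensor=weight, transpose=False, first=True
)
# inverse_apply flips `transpose` and divides by size, undoing H @ W via H.T @ (H @ W) / n
recovered = Transforms.fetch_inverse_apply("hadamard")(
    transform=hadamard, input_tensor=transformed, transpose=False, first=True
)
assert torch.allclose(weight, recovered, atol=1e-5)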
Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

14+
15+
import math
16+
17+
import numpy
18+
import torch
19+
20+
21+
__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
22+
23+
# adapted from:
24+
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
25+
def deterministic_hadamard_matrix(size: int):
26+
"""
27+
Construct an Hadamard matrix.
28+
29+
Constructs an n-by-n Hadamard matrix, using Sylvester's
30+
construction. `n` must be a power of 2.
31+
32+
:param size: order of the matrix; must be a power of 2
33+
34+
returns a (size, size) hadamard matrix
35+
"""
36+
37+
dtype = int
38+
if size < 1:
39+
lg2 = 0
40+
else:
41+
lg2 = int(math.log(size, 2))
42+
if 2**lg2 != size:
43+
raise ValueError("size must be an positive integer and a power of 2")
44+
45+
H = numpy.array([[1]], dtype=dtype)
46+
47+
# Sylvester's construction
48+
for i in range(0, lg2):
49+
H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
50+
51+
return H
52+
53+
54+
# adapted from:
55+
# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
56+
57+
# TODO: the following library exists for online rotations and should be considered
58+
# in the future:
59+
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
60+
61+
62+
def random_hadamard_matrix(size: int) -> torch.Tensor:
63+
"""
64+
Produces a randomly generated Hadamard matrix.
65+
See https://cornell-relaxml.github.io/quip-sharp/ ,
66+
Section "Randomized Hadamard Transformation"
67+
68+
:param size: The dimension of the matrix. Matrix generated will have dimensions
69+
(size, size)
70+
71+
"""
72+
# TODO: potentially update to add "seed" as an arugment, to allow
73+
# the matrix generated to be reproducible
74+
75+
# Benefits: support other shapes / non powers of 2, support randomization
76+
Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
77+
Q = Q * 2 - 1
78+
Q = torch.diag(Q)
79+
return _matmul_hadU(Q)
80+
81+
82+
def _get_hadK(n, transpose=False):
83+
# NOTE: we can easily extend the list of supported shapes/sizes
84+
# by adding to these methods
85+
hadK, K = None, None
86+
if n % 20 == 0:
87+
assert _is_pow2(n // 20)
88+
K = 20
89+
hadK = _get_had20().T if transpose else _get_had20()
90+
elif n % 12 == 0:
91+
assert _is_pow2(n // 12)
92+
K = 12
93+
hadK = _get_had12().T if transpose else _get_had12()
94+
else:
95+
assert _is_pow2(n)
96+
K = 1
97+
98+
return hadK, K
99+
100+
101+
def _matmul_hadU(X, transpose=False):
102+
n = X.shape[-1]
103+
# Check if we have the determined hadamard matrix
104+
hadK, K = _get_hadK(n, transpose)
105+
# Reshape diag matrix with randomized -1/+1
106+
input = X.clone().view(-1, n, 1)
107+
output = input.clone()
108+
109+
# for cases when hadK is not predetermined, determine hadamard matrix
110+
while input.shape[1] > K:
111+
input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
112+
output = output.view(input.shape)
113+
output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
114+
output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
115+
output = output.view(input.shape[0], input.shape[1], -1)
116+
(input, output) = (output, input)
117+
del output
118+
119+
# K == 1 when hadK is None; this happens when the size dim (n)
120+
# is not comaptible with any of the maintained hadamard matrices
121+
122+
if K > 1:
123+
# Do not explicitly repeat - OOM
124+
# input = torch.bmm(
125+
# hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
126+
# Use bcast instead
127+
128+
# for cases when hadK is pre-determined
129+
input = hadK.view(1, K, K).to(input) @ input
130+
131+
# normalize
132+
return input.view(X.shape) / torch.tensor(n).sqrt()
133+
134+
135+
def _is_pow2(n):
136+
return (n & (n - 1) == 0) and (n > 0)
137+
138+
139+
def _reshape_bits(packed_bits, original_size):
140+
had_unpacked = numpy.unpackbits(packed_bits)
141+
had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
142+
had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
143+
return had_unpacked
144+
145+
146+
# http://www.neilsloane.com/hadamard/index.html
147+
def _get_had12():
148+
# fmt: off
149+
had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
150+
62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
151+
# fmt: on
152+
# TODO: just unpack during apply
153+
had_12_unpacked = _reshape_bits(had_12, original_size=12)
154+
return torch.FloatTensor(had_12_unpacked)
155+
156+
157+
def _get_had20():
158+
# fmt: off
159+
had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
160+
216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
161+
97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
162+
205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
163+
# fmt: on
164+
# TODO: just unpack during apply
165+
had_20_unpacked = _reshape_bits(had_20, original_size=20)
166+
return torch.FloatTensor(had_20_unpacked)
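A small numerical check of the two generators above, added for illustration only. The deterministic matrix is unnormalized, so H @ H.T equals n * I; the randomized variant is divided by sqrt(n), so it is orthogonal:

import numpy
import torch
from compressed_tensors.transforms.hadamard_utils import (
    deterministic_hadamard_matrix,
    random_hadamard_matrix,
)

# Sylvester construction, entries +/-1, unnormalized: H @ H.T == n * I
H = deterministic_hadamard_matrix(size=16)
assert numpy.array_equal(H @ H.T, 16 * numpy.eye(16))

# randomized Hadamard, normalized by sqrt(n): Q @ Q.T ~= I
Q = random_hadamard_matrix(size=32)
assert torch.allclose(Q @ Q.T, torch.eye(32, dtype=Q.dtype), atol=1e-8)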
