From d8a10ecad71d9fe577ac0e086eda9398a626512d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 30 May 2025 13:40:52 -0400 Subject: [PATCH 01/27] add utilities Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 98 ++++++++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 7ee06444..5b60cc9d 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -28,15 +28,17 @@ import contextlib import warnings from functools import wraps -from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union import torch try: + from accelerate import dispatch_model from accelerate.hooks import ( AlignDevicesHook, add_hook_to_module, + named_module_tensors, remove_hook_from_module, ) from accelerate.utils import ( @@ -54,6 +56,8 @@ OffloadedWeightsLoader = None PrefixedDataset = None set_module_tensor_to_device = None + named_module_tensors = None + dispatch_model = None __all__ = [ @@ -70,6 +74,8 @@ "disable_offload", "align_modules", "align_module_device", + "register_offload_module", + "force_cpu_offload", ] @@ -77,6 +83,11 @@ def check_accelerate(fallback: Any): def decorator(func: Callable[[Any], Any]): if not _has_accelerate: + if fallback == "error": + raise ValueError( + "Please install `accelerate` in order to use this function" + ) + @wraps(func) def fallback_fn(*args, **kwargs): return fallback @@ -346,6 +357,7 @@ def delete_from_weights_map( ) +@check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager def disable_offload(module: torch.nn.Module): """ @@ -362,6 +374,7 @@ def disable_offload(module: torch.nn.Module): yield +@check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager def align_modules( modules: Union[torch.nn.Module, Iterable[torch.nn.Module]], @@ -383,6 +396,89 @@ def align_modules( yield +@check_accelerate(fallback=None) +def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module): + """ + Register a submodule with offloading if the parent module is offloaded + + :param base: module to attach submodule to + :param name: name of submodule + :param module: submodule to attach + """ + + if has_offloaded_params(base): + hook: AlignDevicesHook = base._hf_hook + assert hook.offload + assert hook.weights_map is not None + assert hook.tied_params_map is not None + + # offloading kwargs for submodule + place_submodules = False + offload_buffers = True + + # copy device offloading arguments from parent + current_device = next(base.parameters()).device # assume base has parameters + offload_device = get_offloaded_device(base) + + # offload parameters to weights map + for param_name, param in named_module_tensors( + module, include_buffers=offload_buffers, recurse=place_submodules + ): + offloaded = param.to(offload_device) + hook.tied_params_map[offloaded.data_ptr()] = {} # (1) + offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded) + + # if the parent places submodules, offload here + if hook.place_submodules: + set_module_tensor_to_device(module, param_name, current_device) + + # if the parent does not place submodules, then add a hook + # parameters are offloaded by `add_hook_to_module` + if not hook.place_submodules: + weights_map = PrefixedDataset( + hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}." 
+ ) + + submodule_hook = AlignDevicesHook( + execution_device=hook.execution_device, + offload=hook.offload, + io_same_device=False, + weights_map=weights_map, + offload_buffers=offload_buffers, + place_submodules=place_submodules, + skip_keys=None, + tied_params_map=hook.tied_params_map, + ) + add_hook_to_module(module, submodule_hook) + + base.register_module(name, module) + + # (1): Since we cannot know which pointers are shared when we add parameters in an + # online way, assume that all pointers are shared. This comes at no runtime cost + + +@check_accelerate(fallback="error") +def force_cpu_offload(module: torch.nn.Module, execution_device: torch.device): + device_map = {} + + def dfs(name: List[str], module: torch.nn.Module): + if next(module.parameters(recurse=False), None) is not None: + device_map[".".join(name)] = "cpu" + return + + else: + for submodule_name, submodule in module.named_children(): + name.append(submodule_name) + dfs(name, submodule) + name.pop() + + dfs([], module) + + return dispatch_model( + module, device_map, main_device=execution_device, force_hooks=True + ) + + """ Upstreamed Functions """ From d2af05407b2f10a699d3b5637cea3e6491645fcf Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 30 May 2025 14:30:26 -0400 Subject: [PATCH 02/27] add tests Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 31 ++++++++-- tests/test_utils/test_offload.py | 75 +++++++++++++++++++++++-- 2 files changed, 97 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 5b60cc9d..4241479c 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -38,6 +38,7 @@ from accelerate.hooks import ( AlignDevicesHook, add_hook_to_module, + attach_align_device_hook, named_module_tensors, remove_hook_from_module, ) @@ -58,6 +59,7 @@ set_module_tensor_to_device = None named_module_tensors = None dispatch_model = None + attach_align_device_hook = None __all__ = [ @@ -458,10 +460,31 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M @check_accelerate(fallback="error") -def force_cpu_offload(module: torch.nn.Module, execution_device: torch.device): +def force_cpu_offload( + module: torch.nn.Module, execution_device: torch.device +) -> torch.nn.Module: + """ + Force cpu offloading a module, primarily used for testing + + :param module: module containing parameters to offload + :param execution_device: execution device submodules + :return: module with hooks to perform cpu offloading + """ + # edge case: there is a bug in `dispatch_model` which causes + # the function to only work if the model contains submodules + if next(module.children(), None) is None: + attach_align_device_hook( + module, + execution_device=execution_device, + offload=True, + weights_map=module.state_dict(), + tied_params_map={}, + ) + return module + device_map = {} - def dfs(name: List[str], module: torch.nn.Module): + def collect_device_map(name: List[str], module: torch.nn.Module): if next(module.parameters(recurse=False), None) is not None: device_map[".".join(name)] = "cpu" return @@ -469,10 +492,10 @@ def dfs(name: List[str], module: torch.nn.Module): else: for submodule_name, submodule in module.named_children(): name.append(submodule_name) - dfs(name, submodule) + collect_device_map(name, submodule) name.pop() - dfs([], module) + collect_device_map([], module) return dispatch_model( module, device_map, main_device=execution_device, 
force_hooks=True diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 4bc8587d..d2cea240 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -18,8 +18,10 @@ align_modules, delete_offload_parameter, disable_hf_hook, + force_cpu_offload, get_execution_device, has_offloaded_params, + register_offload_module, register_offload_parameter, update_offload_parameter, ) @@ -37,9 +39,17 @@ def forward(self, x): return x * self.a + self.b +class ExampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(1, 2) + + def forward(self, x): + return self.linear(x) + + @requires_accelerate() def test_has_offloaded_params(): - from accelerate.big_modeling import cpu_offload_with_hook from accelerate.hooks import attach_align_device_hook, remove_hook_from_module module = ExampleModule() @@ -48,10 +58,6 @@ def test_has_offloaded_params(): attach_align_device_hook(module, offload=False) assert not has_offloaded_params(module) - remove_hook_from_module(module) - module, _ = cpu_offload_with_hook(module) - assert not has_offloaded_params(module) - remove_hook_from_module(module) attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) assert has_offloaded_params(module) @@ -334,3 +340,62 @@ def test_offload_to_weights_map(): weights_map = PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix) offload_to_weights_map(weights_map, name, new_value) assert weights_map[name] == new_value + + +@requires_gpu +@requires_accelerate() +def test_register_offload_module(): + execution_device = torch.device("cuda") + + # no offloading + model = ExampleModel() + child = torch.nn.Linear(2, 3) + register_offload_module(model, "child", child) + register_offload_module(model.linear, "child", child) + assert child in model.children() + assert child in model.linear.children() + + # with offloading + model = ExampleModel() + child = torch.nn.Linear(2, 3) + force_cpu_offload(model, execution_device) + register_offload_module(model, "child", child) + register_offload_module(model.linear, "child", child) + assert child in model.children() + assert child in model.linear.children() + + # can run modules + model(torch.empty(1)) + child(torch.empty(2, device=execution_device)) + + +@requires_gpu +@requires_accelerate() +def test_force_cpu_offload(): + execution_device = torch.device("cuda") + + # single module + module = torch.nn.Linear(1, 2) + module = force_cpu_offload(module, execution_device) + assert has_offloaded_params(module) + assert module._hf_hook.offload + assert module.weight.device == torch.device("meta") + assert "weight" in module._hf_hook.weights_map + assert module._hf_hook.tied_params_map is not None + + # can run + module(torch.empty(1, device=execution_device)) + + # model + model = ExampleModel() + model = force_cpu_offload(model, execution_device) + assert not has_offloaded_params(model) + + assert has_offloaded_params(model.linear) + assert model.linear._hf_hook.offload + assert model.linear.weight.device == torch.device("meta") + assert "weight" in model.linear._hf_hook.weights_map + assert model.linear._hf_hook.tied_params_map is not None + + # can run + model(torch.empty(1, device=execution_device)) From e32d5b5ccb240785203a771d10b522d9df7f1a3f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 30 May 2025 15:39:35 -0400 Subject: [PATCH 03/27] add additional tests Signed-off-by: Kyle Sayers --- tests/test_utils/test_offload.py | 22 ++++++++++------------ 1 file 
changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index d2cea240..a2c357a4 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -344,9 +344,8 @@ def test_offload_to_weights_map(): @requires_gpu @requires_accelerate() -def test_register_offload_module(): - execution_device = torch.device("cuda") - +@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")]) +def test_register_offload_module(exec_device): # no offloading model = ExampleModel() child = torch.nn.Linear(2, 3) @@ -358,7 +357,7 @@ def test_register_offload_module(): # with offloading model = ExampleModel() child = torch.nn.Linear(2, 3) - force_cpu_offload(model, execution_device) + force_cpu_offload(model, exec_device) register_offload_module(model, "child", child) register_offload_module(model.linear, "child", child) assert child in model.children() @@ -366,17 +365,16 @@ def test_register_offload_module(): # can run modules model(torch.empty(1)) - child(torch.empty(2, device=execution_device)) + child(torch.empty(2, device=exec_device)) @requires_gpu @requires_accelerate() -def test_force_cpu_offload(): - execution_device = torch.device("cuda") - +@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")]) +def test_force_cpu_offload(exec_device): # single module module = torch.nn.Linear(1, 2) - module = force_cpu_offload(module, execution_device) + module = force_cpu_offload(module, exec_device) assert has_offloaded_params(module) assert module._hf_hook.offload assert module.weight.device == torch.device("meta") @@ -384,11 +382,11 @@ def test_force_cpu_offload(): assert module._hf_hook.tied_params_map is not None # can run - module(torch.empty(1, device=execution_device)) + module(torch.empty(1, device=exec_device)) # model model = ExampleModel() - model = force_cpu_offload(model, execution_device) + model = force_cpu_offload(model, exec_device) assert not has_offloaded_params(model) assert has_offloaded_params(model.linear) @@ -398,4 +396,4 @@ def test_force_cpu_offload(): assert model.linear._hf_hook.tied_params_map is not None # can run - model(torch.empty(1, device=execution_device)) + model(torch.empty(1, device=exec_device)) From 9d0518b0a702a9f1808eca69fc867574a97115f5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 30 May 2025 15:53:41 -0400 Subject: [PATCH 04/27] add utils and tests Signed-off-by: Kyle Sayers --- .../transform/utils/__init__.py | 13 ++ .../transform/utils/hadamard.py | 165 ++++++++++++++++++ .../transform/utils/utils.py | 85 +++++++++ src/compressed_tensors/utils/helpers.py | 53 ++++++ tests/test_transform/utils/test_hadamards.py | 60 +++++++ tests/test_utils/test_helpers.py | 42 ++++- 6 files changed, 417 insertions(+), 1 deletion(-) create mode 100644 src/compressed_tensors/transform/utils/__init__.py create mode 100644 src/compressed_tensors/transform/utils/hadamard.py create mode 100644 src/compressed_tensors/transform/utils/utils.py create mode 100644 tests/test_transform/utils/test_hadamards.py diff --git a/src/compressed_tensors/transform/utils/__init__.py b/src/compressed_tensors/transform/utils/__init__.py new file mode 100644 index 00000000..0c44f887 --- /dev/null +++ b/src/compressed_tensors/transform/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py new file mode 100644 index 00000000..1f042941 --- /dev/null +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -0,0 +1,165 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +import numpy +import torch + + +__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"] + +# adapted from: +# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py +def deterministic_hadamard_matrix(size: int) -> numpy.ndarray: + """ + Construct an Hadamard matrix. + + Constructs an n-by-n Hadamard matrix, using Sylvester's + construction. `n` must be a power of 2. + + :param size: order of the matrix; must be a power of 2 + + returns a (size, size) hadamard matrix + """ + if size <= 0: + raise ValueError("Cannot construct deterministic hadamard of size <= 0") + + log2 = int(math.log(size, 2)) + if size != 2**log2: + raise ValueError("Cannot construct deterministic hadamard of size != 2^n") + + H = numpy.array([[1]], dtype=int) + + # Sylvester's construction + for i in range(0, log2): + H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H)))) + + return H + + +# adapted from: +# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py + +# TODO: the following library exists for online rotations and should be considered +# in the future: +# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master + + +def random_hadamard_matrix(size: int) -> torch.Tensor: + """ + Produces a randomly generated Hadamard matrix. + See https://cornell-relaxml.github.io/quip-sharp/ , + Section "Randomized Hadamard Transformation" + + :param size: The dimension of the matrix. 
Matrix generated will have dimensions + (size, size) + + """ + # TODO: potentially update to add "seed" as an arugment, to allow + # the matrix generated to be reproducible + + # Benefits: support other shapes / non powers of 2, support randomization + Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + Q = Q * 2 - 1 + Q = torch.diag(Q) + return _matmul_hadU(Q) + + +def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]: + # NOTE: we can easily extend the list of supported shapes/sizes + # by adding to these methods + hadK, K = None, None + if n % 20 == 0: + assert _is_pow2(n // 20) + K = 20 + hadK = _get_had20().T if transpose else _get_had20() + elif n % 12 == 0: + assert _is_pow2(n // 12) + K = 12 + hadK = _get_had12().T if transpose else _get_had12() + else: + assert _is_pow2(n) + K = 1 + + return hadK, K + + +def _matmul_hadU(X, transpose=False) -> torch.Tensor: + n = X.shape[-1] + # Check if we have the determined hadamard matrix + hadK, K = _get_hadK(n, transpose) + # Reshape diag matrix with randomized -1/+1 + input = X.clone().view(-1, n, 1) + output = input.clone() + + # for cases when hadK is not predetermined, determine hadamard matrix + while input.shape[1] > K: + input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2]) + output = output.view(input.shape) + output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] + output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] + output = output.view(input.shape[0], input.shape[1], -1) + (input, output) = (output, input) + del output + + # K == 1 when hadK is None; this happens when the size dim (n) + # is not comaptible with any of the maintained hadamard matrices + + if K > 1: + # Do not explicitly repeat - OOM + # input = torch.bmm( + # hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input) + # Use bcast instead + + # for cases when hadK is pre-determined + input = hadK.view(1, K, K).to(input) @ input + + # normalize + return input.view(X.shape) / torch.tensor(n).sqrt() + + +def _is_pow2(n: int) -> bool: + return (n & (n - 1) == 0) and (n > 0) + + +def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray: + had_unpacked = numpy.unpackbits(packed_bits) + had_unpacked = [1 if x == 1 else -1 for x in had_unpacked] + had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size)) + return had_unpacked + + +# http://www.neilsloane.com/hadamard/index.html +def _get_had12() -> torch.Tensor: + # fmt: off + had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218, + 62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8) + # fmt: on + # TODO: just unpack during apply + had_12_unpacked = _reshape_bits(had_12, original_size=12) + return torch.tensor(had_12_unpacked) + + +def _get_had20() -> torch.Tensor: + # fmt: off + had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252, + 216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243, + 97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43, + 205, 225, 94, 107, 10, 243], dtype=numpy.uint8) + # fmt: on + # TODO: just unpack during apply + had_20_unpacked = _reshape_bits(had_20, original_size=20) + return torch.tensor(had_20_unpacked) diff --git a/src/compressed_tensors/transform/utils/utils.py b/src/compressed_tensors/transform/utils/utils.py new file mode 100644 index 00000000..eebe3663 --- /dev/null +++ b/src/compressed_tensors/transform/utils/utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2021 - present / 
Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from compressed_tensors.transform import TransformLocation + + +__all__ = ["get_matrix_size", "apply_transform_weight"] + + +def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int: + """ + Determine the size of a matrix given its location on the module + + :param module: module that matrix will be applied to + :param location: location on module + :return: size of matrix + """ + assert isinstance(module, torch.nn.Linear) + if location in ("input", TransformLocation.WEIGHT_INPUT): + return module.in_features + else: + return module.out_features + + +def apply_transform_weight( + weight: torch.Tensor, + value: torch.Tensor, + location: TransformLocation, +) -> torch.Tensor: + """ + Using the transform location, determine how to apply the transform weight to the + given value + + let x be input activation + W be weight, + yh, xh, Wh be transformed output, input, weight + + note that + y = (x W.T) // torch.nn.Linear + yh = (xh) (Wh).T // transformed + + let V, Vi be transform matrices on input side + U, Ui be transform matrices on output side + + show that the following values for yh, xh, and Wh are consistent + + pick xh = (x V) + Wh = (U.T W Vi.T) + yh = (y U) + + (xh) (Wh).T = (x V) (U.T W Vi.T).T + = (x V) (Vi W.T U) // transpose matrix product identity + = (x W.T) U + = y U + = yh + + :param weight: transform weight to apply + :param value: value to apply weight to + :param location: determines how weight should be applied + :return: value after transform weight has been applied + """ + + if location == TransformLocation.INPUT: + return value @ weight + + elif location == TransformLocation.WEIGHT_INPUT: + return value @ weight.T + + elif location == TransformLocation.WEIGHT_OUTPUT: + return weight.T @ value + + elif location == TransformLocation.OUTPUT: + return value @ weight diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index a842d00e..d8898ae4 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import warnings from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional @@ -38,6 +39,8 @@ "shard_tensor", "pack_bitmasks", "unpack_bitmasks", + "patch_attr", + "ParameterizedDefaultDict", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -328,3 +331,53 @@ def unpack_bitmasks( ) return unpacked_bitmasks_torch + + +@contextlib.contextmanager +def patch_attr(base: object, attr: str, value: Any): + """ + Patch the value of an object attribute. 
Original value is restored upon exit + + :param base: object which has the attribute to patch + :param attr: name of the the attribute to patch + :param value: used to replace original value + + Usage: + >>> from types import SimpleNamespace + >>> obj = SimpleNamespace() + >>> with patch_attr(obj, "attribute", "value"): + ... assert obj.attribute == "value" + >>> assert not hasattr(obj, "attribute") + """ + _sentinel = object() + original_value = getattr(base, attr, _sentinel) + + setattr(base, attr, value) + try: + yield + finally: + if original_value is not _sentinel: + setattr(base, attr, original_value) + else: + delattr(base, attr) + + +class ParameterizedDefaultDict(dict): + """ + Similar to `collections.DefaultDict`, but upon fetching a key which is missing, + the key is passed as arguments to the `default_factory` + + :param default_factory: function which takes a key as input and returns the + corresponding default value + """ + + def __init__(self, default_factory: Callable[[Any], Any]): + self.default_factory = default_factory + + def __missing__(self, key): + if isinstance(key, tuple): + value = self.default_factory(*key) + else: + value = self.default_factory(key) + self[key] = value + return value diff --git a/tests/test_transform/utils/test_hadamards.py b/tests/test_transform/utils/test_hadamards.py new file mode 100644 index 00000000..ae8a0664 --- /dev/null +++ b/tests/test_transform/utils/test_hadamards.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy +import pytest +import torch +from compressed_tensors.transform.utils.hadamard import ( + _get_had12, + _get_had20, + deterministic_hadamard_matrix, + random_hadamard_matrix, +) + + +@pytest.mark.parametrize( + "had_func", + [ + _get_had12, + _get_had20, + ], +) +def test_packed_hadamard_compliant(had_func): + had_matrix = had_func() + size = had_matrix.shape[0] + # HH.T == nI + val_1 = had_matrix @ had_matrix.T + assert torch.equal(val_1 / size, torch.eye(size)) + + +@pytest.mark.parametrize( + "size", + [4096, 2048], +) +def test_random_hadamard_matrix_compliant(size): + had_matrix = random_hadamard_matrix(size) + val_1 = torch.round(had_matrix @ had_matrix.T) + assert torch.equal(val_1, torch.eye(size)) + + +@pytest.mark.parametrize( + "size", + [1024], +) +def test_deterministic_hadamard_compliant(size): + had_matrix = deterministic_hadamard_matrix(size) + # HH.T == nI + val_1 = had_matrix @ had_matrix.T + assert numpy.array_equal(val_1 / size, numpy.eye(size)) diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py index b2fab70d..e50f141e 100644 --- a/tests/test_utils/test_helpers.py +++ b/tests/test_utils/test_helpers.py @@ -13,10 +13,17 @@ # limitations under the License. 
import os +from types import SimpleNamespace import pytest import torch -from compressed_tensors import load_compressed, save_compressed, save_compressed_model +from compressed_tensors import ( + ParameterizedDefaultDict, + load_compressed, + patch_attr, + save_compressed, + save_compressed_model, +) from compressed_tensors.config import BitmaskConfig from safetensors.torch import save_model from transformers import AutoModelForCausalLM @@ -151,3 +158,36 @@ def test_save_compressed_model(tmp_path, llama_model): # make sure that compressed model is smaller # than uncompressed by roughly 1.14 (value established empirically) assert pytest.approx(size_uncompressed_kb / size_compressed_kb, 0.01) == 1.14 + + +def test_patch_attr(): + # patch, original value + obj = SimpleNamespace() + obj.attribute = "original" + with patch_attr(obj, "attribute", "patched"): + assert obj.attribute == "patched" + obj.attribute = "modified" + assert obj.attribute == "original" + + # patch, no original attribute + obj = SimpleNamespace() + with patch_attr(obj, "attribute", "patched"): + assert obj.attribute == "patched" + obj.attribute = "modified" + assert not hasattr(obj, "attribute") + + +def test_parameterized_default_dict(): + def add_one(value): + return value + 1 + + add_dict = ParameterizedDefaultDict(add_one) + assert add_dict[0] == 1 + assert add_dict[1] == 2 + + def sum_vals(a, b): + return a + b + + sum_dict = ParameterizedDefaultDict(sum_vals) + assert sum_dict[0, 1] == 1 + assert sum_dict[5, 7] == 12 From 8c5a2d96af1f5e1b11bcc1ea5476685cec27f816 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 30 May 2025 16:06:10 -0400 Subject: [PATCH 05/27] Implement transform factories Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/__init__.py | 5 + .../transform/factory/__init__.py | 13 ++ .../transform/factory/base.py | 161 +++++++++++++++++ .../transform/factory/hadamard.py | 79 +++++++++ .../transform/factory/matrix_multiply.py | 86 +++++++++ .../transform/factory/random_hadamard.py | 38 ++++ .../transform/transform_args.py | 2 +- .../transform/utils/__init__.py | 13 ++ .../transform/utils/hadamard.py | 165 ++++++++++++++++++ .../transform/utils/utils.py | 85 +++++++++ src/compressed_tensors/utils/helpers.py | 53 ++++++ src/compressed_tensors/utils/offload.py | 121 ++++++++++++- .../factory/test_correctness.py | 107 ++++++++++++ tests/test_transform/factory/test_memory.py | 112 ++++++++++++ tests/test_transform/utils/test_hadamards.py | 60 +++++++ tests/test_utils/test_helpers.py | 42 ++++- tests/test_utils/test_offload.py | 75 +++++++- 17 files changed, 1209 insertions(+), 8 deletions(-) create mode 100644 src/compressed_tensors/transform/factory/__init__.py create mode 100644 src/compressed_tensors/transform/factory/base.py create mode 100644 src/compressed_tensors/transform/factory/hadamard.py create mode 100644 src/compressed_tensors/transform/factory/matrix_multiply.py create mode 100644 src/compressed_tensors/transform/factory/random_hadamard.py create mode 100644 src/compressed_tensors/transform/utils/__init__.py create mode 100644 src/compressed_tensors/transform/utils/hadamard.py create mode 100644 src/compressed_tensors/transform/utils/utils.py create mode 100644 tests/test_transform/factory/test_correctness.py create mode 100644 tests/test_transform/factory/test_memory.py create mode 100644 tests/test_transform/utils/test_hadamards.py diff --git a/src/compressed_tensors/transform/__init__.py b/src/compressed_tensors/transform/__init__.py index 70a7bd49..f6d656dd 100644 --- 
a/src/compressed_tensors/transform/__init__.py +++ b/src/compressed_tensors/transform/__init__.py @@ -18,3 +18,8 @@ from .transform_args import * from .transform_scheme import * from .transform_config import * + +from .factory.base import * +from .factory.hadamard import * +from .factory.matrix_multiply import * +from .factory.random_hadamard import * diff --git a/src/compressed_tensors/transform/factory/__init__.py b/src/compressed_tensors/transform/factory/__init__.py new file mode 100644 index 00000000..0c44f887 --- /dev/null +++ b/src/compressed_tensors/transform/factory/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py new file mode 100644 index 00000000..30033447 --- /dev/null +++ b/src/compressed_tensors/transform/factory/base.py @@ -0,0 +1,161 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod + +import torch +import torch.nn.utils.parametrize as P +from compressed_tensors.quantization.lifecycle import is_target # TODO: move to utils +from compressed_tensors.registry.registry import RegistryMixin, T +from compressed_tensors.transform import ( + TransformArgs, + TransformLocation, + TransformScheme, +) +from compressed_tensors.utils import ( + align_module_device, + has_offloaded_params, + patch_attr, + register_offload_module, + update_offload_parameter, +) +from torch import Tensor +from torch.nn import Module, Parameter + + +__all__ = ["TransformFactory", "TransformBase"] + + +class TransformFactory(RegistryMixin, ABC): + """ + Abstract factory base used to create and apply transforms to a model + + :param name: name associated with transform scheme + :param scheme: transform scheme which defines how transforms should be created + :param seed: random seed used to transform weight randomization + """ + + def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + self.name = name + self.scheme = scheme + self.seed = seed + + @classmethod + def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T: + """ + Create a transform factory from a scheme + + :param scheme: defines how transforms should be created + :param kwargs: TransformFactory constructor arguments + :return: subclass of `TransformFactory` corresponding to the scheme type + """ + constructor = cls.get_value_from_registry(name=scheme.type) + return constructor(scheme=scheme, **kwargs) + + @abstractmethod + def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase": + """ + Abstract method which defines how a transform should be created. May utilize + caching to maximize shared memory + + :param module: parent module that transform will be applied to + :param args: defines how the transform will be applied to the module + :return: instance of TransformBase + """ + raise NotImplementedError() + + def apply_to_model(self, model: Module): + """ + Create transforms and apply them to the model + + :param model: module to apply transforms to + """ + for arg in self.scheme.apply: + for path, module in list(model.named_modules()): + if is_target(path, module, arg.targets, arg.ignore): + self._apply_to_module(module, arg) + + def _apply_to_module(self, module: Module, args: TransformArgs): + """ + Create transforms and apply them to the module + + :param module: target module to apply transforms to + :param args: defines how the transform will be applied to the target module + """ + # create transform as submodule + transform_name = f"{self.name}_{args.location}" + transform = self.create_transform(module, args) + register_offload_module(module, transform_name, transform) # (1) + + # register input transformation hook + if args.location == TransformLocation.INPUT: + + def input_hook(_, args): + input = args[0] + return transform(input) + + module.register_forward_pre_hook(input_hook) + + # eagerly apply transformation to weight + elif args.location in ( + TransformLocation.WEIGHT_INPUT, + TransformLocation.WEIGHT_OUTPUT, + ): + assert isinstance(module, torch.nn.Linear) + assert module.bias is None + + with torch.no_grad(), align_module_device(module): + update_offload_parameter(module, "weight", transform(module.weight)) + + if self.scheme.requires_grad: + # for training, the weight changes with every forward pass + # so we can leverage parametrization to propagate the gradient + if has_offloaded_params(module): + raise 
ValueError("Offloaded training is not supported") + P.register_parametrization(module, "weight", transform) + + # register output transformation hook + elif args.location == TransformLocation.OUTPUT: + + def output_hook(_, _input, output): + return transform(output) + + module.register_forward_hook(output_hook) + + # other locations such as q_attn and k_attn have not been implemented + else: + raise NotImplementedError() + + # (1) even in the `weight` cases, this submodule attachment is needed in order + # to support saving in the frozen state + + +class TransformBase(Module, ABC): + """ + Represents the application of a transform accord to TransformArgs + """ + + args: TransformArgs + weight: Parameter + + @abstractmethod + def forward(self, value: Tensor) -> Tensor: + raise NotImplementedError() + + def right_inverse(self, value: Tensor) -> Tensor: + with patch_attr(self.args, "inverse", not self.args.inverse): + return self.forward(value) + + def __repr__(self): + return f"{self.__class__.__name__}(inverse={self.args.inverse})" diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py new file mode 100644 index 00000000..b73e0687 --- /dev/null +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from compressed_tensors.transform import TransformArgs, TransformScheme +from compressed_tensors.transform.factory.base import TransformBase, TransformFactory +from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix +from compressed_tensors.transform.utils.utils import ( + apply_transform_weight, + get_matrix_size, +) +from compressed_tensors.utils import get_offloaded_device +from compressed_tensors.utils.helpers import ParameterizedDefaultDict +from torch import Tensor, device, dtype +from torch.nn import Linear, Module, Parameter + + +@TransformFactory.register("hadamard") +class HadamardFactory(TransformFactory): + """ + Factory used to apply hadamard transforms to a model + + :param name: name associated with transform scheme + :param scheme: transform scheme which defines how transforms should be created + :param seed: random seed used to transform weight randomization + """ + + def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + super().__init__(name, scheme, seed) + self.weights = ParameterizedDefaultDict(self._create_weight) + + def create_transform(self, module: Module, args: TransformArgs): + """ + Create a HadamardTransform for applying to a module. 
Transforms with the same + size, dtype, and device are cached + + :param module: parent module that transform will be applied to + :param args: defines how the transform will be applied to the module + """ + assert isinstance(module, Linear) + size = get_matrix_size(module, args.location) + dtype = module.weight.dtype + device = get_offloaded_device(module) + + weight = self.weights[size, dtype, device] + return HadamardTransform(weight, args) + + def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: + data = torch.tensor(deterministic_hadamard_matrix(size)) # TODO: seed=self.seed + data = data.to(dtype=dtype, device=device) + return Parameter(data, requires_grad=self.scheme.requires_grad) + + +class HadamardTransform(TransformBase): + def __init__(self, weight: Parameter, args: TransformArgs): + super().__init__() + self.weight = weight + self.args = args + + def forward(self, value: Tensor) -> Tensor: + if not self.args.inverse: + weight = self.weight + else: + weight = self.weight.T / self.weight.size(0) + + return apply_transform_weight(weight, value, self.args.location) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py new file mode 100644 index 00000000..13a27f79 --- /dev/null +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -0,0 +1,86 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from compressed_tensors.transform import TransformArgs, TransformScheme +from compressed_tensors.transform.factory.base import TransformBase, TransformFactory +from compressed_tensors.transform.utils.utils import ( + apply_transform_weight, + get_matrix_size, +) +from compressed_tensors.utils import get_offloaded_device +from compressed_tensors.utils.helpers import ParameterizedDefaultDict +from torch import Tensor, device, dtype +from torch.nn import Linear, Module, Parameter + + +@TransformFactory.register("matrix-mul") +class RandomMatrixFactory(TransformFactory): + """ + Factory used to apply random matrix transforms to a model + + :param name: name associated with transform scheme + :param scheme: transform scheme which defines how transforms should be created + :param seed: random seed used to transform weight randomization + """ + + def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + super().__init__(name, scheme, seed) + self.weights = ParameterizedDefaultDict(self._create_weight) + self.inverses = ParameterizedDefaultDict(self._create_inverse) + + def create_transform(self, module: Module, args: TransformArgs): + """ + Create a RandomMatrixTransform for applying to a module. 
Transforms with the + same size, dtype, and device are cached + + :param module: parent module that transform will be applied to + :param args: defines how the transform will be applied to the module + """ + assert isinstance(module, Linear) + size = get_matrix_size(module, args.location) + dtype = module.weight.dtype + device = get_offloaded_device(module) + + if not args.inverse: + weight = self.weights[size, dtype, device] + else: + weight = self.inverses[size, dtype, device] + return RandomMatrixTransform(weight, args) + + def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: + data = torch.rand((size, size), dtype=dtype, device=device) + return Parameter(data, requires_grad=self.scheme.requires_grad) + + def _create_inverse(self, size: int, dtype: dtype, device: device) -> Parameter: + weight = self.weights[size, dtype, device] + return Parameter(high_precision_invert(weight.data), requires_grad=False) + + +class RandomMatrixTransform(TransformBase): + def __init__(self, weight: Tensor, args: TransformArgs): + super().__init__() + self.weight = weight # is an inverse if args.inverse + self.args = args + + def forward(self, value: Tensor) -> Parameter: + return apply_transform_weight(self.weight, value, self.args.location) + + def right_inverse(self, value: Tensor) -> Tensor: + inverse = high_precision_invert(self.weight) + return apply_transform_weight(inverse, value, self.args.location) + + +def high_precision_invert(weight: Tensor) -> Tensor: + return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype) diff --git a/src/compressed_tensors/transform/factory/random_hadamard.py b/src/compressed_tensors/transform/factory/random_hadamard.py new file mode 100644 index 00000000..e4f14186 --- /dev/null +++ b/src/compressed_tensors/transform/factory/random_hadamard.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from compressed_tensors.transform import HadamardFactory, TransformFactory +from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix +from torch import device, dtype +from torch.nn import Parameter + + +@TransformFactory.register("random-hadamard") +class RandomHadamardFactory(HadamardFactory): + """ + Factory used to apply random hadamard transforms to a model + + :param name: name associated with transform scheme + :param scheme: transform scheme which defines how transforms should be created + :param seed: random seed used to transform weight randomization + """ + + def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: + for key in self.weights.keys(): + if key[0] == size: + return self.weights[key].to(dtype=dtype, device=device) + + data = random_hadamard_matrix(size) # seed + data = data.to(dtype=dtype, device=device) + return Parameter(data, requires_grad=self.scheme.requires_grad) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index d0487678..a9ed36c9 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -18,7 +18,7 @@ from pydantic import BaseModel, Field, field_validator -__all__ = ["TransformArgs"] +__all__ = ["TransformLocation", "TransformArgs"] class TransformLocation(str, Enum): diff --git a/src/compressed_tensors/transform/utils/__init__.py b/src/compressed_tensors/transform/utils/__init__.py new file mode 100644 index 00000000..0c44f887 --- /dev/null +++ b/src/compressed_tensors/transform/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py new file mode 100644 index 00000000..1f042941 --- /dev/null +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -0,0 +1,165 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Tuple + +import numpy +import torch + + +__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"] + +# adapted from: +# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py +def deterministic_hadamard_matrix(size: int) -> numpy.ndarray: + """ + Construct an Hadamard matrix. 
+ + Constructs an n-by-n Hadamard matrix, using Sylvester's + construction. `n` must be a power of 2. + + :param size: order of the matrix; must be a power of 2 + + returns a (size, size) hadamard matrix + """ + if size <= 0: + raise ValueError("Cannot construct deterministic hadamard of size <= 0") + + log2 = int(math.log(size, 2)) + if size != 2**log2: + raise ValueError("Cannot construct deterministic hadamard of size != 2^n") + + H = numpy.array([[1]], dtype=int) + + # Sylvester's construction + for i in range(0, log2): + H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H)))) + + return H + + +# adapted from: +# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py + +# TODO: the following library exists for online rotations and should be considered +# in the future: +# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master + + +def random_hadamard_matrix(size: int) -> torch.Tensor: + """ + Produces a randomly generated Hadamard matrix. + See https://cornell-relaxml.github.io/quip-sharp/ , + Section "Randomized Hadamard Transformation" + + :param size: The dimension of the matrix. Matrix generated will have dimensions + (size, size) + + """ + # TODO: potentially update to add "seed" as an arugment, to allow + # the matrix generated to be reproducible + + # Benefits: support other shapes / non powers of 2, support randomization + Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + Q = Q * 2 - 1 + Q = torch.diag(Q) + return _matmul_hadU(Q) + + +def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]: + # NOTE: we can easily extend the list of supported shapes/sizes + # by adding to these methods + hadK, K = None, None + if n % 20 == 0: + assert _is_pow2(n // 20) + K = 20 + hadK = _get_had20().T if transpose else _get_had20() + elif n % 12 == 0: + assert _is_pow2(n // 12) + K = 12 + hadK = _get_had12().T if transpose else _get_had12() + else: + assert _is_pow2(n) + K = 1 + + return hadK, K + + +def _matmul_hadU(X, transpose=False) -> torch.Tensor: + n = X.shape[-1] + # Check if we have the determined hadamard matrix + hadK, K = _get_hadK(n, transpose) + # Reshape diag matrix with randomized -1/+1 + input = X.clone().view(-1, n, 1) + output = input.clone() + + # for cases when hadK is not predetermined, determine hadamard matrix + while input.shape[1] > K: + input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2]) + output = output.view(input.shape) + output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] + output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] + output = output.view(input.shape[0], input.shape[1], -1) + (input, output) = (output, input) + del output + + # K == 1 when hadK is None; this happens when the size dim (n) + # is not comaptible with any of the maintained hadamard matrices + + if K > 1: + # Do not explicitly repeat - OOM + # input = torch.bmm( + # hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input) + # Use bcast instead + + # for cases when hadK is pre-determined + input = hadK.view(1, K, K).to(input) @ input + + # normalize + return input.view(X.shape) / torch.tensor(n).sqrt() + + +def _is_pow2(n: int) -> bool: + return (n & (n - 1) == 0) and (n > 0) + + +def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray: + had_unpacked = numpy.unpackbits(packed_bits) + had_unpacked = [1 if x == 1 else -1 for x in had_unpacked] + had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size)) + return 
had_unpacked + + +# http://www.neilsloane.com/hadamard/index.html +def _get_had12() -> torch.Tensor: + # fmt: off + had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218, + 62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8) + # fmt: on + # TODO: just unpack during apply + had_12_unpacked = _reshape_bits(had_12, original_size=12) + return torch.tensor(had_12_unpacked) + + +def _get_had20() -> torch.Tensor: + # fmt: off + had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252, + 216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243, + 97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43, + 205, 225, 94, 107, 10, 243], dtype=numpy.uint8) + # fmt: on + # TODO: just unpack during apply + had_20_unpacked = _reshape_bits(had_20, original_size=20) + return torch.tensor(had_20_unpacked) diff --git a/src/compressed_tensors/transform/utils/utils.py b/src/compressed_tensors/transform/utils/utils.py new file mode 100644 index 00000000..eebe3663 --- /dev/null +++ b/src/compressed_tensors/transform/utils/utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from compressed_tensors.transform import TransformLocation + + +__all__ = ["get_matrix_size", "apply_transform_weight"] + + +def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int: + """ + Determine the size of a matrix given its location on the module + + :param module: module that matrix will be applied to + :param location: location on module + :return: size of matrix + """ + assert isinstance(module, torch.nn.Linear) + if location in ("input", TransformLocation.WEIGHT_INPUT): + return module.in_features + else: + return module.out_features + + +def apply_transform_weight( + weight: torch.Tensor, + value: torch.Tensor, + location: TransformLocation, +) -> torch.Tensor: + """ + Using the transform location, determine how to apply the transform weight to the + given value + + let x be input activation + W be weight, + yh, xh, Wh be transformed output, input, weight + + note that + y = (x W.T) // torch.nn.Linear + yh = (xh) (Wh).T // transformed + + let V, Vi be transform matrices on input side + U, Ui be transform matrices on output side + + show that the following values for yh, xh, and Wh are consistent + + pick xh = (x V) + Wh = (U.T W Vi.T) + yh = (y U) + + (xh) (Wh).T = (x V) (U.T W Vi.T).T + = (x V) (Vi W.T U) // transpose matrix product identity + = (x W.T) U + = y U + = yh + + :param weight: transform weight to apply + :param value: value to apply weight to + :param location: determines how weight should be applied + :return: value after transform weight has been applied + """ + + if location == TransformLocation.INPUT: + return value @ weight + + elif location == TransformLocation.WEIGHT_INPUT: + return value @ weight.T + + elif location == TransformLocation.WEIGHT_OUTPUT: + return weight.T @ value + + elif location == 
TransformLocation.OUTPUT: + return value @ weight diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index a842d00e..d8898ae4 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import warnings from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional @@ -38,6 +39,8 @@ "shard_tensor", "pack_bitmasks", "unpack_bitmasks", + "patch_attr", + "ParameterizedDefaultDict", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -328,3 +331,53 @@ def unpack_bitmasks( ) return unpacked_bitmasks_torch + + +@contextlib.contextmanager +def patch_attr(base: object, attr: str, value: Any): + """ + Patch the value of an object attribute. Original value is restored upon exit + + :param base: object which has the attribute to patch + :param attr: name of the the attribute to patch + :param value: used to replace original value + + Usage: + >>> from types import SimpleNamespace + >>> obj = SimpleNamespace() + >>> with patch_attr(obj, "attribute", "value"): + ... assert obj.attribute == "value" + >>> assert not hasattr(obj, "attribute") + """ + _sentinel = object() + original_value = getattr(base, attr, _sentinel) + + setattr(base, attr, value) + try: + yield + finally: + if original_value is not _sentinel: + setattr(base, attr, original_value) + else: + delattr(base, attr) + + +class ParameterizedDefaultDict(dict): + """ + Similar to `collections.DefaultDict`, but upon fetching a key which is missing, + the key is passed as arguments to the `default_factory` + + :param default_factory: function which takes a key as input and returns the + corresponding default value + """ + + def __init__(self, default_factory: Callable[[Any], Any]): + self.default_factory = default_factory + + def __missing__(self, key): + if isinstance(key, tuple): + value = self.default_factory(*key) + else: + value = self.default_factory(key) + self[key] = value + return value diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 7ee06444..4241479c 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -28,15 +28,18 @@ import contextlib import warnings from functools import wraps -from typing import Any, Callable, Dict, Iterable, Literal, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Union import torch try: + from accelerate import dispatch_model from accelerate.hooks import ( AlignDevicesHook, add_hook_to_module, + attach_align_device_hook, + named_module_tensors, remove_hook_from_module, ) from accelerate.utils import ( @@ -54,6 +57,9 @@ OffloadedWeightsLoader = None PrefixedDataset = None set_module_tensor_to_device = None + named_module_tensors = None + dispatch_model = None + attach_align_device_hook = None __all__ = [ @@ -70,6 +76,8 @@ "disable_offload", "align_modules", "align_module_device", + "register_offload_module", + "force_cpu_offload", ] @@ -77,6 +85,11 @@ def check_accelerate(fallback: Any): def decorator(func: Callable[[Any], Any]): if not _has_accelerate: + if fallback == "error": + raise ValueError( + "Please install `accelerate` in order to use this function" + ) + @wraps(func) def fallback_fn(*args, **kwargs): return fallback @@ -346,6 +359,7 @@ def delete_from_weights_map( ) 
+@check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager def disable_offload(module: torch.nn.Module): """ @@ -362,6 +376,7 @@ def disable_offload(module: torch.nn.Module): yield +@check_accelerate(fallback=contextlib.nullcontext()) @contextlib.contextmanager def align_modules( modules: Union[torch.nn.Module, Iterable[torch.nn.Module]], @@ -383,6 +398,110 @@ def align_modules( yield +@check_accelerate(fallback=None) +def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module): + """ + Register a submodule with offloading if the parent module is offloaded + + :param base: module to attach submodule to + :param name: name of submodule + :param module: submodule to attach + """ + + if has_offloaded_params(base): + hook: AlignDevicesHook = base._hf_hook + assert hook.offload + assert hook.weights_map is not None + assert hook.tied_params_map is not None + + # offloading kwargs for submodule + place_submodules = False + offload_buffers = True + + # copy device offloading arguments from parent + current_device = next(base.parameters()).device # assume base has parameters + offload_device = get_offloaded_device(base) + + # offload parameters to weights map + for param_name, param in named_module_tensors( + module, include_buffers=offload_buffers, recurse=place_submodules + ): + offloaded = param.to(offload_device) + hook.tied_params_map[offloaded.data_ptr()] = {} # (1) + offload_to_weights_map(hook.weights_map, f"{name}.{param_name}", offloaded) + + # if the parent places submodules, offload here + if hook.place_submodules: + set_module_tensor_to_device(module, param_name, current_device) + + # if the parent does not place submodules, then add a hook + # parameters are offloaded by `add_hook_to_module` + if not hook.place_submodules: + weights_map = PrefixedDataset( + hook.weights_map.dataset, prefix=f"{hook.weights_map.prefix}{name}." + ) + + submodule_hook = AlignDevicesHook( + execution_device=hook.execution_device, + offload=hook.offload, + io_same_device=False, + weights_map=weights_map, + offload_buffers=offload_buffers, + place_submodules=place_submodules, + skip_keys=None, + tied_params_map=hook.tied_params_map, + ) + add_hook_to_module(module, submodule_hook) + + base.register_module(name, module) + + # (1): Since we cannot know which pointers are shared when we add parameters in an + # online way, assume that all pointers are shared. 
This comes at no runtime cost + + +@check_accelerate(fallback="error") +def force_cpu_offload( + module: torch.nn.Module, execution_device: torch.device +) -> torch.nn.Module: + """ + Force cpu offloading a module, primarily used for testing + + :param module: module containing parameters to offload + :param execution_device: execution device submodules + :return: module with hooks to perform cpu offloading + """ + # edge case: there is a bug in `dispatch_model` which causes + # the function to only work if the model contains submodules + if next(module.children(), None) is None: + attach_align_device_hook( + module, + execution_device=execution_device, + offload=True, + weights_map=module.state_dict(), + tied_params_map={}, + ) + return module + + device_map = {} + + def collect_device_map(name: List[str], module: torch.nn.Module): + if next(module.parameters(recurse=False), None) is not None: + device_map[".".join(name)] = "cpu" + return + + else: + for submodule_name, submodule in module.named_children(): + name.append(submodule_name) + collect_device_map(name, submodule) + name.pop() + + collect_device_map([], module) + + return dispatch_model( + module, device_map, main_device=execution_device, force_hooks=True + ) + + """ Upstreamed Functions """ diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py new file mode 100644 index 00000000..44cc89e5 --- /dev/null +++ b/tests/test_transform/factory/test_correctness.py @@ -0,0 +1,107 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
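For orientation before the tests that follow, here is a minimal sketch of how the offload utilities added above compose. It is illustrative only and not part of the patch: it assumes `accelerate` is installed and a CUDA device is available, and the module name, sizes, and the "child" submodule are arbitrary stand-ins.

import torch
from compressed_tensors.utils import (
    force_cpu_offload,
    has_offloaded_params,
    register_offload_module,
)

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)

# dispatch the model so weights live on cpu and are onloaded per forward call
model = force_cpu_offload(TinyModel(), torch.device("cuda"))
assert has_offloaded_params(model.linear)
assert model.linear.weight.device == torch.device("meta")

# a submodule registered afterwards is folded into the parent's offloading hook
register_offload_module(model.linear, "child", torch.nn.Linear(4, 2))

# forward still runs; weights are aligned to the execution device per call
model(torch.empty(1, 8, device="cuda"))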
+ +import pytest +import torch +from compressed_tensors.transform import ( + TransformArgs, + TransformFactory, + TransformScheme, +) +from compressed_tensors.utils import align_modules, force_cpu_offload +from tests.testing_utils import requires_accelerate, requires_gpu + + +class TransformableModel(torch.nn.Module): + def __init__(self, *sizes): + super().__init__() + self.fcs = torch.nn.ModuleList([]) + self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False)) + for index in range(1, len(sizes) - 1): + self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)) + + def forward(self, x): + for layer in self.fcs: + x = layer(x) + return x + + +@pytest.mark.parametrize( + "scheme", + [TransformScheme(type=name) for name in TransformFactory.registered_names()], +) +def test_correctness_linear(scheme): + size = (4, 8) + module = torch.nn.Linear(*size, bias=True) + factory = TransformFactory.from_scheme(scheme, name="") + + input_tfm = factory.create_transform( + module, TransformArgs(targets="Linear", location="input", inverse=True) + ) + w_in_tfm = factory.create_transform( + module, TransformArgs(targets="Linear", location="weight_input") + ) + w_out_tfm = factory.create_transform( + module, TransformArgs(targets="Linear", location="weight_output") + ) + output_tfm = factory.create_transform( + module, TransformArgs(targets="Linear", location="output", inverse=True) + ) + + input = torch.rand((17, size[0])) + true_output = input @ module.weight.T + input_transformed = input_tfm(input) + weight_transformed = w_out_tfm(w_in_tfm(module.weight)) + output = output_tfm(input_transformed @ weight_transformed.T) + + torch.allclose(true_output, output, atol=1e-7, rtol=0.0) + + +@pytest.mark.parametrize( + "scheme", + [TransformScheme(type=name) for name in TransformFactory.registered_names()], +) +def test_correctness_model(scheme, offload=False): + # load model + model = TransformableModel(2, 4, 8, 16) + if offload: + model = force_cpu_offload(model, torch.device("cuda")) + + # create factory + scheme.apply = [ + TransformArgs(targets="fcs.0", location="input"), + TransformArgs(targets="fcs.2", location="output", inverse=True), + ] + factory = TransformFactory.from_scheme(scheme, name="") + + # create inputs + input = torch.rand((17, model.fcs[0].in_features)) + if offload: + input = input.to(torch.device("cuda")) + + # compare outputs + true_output = model(input) + factory.apply_to_model(model) + output = model(input) + torch.allclose(true_output, output, atol=1e-7, rtol=0.0) + + +@requires_gpu +@requires_accelerate() +@pytest.mark.parametrize( + "scheme", + [TransformScheme(type=name) for name in TransformFactory.registered_names()], +) +def test_correctness_model_offload(scheme): + test_correctness_model(scheme, offload=True) diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py new file mode 100644 index 00000000..49e882e4 --- /dev/null +++ b/tests/test_transform/factory/test_memory.py @@ -0,0 +1,112 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import Counter + +import pytest +import torch +from compressed_tensors.transform import ( + TransformArgs, + TransformBase, + TransformFactory, + TransformScheme, +) +from compressed_tensors.utils import align_modules, force_cpu_offload +from tests.testing_utils import requires_accelerate, requires_gpu + + +class TransformableModel(torch.nn.Module): + def __init__(self, *sizes): + super().__init__() + self.fcs = torch.nn.ModuleList([]) + self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False)) + for index in range(1, len(sizes) - 1): + self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)) + + def forward(self, x): + for layer in self.fcs: + x = layer(x) + return x + + +@pytest.mark.parametrize( + "scheme", + [TransformScheme(type=name) for name in TransformFactory.registered_names()], +) +def test_memory_sharing(scheme, offload=False): + # load scheme and factory + scheme = TransformScheme( + type="hadamard", + apply=[ + TransformArgs(targets="Linear", location="input"), + TransformArgs(targets="Linear", location="output"), + ], + ) + factory = TransformFactory.from_scheme(scheme, name="") + + # load model (maybe with offloading) + model = TransformableModel(2, 2, 4, 4, 8, 8) + if offload: + force_cpu_offload(model, torch.device("cuda")) + + # add transforms to model + factory.apply_to_model(model) + + # check that memory is shared when onloaded + with align_modules(model.modules()): + weights = [m.weight for m in model.modules() if isinstance(m, TransformBase)] + weight_to_count = Counter(weights) + size_to_weight = {weight.size(0): weight for weight in weight_to_count} + + assert len(weight_to_count) == len(size_to_weight) == 3 + assert weight_to_count[size_to_weight[2]] == 3 + assert weight_to_count[size_to_weight[4]] == 4 + assert weight_to_count[size_to_weight[8]] == 3 + + # check that memory is shared in offloaded dict + if offload: + weights_map = dict(model.fcs[0]._hf_hook.weights_map.dataset) + offloaded_weights = [ + value + for name, value in weights_map.items() + if name.endswith("_input.weight") or name.endswith("_output.weight") + ] + weight_to_count = Counter(offloaded_weights) + size_to_weight = {weight.size(0): weight for weight in weight_to_count} + + assert len(weight_to_count) == len(size_to_weight) == 3 + assert weight_to_count[size_to_weight[2]] == 3 + assert weight_to_count[size_to_weight[4]] == 4 + assert weight_to_count[size_to_weight[8]] == 3 + + +@requires_gpu +@requires_accelerate() +@pytest.mark.parametrize( + "scheme", + [TransformScheme(type=name) for name in TransformFactory.registered_names()], +) +def test_memory_sharing_offload(scheme): + test_memory_sharing(scheme, offload=True) + + +@pytest.mark.parametrize( + "scheme", + [ + TransformScheme(type=name, requires_grad=True) + for name in TransformFactory.registered_names() + ], +) +def test_memory_sharing_training(scheme): + test_memory_sharing(scheme, offload=False) diff --git a/tests/test_transform/utils/test_hadamards.py b/tests/test_transform/utils/test_hadamards.py new file mode 100644 index 00000000..ae8a0664 --- /dev/null +++ b/tests/test_transform/utils/test_hadamards.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy +import pytest +import torch +from compressed_tensors.transform.utils.hadamard import ( + _get_had12, + _get_had20, + deterministic_hadamard_matrix, + random_hadamard_matrix, +) + + +@pytest.mark.parametrize( + "had_func", + [ + _get_had12, + _get_had20, + ], +) +def test_packed_hadamard_compliant(had_func): + had_matrix = had_func() + size = had_matrix.shape[0] + # HH.T == nI + val_1 = had_matrix @ had_matrix.T + assert torch.equal(val_1 / size, torch.eye(size)) + + +@pytest.mark.parametrize( + "size", + [4096, 2048], +) +def test_random_hadamard_matrix_compliant(size): + had_matrix = random_hadamard_matrix(size) + val_1 = torch.round(had_matrix @ had_matrix.T) + assert torch.equal(val_1, torch.eye(size)) + + +@pytest.mark.parametrize( + "size", + [1024], +) +def test_deterministic_hadamard_compliant(size): + had_matrix = deterministic_hadamard_matrix(size) + # HH.T == nI + val_1 = had_matrix @ had_matrix.T + assert numpy.array_equal(val_1 / size, numpy.eye(size)) diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py index b2fab70d..e50f141e 100644 --- a/tests/test_utils/test_helpers.py +++ b/tests/test_utils/test_helpers.py @@ -13,10 +13,17 @@ # limitations under the License. import os +from types import SimpleNamespace import pytest import torch -from compressed_tensors import load_compressed, save_compressed, save_compressed_model +from compressed_tensors import ( + ParameterizedDefaultDict, + load_compressed, + patch_attr, + save_compressed, + save_compressed_model, +) from compressed_tensors.config import BitmaskConfig from safetensors.torch import save_model from transformers import AutoModelForCausalLM @@ -151,3 +158,36 @@ def test_save_compressed_model(tmp_path, llama_model): # make sure that compressed model is smaller # than uncompressed by roughly 1.14 (value established empirically) assert pytest.approx(size_uncompressed_kb / size_compressed_kb, 0.01) == 1.14 + + +def test_patch_attr(): + # patch, original value + obj = SimpleNamespace() + obj.attribute = "original" + with patch_attr(obj, "attribute", "patched"): + assert obj.attribute == "patched" + obj.attribute = "modified" + assert obj.attribute == "original" + + # patch, no original attribute + obj = SimpleNamespace() + with patch_attr(obj, "attribute", "patched"): + assert obj.attribute == "patched" + obj.attribute = "modified" + assert not hasattr(obj, "attribute") + + +def test_parameterized_default_dict(): + def add_one(value): + return value + 1 + + add_dict = ParameterizedDefaultDict(add_one) + assert add_dict[0] == 1 + assert add_dict[1] == 2 + + def sum_vals(a, b): + return a + b + + sum_dict = ParameterizedDefaultDict(sum_vals) + assert sum_dict[0, 1] == 1 + assert sum_dict[5, 7] == 12 diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index 4bc8587d..d2cea240 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -18,8 +18,10 @@ align_modules, delete_offload_parameter, disable_hf_hook, + force_cpu_offload, get_execution_device, has_offloaded_params, + register_offload_module, 
register_offload_parameter, update_offload_parameter, ) @@ -37,9 +39,17 @@ def forward(self, x): return x * self.a + self.b +class ExampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(1, 2) + + def forward(self, x): + return self.linear(x) + + @requires_accelerate() def test_has_offloaded_params(): - from accelerate.big_modeling import cpu_offload_with_hook from accelerate.hooks import attach_align_device_hook, remove_hook_from_module module = ExampleModule() @@ -48,10 +58,6 @@ def test_has_offloaded_params(): attach_align_device_hook(module, offload=False) assert not has_offloaded_params(module) - remove_hook_from_module(module) - module, _ = cpu_offload_with_hook(module) - assert not has_offloaded_params(module) - remove_hook_from_module(module) attach_align_device_hook(module, offload=True, weights_map=module.state_dict()) assert has_offloaded_params(module) @@ -334,3 +340,62 @@ def test_offload_to_weights_map(): weights_map = PrefixedDataset(OffloadedWeightsLoader({name: old_value}), prefix) offload_to_weights_map(weights_map, name, new_value) assert weights_map[name] == new_value + + +@requires_gpu +@requires_accelerate() +def test_register_offload_module(): + execution_device = torch.device("cuda") + + # no offloading + model = ExampleModel() + child = torch.nn.Linear(2, 3) + register_offload_module(model, "child", child) + register_offload_module(model.linear, "child", child) + assert child in model.children() + assert child in model.linear.children() + + # with offloading + model = ExampleModel() + child = torch.nn.Linear(2, 3) + force_cpu_offload(model, execution_device) + register_offload_module(model, "child", child) + register_offload_module(model.linear, "child", child) + assert child in model.children() + assert child in model.linear.children() + + # can run modules + model(torch.empty(1)) + child(torch.empty(2, device=execution_device)) + + +@requires_gpu +@requires_accelerate() +def test_force_cpu_offload(): + execution_device = torch.device("cuda") + + # single module + module = torch.nn.Linear(1, 2) + module = force_cpu_offload(module, execution_device) + assert has_offloaded_params(module) + assert module._hf_hook.offload + assert module.weight.device == torch.device("meta") + assert "weight" in module._hf_hook.weights_map + assert module._hf_hook.tied_params_map is not None + + # can run + module(torch.empty(1, device=execution_device)) + + # model + model = ExampleModel() + model = force_cpu_offload(model, execution_device) + assert not has_offloaded_params(model) + + assert has_offloaded_params(model.linear) + assert model.linear._hf_hook.offload + assert model.linear.weight.device == torch.device("meta") + assert "weight" in model.linear._hf_hook.weights_map + assert model.linear._hf_hook.tied_params_map is not None + + # can run + model(torch.empty(1, device=execution_device)) From 8d613b3c182978b672371039f5b6d7e4df9ef264 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 30 May 2025 22:14:18 -0400 Subject: [PATCH 06/27] add permutations Signed-off-by: Kyle Sayers --- .../transform/factory/hadamard.py | 27 ++++++++++++++----- .../transform/utils/utils.py | 9 ++++++- .../factory/test_correctness.py | 25 +++++++++-------- tests/test_transform/factory/test_memory.py | 27 +++++++++---------- 4 files changed, 52 insertions(+), 36 deletions(-) diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index b73e0687..a05accd4 100644 --- 
a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Union import torch from compressed_tensors.transform import TransformArgs, TransformScheme from compressed_tensors.transform.factory.base import TransformBase, TransformFactory from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix from compressed_tensors.transform.utils.utils import ( + apply_permutation, apply_transform_weight, get_matrix_size, ) @@ -41,6 +42,7 @@ class HadamardFactory(TransformFactory): def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): super().__init__(name, scheme, seed) self.weights = ParameterizedDefaultDict(self._create_weight) + self.perms = ParameterizedDefaultDict(self._create_permutation) def create_transform(self, module: Module, args: TransformArgs): """ @@ -56,24 +58,35 @@ def create_transform(self, module: Module, args: TransformArgs): device = get_offloaded_device(module) weight = self.weights[size, dtype, device] - return HadamardTransform(weight, args) + perm = self.perms[module, weight] if self.scheme.randomize_modules else None + return HadamardTransform(weight, perm, args) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: data = torch.tensor(deterministic_hadamard_matrix(size)) # TODO: seed=self.seed data = data.to(dtype=dtype, device=device) return Parameter(data, requires_grad=self.scheme.requires_grad) + def _create_permutation(self, module: Module, weight: Parameter) -> Parameter: + data = torch.randperm(weight.size(0)) + return Parameter(data, requires_grad=False) + class HadamardTransform(TransformBase): - def __init__(self, weight: Parameter, args: TransformArgs): + def __init__( + self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs + ): super().__init__() self.weight = weight + self.perm = perm self.args = args def forward(self, value: Tensor) -> Tensor: - if not self.args.inverse: - weight = self.weight - else: - weight = self.weight.T / self.weight.size(0) + weight = self.weight + + if self.perm is not None: + weight = apply_permutation(weight, self.perm) + + if self.args.inverse: + weight = weight.T / weight.size(0) return apply_transform_weight(weight, value, self.args.location) diff --git a/src/compressed_tensors/transform/utils/utils.py b/src/compressed_tensors/transform/utils/utils.py index eebe3663..88cbf3aa 100644 --- a/src/compressed_tensors/transform/utils/utils.py +++ b/src/compressed_tensors/transform/utils/utils.py @@ -16,7 +16,7 @@ from compressed_tensors.transform import TransformLocation -__all__ = ["get_matrix_size", "apply_transform_weight"] +__all__ = ["get_matrix_size", "apply_transform_weight", "apply_permutation"] def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int: @@ -83,3 +83,10 @@ def apply_transform_weight( elif location == TransformLocation.OUTPUT: return value @ weight + + +def apply_permutation(weight: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + weight = weight.clone() + diag_indices = torch.arange(weight.size(0)) + weight[diag_indices, diag_indices] = weight.diagonal()[perm] + return weight diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index 44cc89e5..8ac4dbab 100644 --- a/tests/test_transform/factory/test_correctness.py +++ 
b/tests/test_transform/factory/test_correctness.py @@ -19,10 +19,18 @@ TransformFactory, TransformScheme, ) -from compressed_tensors.utils import align_modules, force_cpu_offload +from compressed_tensors.utils import force_cpu_offload from tests.testing_utils import requires_accelerate, requires_gpu +_test_schemes = [ + TransformScheme(type=name) for name in TransformFactory.registered_names() +] + [ + TransformScheme(type=name, randomize_modules=True) + for name in TransformFactory.registered_names() +] + + class TransformableModel(torch.nn.Module): def __init__(self, *sizes): super().__init__() @@ -37,10 +45,7 @@ def forward(self, x): return x -@pytest.mark.parametrize( - "scheme", - [TransformScheme(type=name) for name in TransformFactory.registered_names()], -) +@pytest.mark.parametrize("scheme", _test_schemes) def test_correctness_linear(scheme): size = (4, 8) module = torch.nn.Linear(*size, bias=True) @@ -68,10 +73,7 @@ def test_correctness_linear(scheme): torch.allclose(true_output, output, atol=1e-7, rtol=0.0) -@pytest.mark.parametrize( - "scheme", - [TransformScheme(type=name) for name in TransformFactory.registered_names()], -) +@pytest.mark.parametrize("scheme", _test_schemes) def test_correctness_model(scheme, offload=False): # load model model = TransformableModel(2, 4, 8, 16) @@ -99,9 +101,6 @@ def test_correctness_model(scheme, offload=False): @requires_gpu @requires_accelerate() -@pytest.mark.parametrize( - "scheme", - [TransformScheme(type=name) for name in TransformFactory.registered_names()], -) +@pytest.mark.parametrize("scheme", _test_schemes) def test_correctness_model_offload(scheme): test_correctness_model(scheme, offload=True) diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py index 49e882e4..63e37561 100644 --- a/tests/test_transform/factory/test_memory.py +++ b/tests/test_transform/factory/test_memory.py @@ -26,6 +26,14 @@ from tests.testing_utils import requires_accelerate, requires_gpu +_test_schemes = [ + TransformScheme(type=name) for name in TransformFactory.registered_names() +] + [ + TransformScheme(type=name, randomize_modules=True) + for name in TransformFactory.registered_names() +] + + class TransformableModel(torch.nn.Module): def __init__(self, *sizes): super().__init__() @@ -40,10 +48,7 @@ def forward(self, x): return x -@pytest.mark.parametrize( - "scheme", - [TransformScheme(type=name) for name in TransformFactory.registered_names()], -) +@pytest.mark.parametrize("scheme", _test_schemes) def test_memory_sharing(scheme, offload=False): # load scheme and factory scheme = TransformScheme( @@ -93,20 +98,12 @@ def test_memory_sharing(scheme, offload=False): @requires_gpu @requires_accelerate() -@pytest.mark.parametrize( - "scheme", - [TransformScheme(type=name) for name in TransformFactory.registered_names()], -) +@pytest.mark.parametrize("scheme", _test_schemes) def test_memory_sharing_offload(scheme): test_memory_sharing(scheme, offload=True) -@pytest.mark.parametrize( - "scheme", - [ - TransformScheme(type=name, requires_grad=True) - for name in TransformFactory.registered_names() - ], -) +@pytest.mark.parametrize("scheme", _test_schemes) def test_memory_sharing_training(scheme): + scheme.requires_grad = True test_memory_sharing(scheme, offload=False) From 57d171afa55ceda0b3e16f9e0d9c2e5baed73147 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 31 May 2025 00:48:48 -0400 Subject: [PATCH 07/27] add delete_offload_module Signed-off-by: Kyle Sayers --- src/compressed_tensors/utils/offload.py | 16 
++++++++++++++- tests/test_utils/test_offload.py | 27 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 4241479c..10b0810e 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -77,6 +77,7 @@ "align_modules", "align_module_device", "register_offload_module", + "delete_offload_module", "force_cpu_offload", ] @@ -398,7 +399,6 @@ def align_modules( yield -@check_accelerate(fallback=None) def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.Module): """ Register a submodule with offloading if the parent module is offloaded @@ -459,6 +459,20 @@ def register_offload_module(base: torch.nn.Module, name: str, module: torch.nn.M # online way, assume that all pointers are shared. This comes at no runtime cost +def delete_offload_module(base: torch.nn.Module, name: str): + """ + Delete a submodule from a model which may contain offloading + :param base: parent module to delete submodule from + :param name: name of submodule on parent + """ + module: torch.nn.Module = getattr(base, name) + + for param_name, _ in list(module.named_parameters()): + delete_offload_parameter(module, param_name) + + delattr(base, name) + + @check_accelerate(fallback="error") def force_cpu_offload( module: torch.nn.Module, execution_device: torch.device diff --git a/tests/test_utils/test_offload.py b/tests/test_utils/test_offload.py index a2c357a4..0d3aa15d 100644 --- a/tests/test_utils/test_offload.py +++ b/tests/test_utils/test_offload.py @@ -16,6 +16,7 @@ from compressed_tensors.utils import ( align_module_device, align_modules, + delete_offload_module, delete_offload_parameter, disable_hf_hook, force_cpu_offload, @@ -368,6 +369,32 @@ def test_register_offload_module(exec_device): child(torch.empty(2, device=exec_device)) +@requires_gpu +@requires_accelerate() +@pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")]) +def test_delete_offload_module(exec_device): + # no offloading + model = ExampleModel() + child = torch.nn.Linear(2, 3) + register_offload_module(model, "child", child) + register_offload_module(model.linear, "child", child) + delete_offload_module(model, "child") + delete_offload_module(model.linear, "child") + assert not child in model.children() + assert not child in model.linear.children() + + # with offloading + model = ExampleModel() + child = torch.nn.Linear(2, 3) + force_cpu_offload(model, exec_device) + register_offload_module(model, "child", child) + register_offload_module(model.linear, "child", child) + delete_offload_module(model, "child") + delete_offload_module(model.linear, "child") + assert not child in model.children() + assert not child in model.linear.children() + + @requires_gpu @requires_accelerate() @pytest.mark.parametrize("exec_device", [torch.device("cpu"), torch.device("cuda")]) From aa7d21b611e709945f0c884c0044db0dfce287ac Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 31 May 2025 09:53:18 -0400 Subject: [PATCH 08/27] key inverses by weight Signed-off-by: Kyle Sayers --- .../transform/factory/matrix_multiply.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 13a27f79..75890beb 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ 
b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -53,19 +53,19 @@ def create_transform(self, module: Module, args: TransformArgs): dtype = module.weight.dtype device = get_offloaded_device(module) - if not args.inverse: - weight = self.weights[size, dtype, device] - else: - weight = self.inverses[size, dtype, device] + weight = self.weights[size, dtype, device] + if args.inverse: + weight = self.inverses[weight] + return RandomMatrixTransform(weight, args) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: data = torch.rand((size, size), dtype=dtype, device=device) return Parameter(data, requires_grad=self.scheme.requires_grad) - def _create_inverse(self, size: int, dtype: dtype, device: device) -> Parameter: - weight = self.weights[size, dtype, device] - return Parameter(high_precision_invert(weight.data), requires_grad=False) + def _create_inverse(self, weight: Parameter) -> Parameter: + data = high_precision_invert(weight.data) + return Parameter(data, requires_grad=False) class RandomMatrixTransform(TransformBase): From 6901e02eaa8252514dd5d97c06a2d6e1b6b7b526 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 31 May 2025 10:56:34 -0400 Subject: [PATCH 09/27] fix tests Signed-off-by: Kyle Sayers --- .../transform/factory/random_hadamard.py | 4 ---- .../transform/utils/hadamard.py | 2 +- .../factory/test_correctness.py | 21 +++++++++++++------ 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/transform/factory/random_hadamard.py b/src/compressed_tensors/transform/factory/random_hadamard.py index e4f14186..cc12d8c4 100644 --- a/src/compressed_tensors/transform/factory/random_hadamard.py +++ b/src/compressed_tensors/transform/factory/random_hadamard.py @@ -29,10 +29,6 @@ class RandomHadamardFactory(HadamardFactory): """ def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: - for key in self.weights.keys(): - if key[0] == size: - return self.weights[key].to(dtype=dtype, device=device) - data = random_hadamard_matrix(size) # seed data = data.to(dtype=dtype, device=device) return Parameter(data, requires_grad=self.scheme.requires_grad) diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index 1f042941..1a71b116 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -128,7 +128,7 @@ def _matmul_hadU(X, transpose=False) -> torch.Tensor: input = hadK.view(1, K, K).to(input) @ input # normalize - return input.view(X.shape) / torch.tensor(n).sqrt() + return input.view(X.shape) def _is_pow2(n: int) -> bool: diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index 44cc89e5..bab1117e 100644 --- a/tests/test_transform/factory/test_correctness.py +++ b/tests/test_transform/factory/test_correctness.py @@ -64,8 +64,7 @@ def test_correctness_linear(scheme): input_transformed = input_tfm(input) weight_transformed = w_out_tfm(w_in_tfm(module.weight)) output = output_tfm(input_transformed @ weight_transformed.T) - - torch.allclose(true_output, output, atol=1e-7, rtol=0.0) + assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) @pytest.mark.parametrize( @@ -74,14 +73,24 @@ def test_correctness_linear(scheme): ) def test_correctness_model(scheme, offload=False): # load model - model = TransformableModel(2, 4, 8, 16) + model = TransformableModel(2, 4, 8, 16, 32, 64) if offload: model = 
force_cpu_offload(model, torch.device("cuda")) # create factory scheme.apply = [ - TransformArgs(targets="fcs.0", location="input"), - TransformArgs(targets="fcs.2", location="output", inverse=True), + # weight output -> input + TransformArgs(targets="fcs.0", location="weight_output"), + TransformArgs(targets="fcs.1", location="input", inverse=True), + # output -> weight input + TransformArgs(targets="fcs.1", location="output"), + TransformArgs(targets="fcs.2", location="weight_input", inverse=True), + # output -> input + TransformArgs(targets="fcs.2", location="output"), + TransformArgs(targets="fcs.3", location="input", inverse=True), + # weight output -> weight input + TransformArgs(targets="fcs.3", location="weight_output"), + TransformArgs(targets="fcs.4", location="weight_input", inverse=True), ] factory = TransformFactory.from_scheme(scheme, name="") @@ -94,7 +103,7 @@ def test_correctness_model(scheme, offload=False): true_output = model(input) factory.apply_to_model(model) output = model(input) - torch.allclose(true_output, output, atol=1e-7, rtol=0.0) + assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) @requires_gpu From 47ae9fe14d598a9cbb07475b8ca438f6b7e7405e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 31 May 2025 11:00:21 -0400 Subject: [PATCH 10/27] standardize random hadamard Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/utils/hadamard.py | 2 +- tests/test_transform/utils/test_hadamards.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index 1f042941..1a71b116 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -128,7 +128,7 @@ def _matmul_hadU(X, transpose=False) -> torch.Tensor: input = hadK.view(1, K, K).to(input) @ input # normalize - return input.view(X.shape) / torch.tensor(n).sqrt() + return input.view(X.shape) def _is_pow2(n: int) -> bool: diff --git a/tests/test_transform/utils/test_hadamards.py b/tests/test_transform/utils/test_hadamards.py index ae8a0664..1faa839d 100644 --- a/tests/test_transform/utils/test_hadamards.py +++ b/tests/test_transform/utils/test_hadamards.py @@ -46,7 +46,7 @@ def test_packed_hadamard_compliant(had_func): def test_random_hadamard_matrix_compliant(size): had_matrix = random_hadamard_matrix(size) val_1 = torch.round(had_matrix @ had_matrix.T) - assert torch.equal(val_1, torch.eye(size)) + assert torch.equal(val_1 / size, torch.eye(size)) @pytest.mark.parametrize( From 10391001b88d6bce3e9de976cfe543dd22be2b67 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 31 May 2025 11:20:36 -0400 Subject: [PATCH 11/27] prepend input hooks Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/factory/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 30033447..33cca603 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -105,7 +105,7 @@ def input_hook(_, args): input = args[0] return transform(input) - module.register_forward_pre_hook(input_hook) + module.register_forward_pre_hook(input_hook, prepend=True) # eagerly apply transformation to weight elif args.location in ( From 68ec14e2c71fefa6aa2aecacc4567cdc6362cc41 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 13:13:59 -0400 Subject: [PATCH 
12/27] apply sqrt division first Signed-off-by: Kyle Sayers --- .../transform/utils/hadamard.py | 6 +++--- .../{test_hadamards.py => test_hadamard.py} | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) rename tests/test_transform/utils/{test_hadamards.py => test_hadamard.py} (78%) diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index 1a71b116..f56c4a48 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -23,7 +23,7 @@ # adapted from: # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py -def deterministic_hadamard_matrix(size: int) -> numpy.ndarray: +def deterministic_hadamard_matrix(size: int) -> torch.Tensor: """ Construct an Hadamard matrix. @@ -47,7 +47,7 @@ def deterministic_hadamard_matrix(size: int) -> numpy.ndarray: for i in range(0, log2): H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H)))) - return H + return torch.from_numpy(H / math.sqrt(size)) # adapted from: @@ -75,7 +75,7 @@ def random_hadamard_matrix(size: int) -> torch.Tensor: Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) - return _matmul_hadU(Q) + return _matmul_hadU(Q) / math.sqrt(size) def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]: diff --git a/tests/test_transform/utils/test_hadamards.py b/tests/test_transform/utils/test_hadamard.py similarity index 78% rename from tests/test_transform/utils/test_hadamards.py rename to tests/test_transform/utils/test_hadamard.py index 1faa839d..e8f7e359 100644 --- a/tests/test_transform/utils/test_hadamards.py +++ b/tests/test_transform/utils/test_hadamard.py @@ -33,10 +33,10 @@ ) def test_packed_hadamard_compliant(had_func): had_matrix = had_func() - size = had_matrix.shape[0] + size = had_matrix.size(0) # HH.T == nI - val_1 = had_matrix @ had_matrix.T - assert torch.equal(val_1 / size, torch.eye(size)) + product = had_matrix @ had_matrix.T + assert torch.equal(product, size * torch.eye(size)) @pytest.mark.parametrize( @@ -45,8 +45,8 @@ def test_packed_hadamard_compliant(had_func): ) def test_random_hadamard_matrix_compliant(size): had_matrix = random_hadamard_matrix(size) - val_1 = torch.round(had_matrix @ had_matrix.T) - assert torch.equal(val_1 / size, torch.eye(size)) + product = torch.round(had_matrix @ had_matrix.T) + assert torch.equal(product, torch.eye(size)) @pytest.mark.parametrize( @@ -55,6 +55,6 @@ def test_random_hadamard_matrix_compliant(size): ) def test_deterministic_hadamard_compliant(size): had_matrix = deterministic_hadamard_matrix(size) - # HH.T == nI - val_1 = had_matrix @ had_matrix.T - assert numpy.array_equal(val_1 / size, numpy.eye(size)) + # (H / sqrt(n))(H.T / sqrt(n)) == I + product = had_matrix @ had_matrix.T + assert numpy.array_equal(product, numpy.eye(size)) From b117523704f4a71038de890b77bdabc30aac3cc5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 13:15:57 -0400 Subject: [PATCH 13/27] use divided hadamards Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/factory/hadamard.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index b73e0687..c161d9bd 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -59,7 +59,7 @@ def create_transform(self, 
module: Module, args: TransformArgs): return HadamardTransform(weight, args) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: - data = torch.tensor(deterministic_hadamard_matrix(size)) # TODO: seed=self.seed + data = deterministic_hadamard_matrix(size) # TODO: seed=self.seed data = data.to(dtype=dtype, device=device) return Parameter(data, requires_grad=self.scheme.requires_grad) @@ -74,6 +74,6 @@ def forward(self, value: Tensor) -> Tensor: if not self.args.inverse: weight = self.weight else: - weight = self.weight.T / self.weight.size(0) + weight = self.weight.T return apply_transform_weight(weight, value, self.args.location) From a46f7541b6864771bd758046a3fe581790979cb6 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 13:19:23 -0400 Subject: [PATCH 14/27] fix typo Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/utils/hadamard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index f56c4a48..622e7537 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -68,7 +68,7 @@ def random_hadamard_matrix(size: int) -> torch.Tensor: (size, size) """ - # TODO: potentially update to add "seed" as an arugment, to allow + # TODO: potentially update to add "seed" as an argument, to allow # the matrix generated to be reproducible # Benefits: support other shapes / non powers of 2, support randomization From cb1cb5293055a98aea655cd87fd26e8512afa5b8 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 14:09:24 -0400 Subject: [PATCH 15/27] add random option Signed-off-by: Kyle Sayers --- .../transform/utils/hadamard.py | 17 +++++++------ tests/test_transform/utils/test_hadamard.py | 24 +++++++++++++++++++ 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index 622e7537..b0cfc1e9 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Tuple +from typing import Optional, Tuple import numpy import torch @@ -58,21 +58,20 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor: # https://github.com/Dao-AILab/fast-hadamard-transform/tree/master -def random_hadamard_matrix(size: int) -> torch.Tensor: +def random_hadamard_matrix( + size: int, gen: Optional[torch.Generator] = None +) -> torch.Tensor: """ Produces a randomly generated Hadamard matrix. See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation" - :param size: The dimension of the matrix. 
Matrix generated will have dimensions - (size, size) - + :param size: The dimension of the hamadard matrix + :param gen: Optional generator random values + :return: randomly generated hadamard matrix """ - # TODO: potentially update to add "seed" as an argument, to allow - # the matrix generated to be reproducible - # Benefits: support other shapes / non powers of 2, support randomization - Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) return _matmul_hadU(Q) / math.sqrt(size) diff --git a/tests/test_transform/utils/test_hadamard.py b/tests/test_transform/utils/test_hadamard.py index e8f7e359..41532990 100644 --- a/tests/test_transform/utils/test_hadamard.py +++ b/tests/test_transform/utils/test_hadamard.py @@ -49,6 +49,30 @@ def test_random_hadamard_matrix_compliant(size): assert torch.equal(product, torch.eye(size)) +def test_random_hadamard_generator(): + generator = torch.Generator().manual_seed(42) + one = random_hadamard_matrix(2048, generator) + two = random_hadamard_matrix(2048, generator) + + one_true = torch.tensor( + [ + [-1, -1, -1], + [+1, -1, +1], + [-1, -1, +1], + ] + ) + two_true = torch.tensor( + [ + [-1, -1, -1], + [-1, +1, -1], + [+1, +1, -1], + ] + ) + + assert torch.all(one[:3, :3].sign() == one_true.sign()) + assert torch.all(two[:3, :3].sign() == two_true.sign()) + + @pytest.mark.parametrize( "size", [1024], From 02af1e9dffd9f3d896fbec1a1120899c9a887f92 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 14:20:51 -0400 Subject: [PATCH 16/27] use random seeds, rename matrix multiply Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/factory/base.py | 4 ++-- src/compressed_tensors/transform/factory/hadamard.py | 3 ++- src/compressed_tensors/transform/factory/matrix_multiply.py | 2 +- src/compressed_tensors/transform/factory/random_hadamard.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 33cca603..fb15247f 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -82,8 +82,8 @@ def apply_to_model(self, model: Module): :param model: module to apply transforms to """ for arg in self.scheme.apply: - for path, module in list(model.named_modules()): - if is_target(path, module, arg.targets, arg.ignore): + for name, module in list(model.named_modules()): + if is_target(name, module, arg.targets, arg.ignore): self._apply_to_module(module, arg) def _apply_to_module(self, module: Module, args: TransformArgs): diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index c161d9bd..5e4f094a 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -40,6 +40,7 @@ class HadamardFactory(TransformFactory): def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): super().__init__(name, scheme, seed) + self.generator = torch.Generator(device="cpu").manual_seed(seed) self.weights = ParameterizedDefaultDict(self._create_weight) def create_transform(self, module: Module, args: TransformArgs): @@ -59,7 +60,7 @@ def create_transform(self, module: Module, args: TransformArgs): return HadamardTransform(weight, args) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: - 
data = deterministic_hadamard_matrix(size) # TODO: seed=self.seed + data = deterministic_hadamard_matrix(size) data = data.to(dtype=dtype, device=device) return Parameter(data, requires_grad=self.scheme.requires_grad) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 75890beb..8e20b466 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -25,7 +25,7 @@ from torch.nn import Linear, Module, Parameter -@TransformFactory.register("matrix-mul") +@TransformFactory.register("random-matrix") class RandomMatrixFactory(TransformFactory): """ Factory used to apply random matrix transforms to a model diff --git a/src/compressed_tensors/transform/factory/random_hadamard.py b/src/compressed_tensors/transform/factory/random_hadamard.py index cc12d8c4..98113afe 100644 --- a/src/compressed_tensors/transform/factory/random_hadamard.py +++ b/src/compressed_tensors/transform/factory/random_hadamard.py @@ -29,6 +29,6 @@ class RandomHadamardFactory(HadamardFactory): """ def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: - data = random_hadamard_matrix(size) # seed + data = random_hadamard_matrix(size, self.generator) data = data.to(dtype=dtype, device=device) return Parameter(data, requires_grad=self.scheme.requires_grad) From f45f3e928e40411c69d6f4eeae2e4b80b9383187 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 14:24:14 -0400 Subject: [PATCH 17/27] add deterministic generation to random matrix Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/factory/base.py | 1 + src/compressed_tensors/transform/factory/hadamard.py | 1 - src/compressed_tensors/transform/factory/matrix_multiply.py | 4 +++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index fb15247f..cdda6ce8 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -49,6 +49,7 @@ class TransformFactory(RegistryMixin, ABC): def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): self.name = name self.scheme = scheme + self.generator = torch.Generator().manual_seed(seed) self.seed = seed @classmethod diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index 5e4f094a..c77d3bfe 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -40,7 +40,6 @@ class HadamardFactory(TransformFactory): def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): super().__init__(name, scheme, seed) - self.generator = torch.Generator(device="cpu").manual_seed(seed) self.weights = ParameterizedDefaultDict(self._create_weight) def create_transform(self, module: Module, args: TransformArgs): diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 8e20b466..15d4e65d 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -60,7 +60,9 @@ def create_transform(self, module: Module, args: TransformArgs): return RandomMatrixTransform(weight, args) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: - data = 
torch.rand((size, size), dtype=dtype, device=device) + data = torch.rand( + (size, size), generator=self.generator, dtype=dtype, device=device + ) return Parameter(data, requires_grad=self.scheme.requires_grad) def _create_inverse(self, weight: Parameter) -> Parameter: From 7a7abdf6bd7813337120aa79ec12c3c200f3a409 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 15:45:39 -0400 Subject: [PATCH 18/27] fix perm math Signed-off-by: Kyle Sayers --- .../model_compressors/model_compressor.py | 34 +++-- .../compressors/quantized_compressors/base.py | 33 ++++- .../quantization/lifecycle/apply.py | 11 +- .../quantization/lifecycle/forward.py | 64 ++++++---- .../quantization/lifecycle/initialize.py | 120 +----------------- .../quantization/quant_args.py | 1 + .../quantization/utils/helpers.py | 19 +-- .../transform/factory/base.py | 13 +- .../transform/factory/hadamard.py | 17 ++- .../transform/factory/matrix_multiply.py | 24 ++-- .../transform/factory/random_hadamard.py | 6 +- .../transform/transform_config.py | 4 +- .../transform/transform_scheme.py | 7 +- .../transform/utils/hadamard.py | 25 ++-- .../transform/utils/utils.py | 9 +- .../quantized_compressors/test_fp8_quant.py | 6 +- .../quantized_compressors/test_int_quant.py | 4 +- .../lifecycle/test_forward.py | 24 ++-- .../test_utils/test_helpers.py | 10 +- .../factory/test_correctness.py | 23 +++- tests/test_transform/factory/test_memory.py | 2 +- tests/test_transform/test_transform_scheme.py | 8 +- .../{test_hadamards.py => test_hadamard.py} | 40 ++++-- 23 files changed, 226 insertions(+), 278 deletions(-) rename tests/test_transform/utils/{test_hadamards.py => test_hadamard.py} (58%) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index e81ab8e4..4059ae8d 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -50,6 +50,7 @@ align_module_device, delete_offload_parameter, get_execution_device, + get_offloaded_device, get_safetensors_folder, has_offloaded_params, merge_names, @@ -408,16 +409,17 @@ def compress_model(self, model: Module): ) # remove any existing parameters - device = get_execution_device(module) + exec_device = get_execution_device(module) + offload_device = get_offloaded_device(module) for name, _ in list(module.named_parameters()): - delattr(module, name) + delete_offload_parameter(module, name) # replace with compressed parameters for name, value in state_dict.items(): name = name.removeprefix(f"{prefix}.") - value = value.to(device) + value = value.to(exec_device) param = torch.nn.Parameter(value, requires_grad=False) - register_offload_parameter(module, name, param) + register_offload_parameter(module, name, param, offload_device) module.quantization_status = QuantizationStatus.COMPRESSED @@ -460,30 +462,26 @@ def decompress_model(self, model: Module): # quantization second if prefix in module_to_scheme: - generator = self.quantization_compressor.decompress_from_state_dict( - state_dict, - names_to_scheme=module_to_scheme, + state_dict = ( + self.quantization_compressor.decompress_module_from_state_dict( + prefix, + state_dict, + scheme=module_to_scheme[prefix], + ) ) - # generates (mod_path, {param_name, param_val}) - # of compressed params and used params, but not unused params - # some used params are removed by get_unexpected_file_keys - state_dict = { - merge_names(module_path, 
param_name): param_value - for module_path, compressed_data in generator - for param_name, param_value in compressed_data.items() - } # remove any existing parameters - device = get_execution_device(module) + exec_device = get_execution_device(module) + offload_device = get_offloaded_device(module) for name, _ in list(module.named_parameters()): delete_offload_parameter(module, name) # replace with decompressed parameters for name, value in state_dict.items(): name = name.removeprefix(f"{prefix}.") - value = value.to(device) + value = value.to(exec_device) param = torch.nn.Parameter(value, requires_grad=False) - register_offload_parameter(module, name, param) + register_offload_parameter(module, name, param, offload_device) module.quantization_status = QuantizationStatus.FROZEN diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 1102721c..d0a07302 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -24,6 +24,7 @@ get_nested_weight_mappings, merge_names, ) +from compressed_tensors.utils.safetensors_load import match_param_name from safetensors import safe_open from torch import Tensor from tqdm import tqdm @@ -223,9 +224,7 @@ def decompress_from_state_dict( state_dict, self.compression_param_names ) for module_path in weight_mappings.keys(): - weight_data = {} - for param_name, param_value in weight_mappings[module_path].items(): - weight_data[param_name] = param_value + weight_data = weight_mappings[module_path].copy() if "weight_scale" in weight_data: quant_args = names_to_scheme[module_path].weights @@ -234,3 +233,31 @@ def decompress_from_state_dict( ) weight_data["weight"] = decompressed yield module_path, weight_data + + def decompress_module_from_state_dict( + self, + prefix: str, + state_dict: Dict[str, torch.Tensor], + scheme: QuantizationScheme, + ) -> Dict[str, torch.Tensor]: + """ + Only used by in-memory decompression pathways to decompress the parameters of + one module + + :param prefix: prefix of state_dict, typically the path to the module + :param state_dict: state dict containing module parameter values + :param scheme: quantization scheme of module to decompress + :return: state dict with weight decompressed if applicable + """ + state_dict = { + key.removeprefix(f"{prefix}."): value for key, value in state_dict.items() + } + + if "weight_scale" in state_dict: + state_dict["weight"] = self.decompress_weight( + compressed_data=state_dict, quantization_args=scheme.weights + ) + + state_dict = {f"{prefix}.{key}": value for key, value in state_dict.items()} + + return state_dict diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 6bcc940f..566a1963 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -27,14 +27,8 @@ ) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, - update_fused_layer_weight_global_scales, -) -from compressed_tensors.quantization.quant_args import ( - FP4_E2M1_DATA, - FP8_E4M3_DATA, - QuantizationArgs, - QuantizationType, ) +from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.quant_config import ( QuantizationConfig, QuantizationStatus, @@ -272,9 +266,6 @@ def apply_quantization_status(model: 
Module, status: QuantizationStatus): ) ) - if status == QuantizationStatus.INITIALIZED: - update_fused_layer_weight_global_scales(model) - if current_status < status >= QuantizationStatus.COMPRESSED > current_status: model.apply(compress_quantized_weights) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 6662cd2b..b4ca3a82 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -21,7 +21,6 @@ DynamicType, QuantizationArgs, QuantizationStrategy, - QuantizationType, round_to_quantized_type, ) from compressed_tensors.quantization.quant_config import QuantizationStatus @@ -227,31 +226,42 @@ def _process_quantization( perm = torch.argsort(g_idx) x = safe_permute(x, perm, dim=1) - # TODO: experiment with vectorizing for loop for performance - end = 0 - for index, group_count in enumerate(group_sizes): - sc = scale[:, index].view(-1, 1) - zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None - - start = end - end = start + group_count - if do_quantize: - output[:, start:end] = _quantize( - x=x[:, start:end], - scale=sc, - zero_point=zp, - q_min=q_min, - q_max=q_max, - args=args, - dtype=dtype, - global_scale=global_scale, - ) + x = torch.reshape( + x, + ( + x.shape[0], + ceil(x.shape[1] / group_size), + group_size, + ), + ) - if do_dequantize: - input = output[:, start:end] if do_quantize else x[:, start:end] - output[:, start:end] = _dequantize( - x_q=input, scale=sc, zero_point=zp, global_scale=global_scale - ) + if do_quantize: + output = _quantize( + x=x, + scale=scale.unsqueeze(-1), + zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None, + dtype=dtype, + global_scale=global_scale, + q_min=q_min, + q_max=q_max, + args=args, + ) + + if do_dequantize: + input = output if do_quantize else x + output = _dequantize( + x_q=input, + scale=scale.unsqueeze(-1), + zero_point=zero_point.unsqueeze(-1) if zero_point is not None else None, + global_scale=global_scale, + ) + + output = torch.reshape( + output, + (output.shape[0], output.shape[1] * output.shape[2]), + ) + + output = output.to(output_dtype) if not is_column_order: output = safe_permute(output, torch.argsort(perm), dim=1) @@ -394,7 +404,7 @@ def _quantize( # if a global scale is optionally provided, use it # to further scale the local `scale` parameter - if global_scale: + if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale scaled = x / scale @@ -427,7 +437,7 @@ def _dequantize( # if a global scale is optionally provided, use it # to further scale the local `scale` parameter - if global_scale: + if global_scale is not None: scale = scale.to(global_scale.dtype) / global_scale dequant_value = x_q.to(scale.dtype) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index af58f810..806a98f0 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -23,26 +23,18 @@ wrap_module_forward_quantized, ) from compressed_tensors.quantization.quant_args import ( - FP4_E2M1_DATA, FP8_E4M3_DATA, ActivationOrdering, QuantizationArgs, QuantizationStrategy, - QuantizationType, ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from 
compressed_tensors.quantization.utils import ( - generate_global_scale, - is_fp4, - is_kv_cache_quant_scheme, - iter_named_quantizable_modules, -) +from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme from compressed_tensors.utils import ( disable_hf_hook, get_execution_device, register_offload_parameter, - update_parameter_data, ) from torch.nn import Module, Parameter @@ -51,7 +43,6 @@ "initialize_module_for_quantization", "is_attention_module", "KVCacheScaleType", - "update_fused_layer_weight_global_scales", ] @@ -162,22 +153,13 @@ def _initialize_scale_zero_point( # initialize on execution device to avoid performing quantized ops on cpu device = get_execution_device(module) - # 1. Create global_scales for tensor_group + # 1. Create global_scales for tensor_group - generates + # a per tensor scale if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: - # TODO: should move to llmcompressor - if base_name == "weight": - # When applying weight-only FP4 quantization, generate a global_scale - # This scale is applied during runtime to ensure that the generated - # local scale falls properly within the FP8 range (i.e max value is FP8_max) - # which is the expected dtype of NVFP4A16 scales - value = generate_global_scale(input_tensor=module.weight) - value = value.to(device) - init_global_scale = Parameter(value, requires_grad=False) - else: - init_global_scale = Parameter( - torch.empty(1, dtype=torch.float32, device=device), - requires_grad=False, - ) + init_global_scale = Parameter( + torch.empty(1, dtype=torch.float32, device=device), + requires_grad=False, + ) register_offload_parameter( module, f"{base_name}_global_scale", init_global_scale ) @@ -258,91 +240,3 @@ def _initialize_attn_scales(module: Module) -> None: requires_grad=False, ) register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale) - - -# TODO: Potentially introduce an argument to turn this off -# Only relevant for NVFP4A16 currently -def update_fused_layer_weight_global_scales(model: torch.nn.Module): - """ - When running NVFP4A16 quantization, update the global scale - such that q,k,v layers are treated as one tensor with the same - global_scale and gate_proj/up_proj layers are treated as one tensor - with the same global scale. This is requirement currently being set - by vLLM and may be removed in the future OR potentially make it - an optional step. - - :param model: model to quantize - """ - - def _is_attention_module(module: Module): - return "attention" in module.__class__.__name__.lower() and ( - hasattr(module, "k_proj") - or hasattr(module, "v_proj") - or hasattr(module, "qkv_proj") - ) - - def _is_mlp_module(module: Module): - return "mlp" in module.__class__.__name__.lower() and ( - hasattr(module, "gate_proj") or hasattr(module, "up_proj") - ) - - def _valid_fp4_quant(layer_list: List[torch.nn.Linear]): - """ - Return True if all the linear layers in the layer_list are - NVFP4A16 quantized. 
- """ - for layer in layer_list: - scheme = getattr(layer, "quantization_scheme", None) - if scheme is None: - return False - - weight_quant_args = scheme.weights - - if weight_quant_args is None: - return False - - if not is_fp4(quantization_args=weight_quant_args): - return False - return True - - for name, submodule in iter_named_quantizable_modules( - model, - include_attn=True, - include_mlp=True, - ): - - if _is_attention_module(submodule): - # already fused/treated as one layer - if hasattr(submodule, "qkv_proj"): - continue - - if not _valid_fp4_quant( - [submodule.q_proj, submodule.v_proj, submodule.k_proj] - ): - continue - - q_weight = submodule.q_proj.weight.data - v_weight = submodule.v_proj.weight.data - k_weight = submodule.k_proj.weight.data - - value = generate_global_scale( - input_tensor=torch.cat((q_weight, v_weight, k_weight), dim=0) - ) - - update_parameter_data(submodule.q_proj, value, "weight_global_scale") - update_parameter_data(submodule.k_proj, value, "weight_global_scale") - update_parameter_data(submodule.v_proj, value, "weight_global_scale") - - if _is_mlp_module(submodule): - if not _valid_fp4_quant([submodule.gate_proj, submodule.up_proj]): - continue - - gate_data = submodule.gate_proj.weight.data - up_data = submodule.up_proj.weight.data - - value = generate_global_scale( - input_tensor=torch.cat((gate_data, up_data), dim=0) - ) - - update_parameter_data(submodule.gate_proj, value, "weight_global_scale") - update_parameter_data(submodule.up_proj, value, "weight_global_scale") diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 1fadc847..fdf34a28 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -53,6 +53,7 @@ class FP4_E2M1_DATA(FloatArgs): min = -6.0 @staticmethod + @torch.compile def cast_to_fp4(x): sign = torch.sign(x) x = torch.abs(x) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 4d33bea3..5d855f75 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -47,7 +47,7 @@ "compute_dynamic_scales_and_zp", "calculate_range", "calculate_qparams", - "generate_global_scale", + "generate_gparam", "is_fp4", ] @@ -81,7 +81,7 @@ def calculate_qparams( currently only applied/supported for Fp4 :return: tuple of the calculated scale(s) and zero point(s). For FP4, the calculated - scale if of dtype FP8 + scale is of dtype FP8 """ # based on the implementations for consuming quantized values, # 0.0 must always be representable within the quantized range @@ -110,6 +110,7 @@ def calculate_qparams( else: scales = max_val_pos / (float(bit_range) / 2) + # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped if scales.dtype == FP8_E4M3_DATA.dtype: # torch.clamp not supported for FP8 # use the next largest fp8 value from 0 @@ -475,8 +476,9 @@ def parse_out_kv_cache_args( return kv_cache_args, quant_scheme_to_layers -def generate_global_scale( - input_tensor: torch.Tensor, +def generate_gparam( + updated_min_val: torch.Tensor, + updated_max_val: torch.Tensor, scale_data: Optional[FloatArgs] = FP8_E4M3_DATA, quant_data: Optional[FloatArgs] = FP4_E2M1_DATA, dtype: Optional[torch.dtype] = torch.float32, @@ -490,7 +492,8 @@ def generate_global_scale( attempts to use the entire FP8 dtype range while mapping a per-group max to the FP4 max. 
""" - scale_dtype = scale_data.dtype - tensor_amax = torch.abs(input_tensor.data).max().to(dtype) - global_scale = scale_data.max * quant_data.max / tensor_amax - return global_scale.to(dtype) + min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val)) + max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val)) + max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) + global_scale = scale_data.max * quant_data.max / max_val_pos + return global_scale.to(dtype).reshape([1]) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 30033447..4c7b6c91 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -13,6 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod +from typing import Optional import torch import torch.nn.utils.parametrize as P @@ -46,10 +47,12 @@ class TransformFactory(RegistryMixin, ABC): :param seed: random seed used to transform weight randomization """ - def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None): self.name = name self.scheme = scheme - self.seed = seed + self.generator = torch.Generator() + if seed is not None: + self.generator.manual_seed(seed) @classmethod def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T: @@ -82,8 +85,8 @@ def apply_to_model(self, model: Module): :param model: module to apply transforms to """ for arg in self.scheme.apply: - for path, module in list(model.named_modules()): - if is_target(path, module, arg.targets, arg.ignore): + for name, module in list(model.named_modules()): + if is_target(name, module, arg.targets, arg.ignore): self._apply_to_module(module, arg) def _apply_to_module(self, module: Module, args: TransformArgs): @@ -105,7 +108,7 @@ def input_hook(_, args): input = args[0] return transform(input) - module.register_forward_pre_hook(input_hook) + module.register_forward_pre_hook(input_hook, prepend=True) # eagerly apply transformation to weight elif args.location in ( diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index a05accd4..b4d5f7de 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Union +from typing import Optional, Union import torch from compressed_tensors.transform import TransformArgs, TransformScheme from compressed_tensors.transform.factory.base import TransformBase, TransformFactory from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix from compressed_tensors.transform.utils.utils import ( - apply_permutation, apply_transform_weight, get_matrix_size, ) @@ -39,7 +38,7 @@ class HadamardFactory(TransformFactory): :param seed: random seed used to transform weight randomization """ - def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None): super().__init__(name, scheme, seed) self.weights = ParameterizedDefaultDict(self._create_weight) self.perms = ParameterizedDefaultDict(self._create_permutation) @@ -58,16 +57,16 @@ def create_transform(self, module: Module, args: TransformArgs): device = get_offloaded_device(module) weight = self.weights[size, dtype, device] - perm = self.perms[module, weight] if self.scheme.randomize_modules else None + perm = self.perms[weight] if self.scheme.randomize else None return HadamardTransform(weight, perm, args) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: - data = torch.tensor(deterministic_hadamard_matrix(size)) # TODO: seed=self.seed + data = deterministic_hadamard_matrix(size) data = data.to(dtype=dtype, device=device) return Parameter(data, requires_grad=self.scheme.requires_grad) - def _create_permutation(self, module: Module, weight: Parameter) -> Parameter: - data = torch.randperm(weight.size(0)) + def _create_permutation(self, weight: Parameter) -> Parameter: + data = torch.randperm(weight.size(0), generator=self.generator) return Parameter(data, requires_grad=False) @@ -84,9 +83,9 @@ def forward(self, value: Tensor) -> Tensor: weight = self.weight if self.perm is not None: - weight = apply_permutation(weight, self.perm) + weight = weight[self.perm][:, self.perm] if self.args.inverse: - weight = weight.T / weight.size(0) + weight = weight.T return apply_transform_weight(weight, value, self.args.location) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 13a27f79..e551fc5f 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional + import torch from compressed_tensors.transform import TransformArgs, TransformScheme from compressed_tensors.transform.factory.base import TransformBase, TransformFactory @@ -25,7 +27,7 @@ from torch.nn import Linear, Module, Parameter -@TransformFactory.register("matrix-mul") +@TransformFactory.register("random-matrix") class RandomMatrixFactory(TransformFactory): """ Factory used to apply random matrix transforms to a model @@ -35,7 +37,7 @@ class RandomMatrixFactory(TransformFactory): :param seed: random seed used to transform weight randomization """ - def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None): super().__init__(name, scheme, seed) self.weights = ParameterizedDefaultDict(self._create_weight) self.inverses = ParameterizedDefaultDict(self._create_inverse) @@ -53,19 +55,21 @@ def create_transform(self, module: Module, args: TransformArgs): dtype = module.weight.dtype device = get_offloaded_device(module) - if not args.inverse: - weight = self.weights[size, dtype, device] - else: - weight = self.inverses[size, dtype, device] + weight = self.weights[size, dtype, device] + if args.inverse: + weight = self.inverses[weight] + return RandomMatrixTransform(weight, args) def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: - data = torch.rand((size, size), dtype=dtype, device=device) + data = torch.rand( + (size, size), generator=self.generator, dtype=dtype, device=device + ) return Parameter(data, requires_grad=self.scheme.requires_grad) - def _create_inverse(self, size: int, dtype: dtype, device: device) -> Parameter: - weight = self.weights[size, dtype, device] - return Parameter(high_precision_invert(weight.data), requires_grad=False) + def _create_inverse(self, weight: Parameter) -> Parameter: + data = high_precision_invert(weight.data) + return Parameter(data, requires_grad=False) class RandomMatrixTransform(TransformBase): diff --git a/src/compressed_tensors/transform/factory/random_hadamard.py b/src/compressed_tensors/transform/factory/random_hadamard.py index e4f14186..98113afe 100644 --- a/src/compressed_tensors/transform/factory/random_hadamard.py +++ b/src/compressed_tensors/transform/factory/random_hadamard.py @@ -29,10 +29,6 @@ class RandomHadamardFactory(HadamardFactory): """ def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter: - for key in self.weights.keys(): - if key[0] == size: - return self.weights[key].to(dtype=dtype, device=device) - - data = random_hadamard_matrix(size) # seed + data = random_hadamard_matrix(size, self.generator) data = data.to(dtype=dtype, device=device) return Parameter(data, requires_grad=self.scheme.requires_grad) diff --git a/src/compressed_tensors/transform/transform_config.py b/src/compressed_tensors/transform/transform_config.py index 414c21e0..df178c42 100644 --- a/src/compressed_tensors/transform/transform_config.py +++ b/src/compressed_tensors/transform/transform_config.py @@ -49,7 +49,7 @@ class TransformConfig(BaseModel): inverse=True, ), ], - randomize_modules=True, + randomize=True, ), "u": TransformScheme( type="hadamard", @@ -62,7 +62,7 @@ class TransformConfig(BaseModel): targets=["Linear"], location="output", inverse=True # non-mergable ), ], - randomize_modules=True, + randomize=True, ), } ) diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py index 1335063c..64d646e0 100644 
--- a/src/compressed_tensors/transform/transform_scheme.py +++ b/src/compressed_tensors/transform/transform_scheme.py @@ -31,13 +31,12 @@ class TransformScheme(BaseModel): (see `Transforms.registered_names()`) :param apply: list of TransformationArgs containing the information about the modules that should be targeted by the specified transform - :param randomize_modules: True if unique transforms should be applied to each - unique module targeted by `apply`, otherwise reuse transform weights where - applicable + :param randomize: True if uniquely randomized transform weights should be used, + otherwise use identical transform weights where applicable :param requires_grad: True if weights include gradients for training """ type: str apply: List[TransformArgs] = Field(default_factory=list) - randomize_modules: bool = Field(default=False) + randomize: bool = Field(default=False) requires_grad: bool = Field(default=False) diff --git a/src/compressed_tensors/transform/utils/hadamard.py b/src/compressed_tensors/transform/utils/hadamard.py index 1f042941..b0cfc1e9 100644 --- a/src/compressed_tensors/transform/utils/hadamard.py +++ b/src/compressed_tensors/transform/utils/hadamard.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Tuple +from typing import Optional, Tuple import numpy import torch @@ -23,7 +23,7 @@ # adapted from: # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py -def deterministic_hadamard_matrix(size: int) -> numpy.ndarray: +def deterministic_hadamard_matrix(size: int) -> torch.Tensor: """ Construct an Hadamard matrix. @@ -47,7 +47,7 @@ def deterministic_hadamard_matrix(size: int) -> numpy.ndarray: for i in range(0, log2): H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H)))) - return H + return torch.from_numpy(H / math.sqrt(size)) # adapted from: @@ -58,24 +58,23 @@ def deterministic_hadamard_matrix(size: int) -> numpy.ndarray: # https://github.com/Dao-AILab/fast-hadamard-transform/tree/master -def random_hadamard_matrix(size: int) -> torch.Tensor: +def random_hadamard_matrix( + size: int, gen: Optional[torch.Generator] = None +) -> torch.Tensor: """ Produces a randomly generated Hadamard matrix. See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation" - :param size: The dimension of the matrix. 
Matrix generated will have dimensions - (size, size) - + :param size: The dimension of the hamadard matrix + :param gen: Optional generator random values + :return: randomly generated hadamard matrix """ - # TODO: potentially update to add "seed" as an arugment, to allow - # the matrix generated to be reproducible - # Benefits: support other shapes / non powers of 2, support randomization - Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) + Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) - return _matmul_hadU(Q) + return _matmul_hadU(Q) / math.sqrt(size) def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]: @@ -128,7 +127,7 @@ def _matmul_hadU(X, transpose=False) -> torch.Tensor: input = hadK.view(1, K, K).to(input) @ input # normalize - return input.view(X.shape) / torch.tensor(n).sqrt() + return input.view(X.shape) def _is_pow2(n: int) -> bool: diff --git a/src/compressed_tensors/transform/utils/utils.py b/src/compressed_tensors/transform/utils/utils.py index 88cbf3aa..eebe3663 100644 --- a/src/compressed_tensors/transform/utils/utils.py +++ b/src/compressed_tensors/transform/utils/utils.py @@ -16,7 +16,7 @@ from compressed_tensors.transform import TransformLocation -__all__ = ["get_matrix_size", "apply_transform_weight", "apply_permutation"] +__all__ = ["get_matrix_size", "apply_transform_weight"] def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int: @@ -83,10 +83,3 @@ def apply_transform_weight( elif location == TransformLocation.OUTPUT: return value @ weight - - -def apply_permutation(weight: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: - weight = weight.clone() - diag_indices = torch.arange(weight.size(0)) - weight[diag_indices, diag_indices] = weight.diagonal()[perm] - return weight diff --git a/tests/test_compressors/quantized_compressors/test_fp8_quant.py b/tests/test_compressors/quantized_compressors/test_fp8_quant.py index e8bbc0d2..2fb2d62d 100644 --- a/tests/test_compressors/quantized_compressors/test_fp8_quant.py +++ b/tests/test_compressors/quantized_compressors/test_fp8_quant.py @@ -61,8 +61,8 @@ def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor: [ QuantizationStrategy.GROUP, 128, - torch.rand((512, 8, 1)) * 0.01, - torch.zeros((512, 8, 1), dtype=torch.int8), + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8), dtype=torch.int8), ], [ QuantizationStrategy.CHANNEL, @@ -79,7 +79,7 @@ def test_quant_format(strategy, group_size, sc, zp): "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.float32), } if group_size is not None: - dense_state_dict["dummy.weight_g_idx"] = make_dummy_g_idx(512, group_size) + dense_state_dict["dummy.weight_g_idx"] = make_dummy_g_idx(1024, group_size) quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) diff --git a/tests/test_compressors/quantized_compressors/test_int_quant.py b/tests/test_compressors/quantized_compressors/test_int_quant.py index 991444cc..627af582 100644 --- a/tests/test_compressors/quantized_compressors/test_int_quant.py +++ b/tests/test_compressors/quantized_compressors/test_int_quant.py @@ -53,8 +53,8 @@ def get_dummy_quant_config(strategy, group_size=None, symmetric=True): QuantizationStrategy.GROUP, True, 128, - torch.rand((512, 8, 1)) * 0.01, - torch.zeros((512, 8, 1), dtype=torch.int8), + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8), dtype=torch.int8), ], [ QuantizationStrategy.CHANNEL, diff --git 
a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index 542cd8b9..deb6cc30 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -108,8 +108,8 @@ def test_forward_quantize( "int", QuantizationStrategy.GROUP, 128, - torch.rand((512, 8, 1)) * 0.01, - torch.zeros((512, 8, 1)), + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), None, ), ( @@ -117,8 +117,8 @@ def test_forward_quantize( "int", QuantizationStrategy.GROUP, 128, - torch.rand((512, 8, 1)) * 0.01, - torch.zeros((512, 8, 1)), + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), ), ( @@ -135,8 +135,8 @@ def test_forward_quantize( "float", QuantizationStrategy.GROUP, 128, - torch.rand((512, 8, 1)) * 0.01, - torch.zeros((512, 8, 1)), + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), None, ), ( @@ -144,8 +144,8 @@ def test_forward_quantize( "float", QuantizationStrategy.GROUP, 128, - torch.rand((512, 8, 1)) * 0.01, - torch.zeros((512, 8, 1)), + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), ), ], @@ -174,8 +174,8 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx "int", QuantizationStrategy.GROUP, 128, - torch.rand((512, 8, 1)) * 0.01, - torch.zeros((512, 8, 1)), + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), None, ), ( @@ -183,8 +183,8 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx "int", QuantizationStrategy.GROUP, 128, - torch.rand((512, 8, 1)) * 0.01, - torch.zeros((512, 8, 1)), + torch.rand((512, 8)) * 0.01, + torch.zeros((512, 8)), make_dummy_g_idx(1024, 128), ), ], diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py index 294bef8d..4cdc0c8a 100644 --- a/tests/test_quantization/test_utils/test_helpers.py +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -20,10 +20,7 @@ QuantizationArgs, QuantizationStrategy, ) -from compressed_tensors.quantization.utils import ( - calculate_qparams, - generate_global_scale, -) +from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam @pytest.mark.parametrize( @@ -70,8 +67,9 @@ def test_fused_global_scales(): layer = torch.nn.Linear(7, 8) max_tensor_value = torch.abs(layer.weight.data).max() # use defaults - global_scale = generate_global_scale(layer.weight) + min_val, max_val = torch.aminmax(layer.weight) + global_scale = generate_gparam(min_val.data, max_val.data) # max value should be = (448 * 6) / global_scale - assert max_tensor_value == pytest.approx( + assert max_tensor_value.item() == pytest.approx( FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001 ) diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index 8ac4dbab..45293c94 100644 --- a/tests/test_transform/factory/test_correctness.py +++ b/tests/test_transform/factory/test_correctness.py @@ -26,7 +26,7 @@ _test_schemes = [ TransformScheme(type=name) for name in TransformFactory.registered_names() ] + [ - TransformScheme(type=name, randomize_modules=True) + TransformScheme(type=name, randomize=True) for name in TransformFactory.registered_names() ] @@ -69,21 +69,30 @@ def test_correctness_linear(scheme): input_transformed = input_tfm(input) weight_transformed = w_out_tfm(w_in_tfm(module.weight)) output = output_tfm(input_transformed @ weight_transformed.T) - - torch.allclose(true_output, 
output, atol=1e-7, rtol=0.0) + assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) @pytest.mark.parametrize("scheme", _test_schemes) def test_correctness_model(scheme, offload=False): # load model - model = TransformableModel(2, 4, 8, 16) + model = TransformableModel(2, 4, 8, 16, 32, 64) if offload: model = force_cpu_offload(model, torch.device("cuda")) # create factory scheme.apply = [ - TransformArgs(targets="fcs.0", location="input"), - TransformArgs(targets="fcs.2", location="output", inverse=True), + # weight output -> input + TransformArgs(targets="fcs.0", location="weight_output"), + TransformArgs(targets="fcs.1", location="input", inverse=True), + # output -> weight input + TransformArgs(targets="fcs.1", location="output"), + TransformArgs(targets="fcs.2", location="weight_input", inverse=True), + # output -> input + TransformArgs(targets="fcs.2", location="output"), + TransformArgs(targets="fcs.3", location="input", inverse=True), + # weight output -> weight input + TransformArgs(targets="fcs.3", location="weight_output"), + TransformArgs(targets="fcs.4", location="weight_input", inverse=True), ] factory = TransformFactory.from_scheme(scheme, name="") @@ -96,7 +105,7 @@ def test_correctness_model(scheme, offload=False): true_output = model(input) factory.apply_to_model(model) output = model(input) - torch.allclose(true_output, output, atol=1e-7, rtol=0.0) + assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) @requires_gpu diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py index 63e37561..256fcecc 100644 --- a/tests/test_transform/factory/test_memory.py +++ b/tests/test_transform/factory/test_memory.py @@ -29,7 +29,7 @@ _test_schemes = [ TransformScheme(type=name) for name in TransformFactory.registered_names() ] + [ - TransformScheme(type=name, randomize_modules=True) + TransformScheme(type=name, randomize=True) for name in TransformFactory.registered_names() ] diff --git a/tests/test_transform/test_transform_scheme.py b/tests/test_transform/test_transform_scheme.py index ad851762..839ab46a 100644 --- a/tests/test_transform/test_transform_scheme.py +++ b/tests/test_transform/test_transform_scheme.py @@ -24,7 +24,7 @@ def test_basic_scheme(): type="hadamard", apply=[basic_args], ) - assert not scheme.randomize_modules + assert not scheme.randomize assert scheme.type == "hadamard" assert len(scheme.apply) == 1 assert isinstance(scheme.apply[0], TransformArgs) @@ -43,10 +43,10 @@ def test_multiple_groups_global(): scheme = TransformScheme( type="hadamard", apply=[embedding_args, linear_args], - randomize_modules=True, + randomize=True, ) - assert scheme.randomize_modules + assert scheme.randomize assert scheme.type == "hadamard" assert len(scheme.apply) == 2 assert isinstance(scheme.apply[0], TransformArgs) @@ -69,6 +69,6 @@ def test_multiple_groups(): apply=apply, ) - assert not scheme.randomize_modules + assert not scheme.randomize assert scheme.type == "hadamard" assert len(scheme.apply) == 20 diff --git a/tests/test_transform/utils/test_hadamards.py b/tests/test_transform/utils/test_hadamard.py similarity index 58% rename from tests/test_transform/utils/test_hadamards.py rename to tests/test_transform/utils/test_hadamard.py index ae8a0664..41532990 100644 --- a/tests/test_transform/utils/test_hadamards.py +++ b/tests/test_transform/utils/test_hadamard.py @@ -33,10 +33,10 @@ ) def test_packed_hadamard_compliant(had_func): had_matrix = had_func() - size = had_matrix.shape[0] + size = had_matrix.size(0) # 
HH.T == nI - val_1 = had_matrix @ had_matrix.T - assert torch.equal(val_1 / size, torch.eye(size)) + product = had_matrix @ had_matrix.T + assert torch.equal(product, size * torch.eye(size)) @pytest.mark.parametrize( @@ -45,8 +45,32 @@ def test_packed_hadamard_compliant(had_func): ) def test_random_hadamard_matrix_compliant(size): had_matrix = random_hadamard_matrix(size) - val_1 = torch.round(had_matrix @ had_matrix.T) - assert torch.equal(val_1, torch.eye(size)) + product = torch.round(had_matrix @ had_matrix.T) + assert torch.equal(product, torch.eye(size)) + + +def test_random_hadamard_generator(): + generator = torch.Generator().manual_seed(42) + one = random_hadamard_matrix(2048, generator) + two = random_hadamard_matrix(2048, generator) + + one_true = torch.tensor( + [ + [-1, -1, -1], + [+1, -1, +1], + [-1, -1, +1], + ] + ) + two_true = torch.tensor( + [ + [-1, -1, -1], + [-1, +1, -1], + [+1, +1, -1], + ] + ) + + assert torch.all(one[:3, :3].sign() == one_true.sign()) + assert torch.all(two[:3, :3].sign() == two_true.sign()) @pytest.mark.parametrize( @@ -55,6 +79,6 @@ def test_random_hadamard_matrix_compliant(size): ) def test_deterministic_hadamard_compliant(size): had_matrix = deterministic_hadamard_matrix(size) - # HH.T == nI - val_1 = had_matrix @ had_matrix.T - assert numpy.array_equal(val_1 / size, numpy.eye(size)) + # (H / sqrt(n))(H.T / sqrt(n)) == I + product = had_matrix @ had_matrix.T + assert numpy.array_equal(product, numpy.eye(size)) From 6e52894753ca76e9625d289d3b298b58b204f039 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 16:40:51 -0400 Subject: [PATCH 19/27] update docstrings Signed-off-by: Kyle Sayers --- .../transform/transform_args.py | 20 +++++++++++++++++-- .../transform/utils/utils.py | 14 +++++++++---- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index d0487678..4aa89843 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -13,15 +13,31 @@ # limitations under the License. from enum import Enum -from typing import Any, List +from typing import List from pydantic import BaseModel, Field, field_validator -__all__ = ["TransformArgs"] +__all__ = ["TransformArgs", "TransformLocation"] class TransformLocation(str, Enum): + """ + Enum representing which parameters/activations a transform weight should be applied + to on a given module. 
+ + | -------------------------------------------------------------------------------------------------------- | # noqa: E501 + | Name | Runtime | Values | Candidate Inverse Locations | # noqa: E501 + | --------------- | ----------- | ------------- | -------------------------------------------------------- | # noqa: E501 + | `INPUT` | online | activations | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT` | # noqa: E501 + | `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501 + | `WEIGHT_OUTPUT` | offline | weight | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 + | `OUTPUT` | online | activations | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT` | # noqa: E501 + | `K_CACHE` | online | key_values | `q_proj.Q_ATTN` | # noqa: E501 + | `Q_ATTN` | online | query_values | `k_proj.K_CACHE` | # noqa: E501 + | -------------------------------------------------------------------------------------------------------- | # noqa: E501 + """ + INPUT = "input" WEIGHT_INPUT = "weight_input" WEIGHT_OUTPUT = "weight_output" diff --git a/src/compressed_tensors/transform/utils/utils.py b/src/compressed_tensors/transform/utils/utils.py index eebe3663..e60d24dc 100644 --- a/src/compressed_tensors/transform/utils/utils.py +++ b/src/compressed_tensors/transform/utils/utils.py @@ -41,7 +41,9 @@ def apply_transform_weight( ) -> torch.Tensor: """ Using the transform location, determine how to apply the transform weight to the - given value + given value. For more info on input and output transforms, see `TransformLocation` + + The following explains how weights should be applied to values according to location let x be input activation W be weight, @@ -49,17 +51,18 @@ def apply_transform_weight( note that y = (x W.T) // torch.nn.Linear - yh = (xh) (Wh).T // transformed + + Choose values for yh, xh, and Wh which incorporate matrix transforms let V, Vi be transform matrices on input side U, Ui be transform matrices on output side - show that the following values for yh, xh, and Wh are consistent - pick xh = (x V) Wh = (U.T W Vi.T) yh = (y U) + The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh + (xh) (Wh).T = (x V) (U.T W Vi.T).T = (x V) (Vi W.T U) // transpose matrix product identity = (x W.T) U @@ -83,3 +86,6 @@ def apply_transform_weight( elif location == TransformLocation.OUTPUT: return value @ weight + + else: + raise NotImplementedError(f"{location} has not been implemented yet") From 72309330ca6ad392bcc470a7426b5f20040d5ea9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 16:49:32 -0400 Subject: [PATCH 20/27] update docstrings Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/transform_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/transform/transform_args.py b/src/compressed_tensors/transform/transform_args.py index 4aa89843..16ab10b3 100644 --- a/src/compressed_tensors/transform/transform_args.py +++ b/src/compressed_tensors/transform/transform_args.py @@ -27,7 +27,7 @@ class TransformLocation(str, Enum): to on a given module. 
| -------------------------------------------------------------------------------------------------------- | # noqa: E501 - | Name | Runtime | Values | Candidate Inverse Locations | # noqa: E501 + | Name | Runtime | Values | Locations Where Inverse Could Be Applied | # noqa: E501 | --------------- | ----------- | ------------- | -------------------------------------------------------- | # noqa: E501 | `INPUT` | online | activations | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT` | # noqa: E501 | `WEIGHT_INPUT` | offline | weight | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT` | # noqa: E501 From 92ddea972b400167155bf8bc5c2681b72f64f17b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 16:57:21 -0400 Subject: [PATCH 21/27] cleanup Signed-off-by: Kyle Sayers --- .../factory/test_correctness.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index 45293c94..03617599 100644 --- a/tests/test_transform/factory/test_correctness.py +++ b/tests/test_transform/factory/test_correctness.py @@ -23,12 +23,13 @@ from tests.testing_utils import requires_accelerate, requires_gpu -_test_schemes = [ - TransformScheme(type=name) for name in TransformFactory.registered_names() -] + [ - TransformScheme(type=name, randomize=True) - for name in TransformFactory.registered_names() -] +def all_schemes(): + base = [TransformScheme(type=name) for name in TransformFactory.registered_names()] + randomized = [ + TransformScheme(type=name, randomize=True) + for name in TransformFactory.registered_names() + ] + return base + randomized class TransformableModel(torch.nn.Module): @@ -45,7 +46,7 @@ def forward(self, x): return x -@pytest.mark.parametrize("scheme", _test_schemes) +@pytest.mark.parametrize("scheme", all_schemes()) def test_correctness_linear(scheme): size = (4, 8) module = torch.nn.Linear(*size, bias=True) @@ -72,7 +73,7 @@ def test_correctness_linear(scheme): assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) -@pytest.mark.parametrize("scheme", _test_schemes) +@pytest.mark.parametrize("scheme", all_schemes()) def test_correctness_model(scheme, offload=False): # load model model = TransformableModel(2, 4, 8, 16, 32, 64) @@ -110,6 +111,6 @@ def test_correctness_model(scheme, offload=False): @requires_gpu @requires_accelerate() -@pytest.mark.parametrize("scheme", _test_schemes) +@pytest.mark.parametrize("scheme", all_schemes()) def test_correctness_model_offload(scheme): test_correctness_model(scheme, offload=True) From 779956fc31bd259fba82d0d9b9b66b85227ac3ab Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 16:58:35 -0400 Subject: [PATCH 22/27] cleanup 2 Signed-off-by: Kyle Sayers --- tests/test_transform/factory/test_memory.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py index 256fcecc..52ac879d 100644 --- a/tests/test_transform/factory/test_memory.py +++ b/tests/test_transform/factory/test_memory.py @@ -26,12 +26,13 @@ from tests.testing_utils import requires_accelerate, requires_gpu -_test_schemes = [ - TransformScheme(type=name) for name in TransformFactory.registered_names() -] + [ - TransformScheme(type=name, randomize=True) - for name in TransformFactory.registered_names() -] +def all_schemes(): + base = [TransformScheme(type=name) for name in TransformFactory.registered_names()] 
+ randomized = [ + TransformScheme(type=name, randomize=True) + for name in TransformFactory.registered_names() + ] + return base + randomized class TransformableModel(torch.nn.Module): @@ -48,7 +49,7 @@ def forward(self, x): return x -@pytest.mark.parametrize("scheme", _test_schemes) +@pytest.mark.parametrize("scheme", all_schemes()) def test_memory_sharing(scheme, offload=False): # load scheme and factory scheme = TransformScheme( @@ -98,12 +99,12 @@ def test_memory_sharing(scheme, offload=False): @requires_gpu @requires_accelerate() -@pytest.mark.parametrize("scheme", _test_schemes) +@pytest.mark.parametrize("scheme", all_schemes()) def test_memory_sharing_offload(scheme): test_memory_sharing(scheme, offload=True) -@pytest.mark.parametrize("scheme", _test_schemes) +@pytest.mark.parametrize("scheme", all_schemes()) def test_memory_sharing_training(scheme): scheme.requires_grad = True test_memory_sharing(scheme, offload=False) From dd72b6aab1247e7933c2162bbfd19da52180adbc Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 5 Jun 2025 17:02:57 -0400 Subject: [PATCH 23/27] make seed optional Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/factory/base.py | 8 +++++--- src/compressed_tensors/transform/factory/hadamard.py | 2 +- .../transform/factory/matrix_multiply.py | 4 +++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index cdda6ce8..4c7b6c91 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -13,6 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod +from typing import Optional import torch import torch.nn.utils.parametrize as P @@ -46,11 +47,12 @@ class TransformFactory(RegistryMixin, ABC): :param seed: random seed used to transform weight randomization """ - def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None): self.name = name self.scheme = scheme - self.generator = torch.Generator().manual_seed(seed) - self.seed = seed + self.generator = torch.Generator() + if seed is not None: + self.generator.manual_seed(seed) @classmethod def from_scheme(cls: type[T], scheme: TransformScheme, **kwargs) -> T: diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py index c77d3bfe..b1da88a3 100644 --- a/src/compressed_tensors/transform/factory/hadamard.py +++ b/src/compressed_tensors/transform/factory/hadamard.py @@ -38,7 +38,7 @@ class HadamardFactory(TransformFactory): :param seed: random seed used to transform weight randomization """ - def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None): super().__init__(name, scheme, seed) self.weights = ParameterizedDefaultDict(self._create_weight) diff --git a/src/compressed_tensors/transform/factory/matrix_multiply.py b/src/compressed_tensors/transform/factory/matrix_multiply.py index 15d4e65d..e551fc5f 100644 --- a/src/compressed_tensors/transform/factory/matrix_multiply.py +++ b/src/compressed_tensors/transform/factory/matrix_multiply.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional + import torch from compressed_tensors.transform import TransformArgs, TransformScheme from compressed_tensors.transform.factory.base import TransformBase, TransformFactory @@ -35,7 +37,7 @@ class RandomMatrixFactory(TransformFactory): :param seed: random seed used to transform weight randomization """ - def __init__(self, name: str, scheme: TransformScheme, seed: int = 42): + def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None): super().__init__(name, scheme, seed) self.weights = ParameterizedDefaultDict(self._create_weight) self.inverses = ParameterizedDefaultDict(self._create_inverse) From da19b0f8e5e634f10828d4ebedec031f51e1d7ea Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 9 Jun 2025 17:02:14 -0400 Subject: [PATCH 24/27] remove iterable check and missing return value Signed-off-by: Kyle Sayers --- src/compressed_tensors/quantization/lifecycle/apply.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 566a1963..8bac1a0a 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -331,12 +331,10 @@ def find_name_or_class_matches( 3. matches on module names """ targets = sorted(targets, key=lambda x: ("re:" in x, x)) - if isinstance(targets, Iterable): - matches = _find_matches(name, targets) + _find_matches( - module.__class__.__name__, targets, check_contains - ) - matches = [match for match in matches if match is not None] - return matches + matches = _find_matches(name, targets) + _find_matches( + module.__class__.__name__, targets, check_contains + ) + return [match for match in matches if match is not None] def _find_matches( From 6e1ec396fd773992b2b83d0faa2fcd3f1357b026 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Jun 2025 11:30:19 -0400 Subject: [PATCH 25/27] Remove unrelated changes --- src/compressed_tensors/quantization/lifecycle/apply.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 8bac1a0a..566a1963 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -331,10 +331,12 @@ def find_name_or_class_matches( 3. 
matches on module names """ targets = sorted(targets, key=lambda x: ("re:" in x, x)) - matches = _find_matches(name, targets) + _find_matches( - module.__class__.__name__, targets, check_contains - ) - return [match for match in matches if match is not None] + if isinstance(targets, Iterable): + matches = _find_matches(name, targets) + _find_matches( + module.__class__.__name__, targets, check_contains + ) + matches = [match for match in matches if match is not None] + return matches def _find_matches( From 938e7025cf4599c246183ed6b7f5894d2b2d747a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Jun 2025 11:35:03 -0400 Subject: [PATCH 26/27] simplify code Signed-off-by: Kyle Sayers --- tests/test_transform/factory/test_correctness.py | 8 +++----- tests/test_transform/factory/test_memory.py | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index 03617599..1745281f 100644 --- a/tests/test_transform/factory/test_correctness.py +++ b/tests/test_transform/factory/test_correctness.py @@ -24,11 +24,9 @@ def all_schemes(): - base = [TransformScheme(type=name) for name in TransformFactory.registered_names()] - randomized = [ - TransformScheme(type=name, randomize=True) - for name in TransformFactory.registered_names() - ] + all_types = TransformFactory.registered_names() + base = [TransformScheme(type=type) for type in all_types] + randomized = [TransformScheme(type=type, randomize=True) for type in all_types] return base + randomized diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py index 52ac879d..15d72b9b 100644 --- a/tests/test_transform/factory/test_memory.py +++ b/tests/test_transform/factory/test_memory.py @@ -27,11 +27,9 @@ def all_schemes(): - base = [TransformScheme(type=name) for name in TransformFactory.registered_names()] - randomized = [ - TransformScheme(type=name, randomize=True) - for name in TransformFactory.registered_names() - ] + all_types = TransformFactory.registered_names() + base = [TransformScheme(type=type) for type in all_types] + randomized = [TransformScheme(type=type, randomize=True) for type in all_types] return base + randomized From 27bc0b34857eedff3339be6eec91908f6bd21f57 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 10 Jun 2025 17:53:37 -0400 Subject: [PATCH 27/27] implement apply, use in tests Signed-off-by: Kyle Sayers --- src/compressed_tensors/transform/__init__.py | 1 + src/compressed_tensors/transform/apply.py | 25 +++++++ tests/test_transform/conftest.py | 52 ++++++++++++++ .../factory/test_correctness.py | 72 +++++++------------ tests/test_transform/factory/test_memory.py | 64 +++++++---------- 5 files changed, 132 insertions(+), 82 deletions(-) create mode 100644 src/compressed_tensors/transform/apply.py create mode 100644 tests/test_transform/conftest.py diff --git a/src/compressed_tensors/transform/__init__.py b/src/compressed_tensors/transform/__init__.py index f6d656dd..e7546d62 100644 --- a/src/compressed_tensors/transform/__init__.py +++ b/src/compressed_tensors/transform/__init__.py @@ -23,3 +23,4 @@ from .factory.hadamard import * from .factory.matrix_multiply import * from .factory.random_hadamard import * +from .apply import * diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py new file mode 100644 index 00000000..fb745b91 --- /dev/null +++ b/src/compressed_tensors/transform/apply.py @@ -0,0 +1,25 @@ +# 
Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from compressed_tensors.transform import TransformConfig, TransformFactory + + +__all__ = ["apply_transform_config"] + + +def apply_transform_config(model: torch.nn.Module, config: TransformConfig): + for name, scheme in config.config_groups.items(): + factory = TransformFactory.from_scheme(scheme, name=name) + factory.apply_to_model(model) diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py new file mode 100644 index 00000000..8681b2f8 --- /dev/null +++ b/tests/test_transform/conftest.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch +from compressed_tensors.transform import TransformArgs + + +class TransformableModel(torch.nn.Module): + def __init__(self, *sizes): + super().__init__() + self.fcs = torch.nn.ModuleList([]) + self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False)) + for index in range(1, len(sizes) - 1): + self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)) + + def forward(self, x): + for layer in self.fcs: + x = layer(x) + return x + + +@pytest.fixture(scope="function") +def model_apply(): + model = TransformableModel(2, 4, 8, 16, 32, 64) + apply = [ + # weight output -> input + TransformArgs(targets="fcs.0", location="weight_output"), + TransformArgs(targets="fcs.1", location="input", inverse=True), + # output -> weight input + TransformArgs(targets="fcs.1", location="output"), + TransformArgs(targets="fcs.2", location="weight_input", inverse=True), + # output -> input + TransformArgs(targets="fcs.2", location="output"), + TransformArgs(targets="fcs.3", location="input", inverse=True), + # weight output -> weight input + TransformArgs(targets="fcs.3", location="weight_output"), + TransformArgs(targets="fcs.4", location="weight_input", inverse=True), + ] + + return model, apply diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py index 1745281f..c125d8f8 100644 --- a/tests/test_transform/factory/test_correctness.py +++ b/tests/test_transform/factory/test_correctness.py @@ -16,38 +16,27 @@ import torch from compressed_tensors.transform import ( TransformArgs, + TransformConfig, TransformFactory, TransformScheme, + apply_transform_config, ) from compressed_tensors.utils import force_cpu_offload from tests.testing_utils import requires_accelerate, requires_gpu -def all_schemes(): +def scheme_kwargs(): all_types = TransformFactory.registered_names() - base = [TransformScheme(type=type) for type in all_types] - randomized = [TransformScheme(type=type, randomize=True) for type in all_types] + base = [{"type": type} for type in all_types] + randomized = [{"type": type, "randomize": True} for type in all_types] return base + randomized -class TransformableModel(torch.nn.Module): - def __init__(self, *sizes): - super().__init__() - self.fcs = torch.nn.ModuleList([]) - self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False)) - for index in range(1, len(sizes) - 1): - self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)) - - def forward(self, x): - for layer in self.fcs: - x = layer(x) - return x - - -@pytest.mark.parametrize("scheme", all_schemes()) -def test_correctness_linear(scheme): +@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs()) +def test_correctness_linear(scheme_kwargs): size = (4, 8) module = torch.nn.Linear(*size, bias=True) + scheme = TransformScheme(**scheme_kwargs) factory = TransformFactory.from_scheme(scheme, name="") input_tfm = factory.create_transform( @@ -71,44 +60,37 @@ def test_correctness_linear(scheme): assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) -@pytest.mark.parametrize("scheme", all_schemes()) -def test_correctness_model(scheme, offload=False): +@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs()) +def test_correctness_model(scheme_kwargs, model_apply, offload=False): # load model - model = TransformableModel(2, 4, 8, 16, 32, 64) + model = model_apply[0] if offload: model = force_cpu_offload(model, torch.device("cuda")) - # create factory - scheme.apply = [ - # weight output -> input - 
TransformArgs(targets="fcs.0", location="weight_output"), - TransformArgs(targets="fcs.1", location="input", inverse=True), - # output -> weight input - TransformArgs(targets="fcs.1", location="output"), - TransformArgs(targets="fcs.2", location="weight_input", inverse=True), - # output -> input - TransformArgs(targets="fcs.2", location="output"), - TransformArgs(targets="fcs.3", location="input", inverse=True), - # weight output -> weight input - TransformArgs(targets="fcs.3", location="weight_output"), - TransformArgs(targets="fcs.4", location="weight_input", inverse=True), - ] - factory = TransformFactory.from_scheme(scheme, name="") - - # create inputs + # get output input = torch.rand((17, model.fcs[0].in_features)) if offload: input = input.to(torch.device("cuda")) + true_output = model(input) + + # apply transforms + config = TransformConfig( + config_groups={ + "": TransformScheme( + **scheme_kwargs, + apply=model_apply[1], + ) + } + ) + apply_transform_config(model, config) # compare outputs - true_output = model(input) - factory.apply_to_model(model) output = model(input) assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0) @requires_gpu @requires_accelerate() -@pytest.mark.parametrize("scheme", all_schemes()) -def test_correctness_model_offload(scheme): - test_correctness_model(scheme, offload=True) +@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs()) +def test_correctness_model_offload(scheme_kwargs, model_apply): + test_correctness_model(scheme_kwargs, model_apply, offload=True) diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py index 15d72b9b..8ef84ddb 100644 --- a/tests/test_transform/factory/test_memory.py +++ b/tests/test_transform/factory/test_memory.py @@ -19,53 +19,43 @@ from compressed_tensors.transform import ( TransformArgs, TransformBase, + TransformConfig, TransformFactory, TransformScheme, + apply_transform_config, ) from compressed_tensors.utils import align_modules, force_cpu_offload +from tests.test_transform.conftest import TransformableModel from tests.testing_utils import requires_accelerate, requires_gpu -def all_schemes(): +def scheme_kwargs(): all_types = TransformFactory.registered_names() - base = [TransformScheme(type=type) for type in all_types] - randomized = [TransformScheme(type=type, randomize=True) for type in all_types] + base = [{"type": type} for type in all_types] + randomized = [{"type": type, "randomize": True} for type in all_types] return base + randomized -class TransformableModel(torch.nn.Module): - def __init__(self, *sizes): - super().__init__() - self.fcs = torch.nn.ModuleList([]) - self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False)) - for index in range(1, len(sizes) - 1): - self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False)) - - def forward(self, x): - for layer in self.fcs: - x = layer(x) - return x - - -@pytest.mark.parametrize("scheme", all_schemes()) -def test_memory_sharing(scheme, offload=False): - # load scheme and factory - scheme = TransformScheme( - type="hadamard", - apply=[ - TransformArgs(targets="Linear", location="input"), - TransformArgs(targets="Linear", location="output"), - ], - ) - factory = TransformFactory.from_scheme(scheme, name="") - +@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs()) +def test_memory_sharing(scheme_kwargs, offload=False): # load model (maybe with offloading) model = TransformableModel(2, 2, 4, 4, 8, 8) if offload: force_cpu_offload(model, torch.device("cuda")) # add 
transforms to model - factory.apply_to_model(model) + config = TransformConfig( + config_groups={ + "": TransformScheme( + **scheme_kwargs, + apply=[ + TransformArgs(targets="Linear", location="input"), + TransformArgs(targets="Linear", location="output"), + ], + ) + } + ) + apply_transform_config(model, config) # check that memory is shared when onloaded with align_modules(model.modules()): @@ -97,12 +87,12 @@ def test_memory_sharing(scheme, offload=False): @requires_gpu @requires_accelerate() -@pytest.mark.parametrize("scheme", all_schemes()) -def test_memory_sharing_offload(scheme): - test_memory_sharing(scheme, offload=True) +@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs()) +def test_memory_sharing_offload(scheme_kwargs): + test_memory_sharing(scheme_kwargs, offload=True) -@pytest.mark.parametrize("scheme", all_schemes()) -def test_memory_sharing_training(scheme): - scheme.requires_grad = True - test_memory_sharing(scheme, offload=False) +@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs()) +def test_memory_sharing_training(scheme_kwargs): + scheme_kwargs["requires_grad"] = True + test_memory_sharing(scheme_kwargs, offload=False)
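
Note on the new entrypoint: patch 27/27 exposes apply_transform_config from compressed_tensors.transform, and the updated tests exercise it through TransformConfig / TransformScheme / TransformArgs. As a minimal usage sketch (assuming a toy two-layer torch.nn.Sequential rather than any model in this repo, with a power-of-two feature size so the hadamard transform applies; the group name "v" and the targets "0"/"1" are illustrative only):

import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

# toy model for illustration only; its submodule names are "0" and "1"
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
    torch.nn.Linear(64, 64, bias=False),
)

# fuse a hadamard transform into the output side of the first weight and its
# inverse into the input side of the second weight, mirroring the
# "weight output -> weight input" pairing used in the tests, so the composed
# mapping is numerically unchanged up to floating point error
config = TransformConfig(
    config_groups={
        "v": TransformScheme(
            type="hadamard",
            apply=[
                TransformArgs(targets="0", location="weight_output"),
                TransformArgs(targets="1", location="weight_input", inverse=True),
            ],
        )
    }
)
apply_transform_config(model, config)

x = torch.rand((1, 64))
print(model(x).shape)  # weights are transformed in place; output shape is unchanged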