[Transform] apply_transform_config, consolidate test fixtures #348

Open

Wants to merge 41 commits into base: kylesayrs/transform_construct_cache_device

Commits (41, showing changes from all commits)
d8a10ec  add utilities (kylesayrs, May 30, 2025)
d2af054  add tests (kylesayrs, May 30, 2025)
e32d5b5  add additional tests (kylesayrs, May 30, 2025)
9d0518b  add utils and tests (kylesayrs, May 30, 2025)
8c5a2d9  Implement transform factories (kylesayrs, May 30, 2025)
809e367  Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac… (kylesayrs, May 30, 2025)
8d613b3  add permutations (kylesayrs, May 31, 2025)
57d171a  add delete_offload_module (kylesayrs, May 31, 2025)
d77bcef  Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr… (kylesayrs, May 31, 2025)
ab73b43  Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr… (kylesayrs, May 31, 2025)
4b55733  Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p… (kylesayrs, May 31, 2025)
aa7d21b  key inverses by weight (kylesayrs, May 31, 2025)
6901e02  fix tests (kylesayrs, May 31, 2025)
47ae9fe  standardize random hadamard (kylesayrs, May 31, 2025)
34f1343  Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac… (kylesayrs, May 31, 2025)
1039100  prepend input hooks (kylesayrs, May 31, 2025)
5677553  Merge remote-tracking branch 'origin' into kylesayrs/transform_utils (kylesayrs, Jun 5, 2025)
68ec14e  apply sqrt division first (kylesayrs, Jun 5, 2025)
a62418a  Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac… (kylesayrs, Jun 5, 2025)
b117523  use divided hadamards (kylesayrs, Jun 5, 2025)
a46f754  fix typo (kylesayrs, Jun 5, 2025)
cb1cb52  add random option (kylesayrs, Jun 5, 2025)
7c02bb2  Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac… (kylesayrs, Jun 5, 2025)
02af1e9  use random seeds, rename matrix multiply (kylesayrs, Jun 5, 2025)
f45f3e9  add deterministic generation to random matrix (kylesayrs, Jun 5, 2025)
7a7abdf  fix perm math (kylesayrs, Jun 5, 2025)
6e52894  update docstrings (kylesayrs, Jun 5, 2025)
7230933  update docstrings (kylesayrs, Jun 5, 2025)
f74fe3e  Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p… (kylesayrs, Jun 5, 2025)
92ddea9  cleanup (kylesayrs, Jun 5, 2025)
779956f  cleanup 2 (kylesayrs, Jun 5, 2025)
fbd2939  Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac… (kylesayrs, Jun 5, 2025)
dd72b6a  make seed optional (kylesayrs, Jun 5, 2025)
4ae491d  Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p… (kylesayrs, Jun 5, 2025)
da19b0f  remove iterable check and missing return value (kylesayrs, Jun 9, 2025)
7ab17ce  Merge branch 'main' into kylesayrs/transform_permutations (kylesayrs, Jun 10, 2025)
33df50f  Merge remote-tracking branch 'origin' into kylesayrs/transform_permut… (kylesayrs, Jun 10, 2025)
6e1ec39  Remove unrelated changes (kylesayrs, Jun 10, 2025)
938e702  simplify code (kylesayrs, Jun 10, 2025)
27bc0b3  implement apply, use in tests (kylesayrs, Jun 10, 2025)
e7f08e1  Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa… (kylesayrs, Jun 12, 2025)
1 change: 1 addition & 0 deletions src/compressed_tensors/transform/__init__.py
@@ -23,3 +23,4 @@
from .factory.hadamard import *
from .factory.matrix_multiply import *
from .factory.random_hadamard import *
from .apply import *
25 changes: 25 additions & 0 deletions src/compressed_tensors/transform/apply.py
@@ -0,0 +1,25 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.transform import TransformConfig, TransformFactory


__all__ = ["apply_transform_config"]


def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
    for name, scheme in config.config_groups.items():
        factory = TransformFactory.from_scheme(scheme, name=name)
        factory.apply_to_model(model)
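
For orientation, a minimal usage sketch of the new entrypoint (the toy model and group name here are illustrative, not from this PR; the scheme and argument fields are the ones exercised in the tests below):

import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

# toy single-layer model
model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))

config = TransformConfig(
    config_groups={
        "example": TransformScheme(
            type="hadamard",
            randomize=True,
            apply=[
                TransformArgs(targets="Linear", location="input"),
                TransformArgs(targets="Linear", location="output", inverse=True),
            ],
        )
    }
)

# one factory is constructed per scheme; each registers its transforms on the model
apply_transform_config(model, config)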
26 changes: 19 additions & 7 deletions src/compressed_tensors/transform/factory/hadamard.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Optional
+from typing import Optional, Union

import torch
from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
    def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
        super().__init__(name, scheme, seed)
        self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.perms = ParameterizedDefaultDict(self._create_permutation)

    def create_transform(self, module: Module, args: TransformArgs):
        """
@@ -57,7 +58,8 @@ def create_transform(self, module: Module, args: TransformArgs):
        exec_device = get_execution_device(module)

        weight = self.weights.get(size, dtype, device, construct_device=exec_device)
-        return HadamardTransform(weight, args)
+        perm = self.perms[weight] if self.scheme.randomize else None
Reviewer comment (Contributor):
why doesn't this have to be

perm = self.perms.get(weight) if self.scheme.randomize else None

? Doesn't .get need to be called to instantiate the entry in the dict?
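One plausible answer, assuming ParameterizedDefaultDict overrides __missing__ the way collections.defaultdict does: plain indexing already constructs and caches missing entries, so .get would only be needed when extra construction kwargs (like construct_device above) must be forwarded. A sketch of that assumed behavior, not code from this PR:

class ParameterizedDefaultDictSketch(dict):
    """Hypothetical stand-in: builds missing values from their key."""

    def __init__(self, default_factory):
        super().__init__()
        self.default_factory = default_factory

    def __missing__(self, key):
        # called by d[key] on a miss: construct, cache, and return the value
        value = self.default_factory(key)
        self[key] = value
        return value

perms = ParameterizedDefaultDictSketch(lambda key: f"perm-for-{key}")
assert perms["w0"] == "perm-for-w0"  # indexing alone instantiates the entry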

+        return HadamardTransform(weight, perm, args)

    def _create_weight(
        self,
@@ -71,17 +73,27 @@
        data = data.to(device=device)
        return Parameter(data, requires_grad=self.scheme.requires_grad)

+    def _create_permutation(self, weight: Parameter) -> Parameter:
+        data = torch.randperm(weight.size(0), generator=self.generator)
+        return Parameter(data, requires_grad=False)


class HadamardTransform(TransformBase):
-    def __init__(self, weight: Parameter, args: TransformArgs):
+    def __init__(
+        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+    ):
        super().__init__()
        self.weight = weight
+        self.perm = perm
        self.args = args

    def forward(self, value: Tensor) -> Tensor:
-        if not self.args.inverse:
-            weight = self.weight
-        else:
-            weight = self.weight.T
+        weight = self.weight
+
+        if self.perm is not None:
+            weight = weight[self.perm][:, self.perm]
+
+        if self.args.inverse:
+            weight = weight.T

        return apply_transform_weight(weight, value, self.args.location)
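
A note on the permuted forward pass above: applying the same permutation to the rows and columns of an orthogonal matrix keeps it orthogonal, which is why the inverse branch can still just transpose the permuted weight. A standalone check (a sketch, not part of this PR; the Hadamard construction is hand-rolled here):

import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    # build a 2^k x 2^k Hadamard matrix by Sylvester's construction
    h = torch.tensor([[1.0]])
    while h.size(0) < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

n = 8
h = sylvester_hadamard(n) / n**0.5  # normalized so that h @ h.T == I
perm = torch.randperm(n)
hp = h[perm][:, perm]  # same permutation on rows and columns, as in forward()

# P @ H @ P.T is still orthogonal, so hp.T still inverts hp
assert torch.allclose(hp @ hp.T, torch.eye(n), atol=1e-6)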
4 changes: 2 additions & 2 deletions src/compressed_tensors/transform/transform_config.py
@@ -49,7 +49,7 @@ class TransformConfig(BaseModel):
                    inverse=True,
                ),
            ],
-            randomize_modules=True,
+            randomize=True,
        ),
        "u": TransformScheme(
            type="hadamard",
@@ -62,7 +62,7 @@
                TransformArgs(
                    targets=["Linear"], location="output", inverse=True  # non-mergable
                ),
            ],
-            randomize_modules=True,
+            randomize=True,
        ),
    }
)
7 changes: 3 additions & 4 deletions src/compressed_tensors/transform/transform_scheme.py
@@ -31,13 +31,12 @@ class TransformScheme(BaseModel):
        (see `Transforms.registered_names()`)
    :param apply: list of TransformationArgs containing the information about the
        modules that should be targeted by the specified transform
-    :param randomize_modules: True if unique transforms should be applied to each
-        unique module targeted by `apply`, otherwise reuse transform weights where
-        applicable
+    :param randomize: True if uniquely randomized transform weights should be used,
+        otherwise use identical transform weights where applicable
    :param requires_grad: True if weights include gradients for training
    """

    type: str
    apply: List[TransformArgs] = Field(default_factory=list)
-    randomize_modules: bool = Field(default=False)
+    randomize: bool = Field(default=False)
    requires_grad: bool = Field(default=False)
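
A quick illustration of the renamed flag (a sketch using only the fields defined above):

from compressed_tensors.transform import TransformScheme

scheme = TransformScheme(type="hadamard", randomize=True)
assert scheme.randomize is True
assert scheme.requires_grad is False  # defaults unchanged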
52 changes: 52 additions & 0 deletions tests/test_transform/conftest.py
@@ -0,0 +1,52 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from compressed_tensors.transform import TransformArgs


class TransformableModel(torch.nn.Module):
    def __init__(self, *sizes):
        super().__init__()
        self.fcs = torch.nn.ModuleList([])
        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
        for index in range(1, len(sizes) - 1):
            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))

    def forward(self, x):
        for layer in self.fcs:
            x = layer(x)
        return x


@pytest.fixture(scope="function")
def model_apply():
model = TransformableModel(2, 4, 8, 16, 32, 64)
apply = [
# weight output -> input
TransformArgs(targets="fcs.0", location="weight_output"),
TransformArgs(targets="fcs.1", location="input", inverse=True),
# output -> weight input
TransformArgs(targets="fcs.1", location="output"),
TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
# output -> input
TransformArgs(targets="fcs.2", location="output"),
TransformArgs(targets="fcs.3", location="input", inverse=True),
# weight output -> weight input
TransformArgs(targets="fcs.3", location="weight_output"),
TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
]

return model, apply
80 changes: 30 additions & 50 deletions tests/test_transform/factory/test_correctness.py
@@ -16,34 +16,27 @@
import torch
from compressed_tensors.transform import (
    TransformArgs,
+    TransformConfig,
    TransformFactory,
    TransformScheme,
+    apply_transform_config,
)
-from compressed_tensors.utils import align_modules, force_cpu_offload
+from compressed_tensors.utils import force_cpu_offload
from tests.testing_utils import requires_accelerate, requires_gpu


-class TransformableModel(torch.nn.Module):
-    def __init__(self, *sizes):
-        super().__init__()
-        self.fcs = torch.nn.ModuleList([])
-        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
-        for index in range(1, len(sizes) - 1):
-            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
-
-    def forward(self, x):
-        for layer in self.fcs:
-            x = layer(x)
-        return x
+def scheme_kwargs():
+    all_types = TransformFactory.registered_names()
+    base = [{"type": type} for type in all_types]
+    randomized = [{"type": type, "randomize": True} for type in all_types]
+    return base + randomized


-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_linear(scheme):
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_linear(scheme_kwargs):
    size = (4, 8)
    module = torch.nn.Linear(*size, bias=True)
+    scheme = TransformScheme(**scheme_kwargs)
    factory = TransformFactory.from_scheme(scheme, name="")

input_tfm = factory.create_transform(
@@ -67,50 +60,37 @@ def test_correctness_linear(scheme):
    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_model(scheme, offload=False):
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_model(scheme_kwargs, model_apply, offload=False):
    # load model
-    model = TransformableModel(2, 4, 8, 16, 32, 64)
+    model = model_apply[0]
    if offload:
        model = force_cpu_offload(model, torch.device("cuda"))

-    # create factory
-    scheme.apply = [
-        # weight output -> input
-        TransformArgs(targets="fcs.0", location="weight_output"),
-        TransformArgs(targets="fcs.1", location="input", inverse=True),
-        # output -> weight input
-        TransformArgs(targets="fcs.1", location="output"),
-        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
-        # output -> input
-        TransformArgs(targets="fcs.2", location="output"),
-        TransformArgs(targets="fcs.3", location="input", inverse=True),
-        # weight output -> weight input
-        TransformArgs(targets="fcs.3", location="weight_output"),
-        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
-    ]
-    factory = TransformFactory.from_scheme(scheme, name="")
-
-    # create inputs
+    # get output
    input = torch.rand((17, model.fcs[0].in_features))
    if offload:
        input = input.to(torch.device("cuda"))
+    true_output = model(input)
+
+    # apply transforms
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                **scheme_kwargs,
+                apply=model_apply[1],
+            )
+        }
+    )
+    apply_transform_config(model, config)

    # compare outputs
-    true_output = model(input)
-    factory.apply_to_model(model)
    output = model(input)
    assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)


@requires_gpu
@requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_model_offload(scheme):
-    test_correctness_model(scheme, offload=True)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_model_offload(scheme_kwargs, model_apply):
+    test_correctness_model(scheme_kwargs, model_apply, offload=True)
72 changes: 29 additions & 43 deletions tests/test_transform/factory/test_memory.py
@@ -19,49 +19,43 @@
from compressed_tensors.transform import (
    TransformArgs,
    TransformBase,
+    TransformConfig,
    TransformFactory,
    TransformScheme,
+    apply_transform_config,
)
from compressed_tensors.utils import align_modules, force_cpu_offload
+from tests.test_transform.conftest import TransformableModel
from tests.testing_utils import requires_accelerate, requires_gpu


-class TransformableModel(torch.nn.Module):
-    def __init__(self, *sizes):
-        super().__init__()
-        self.fcs = torch.nn.ModuleList([])
-        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
-        for index in range(1, len(sizes) - 1):
-            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
-
-    def forward(self, x):
-        for layer in self.fcs:
-            x = layer(x)
-        return x
+def scheme_kwargs():
+    all_types = TransformFactory.registered_names()
+    base = [{"type": type} for type in all_types]
+    randomized = [{"type": type, "randomize": True} for type in all_types]
+    return base + randomized


-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_memory_sharing(scheme, offload=False):
-    # load scheme and factory
-    scheme = TransformScheme(
-        type="hadamard",
-        apply=[
-            TransformArgs(targets="Linear", location="input"),
-            TransformArgs(targets="Linear", location="output"),
-        ],
-    )
-    factory = TransformFactory.from_scheme(scheme, name="")
-
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing(scheme_kwargs, offload=False):
    # load model (maybe with offloading)
    model = TransformableModel(2, 2, 4, 4, 8, 8)
    if offload:
        force_cpu_offload(model, torch.device("cuda"))

    # add transforms to model
-    factory.apply_to_model(model)
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                **scheme_kwargs,
+                apply=[
+                    TransformArgs(targets="Linear", location="input"),
+                    TransformArgs(targets="Linear", location="output"),
+                ],
+            )
+        }
+    )
+    apply_transform_config(model, config)

# check that memory is shared when onloaded
with align_modules(model.modules()):
@@ -93,20 +87,12 @@ def test_memory_sharing(scheme, offload=False):

@requires_gpu
@requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_memory_sharing_offload(scheme):
-    test_memory_sharing(scheme, offload=True)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing_offload(scheme_kwargs):
+    test_memory_sharing(scheme_kwargs, offload=True)


-@pytest.mark.parametrize(
-    "scheme",
-    [
-        TransformScheme(type=name, requires_grad=True)
-        for name in TransformFactory.registered_names()
-    ],
-)
-def test_memory_sharing_training(scheme):
-    test_memory_sharing(scheme, offload=False)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing_training(scheme_kwargs):
+    scheme_kwargs["requires_grad"] = True
+    test_memory_sharing(scheme_kwargs, offload=False)