diff --git a/src/compressed_tensors/transform/__init__.py b/src/compressed_tensors/transform/__init__.py
index f6d656dd..e7546d62 100644
--- a/src/compressed_tensors/transform/__init__.py
+++ b/src/compressed_tensors/transform/__init__.py
@@ -23,3 +23,4 @@
 from .factory.hadamard import *
 from .factory.matrix_multiply import *
 from .factory.random_hadamard import *
+from .apply import *
diff --git a/src/compressed_tensors/transform/apply.py b/src/compressed_tensors/transform/apply.py
new file mode 100644
index 00000000..fb745b91
--- /dev/null
+++ b/src/compressed_tensors/transform/apply.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.transform import TransformConfig, TransformFactory
+
+
+__all__ = ["apply_transform_config"]
+
+
+def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
+    for name, scheme in config.config_groups.items():
+        factory = TransformFactory.from_scheme(scheme, name=name)
+        factory.apply_to_model(model)
diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index fa128fd1..277bf985 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -41,6 +41,7 @@ class HadamardFactory(TransformFactory):
     def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = None):
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
+        self.perms = ParameterizedDefaultDict(self._create_permutation)
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -57,7 +58,8 @@ def create_transform(self, module: Module, args: TransformArgs):
         exec_device = get_execution_device(module)
 
         weight = self.weights.get(size, dtype, device, construct_device=exec_device)
-        return HadamardTransform(weight, args)
+        perm = self.perms[weight] if self.scheme.randomize else None
+        return HadamardTransform(weight, perm, args)
 
     def _create_weight(
         self,
@@ -71,17 +73,27 @@ def _create_weight(
         data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
+    def _create_permutation(self, weight: Parameter) -> Parameter:
+        data = torch.randperm(weight.size(0), generator=self.generator)
+        return Parameter(data, requires_grad=False)
+
 
 class HadamardTransform(TransformBase):
-    def __init__(self, weight: Parameter, args: TransformArgs):
+    def __init__(
+        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+    ):
         super().__init__()
         self.weight = weight
+        self.perm = perm
         self.args = args
 
     def forward(self, value: Tensor) -> Tensor:
-        if not self.args.inverse:
-            weight = self.weight
-        else:
-            weight = self.weight.T
+        weight = self.weight
+
+        if self.perm is not None:
+            weight = weight[self.perm][:, self.perm]
+
+        if self.args.inverse:
+            weight = weight.T
 
         return apply_transform_weight(weight, value, self.args.location)
diff --git a/src/compressed_tensors/transform/transform_config.py b/src/compressed_tensors/transform/transform_config.py
index 414c21e0..df178c42 100644
--- a/src/compressed_tensors/transform/transform_config.py
+++ b/src/compressed_tensors/transform/transform_config.py
@@ -49,7 +49,7 @@ class TransformConfig(BaseModel):
                     inverse=True,
                 ),
             ],
-            randomize_modules=True,
+            randomize=True,
         ),
         "u": TransformScheme(
             type="hadamard",
@@ -62,7 +62,7 @@ class TransformConfig(BaseModel):
                     targets=["Linear"], location="output", inverse=True  # non-mergable
                 ),
             ],
-            randomize_modules=True,
+            randomize=True,
         ),
     }
 )
diff --git a/src/compressed_tensors/transform/transform_scheme.py b/src/compressed_tensors/transform/transform_scheme.py
index 1335063c..64d646e0 100644
--- a/src/compressed_tensors/transform/transform_scheme.py
+++ b/src/compressed_tensors/transform/transform_scheme.py
@@ -31,13 +31,12 @@ class TransformScheme(BaseModel):
         (see `Transforms.registered_names()`)
     :param apply: list of TransformationArgs containing the information about the
         modules that should be targeted by the specified transform
-    :param randomize_modules: True if unique transforms should be applied to each
-        unique module targeted by `apply`, otherwise reuse transform weights where
-        applicable
+    :param randomize: True if uniquely randomized transform weights should be used,
+        otherwise use identical transform weights where applicable
    :param requires_grad: True if weights include gradients for training
     """
 
     type: str
     apply: List[TransformArgs] = Field(default_factory=list)
-    randomize_modules: bool = Field(default=False)
+    randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
diff --git a/tests/test_transform/conftest.py b/tests/test_transform/conftest.py
new file mode 100644
index 00000000..8681b2f8
--- /dev/null
+++ b/tests/test_transform/conftest.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from compressed_tensors.transform import TransformArgs
+
+
+class TransformableModel(torch.nn.Module):
+    def __init__(self, *sizes):
+        super().__init__()
+        self.fcs = torch.nn.ModuleList([])
+        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
+        for index in range(1, len(sizes) - 1):
+            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
+
+    def forward(self, x):
+        for layer in self.fcs:
+            x = layer(x)
+        return x
+
+
+@pytest.fixture(scope="function")
+def model_apply():
+    model = TransformableModel(2, 4, 8, 16, 32, 64)
+    apply = [
+        # weight output -> input
+        TransformArgs(targets="fcs.0", location="weight_output"),
+        TransformArgs(targets="fcs.1", location="input", inverse=True),
+        # output -> weight input
+        TransformArgs(targets="fcs.1", location="output"),
+        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
+        # output -> input
+        TransformArgs(targets="fcs.2", location="output"),
+        TransformArgs(targets="fcs.3", location="input", inverse=True),
+        # weight output -> weight input
+        TransformArgs(targets="fcs.3", location="weight_output"),
+        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
+    ]
+
+    return model, apply
diff --git a/tests/test_transform/factory/test_correctness.py b/tests/test_transform/factory/test_correctness.py
index bab1117e..c125d8f8 100644
--- a/tests/test_transform/factory/test_correctness.py
+++ b/tests/test_transform/factory/test_correctness.py
@@ -16,34 +16,27 @@
 import torch
 from compressed_tensors.transform import (
     TransformArgs,
+    TransformConfig,
     TransformFactory,
     TransformScheme,
+    apply_transform_config,
 )
-from compressed_tensors.utils import align_modules, force_cpu_offload
+from compressed_tensors.utils import force_cpu_offload
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
-class TransformableModel(torch.nn.Module):
-    def __init__(self, *sizes):
-        super().__init__()
-        self.fcs = torch.nn.ModuleList([])
-        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
-        for index in range(1, len(sizes) - 1):
-            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
+def scheme_kwargs():
+    all_types = TransformFactory.registered_names()
+    base = [{"type": type} for type in all_types]
+    randomized = [{"type": type, "randomize": True} for type in all_types]
+    return base + randomized
 
-    def forward(self, x):
-        for layer in self.fcs:
-            x = layer(x)
-        return x
-
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_linear(scheme):
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_linear(scheme_kwargs):
     size = (4, 8)
     module = torch.nn.Linear(*size, bias=True)
+    scheme = TransformScheme(**scheme_kwargs)
     factory = TransformFactory.from_scheme(scheme, name="")
 
     input_tfm = factory.create_transform(
@@ -67,50 +60,37 @@
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_model(scheme, offload=False):
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_model(scheme_kwargs, model_apply, offload=False):
     # load model
-    model = TransformableModel(2, 4, 8, 16, 32, 64)
+    model = model_apply[0]
     if offload:
         model = force_cpu_offload(model, torch.device("cuda"))
 
-    # create factory
-    scheme.apply = [
-        # weight output -> input
-        TransformArgs(targets="fcs.0", location="weight_output"),
-        TransformArgs(targets="fcs.1", location="input", inverse=True),
-        # output -> weight input
-        TransformArgs(targets="fcs.1", location="output"),
-        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
-        # output -> input
-        TransformArgs(targets="fcs.2", location="output"),
-        TransformArgs(targets="fcs.3", location="input", inverse=True),
-        # weight output -> weight input
-        TransformArgs(targets="fcs.3", location="weight_output"),
-        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
-    ]
-    factory = TransformFactory.from_scheme(scheme, name="")
-
-    # create inputs
+    # get output
     input = torch.rand((17, model.fcs[0].in_features))
     if offload:
         input = input.to(torch.device("cuda"))
+    true_output = model(input)
+
+    # apply transforms
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                **scheme_kwargs,
+                apply=model_apply[1],
+            )
+        }
+    )
+    apply_transform_config(model, config)
 
     # compare outputs
-    true_output = model(input)
-    factory.apply_to_model(model)
     output = model(input)
     assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
 
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_correctness_model_offload(scheme):
-    test_correctness_model(scheme, offload=True)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_correctness_model_offload(scheme_kwargs, model_apply):
+    test_correctness_model(scheme_kwargs, model_apply, offload=True)
diff --git a/tests/test_transform/factory/test_memory.py b/tests/test_transform/factory/test_memory.py
index 49e882e4..8ef84ddb 100644
--- a/tests/test_transform/factory/test_memory.py
+++ b/tests/test_transform/factory/test_memory.py
@@ -19,49 +19,43 @@
 from compressed_tensors.transform import (
     TransformArgs,
     TransformBase,
+    TransformConfig,
     TransformFactory,
     TransformScheme,
+    apply_transform_config,
 )
 from compressed_tensors.utils import align_modules, force_cpu_offload
+from tests.test_transform.conftest import TransformableModel
 from tests.testing_utils import requires_accelerate, requires_gpu
 
 
-class TransformableModel(torch.nn.Module):
-    def __init__(self, *sizes):
-        super().__init__()
-        self.fcs = torch.nn.ModuleList([])
-        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
-        for index in range(1, len(sizes) - 1):
-            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))
+def scheme_kwargs():
+    all_types = TransformFactory.registered_names()
+    base = [{"type": type} for type in all_types]
+    randomized = [{"type": type, "randomize": True} for type in all_types]
+    return base + randomized
 
-    def forward(self, x):
-        for layer in self.fcs:
-            x = layer(x)
-        return x
-
-
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_memory_sharing(scheme, offload=False):
-    # load scheme and factory
-    scheme = TransformScheme(
-        type="hadamard",
-        apply=[
-            TransformArgs(targets="Linear", location="input"),
-            TransformArgs(targets="Linear", location="output"),
-        ],
-    )
-    factory = TransformFactory.from_scheme(scheme, name="")
 
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing(scheme_kwargs, offload=False):
     # load model (maybe with offloading)
     model = TransformableModel(2, 2, 4, 4, 8, 8)
     if offload:
         force_cpu_offload(model, torch.device("cuda"))
 
     # add transforms to model
-    factory.apply_to_model(model)
+    config = TransformConfig(
+        config_groups={
+            "": TransformScheme(
+                **scheme_kwargs,
+                apply=[
+                    TransformArgs(targets="Linear", location="input"),
+                    TransformArgs(targets="Linear", location="output"),
+                ],
+            )
+        }
+    )
+    apply_transform_config(model, config)
 
     # check that memory is shared when onloaded
     with align_modules(model.modules()):
@@ -93,20 +87,12 @@
 
 @requires_gpu
 @requires_accelerate()
-@pytest.mark.parametrize(
-    "scheme",
-    [TransformScheme(type=name) for name in TransformFactory.registered_names()],
-)
-def test_memory_sharing_offload(scheme):
-    test_memory_sharing(scheme, offload=True)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing_offload(scheme_kwargs):
+    test_memory_sharing(scheme_kwargs, offload=True)
 
 
-@pytest.mark.parametrize(
-    "scheme",
-    [
-        TransformScheme(type=name, requires_grad=True)
-        for name in TransformFactory.registered_names()
-    ],
-)
-def test_memory_sharing_training(scheme):
-    test_memory_sharing(scheme, offload=False)
+@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
+def test_memory_sharing_training(scheme_kwargs):
+    scheme_kwargs["requires_grad"] = True
+    test_memory_sharing(scheme_kwargs, offload=False)
diff --git a/tests/test_transform/test_transform_scheme.py b/tests/test_transform/test_transform_scheme.py
index ad851762..839ab46a 100644
--- a/tests/test_transform/test_transform_scheme.py
+++ b/tests/test_transform/test_transform_scheme.py
@@ -24,7 +24,7 @@ def test_basic_scheme():
         type="hadamard",
         apply=[basic_args],
     )
-    assert not scheme.randomize_modules
+    assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 1
     assert isinstance(scheme.apply[0], TransformArgs)
@@ -43,10 +43,10 @@ def test_multiple_groups_global():
     scheme = TransformScheme(
         type="hadamard",
         apply=[embedding_args, linear_args],
-        randomize_modules=True,
+        randomize=True,
     )
 
-    assert scheme.randomize_modules
+    assert scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 2
     assert isinstance(scheme.apply[0], TransformArgs)
@@ -69,6 +69,6 @@ def test_multiple_groups():
         apply=apply,
     )
 
-    assert not scheme.randomize_modules
+    assert not scheme.randomize
     assert scheme.type == "hadamard"
     assert len(scheme.apply) == 20