[Transform] apply_transform_config, consolidate test fixtures #348
Open: kylesayrs wants to merge 41 commits into kylesayrs/transform_construct_cache_device from kylesayrs/transform_apply
Commits (41), all by kylesayrs:
d8a10ec add utilities
d2af054 add tests
e32d5b5 add additional tests
9d0518b add utils and tests
8c5a2d9 Implement transform factories
809e367 Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
8d613b3 add permutations
57d171a add delete_offload_module
d77bcef Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr…
ab73b43 Merge branch 'kylesayrs/transform-accelerate-utilities' into kylesayr…
4b55733 Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
aa7d21b key inverses by weight
6901e02 fix tests
47ae9fe standardize random hadamard
34f1343 Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
1039100 prepend input hooks
5677553 Merge remote-tracking branch 'origin' into kylesayrs/transform_utils
68ec14e apply sqrt division first
a62418a Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
b117523 use divided hadamards
a46f754 fix typo
cb1cb52 add random option
7c02bb2 Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
02af1e9 use random seeds, rename matrix multiply
f45f3e9 add deterministic generation to random matrix
7a7abdf fix perm math
6e52894 update docstrings
7230933 update docstrings
f74fe3e Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
92ddea9 cleanup
779956f cleanup 2
fbd2939 Merge branch 'kylesayrs/transform_utils' into kylesayrs/transform_fac…
dd72b6a make seed optional
4ae491d Merge branch 'kylesayrs/transform_factory' into kylesayrs/transform_p…
da19b0f remove iterable check and missing return value
7ab17ce Merge branch 'main' into kylesayrs/transform_permutations
33df50f Merge remote-tracking branch 'origin' into kylesayrs/transform_permut…
6e1ec39 Remove unrelated changes
938e702 simplify code
27bc0b3 implement apply, use in tests
e7f08e1 Merge branch 'kylesayrs/transform_construct_cache_device' into kylesa…
New file (+25 lines): defines the `apply_transform_config` entry point.

```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from compressed_tensors.transform import TransformConfig, TransformFactory


__all__ = ["apply_transform_config"]


def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
    # build one factory per named config group and apply it to the model
    for name, scheme in config.config_groups.items():
        factory = TransformFactory.from_scheme(scheme, name=name)
        factory.apply_to_model(model)
```
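For orientation, a minimal sketch of how this entry point might be called. The `TransformScheme(type=..., apply=...)` construction follows the library's scheme model as used in this PR's tests, but the specific values here (`"group_0"`, the `"hadamard"` type, the target name) are illustrative assumptions, not code from this diff.

```python
# Illustrative usage sketch; config values are assumptions, not from this PR.
import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))

config = TransformConfig(
    config_groups={
        # one named group -> one TransformFactory applied to the model
        "group_0": TransformScheme(
            type="hadamard",  # assumed scheme type for illustration
            apply=[TransformArgs(targets="0", location="weight_input")],
        )
    }
)

apply_transform_config(model, config)
```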
New file (+52 lines): consolidated test fixtures shared across transform tests.

```python
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
# (Apache License 2.0 header, identical to the file above)

import pytest
import torch
from compressed_tensors.transform import TransformArgs


class TransformableModel(torch.nn.Module):
    def __init__(self, *sizes):
        super().__init__()
        self.fcs = torch.nn.ModuleList([])
        self.fcs.append(torch.nn.Linear(sizes[0], sizes[1], bias=False))
        for index in range(1, len(sizes) - 1):
            self.fcs.append(torch.nn.Linear(sizes[index], sizes[index + 1], bias=False))

    def forward(self, x):
        for layer in self.fcs:
            x = layer(x)
        return x


@pytest.fixture(scope="function")
def model_apply():
    model = TransformableModel(2, 4, 8, 16, 32, 64)
    apply = [
        # weight output -> input
        TransformArgs(targets="fcs.0", location="weight_output"),
        TransformArgs(targets="fcs.1", location="input", inverse=True),
        # output -> weight input
        TransformArgs(targets="fcs.1", location="output"),
        TransformArgs(targets="fcs.2", location="weight_input", inverse=True),
        # output -> input
        TransformArgs(targets="fcs.2", location="output"),
        TransformArgs(targets="fcs.3", location="input", inverse=True),
        # weight output -> weight input
        TransformArgs(targets="fcs.3", location="weight_output"),
        TransformArgs(targets="fcs.4", location="weight_input", inverse=True),
    ]

    return model, apply
```
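A sketch of how a test might consume this fixture together with the `apply_transform_config` entry point above. The `TransformScheme(type="hadamard", ...)` construction and the test name are assumptions for illustration; the invariant being exercised is that each (transform, inverse) pair in `apply` cancels, so the model's output is unchanged.

```python
# Hypothetical test consuming the fixture; scheme construction is an assumption.
import torch
from compressed_tensors.transform import (
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)


def test_apply_transforms_preserve_output(model_apply):
    model, apply = model_apply

    input = torch.rand((17, model.fcs[0].in_features))
    true_output = model(input)  # baseline, before any transforms attach

    config = TransformConfig(
        config_groups={"": TransformScheme(type="hadamard", apply=apply)}
    )
    apply_transform_config(model, config)

    # each transform is paired with its inverse, so outputs should match
    assert torch.allclose(true_output, model(input), atol=1e-5)
```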
Review comment:
Why doesn't this have to be ? Doesn't `.get` need to be called to instantiate the entry in the dict?
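The inline code this comment referred to was not captured here, but assuming the question concerns Python dict semantics: `.get` never inserts a key, and even on a `collections.defaultdict` only square-bracket lookup instantiates a missing entry. A minimal illustration:

```python
from collections import defaultdict

d = defaultdict(list)

d.get("a")        # returns None; does NOT create an "a" entry
assert "a" not in d

d["b"].append(1)  # __getitem__ on a missing key runs the default_factory
assert d["b"] == [1]
```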