
Commit 95ba68d

[BugFix]: AttributeError in CompressedLinear (#273)
* Bugfix: CompressedLinear has no attribute _is_compressed
* Raise warning for user
1 parent: b762e56
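
Context for the fix: `from_linear` converts an existing `torch.nn.Linear` by reassigning `module.__class__`, which never runs `CompressedLinear.__init__`, so the `_is_compressed` flag that `__init__` used to set was missing on converted modules and `forward` raised `AttributeError`. A minimal sketch of that failure mode, using stand-in classes rather than the library code:

# Illustrative repro: reassigning __class__ converts an instance without
# running the new class's __init__, so attributes set there never exist.

class Dense:
    def __init__(self):
        self.weight = [1.0, 2.0]

class Compressed(Dense):
    def __init__(self):
        super().__init__()
        self._is_compressed = True  # only exists if __init__ actually ran

    def forward(self):
        if self._is_compressed:  # blows up after a __class__ swap
            self._is_compressed = False

m = Dense()
m.__class__ = Compressed  # the same conversion trick from_linear uses
try:
    m.forward()
except AttributeError as err:
    print(err)  # 'Compressed' object has no attribute '_is_compressed'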

File tree: 1 file changed (+16 additions, -7 deletions)


src/compressed_tensors/linear/compressed_linear.py

Lines changed: 16 additions & 7 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from typing import Dict, Tuple
 
 import torch
@@ -33,14 +34,15 @@ class CompressedLinear(Linear):
     Wrapper module for running a compressed forward pass of a quantized Linear module.
     The wrapped layer will be decompressed on each forward call.
 
-    :param module: dense linear module to replace
-    :param quantization_scheme: quantization config for the module to wrap
-    :param quantization_format: compression format module is stored as
     """
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self._is_compressed = True
+        warnings.warn(
+            "CompressedLinear should not be initialized directly. "
+            "Use the from_linear method instead.",
+            UserWarning,
+        )
 
     @classmethod
     @torch.no_grad()
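
With this change, calling the constructor directly still builds a plain (undecorated) `Linear`, but it now emits a `UserWarning` instead of setting compression state. A quick sketch of what a caller sees, assuming an installed version that includes this commit and the import path implied by the file above; layer sizes are arbitrary:

import warnings
from compressed_tensors.linear.compressed_linear import CompressedLinear

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer = CompressedLinear(16, 16)  # direct init: discouraged path
print(any("from_linear" in str(w.message) for w in caught))  # True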
@@ -50,6 +52,12 @@ def from_linear(
         quantization_scheme: QuantizationScheme,
         quantization_format: str,
     ):
+        """
+        :param module: dense linear module to replace
+        :param quantization_scheme: quantization config for the module to wrap
+        :param quantization_format: compression format module is stored as
+        :return: CompressedLinear module wrapping the input module
+        """
         module.__class__ = CompressedLinear
         module.compressor = BaseCompressor.load_from_registry(quantization_format)
         device = next(module.parameters()).device
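
The parameter docs move to `from_linear`, the supported entry point. A hedged usage sketch: `my_scheme` and the format string are illustrative placeholders, since valid values depend on how the `QuantizationScheme` and the `BaseCompressor` registry are configured elsewhere in the library:

import torch
from compressed_tensors.linear.compressed_linear import CompressedLinear

dense = torch.nn.Linear(4096, 4096)
compressed = CompressedLinear.from_linear(
    module=dense,
    quantization_scheme=my_scheme,  # placeholder: a QuantizationScheme for this layer
    quantization_format="fmt",      # placeholder: a format registered with BaseCompressor
)
print(isinstance(dense, CompressedLinear))  # True: converted in place via __class__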
@@ -90,8 +98,9 @@ def forward(self, input: Tensor) -> Tensor:
         """
         Decompresses the weight, then runs the wrapped forward pass
         """
-        if self._is_compressed:
-            self.weight = self.compressor.decompress_module(self)
-            self._is_compressed = False
+        if self.quantization_status == QuantizationStatus.COMPRESSED:
+            decompressed_weight = self.compressor.decompress_module(self)
+            self.weight.data = decompressed_weight
+            self.quantization_status = QuantizationStatus.FROZEN
 
         return linear(input, self.weight, self.bias)
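
The rewritten gate keys the one-time decompression off `quantization_status`, which is set on the module when quantization is applied, rather than off an attribute that only `__init__` created; writing to `self.weight.data` instead of rebinding `self.weight` also keeps the registered `Parameter` object intact. A minimal sketch of this decompress-once pattern, with a stand-in enum in place of the library's `QuantizationStatus`:

from enum import Enum

class QuantizationStatus(Enum):  # stand-in for the library's enum
    COMPRESSED = "compressed"
    FROZEN = "frozen"

class Layer:
    def __init__(self):
        self.quantization_status = QuantizationStatus.COMPRESSED
        self.decompress_calls = 0

    def forward(self):
        # decompress exactly once, then mark the module frozen
        if self.quantization_status == QuantizationStatus.COMPRESSED:
            self.decompress_calls += 1  # expensive decompression would run here
            self.quantization_status = QuantizationStatus.FROZEN

layer = Layer()
layer.forward()
layer.forward()
print(layer.decompress_calls)  # 1: later calls skip the decompression branch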
