 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from typing import Dict, Tuple

 import torch
@@ -33,14 +34,15 @@ class CompressedLinear(Linear):
     Wrapper module for running a compressed forward pass of a quantized Linear module.
     The wrapped layer will be decompressed on each forward call.

-    :param module: dense linear module to replace
-    :param quantization_scheme: quantization config for the module to wrap
-    :param quantization_format: compression format module is stored as
     """

     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self._is_compressed = True
+        warnings.warn(
+            "CompressedLinear should not be initialized directly. "
+            "Use the from_linear method instead.",
+            UserWarning,
+        )

     @classmethod
     @torch.no_grad()
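
For reviewers who want to see the new behavior in isolation: a minimal, self-contained sketch (generic names, not this repo's test suite) of how the `__init__` warning surfaces to callers and how it can be asserted on.

```python
import warnings

from torch.nn import Linear


class Wrapped(Linear):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        warnings.warn(
            "Wrapped should not be initialized directly. "
            "Use the from_linear method instead.",
            UserWarning,
        )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Wrapped(4, 2)  # direct construction triggers the UserWarning
assert any(issubclass(w.category, UserWarning) for w in caught)
```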
@@ -50,6 +52,12 @@ def from_linear(
         quantization_scheme: QuantizationScheme,
         quantization_format: str,
     ):
+        """
+        :param module: dense linear module to replace
+        :param quantization_scheme: quantization config for the module to wrap
+        :param quantization_format: compression format module is stored as
+        :return: CompressedLinear module wrapping the input module
+        """
         module.__class__ = CompressedLinear
         module.compressor = BaseCompressor.load_from_registry(quantization_format)
         device = next(module.parameters()).device
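
Worth noting: `module.__class__ = CompressedLinear` converts the layer in place without ever calling `__init__`, which is why the new warning does not fire on the `from_linear` path. A minimal sketch of that pattern, with illustrative names:

```python
import torch
from torch.nn import Linear


class LoudLinear(Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        print("custom forward")
        return super().forward(input)


layer = Linear(4, 2)
layer.__class__ = LoudLinear  # same parameters, new dispatch; __init__ is skipped
out = layer(torch.randn(1, 4))  # prints "custom forward"
```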
@@ -90,8 +98,9 @@ def forward(self, input: Tensor) -> Tensor:
         """
         Decompresses the weight, then runs the wrapped forward pass
         """
-        if self._is_compressed:
-            self.weight = self.compressor.decompress_module(self)
-            self._is_compressed = False
+        if self.quantization_status == QuantizationStatus.COMPRESSED:
+            decompressed_weight = self.compressor.decompress_module(self)
+            self.weight.data = decompressed_weight
+            self.quantization_status = QuantizationStatus.FROZEN

         return linear(input, self.weight, self.bias)
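
The forward change gates decompression on the module's shared `quantization_status` instead of a private flag, and writes into `self.weight.data` so the existing `Parameter` object is kept rather than replaced by a raw tensor. A self-contained sketch of the status-gated, decompress-once pattern with a hypothetical stand-in compressor (`Status`, `LazyLinear`, and `decompress` are illustrative, not this library's API):

```python
from enum import Enum

import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn.functional import linear


class Status(Enum):
    COMPRESSED = "compressed"
    FROZEN = "frozen"


class LazyLinear(Linear):
    def decompress(self) -> Tensor:
        # Stand-in for compressor.decompress_module(self).
        return torch.ones_like(self.weight)

    def forward(self, input: Tensor) -> Tensor:
        if self.status == Status.COMPRESSED:
            self.weight.data = self.decompress()  # in-place swap keeps the Parameter
            self.status = Status.FROZEN
        return linear(input, self.weight, self.bias)


layer = LazyLinear(4, 2)
layer.status = Status.COMPRESSED
layer(torch.randn(1, 4))  # first call decompresses and flips the status
layer(torch.randn(1, 4))  # subsequent calls reuse the dense weight
```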