
Commit 063d62d

add apply, serialize, deserialize support
1 parent fadaaf8 commit 063d62d

File tree: 7 files changed (+447, -5 lines)


src/compressed_tensors/base.py

Lines changed: 1 addition & 0 deletions
@@ -18,3 +18,4 @@
 KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+TRANSFORMS_CONFIG = "transforms_config"
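
Note: this constant names the key under which the transforms config is serialized. As a rough sketch of the result (the nesting follows update_config further below; the "quantization_config" and "quant_method" values are assumptions for illustration, not part of this diff):

# Hypothetical config.json contents after saving, expressed as a Python dict.
config_data = {
    "quantization_config": {                    # assumed QUANTIZATION_CONFIG_NAME value
        "quant_method": "compressed-tensors",   # assumed; not part of this diff
        "transforms_config": {                  # keyed by the new TRANSFORMS_CONFIG
            "transform_groups": {},             # serialized TransformationConfig (illustrative)
        },
    },
}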

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 45 additions & 3 deletions
@@ -29,6 +29,7 @@
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
+    TRANSFORMS_CONFIG,
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
@@ -38,13 +39,15 @@
     QuantizationStatus,
     apply_quantization_config,
     load_pretrained_quantization,
+    load_transforms,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
 )
+from compressed_tensors.transforms.transform_config import TransformationConfig
 from compressed_tensors.utils import (
     get_safetensors_folder,
     merge_names,
@@ -133,6 +136,8 @@ def from_compression_config(

         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
+        transforms_config = cls.parse_transforms_config(compression_config)
+
         if sparsity_config is None and quantization_config is None:
             return None

@@ -144,8 +149,13 @@ def from_compression_config(
         if quantization_config is not None:
             quantization_config = QuantizationConfig.model_validate(quantization_config)

+        if transforms_config is not None:
+            transforms_config = TransformationConfig.model_validate(transforms_config)
+
         return cls(
-            sparsity_config=sparsity_config, quantization_config=quantization_config
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transforms_config=transforms_config,
         )

     @classmethod
@@ -170,6 +180,10 @@ def from_pretrained_model(
             model, format=quantization_format
         )

+        # TODO: update to fetch from the pretrained model
+        # using the attached config for now
+        transforms_config = getattr(model, "transforms_config", None)
+
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
@@ -179,9 +193,25 @@ def from_pretrained_model(
             return None

         return cls(
-            sparsity_config=sparsity_config, quantization_config=quantization_config
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transforms_config=transforms_config,
         )

+    @staticmethod
+    def parse_transforms_config(
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+    ) -> Union[Dict[str, Any], None]:
+
+        if compression_config is None:
+            return None
+
+        if is_compressed_tensors_config(compression_config):
+            t_config = compression_config.transforms_config
+            return t_config.model_dump() if t_config is not None else None
+
+        return compression_config.get(TRANSFORMS_CONFIG, None)
+
     @staticmethod
     def parse_sparsity_config(
         compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
@@ -243,9 +273,11 @@ def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
+        transforms_config: Optional[TransformationConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
+        self.transforms_config = transforms_config
         self.sparsity_compressor = None
         self.quantization_compressor = None

@@ -434,10 +466,14 @@ def decompress(self, model_path: str, model: Module):
             self.quantization_config, QuantizationStatus.FROZEN
         ):
             names_to_scheme = apply_quantization_config(
-                model, self.quantization_config
+                model,
+                self.quantization_config,
+                transforms_config=self.transforms_config,
             )
             load_pretrained_quantization(model, model_path)

+        load_transforms(model, model_path)
+
         model_path_or_state_dict = (
             model.state_dict() if sparse_decompressed else model_path
         )
@@ -497,6 +533,12 @@ def update_config(self, save_directory: str):
                 SPARSITY_CONFIG_NAME
             ] = sparsity_config_data

+        if self.transforms_config is not None:
+            transforms_config_data = self.transforms_config.to_dict()
+            config_data[QUANTIZATION_CONFIG_NAME][
+                TRANSFORMS_CONFIG
+            ] = transforms_config_data
+
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ def decompress_weight(
         :param quantization_args: quantization parameters for the weight
         :return: tensor of the decompressed weight
         """
+
         weight = compressed_data["weight_packed"]
         scale = compressed_data["weight_scale"]
         zero_point = compressed_data.get("weight_zero_point", None)

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 120 additions & 2 deletions
@@ -41,6 +41,9 @@
     iter_named_leaf_modules,
     iter_named_quantizable_modules,
 )
+from compressed_tensors.transforms import Transforms
+from compressed_tensors.transforms.transform_config import TransformationConfig
+from compressed_tensors.transforms.transform_data import TransformData
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
@@ -49,20 +52,45 @@

 __all__ = [
     "load_pretrained_quantization",
+    "load_transforms",
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
     "expand_target_names",
     "is_target",
+    "process_transforms_config",
 ]

 from compressed_tensors.quantization.utils.helpers import is_module_quantized
-from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
+from compressed_tensors.utils.safetensors_load import (
+    get_quantization_state_dict,
+    get_weight_mappings,
+)
+from safetensors import safe_open


 _LOGGER = logging.getLogger(__name__)


+def load_transforms(model: Module, model_name_or_path: str):
+    model_path = get_safetensors_folder(model_name_or_path)
+    weight_mappings = get_weight_mappings(model_path)
+
+    state_dict = {}
+    for weight_name, safe_path in weight_mappings.items():
+        if "transform" in weight_name:
+            with safe_open(safe_path, framework="pt", device="cpu") as f:
+                state_dict[weight_name] = f.get_tensor(weight_name)
+
+    for name, submodule in iter_named_leaf_modules(model):
+        transform_data = getattr(submodule, "transform_data", None)
+        if transform_data:
+            for transform_name, transform_data in transform_data.data.items():
+                full_name = f"{name}.{transform_name}"
+                transform_data = state_dict.get(full_name, None)
+                update_parameter_data(submodule, transform_data, transform_name)
+
+
 def load_pretrained_quantization(model: Module, model_name_or_path: str):
     """
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -104,8 +132,92 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
         )


+def process_transforms_config(
+    transforms_config: TransformationConfig,
+    model: torch.nn.Module,
+    quantization_status: Optional[QuantizationStatus] = QuantizationStatus.INITIALIZED,
+):
+    for _, group in transforms_config.transform_groups.items():
+        # Each group/scheme targets one type of transform
+        transform_type = group.transform_type
+        transform_creation_args = group.transform_creation_args
+
+        # Need a better name - too many groups
+        for transform_arg in group.groups:
+            module_targets = transform_arg.module_targets
+
+            for name, submodule in model.named_modules():
+                if len(transform_arg.ignore) > 0:
+                    if matches := find_name_or_class_matches(
+                        name, submodule, transform_arg.ignore
+                    ):
+                        for match in matches:
+                            print("ignoring", match, name)
+                        continue  # layer matches ignore list, continue
+
+                targets = find_name_or_class_matches(
+                    name, submodule, transform_arg.targets
+                )
+
+                if targets:
+                    # Every layer which matches gets its own transform
+                    # Same transform type and args are used however
+
+                    # attach the transform to the submodule
+                    # because we can have more than one transform, need to attach some
+                    # form of key to fetch
+                    # OR we store it in the dictionary, handle cpu-offloading separately
+
+                    if hasattr(submodule, "transform_data"):
+                        idx = submodule.transform_data.idx + 1
+                    else:
+                        idx = 0
+                    # only support weight parameters for now, assume one value in
+                    # module targets
+                    transform_name = f"{module_targets[0]}_transform_{idx}"
+
+                    # create an empty tensor OR create a new transform
+                    dtype = getattr(submodule, module_targets[0]).dtype
+                    if quantization_status in [
+                        QuantizationStatus.COMPRESSED,
+                        QuantizationStatus.FROZEN,
+                    ]:
+                        transform = Transforms.load_from_registry(
+                            transform_type,
+                            dtype=dtype,
+                            empty=True,
+                            **transform_creation_args,
+                        )
+                    else:
+                        transform = Transforms.load_from_registry(
+                            transform_type,
+                            dtype=dtype,
+                            **transform_creation_args,
+                        )
+                    setattr(submodule, transform_name, transform)
+
+                    # add relevant transform data to the submodule as well
+                    data = {
+                        transform_name: {
+                            "type": transform_type,
+                            "call_args": transform_arg.call_args,
+                        }
+                    }
+
+                    if hasattr(submodule, "transform_data"):
+                        submodule.transform_data.data.update(data)
+                        submodule.transform_data.idx = idx
+                    else:
+                        transform_data = TransformData(data=OrderedDict(data))
+                        submodule.transform_data = transform_data
+    return model
+
+
 def apply_quantization_config(
-    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+    model: Module,
+    config: Union[QuantizationConfig, None],
+    run_compressed: bool = False,
+    transforms_config=None,
 ) -> OrderedDict:
     """
     Initializes the model for quantization in-place based on the given config.
@@ -184,6 +296,12 @@ def apply_quantization_config(
             f"{set(config.ignore) - set(ignored_submodules)}"
         )

+    if transforms_config:
+        model.transforms_config = transforms_config
+        model = process_transforms_config(
+            transforms_config, model, config.quantization_status
+        )
+
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
     return names_to_scheme
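
Note: process_transforms_config attaches one transform parameter per matching module, named f"{module_targets[0]}_transform_{idx}" (e.g. weight_transform_0), plus a TransformData bookkeeping entry used later by load_transforms and the forward pass. The config shape it walks can be inferred from the attribute accesses above; a hypothetical example follows (field names mirror those accesses, but the "hadamard" registry key and all values are assumptions, not a verified schema):

# Illustrative only: shape inferred from process_transforms_config above.
hypothetical_transforms_config = {
    "transform_groups": {
        "group_0": {
            "transform_type": "hadamard",               # assumed registry key
            "transform_creation_args": {"size": 2048},  # passed to transform creation
            "groups": [
                {
                    "targets": ["Linear"],         # modules to attach transforms to
                    "ignore": ["lm_head"],         # modules to skip
                    "module_targets": ["weight"],  # only weights supported for now
                    "call_args": {},               # recorded for apply time
                }
            ],
        }
    }
}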

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 19 additions & 0 deletions
@@ -28,6 +28,10 @@
     calculate_range,
     compute_dynamic_scales_and_zp,
 )
+from compressed_tensors.transforms.apply import (
+    apply_inverse_transforms_to_parameter,
+    apply_transforms_to_parameter,
+)
 from compressed_tensors.utils import safe_permute
 from torch.nn import Module

@@ -280,10 +284,25 @@ def wrapped_forward(self, *args, **kwargs):
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
+            transform_data = getattr(module, "transform_data", None)
+            if transform_data is not None:
+                apply_transforms_to_parameter(
+                    module=module,
+                    module_parameter=self.weight,
+                    transform_data=transform_data,
+                )
+
             self.weight.data = forward_quantize(
                 module, self.weight, "weight", scheme.weights
             )

+            if transform_data is not None:
+                apply_inverse_transforms_to_parameter(
+                    module=module,
+                    module_parameter=self.weight,
+                    transform_data=transform_data,
+                )
+
         # perform wrapped forward call
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
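
Note: conceptually, the weights are now fake-quantized in a transformed space: apply the transform, quantize, then apply the inverse so the wrapped forward call sees weights in the original basis, i.e. W -> T^-1(Q(T(W))). A self-contained sketch of that order of operations (naive_fake_quantize and the explicit matrix inverse are stand-ins for the library's scheme-driven quantization and registered transforms, not its API):

import torch

def naive_fake_quantize(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    # symmetric round-to-nearest fake quantization, for illustration only
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().amax() / qmax
    return torch.round(x / scale).clamp(-qmax - 1, qmax) * scale

def quantize_in_transformed_space(weight: torch.Tensor,
                                  transform: torch.Tensor) -> torch.Tensor:
    # mirrors the order of operations in wrapped_forward above:
    # transform, fake-quantize, then inverse-transform
    transformed = weight @ transform                # T(W)
    quantized = naive_fake_quantize(transformed)    # Q(T(W))
    return quantized @ torch.linalg.inv(transform)  # T^-1(Q(T(W)))

# e.g. with a random orthogonal transform:
# q = torch.linalg.qr(torch.randn(64, 64)).Q
# w_hat = quantize_in_transformed_space(torch.randn(128, 64), q)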
