     iter_named_leaf_modules,
     iter_named_quantizable_modules,
 )
+from compressed_tensors.transforms import Transforms
+from compressed_tensors.transforms.transform_config import TransformationConfig
+from compressed_tensors.transforms.transform_data import TransformData
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
⋮
 
 __all__ = [
     "load_pretrained_quantization",
+    "load_transforms",
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
     "expand_target_names",
     "is_target",
+    "process_transforms_config",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
-from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
+from compressed_tensors.utils.safetensors_load import (
+    get_quantization_state_dict,
+    get_weight_mappings,
+)
+from safetensors import safe_open
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
+def load_transforms(model: Module, model_name_or_path: str):
+    """
+    Loads the serialized transform tensors from model_name_or_path into the
+    matching transform parameters attached to the model's submodules.
+    """
+    model_path = get_safetensors_folder(model_name_or_path)
+    weight_mappings = get_weight_mappings(model_path)
+
+    # gather every transform tensor stored in the checkpoint
+    state_dict = {}
+    for weight_name, safe_path in weight_mappings.items():
+        if "transform" in weight_name:
+            with safe_open(safe_path, framework="pt", device="cpu") as f:
+                state_dict[weight_name] = f.get_tensor(weight_name)
+
+    # copy each tensor into the transform parameter registered on its module
+    for name, submodule in iter_named_leaf_modules(model):
+        transform_data = getattr(submodule, "transform_data", None)
+        if transform_data:
+            for transform_name in transform_data.data:
+                full_name = f"{name}.{transform_name}"
+                transform_values = state_dict.get(full_name, None)
+                update_parameter_data(submodule, transform_values, transform_name)
+
+
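For context, a minimal sketch of how `load_transforms` would be called when reloading a checkpoint that was saved with transform tensors; the model class and path are placeholders, not part of this PR:

```python
# Hypothetical usage sketch: assumes a checkpoint whose safetensors files
# contain parameters with "transform" in their names, saved by this PR's flow.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/compressed-model")

# fill the (possibly empty) transform parameters attached to each submodule
load_transforms(model, "path/to/compressed-model")
```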
 def load_pretrained_quantization(model: Module, model_name_or_path: str):
     """
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -104,8 +132,92 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
     )
 
 
+def process_transforms_config(
+    transforms_config: TransformationConfig,
+    model: torch.nn.Module,
+    quantization_status: Optional[QuantizationStatus] = QuantizationStatus.INITIALIZED,
+):
+    for _, group in transforms_config.transform_groups.items():
+        # Each group/scheme targets one type of transform
+        transform_type = group.transform_type
+        transform_creation_args = group.transform_creation_args
+
+        # TODO: find a better name - too many "groups"
+        for transform_arg in group.groups:
+            module_targets = transform_arg.module_targets
+
+            for name, submodule in model.named_modules():
+                if len(transform_arg.ignore) > 0:
+                    if matches := find_name_or_class_matches(
+                        name, submodule, transform_arg.ignore
+                    ):
+                        for match in matches:
+                            _LOGGER.debug("Ignoring layer %s (matched %s)", name, match)
+                        continue  # layer matches the ignore list, skip it
+
+                targets = find_name_or_class_matches(
+                    name, submodule, transform_arg.targets
+                )
+
+                if targets:
+                    # Every layer that matches gets its own transform; the
+                    # transform type and creation args are shared, however.
+
+                    # Attach the transform to the submodule. Because a module
+                    # can hold more than one transform, each needs a key to
+                    # fetch it by; alternatively we could store transforms in
+                    # a dictionary and handle cpu-offloading separately.
+
+                    if hasattr(submodule, "transform_data"):
+                        idx = submodule.transform_data.idx + 1
+                    else:
+                        idx = 0
+                    # only support weight parameters for now; assume one value
+                    # in module_targets
+                    transform_name = f"{module_targets[0]}_transform_{idx}"
+
+                    # create an empty tensor OR create a new transform
+                    dtype = getattr(submodule, module_targets[0]).dtype
+                    if quantization_status in [
+                        QuantizationStatus.COMPRESSED,
+                        QuantizationStatus.FROZEN,
+                    ]:
+                        transform = Transforms.load_from_registry(
+                            transform_type,
+                            dtype=dtype,
+                            empty=True,
+                            **transform_creation_args,
+                        )
+                    else:
+                        transform = Transforms.load_from_registry(
+                            transform_type,
+                            dtype=dtype,
+                            **transform_creation_args,
+                        )
+                    setattr(submodule, transform_name, transform)
+
+                    # add relevant transform data to the submodule as well
+                    data = {
+                        transform_name: {
+                            "type": transform_type,
+                            "call_args": transform_arg.call_args,
+                        }
+                    }
+
+                    if hasattr(submodule, "transform_data"):
+                        submodule.transform_data.data.update(data)
+                        submodule.transform_data.idx = idx
+                    else:
+                        transform_data = TransformData(data=OrderedDict(data))
+                        submodule.transform_data = transform_data
+    return model
+
+
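To make the bookkeeping concrete, a small sketch of how the per-module state left behind by `process_transforms_config` can be inspected; it assumes `model` has already been processed and relies only on the structures created above:

```python
# Sketch: walk the model and report every attached transform along with the
# TransformData record describing how it should be applied.
for name, submodule in model.named_modules():
    transform_data = getattr(submodule, "transform_data", None)
    if transform_data is not None:
        for transform_name, meta in transform_data.data.items():
            transform = getattr(submodule, transform_name)
            print(
                f"{name}.{transform_name}: {type(transform).__name__}, "
                f"type={meta['type']}, call_args={meta['call_args']}"
            )
```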
 def apply_quantization_config(
-    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+    model: Module,
+    config: Union[QuantizationConfig, None],
+    run_compressed: bool = False,
+    transforms_config: Optional[TransformationConfig] = None,
 ) -> OrderedDict:
     """
     Initializes the model for quantization in-place based on the given config.
@@ -184,6 +296,12 @@ def apply_quantization_config(
                 f"{set(config.ignore) - set(ignored_submodules)}"
             )
 
+    if transforms_config:
+        model.transforms_config = transforms_config
+        model = process_transforms_config(
+            transforms_config, model, config.quantization_status
+        )
+
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
     return names_to_scheme
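Putting the pieces together, a hedged sketch of the intended end-to-end flow; construction of `quant_config` and `transforms_config` is elided because their schemas are defined elsewhere in compressed_tensors, and the checkpoint path is a placeholder:

```python
# quant_config: a QuantizationConfig; transforms_config: a TransformationConfig
names_to_scheme = apply_quantization_config(
    model,
    quant_config,
    transforms_config=transforms_config,
)

# when reloading a frozen/compressed checkpoint, the empty transform tensors
# created during initialization are then filled from disk:
load_transforms(model, "path/to/compressed-model")
```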