|
3 | 3 | import logging
|
4 | 4 | import shutil
|
5 | 5 | from pathlib import Path
|
| 6 | +from typing import List, Optional, Tuple |
6 | 7 |
|
7 | 8 | import click
|
8 | 9 | import torch
|
| 10 | +from huggingface_hub import snapshot_download |
9 | 11 | from safetensors import safe_open
|
10 | 12 | from tqdm import tqdm
|
11 | 13 |
|
12 |
| -from mergekit.architecture import ParameterNamesUtils |
13 | 14 | from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
|
14 | 15 | from mergekit.io.tensor_writer import TensorWriter
|
15 | 16 |
|
@@ -197,3 +198,199 @@ def main(
|
197 | 198 |
|
class ParameterNamesUtils:
    """Utility functions for handling parameter names.

    All helpers are static methods. They compare the tensor names of several
    models, detect a leading dotted prefix that separates them (for example
    ``model.``), and filter out names whose tensors differ in shape.
    """

    @staticmethod
    def resolve_model_directory(repo_id: str) -> Path:
        """Resolve the model directory (local or Hugging Face Hub).

        If ``repo_id`` is an existing local directory it is returned as a
        ``Path``; otherwise it is passed to ``snapshot_download`` and the
        path returned by that call is used.
        """
        local_dir = Path(repo_id)
        if local_dir.is_dir():
            return local_dir

        return Path(snapshot_download(repo_id))

    @staticmethod
    def get_model_parameter_names(repo_id: str) -> List[str]:
        """Get parameter names of a model from a Hugging Face repo or local directory."""
        model_dir = ParameterNamesUtils.resolve_model_directory(repo_id)
        index = ShardedTensorIndex.from_disk(str(model_dir))
        return list(index.tensor_paths.keys())

    @staticmethod
    def strip_prefix(name: str, prefix: str) -> str:
        """Remove a single prefix from the start of a name.

        The prefix is only removed when it is non-empty and is followed by a
        dot in ``name``; otherwise ``name`` is returned unchanged.
        """
        if prefix != "" and name.startswith(prefix + "."):
            return name[len(prefix) + 1 :]
        return name

    @staticmethod
    def find_prefix(list1: List[str], list2: List[str]) -> Optional[str]:
        """Find the first-component prefix of ``list1`` that best aligns it with ``list2``.

        Every first dotted component found in ``list1`` is tried as a prefix.
        For each candidate the prefix is stripped from the names in ``list1``
        and the number of ``list2`` entries present in the stripped names is
        counted.

        Returns:
            The candidate with the highest count, or ``""`` when no candidate
            beats the count obtained without stripping. The value is always a
            string. Ties between candidates are resolved in sorted order.
        """
        assert len(list1) >= len(list2), "params name list1 can't be shorter than list2"

        first_components = {item.split(".")[0] for item in list1 if "." in item}
        first_components.discard("")
        # "" (no prefix) goes first; the rest are sorted so that tie-breaking
        # does not depend on set iteration order.
        possible_prefixes = [""] + sorted(first_components)

        prefix_matches = {}
        for prefix in possible_prefixes:
            # A set gives O(1) membership tests instead of scanning a list.
            stripped_names = {
                ParameterNamesUtils.strip_prefix(item, prefix) for item in list1
            }
            prefix_matches[prefix] = sum(1 for item in list2 if item in stripped_names)

        best_prefix = ""  # Default to no prefix
        if max(prefix_matches.values()) > prefix_matches[""]:
            best_prefix = max(prefix_matches, key=prefix_matches.get)

        return best_prefix

    @staticmethod
    def find_common_ordered_names(
        param_names: List[List[str]], prefixes: List[str]
    ) -> List[str]:
        """Identify and return common parameter names across all models, ensuring correct order. Also account for prefix.

        The names of model ``i`` (``i >= 1``) are prepended with
        ``prefixes[i] + "."`` when that prefix is non-empty before being
        intersected with the names of the first model. The result keeps the
        order of ``param_names[0]``.
        """
        common_names = set(param_names[0])
        for i, names in enumerate(param_names[1:], start=1):
            prefix = f"{prefixes[i]}." if prefixes[i] else ""
            common_names.intersection_update({prefix + name for name in names})
        return [name for name in param_names[0] if name in common_names]

    @staticmethod
    def remove_size_conflicts(
        common_names: List[str], referenced_models, prefixes: List[str]
    ) -> List[str]:
        """Drop names whose tensor shape differs between the first model and any other.

        Args:
            common_names: Tensor names as they appear in the first model.
            referenced_models: Objects exposing ``.model.path``; the first one
                is treated as the base model.
            prefixes: Per-model prefixes. For model ``i`` (``i >= 1``) a
                leading ``prefixes[i] + "."`` is stripped from the name before
                it is looked up in that model.

        Returns:
            A copy of ``common_names`` without the names for which a shape
            mismatch was found. Removed names are reported through ``logging``.
        """
        model_dirs = [
            ParameterNamesUtils.resolve_model_directory(m.model.path)
            for m in referenced_models
        ]
        model_indices = [
            ShardedTensorIndex.from_disk(str(model_dir)) for model_dir in model_dirs
        ]

        common_name_and_shape = common_names.copy()
        removed_names = []

        for name in common_names:
            base_shape = ParameterNamesUtils.tensor_shape(name, model_indices[0])

            for i in range(1, len(referenced_models)):
                other_name = name
                prefix = f"{prefixes[i]}." if prefixes[i] else ""
                if name.startswith(prefix) and prefix != "":
                    other_name = name[len(prefix) :]
                shape = ParameterNamesUtils.tensor_shape(other_name, model_indices[i])

                if base_shape != shape:
                    common_name_and_shape.remove(name)
                    removed_names.append((name, base_shape, shape, i))
                    # One mismatch is enough to drop the name.
                    break

        size_mismatch_count = len(removed_names)
        if size_mismatch_count > 0:
            # common_names is never mutated above, so its length is already
            # the total number of tensors that were examined.
            logging.warning(
                f"Size mismatch detected for {size_mismatch_count}/{len(common_names)} tensors. "
                "These names were removed from the merge list."
            )
            logging.info(
                "The following tensors have different shapes across models and were removed from the merge list:"
            )
            for name, base_shape, shape, i in removed_names:
                logging.info(
                    f"Tensor name: {name}, Base model shape: {base_shape}, Mismatched shape: {shape} in model {referenced_models[i].model.path}"
                )

        return common_name_and_shape

    @staticmethod
    def are_common_params_ordered(list1: List[str], list2: List[str]) -> bool:
        """Check if common elements of list2 maintain their relative order in list1.

        Positions in ``list1`` refer to the first occurrence of each name.
        """
        # Map each name to its first position once, instead of calling
        # list.index() for every element of list2.
        first_position = {}
        for position, name in enumerate(list1):
            first_position.setdefault(name, position)

        last_index = -1
        for param in list2:
            current_index = first_position.get(param)
            if current_index is None:
                continue  # not a common parameter
            if current_index < last_index:
                return False
            last_index = current_index
        return True

    @staticmethod
    def ordered_sublist(list1: List[str], list2: List[str]) -> bool:
        """Check if list2 is a contiguous ordered sublist of list1.

        An empty ``list2`` is always a sublist; a ``list2`` longer than
        ``list1`` never is.
        """
        n, m = len(list1), len(list2)
        return any(list1[i : i + m] == list2 for i in range(n - m + 1))

    @staticmethod
    def report_names_similarity(
        base_names: List[str], other_names: List[str]
    ) -> Tuple[Optional[str], str]:
        """Analyze similarity between parameter names of two models and identify shared prefixes.

        Candidate prefixes are ``""`` plus every first dotted component of
        ``base_names``. ``""`` is tried first and the rest in sorted order.

        Returns:
            best_prefix (str): Best matching prefix for parameter names. If
                stripping some candidate makes ``other_names`` a contiguous
                sublist of the base names, that candidate is returned
                immediately; otherwise the candidate with the largest name
                overlap is returned.
            case_message (str): Explanation of the structural relationship.
        """
        first_components = {item.split(".")[0] for item in base_names if "." in item}
        first_components.discard("")
        # Fixed iteration order keeps the chosen prefix reproducible.
        possible_prefixes = [""] + sorted(first_components)

        other_name_set = set(other_names)
        prefixes_subset_overlap = {}

        for prefix in possible_prefixes:
            base_names_stripped = [
                ParameterNamesUtils.strip_prefix(name, prefix) for name in base_names
            ]

            if ParameterNamesUtils.ordered_sublist(base_names_stripped, other_names):
                return prefix, "All params in model have exact match in base model."

            prefixes_subset_overlap[prefix] = set(base_names_stripped).intersection(
                other_name_set
            )

        # "" is always a candidate, so at least one overlap entry exists here.
        best_prefix = max(
            prefixes_subset_overlap, key=lambda x: len(prefixes_subset_overlap[x])
        )
        base_names_stripped = [
            ParameterNamesUtils.strip_prefix(name, best_prefix) for name in base_names
        ]

        overlap = len(prefixes_subset_overlap[best_prefix])
        ordered = ParameterNamesUtils.are_common_params_ordered(
            base_names_stripped, other_names
        )
        stripped_name_set = set(base_names_stripped)
        mismatched = "\n ".join(
            item for item in other_names if item not in stripped_name_set
        )
        # other_names cannot be empty here: an empty list is always an ordered
        # sublist and would have returned inside the loop above.
        case_message = (
            f"{overlap}/{len(other_names)} ({100 * overlap / len(other_names):.2f}%) "
            f"of model parameters are in the base model. \n"
            f" Name ordering is {'preserved' if ordered else 'not preserved'}.\n"
            f" Missing parameters:\n {mismatched}"
        )

        return best_prefix, case_message

    @staticmethod
    def tensor_shape(name, index) -> Tuple[int]:
        """Return the shape of tensor ``name`` as stored in the file listed by ``index``.

        The file is located at ``index.base_path / index.tensor_paths[name]``
        and opened with ``safe_open``; the shape comes from
        ``get_slice(name).get_shape()``.
        """
        from safetensors import safe_open

        tensor_file = Path(index.base_path) / index.tensor_paths[name]
        with safe_open(tensor_file, framework="pt") as f:
            return f.get_slice(name).get_shape()


# The entry-point guard stays at the very end of the module so that every
# definition above, including ParameterNamesUtils, exists before main() runs.
if __name__ == "__main__":
    main()
0 commit comments