Skip to content

Commit 6125c51

Browse files
committed
Fix scripts
1 parent e5996f1 commit 6125c51

File tree

5 files changed

+213
-14
lines changed

5 files changed

+213
-14
lines changed

mergekit/moe/mixtral.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import tqdm
99
import transformers
1010

11-
from mergekit.architecture import MISTRAL_INFO, WeightInfo
11+
from mergekit.architecture import WeightInfo
12+
from mergekit.architecture.mixtral import MISTRAL_INFO
1213
from mergekit.moe.arch import MoEOutputArchitecture
1314
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
1415
from mergekit.moe.config import MoEMergeConfig

mergekit/moe/qwen.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# if the transformers version installed is too old
1313
from transformers.models.qwen2_moe import Qwen2MoeConfig
1414

15-
from mergekit.architecture import QWEN2_INFO
15+
from mergekit.architecture.json_definitions import NAME_TO_ARCH
1616
from mergekit.moe.arch import MoEOutputArchitecture
1717
from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
1818
from mergekit.moe.config import MoEMergeConfig
1919
from mergekit.options import MergeOptions
2020

21+
QWEN2_INFO = NAME_TO_ARCH["Qwen2ForCausalLM"][0]
22+
2123

2224
class QwenMoE(MoEOutputArchitecture):
2325
def name(self) -> str:

mergekit/scripts/fill_missing_params.py

+198-1
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import logging
44
import shutil
55
from pathlib import Path
6+
from typing import List, Optional, Tuple
67

78
import click
89
import torch
10+
from huggingface_hub import snapshot_download
911
from safetensors import safe_open
1012
from tqdm import tqdm
1113

12-
from mergekit.architecture import ParameterNamesUtils
1314
from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
1415
from mergekit.io.tensor_writer import TensorWriter
1516

@@ -197,3 +198,199 @@ def main(
197198

198199
if __name__ == "__main__":
199200
main()
201+
202+
203+
class ParameterNamesUtils:
204+
"""Utility functions for handling parameter names."""
205+
206+
@staticmethod
207+
def resolve_model_directory(repo_id: str) -> Path:
208+
"""Resolve the model directory (local or Hugging Face Hub)."""
209+
if Path(repo_id).is_dir():
210+
return Path(repo_id)
211+
212+
return Path(snapshot_download(repo_id))
213+
214+
@staticmethod
215+
def get_model_parameter_names(repo_id: str) -> List[str]:
216+
"""Get parameter names of a model from a Hugging Face repo or local directory."""
217+
model_dir = ParameterNamesUtils.resolve_model_directory(repo_id)
218+
return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys())
219+
220+
@staticmethod
221+
def strip_prefix(name: str, prefix: str) -> str:
222+
"""Remove a single prefix from the start of a name."""
223+
if prefix != "" and name.startswith(prefix + "."):
224+
return name[len(prefix) + 1 :]
225+
return name
226+
227+
@staticmethod
228+
def find_prefix(list1: List[str], list2: List[str]) -> Optional[str]:
229+
"""
230+
Find a prefix in list1 that, after removal, makes list2 an ordered sublist.
231+
"""
232+
assert len(list1) >= len(list2), "params name list1 can't be shorter than list2"
233+
234+
possible_prefixes = {item.split(".")[0] for item in list1 if "." in item}
235+
possible_prefixes = [""] + list(possible_prefixes)
236+
237+
prefix_matches = {}
238+
best_prefix = "" # Default to no prefix
239+
for prefix in possible_prefixes:
240+
stripped_list1 = [
241+
ParameterNamesUtils.strip_prefix(item, prefix) for item in list1
242+
]
243+
prefix_matches[prefix] = len(
244+
[item for item in list2 if item in stripped_list1]
245+
)
246+
247+
if max(prefix_matches.values()) > prefix_matches[""]:
248+
best_prefix = max(prefix_matches, key=prefix_matches.get)
249+
250+
return best_prefix
251+
252+
@staticmethod
253+
def find_common_ordered_names(
254+
param_names: List[List[str]], prefixes: List[str]
255+
) -> List[str]:
256+
"""Identify and return common parameter names across all models, ensuring correct order. Also account for prefix."""
257+
common_names = set(param_names[0])
258+
for i in range(1, len(param_names)):
259+
prefix = f"{prefixes[i]}." if prefixes[i] else ""
260+
common_names.intersection_update({prefix + name for name in param_names[i]})
261+
return [name for name in param_names[0] if name in common_names]
262+
263+
@staticmethod
264+
def remove_size_conflicts(common_names, referenced_models, prefixes):
265+
model_dirs = [
266+
ParameterNamesUtils.resolve_model_directory(m.model.path)
267+
for m in referenced_models
268+
]
269+
model_indices = [ShardedTensorIndex.from_disk(str(dir)) for dir in model_dirs]
270+
271+
common_name_and_shape = common_names.copy()
272+
removed_names = []
273+
274+
for name in common_names:
275+
base_shape = ParameterNamesUtils.tensor_shape(name, model_indices[0])
276+
277+
for i in range(1, len(referenced_models)):
278+
other_name = name
279+
prefix = f"{prefixes[i]}." if prefixes[i] else ""
280+
if name.startswith(prefix) and prefix != "":
281+
other_name = name[len(prefix) :]
282+
shape = ParameterNamesUtils.tensor_shape(other_name, model_indices[i])
283+
284+
if base_shape != shape:
285+
common_name_and_shape.remove(name)
286+
removed_names.append((name, base_shape, shape, i))
287+
break
288+
289+
size_mismatch_count = len(removed_names)
290+
if size_mismatch_count > 0:
291+
logging.warning(
292+
f"Size mismatch detected for {size_mismatch_count}/{size_mismatch_count + len(common_names)} tensors. "
293+
"These names were removed from the merge list."
294+
)
295+
logging.info(
296+
"The following tensors have different shapes across models and were removed from the merge list:"
297+
)
298+
for name, base_shape, shape, i in removed_names:
299+
logging.info(
300+
f"Tensor name: {name}, Base model shape: {base_shape}, Mismatched shape: {shape} in model {referenced_models[i].model.path}"
301+
)
302+
303+
return common_name_and_shape
304+
305+
@staticmethod
306+
def are_common_params_ordered(list1: List[str], list2: List[str]) -> bool:
307+
"""
308+
Check if common elements of list2 maintain their relative order in list1.
309+
"""
310+
common_params = set(list1).intersection(set(list2))
311+
last_index = -1
312+
313+
for param in list2:
314+
if param in common_params:
315+
current_index = list1.index(param)
316+
if current_index < last_index:
317+
return False
318+
last_index = current_index
319+
return True
320+
321+
@staticmethod
322+
def ordered_sublist(list1: List[str], list2: List[str]) -> bool:
323+
"""
324+
Check if list2 is a contiguous ordered sublist of list1.
325+
"""
326+
n, m = len(list1), len(list2)
327+
328+
for i in range(n - m + 1):
329+
if list1[i : i + m] == list2:
330+
return True
331+
return False
332+
333+
@staticmethod
334+
def report_names_similarity(
335+
base_names: List[str], other_names: List[str]
336+
) -> Tuple[Optional[str], str]:
337+
"""
338+
Analyze similarity between parameter names of two models and identify shared prefixes.
339+
Returns:
340+
best_prefix (str): Best matching prefix for parameter names.
341+
case_message (str): Explanation of the structural relationship.
342+
"""
343+
possible_prefixes = {""}
344+
possible_prefixes.update(
345+
{item.split(".")[0] for item in base_names if "." in item}
346+
)
347+
348+
prefixes_subset_overlap = {}
349+
best_prefix = None
350+
case_message = "No common parameter names found for any prefix"
351+
352+
for prefix in possible_prefixes:
353+
base_names_stripped = [
354+
ParameterNamesUtils.strip_prefix(name, prefix) for name in base_names
355+
]
356+
357+
if ParameterNamesUtils.ordered_sublist(base_names_stripped, other_names):
358+
return prefix, "All params in model have exact match in base model."
359+
360+
intersection = set(base_names_stripped).intersection(set(other_names))
361+
prefixes_subset_overlap[prefix] = intersection
362+
363+
if prefixes_subset_overlap:
364+
best_prefix = max(
365+
prefixes_subset_overlap, key=lambda x: len(prefixes_subset_overlap[x])
366+
)
367+
base_names_stripped = [
368+
ParameterNamesUtils.strip_prefix(name, best_prefix)
369+
for name in base_names
370+
]
371+
372+
overlap = len(prefixes_subset_overlap[best_prefix])
373+
ordered = ParameterNamesUtils.are_common_params_ordered(
374+
base_names_stripped, other_names
375+
)
376+
mismatched = [
377+
item for item in other_names if item not in base_names_stripped
378+
]
379+
mismatched = "\n ".join(mismatched)
380+
case_message = (
381+
f"{overlap}/{len(other_names)} ({100 * overlap / len(other_names):.2f}%) "
382+
f"of model parameters are in the base model. \n"
383+
f" Name ordering is {'preserved' if ordered else 'not preserved'}.\n"
384+
f" Missing parameters:\n {mismatched}"
385+
)
386+
387+
return best_prefix, case_message
388+
389+
@staticmethod
390+
def tensor_shape(name, index) -> Tuple[int]:
391+
from safetensors import safe_open
392+
393+
with safe_open(
394+
Path(index.base_path) / index.tensor_paths[name], framework="pt"
395+
) as f:
396+
return f.get_slice(name).get_shape()

mergekit/scripts/moe.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,6 @@ def select_output_arch(
163163
help="Device to use to compute embeddings",
164164
show_default=True,
165165
)
166-
@click.option(
167-
"--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
168-
)
169166
@click.option(
170167
"--i-understand-this-is-not-useful-without-training",
171168
type=bool,
@@ -180,7 +177,6 @@ def main(
180177
load_in_4bit: bool,
181178
load_in_8bit: bool,
182179
device: str,
183-
verbose: bool,
184180
i_understand_this_is_not_useful_without_training: bool,
185181
merge_options: MergeOptions,
186182
):
@@ -204,7 +200,7 @@ def main(
204200
load_in_8bit=load_in_8bit,
205201
device=device,
206202
allow_all_same=i_understand_this_is_not_useful_without_training,
207-
verbose=verbose,
203+
verbose=merge_options.verbose,
208204
)
209205

210206
if merge_options.write_model_card:

mergekit/scripts/tokensurgeon.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def main(
132132
barycentric=barycentric,
133133
cosine_similarity=cosine_similarity,
134134
name=embed_info.name,
135+
log_reconstruction_error=verbosity > 0,
135136
)
136137

137138
if lm_head_info:
@@ -469,12 +470,14 @@ def get_embeddings(
469470

470471
if log_reconstruction_error:
471472
# compute reconstruction error in donor_embed space
472-
knn_reconstruction_error.append(
473-
torch.nn.functional.mse_loss(
474-
(knn_embeddings.T.to(weights.dtype) @ weights).squeeze(),
475-
token_embedding,
476-
).item()
473+
reconstructed = (
474+
(knn_embeddings.T.to(weights.dtype) @ weights)
475+
.squeeze()
476+
.to(token_embedding.dtype)
477477
)
478+
diff = token_embedding - reconstructed
479+
mse = diff.square().mean().item()
480+
knn_reconstruction_error.append(mse)
478481

479482
# Reconstruct the embedding in original_embed space
480483
res[idx_1] = (e_c_0[indices].T @ weights).squeeze()
@@ -591,7 +594,7 @@ def validate_architecture(
591594
donor_arch_info = arch_info_for_config(donor_cfg)
592595
if donor_arch_info != model_arch_info:
593596
report_issue(
594-
f"Model architectures do not match: {model_arch_info.name()} vs {donor_arch_info.name()}",
597+
f"Model architectures do not match: {model_arch_info.expected_model_type} vs {donor_arch_info.expected_model_type}",
595598
error=not options.allow_crimes,
596599
)
597600

0 commit comments

Comments
 (0)