forked from arcee-ai/mergekit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathregistry.py
105 lines (101 loc) · 3.79 KB
/
registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: BUSL-1.1
from typing import Dict, List
from mergekit.merge_methods.arcee_fusion import ArceeFusionMerge
from mergekit.merge_methods.base import MergeMethod
from mergekit.merge_methods.generalized_task_arithmetic import (
ConsensusMethod,
GeneralizedTaskArithmeticMerge,
)
from mergekit.merge_methods.karcher import KarcherMerge
from mergekit.merge_methods.linear import LinearMerge
from mergekit.merge_methods.model_stock import ModelStockMerge
from mergekit.merge_methods.nuslerp import NuSlerpMerge
from mergekit.merge_methods.passthrough import PassthroughMerge
from mergekit.merge_methods.slerp import SlerpMerge
from mergekit.sparsify import SparsificationMethod
# Every merge method instance that ships with the package, in registration
# order. Each entry is constructed eagerly when this module is imported.
STATIC_MERGE_METHODS: List[MergeMethod] = [
    # Methods that are constructed without any arguments.
    LinearMerge(),
    SlerpMerge(),
    NuSlerpMerge(),
    PassthroughMerge(),
    ModelStockMerge(),
    ArceeFusionMerge(),
    KarcherMerge(),
    # Generalized task arithmetic family: one shared implementation class,
    # instantiated once per named variant. Variants differ only in their
    # consensus method, sparsification method, and the default values for
    # the normalize / rescale options.
    GeneralizedTaskArithmeticMerge(
        method_name="task_arithmetic",
        method_pretty_name="Task Arithmetic",
        method_reference_url="https://arxiv.org/abs/2212.04089",
        consensus_method=None,
        sparsification_method=None,
        default_normalize=False,
        default_rescale=False,
    ),
    GeneralizedTaskArithmeticMerge(
        method_name="ties",
        method_pretty_name="TIES",
        method_reference_url="https://arxiv.org/abs/2306.01708",
        consensus_method=ConsensusMethod.sum,
        sparsification_method=SparsificationMethod.magnitude,
        default_normalize=True,
        default_rescale=False,
    ),
    # The two DARE variants share a reference and a sparsification method;
    # they differ in whether a consensus method is applied.
    GeneralizedTaskArithmeticMerge(
        method_name="dare_ties",
        method_pretty_name="DARE TIES",
        method_reference_url="https://arxiv.org/abs/2311.03099",
        consensus_method=ConsensusMethod.sum,
        sparsification_method=SparsificationMethod.random,
        default_normalize=False,
        default_rescale=True,
    ),
    GeneralizedTaskArithmeticMerge(
        method_name="dare_linear",
        method_pretty_name="Linear DARE",
        method_reference_url="https://arxiv.org/abs/2311.03099",
        consensus_method=None,
        sparsification_method=SparsificationMethod.random,
        default_normalize=False,
        default_rescale=True,
    ),
    # The two Model Breadcrumbs variants, without and with a consensus method.
    GeneralizedTaskArithmeticMerge(
        method_name="breadcrumbs",
        method_pretty_name="Model Breadcrumbs",
        method_reference_url="https://arxiv.org/abs/2312.06795",
        consensus_method=None,
        sparsification_method=SparsificationMethod.magnitude_outliers,
        default_normalize=False,
        default_rescale=False,
    ),
    GeneralizedTaskArithmeticMerge(
        method_name="breadcrumbs_ties",
        method_pretty_name="Model Breadcrumbs with TIES",
        method_reference_url="https://arxiv.org/abs/2312.06795",
        consensus_method=ConsensusMethod.sum,
        sparsification_method=SparsificationMethod.magnitude_outliers,
        default_normalize=False,
        default_rescale=False,
    ),
    # The two DELLA variants, with and without a consensus method.
    GeneralizedTaskArithmeticMerge(
        method_name="della",
        method_pretty_name="DELLA",
        method_reference_url="https://arxiv.org/abs/2406.11617",
        consensus_method=ConsensusMethod.sum,
        sparsification_method=SparsificationMethod.della_magprune,
        default_normalize=True,
        default_rescale=True,
    ),
    GeneralizedTaskArithmeticMerge(
        method_name="della_linear",
        method_pretty_name="Linear DELLA",
        method_reference_url="https://arxiv.org/abs/2406.11617",
        consensus_method=None,
        sparsification_method=SparsificationMethod.della_magprune,
        default_normalize=False,
        default_rescale=True,
    ),
]
# Lookup table keyed by each instance's own ``name()`` result. It is built
# once, at import time, from STATIC_MERGE_METHODS. Because this is a plain
# dict comprehension, if two instances reported the same name the one that
# appears later in the list would replace the earlier one.
REGISTERED_MERGE_METHODS: Dict[str, MergeMethod] = {
    merge_method.name(): merge_method for merge_method in STATIC_MERGE_METHODS
}