Skip to content

Graph + Multi-GPU optimizations #543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mergekit/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
if TYPE_CHECKING:
from mergekit.config import MergeConfiguration

logger = logging.getLogger(__name__)
LOG = logging.getLogger(__name__)


def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture]:
Expand All @@ -44,11 +44,11 @@ def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture
for c in candidates:
if c.expected_model_type == config.model_type:
return c
logger.warning(
LOG.warning(
f"Multiple architectures for {arch_name}, none match model type {config.model_type}"
)

logger.warning(f"No JSON architecture found for {arch_name}")
LOG.warning(f"No JSON architecture found for {arch_name}")
return None


Expand Down
2 changes: 1 addition & 1 deletion mergekit/architecture/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

RE_LAYER_INDEX = re.compile(r"\.(\d+)\.")

logger = logging.getLogger(__name__)
LOG = logging.getLogger(__name__)


def get_model_tensor_names(model: ModelReference, options: MergeOptions) -> List[str]:
Expand Down
11 changes: 5 additions & 6 deletions mergekit/architecture/json_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@ def _template_substitution(
return TemplateWithArithmetic(template).substitute(substitutions)


def _load_architecture_json(name: str) -> ModelArchitecture:
with importlib.resources.open_text(mergekit._data.architectures, name) as f:
text = f.read()
def _load_architecture_json(text: str) -> ModelArchitecture:
data = json.loads(text)
kind = data.get("kind", "module")
if kind == "modular":
Expand Down Expand Up @@ -174,9 +172,10 @@ def _load_all_architectures() -> (
Tuple[List[ModelArchitecture], Dict[str, List[ModelArchitecture]]]
):
architectures: List[ModelArchitecture] = []
for f in importlib.resources.contents(mergekit._data.architectures):
if f.lower().endswith(".json"):
architectures.append(_load_architecture_json(f))
for f in importlib.resources.files(mergekit._data.architectures).iterdir():
if f.is_file() and f.name.lower().endswith(".json"):
text = f.read_text()
architectures.append(_load_architecture_json(text))

name_to_arch: Dict[str, List[JsonModuleArchitecture]] = {}
for arch_info in architectures:
Expand Down
18 changes: 9 additions & 9 deletions mergekit/evo/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from mergekit.options import MergeOptions
from mergekit.plan import MergePlanner

logger = logging.getLogger(__name__)
LOG = logging.getLogger(__name__)


class MergeActorBase:
Expand Down Expand Up @@ -91,18 +91,18 @@ def evaluate_genotype(
) -> dict:
gc.collect()
torch.cuda.empty_cache()
logger.info("Merging model")
LOG.info("Merging model")
merged_path = merge_model(
genotype, self.genome, self.model_storage_path, self.merge_options
)
if not merged_path:
logger.error("Model merge failed")
LOG.error("Model merge failed")
return {"score": None, "results": None}

model_kwargs = {}
if self.quantization_config is not None:
model_kwargs["quantization_config"] = self.quantization_config
logger.info(f"Model merged to {merged_path}")
LOG.info(f"Model merged to {merged_path}")
return evaluate_model(
merged_path,
self.config.tasks,
Expand Down Expand Up @@ -167,7 +167,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
continue

if getattr(cfg_out, key) != getattr(self.arch_info.config, key, None):
logger.warning(f"Config key {key} changed, reinitializing model")
LOG.warning(f"Config key {key} changed, reinitializing model")
different = True
break

Expand Down Expand Up @@ -206,7 +206,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
del inner_model
tokenizer_donor = self.genome.definition.base_model
if tokenizer_donor is None:
logger.warning(
LOG.warning(
"Base model not set, using tokenizer from first model in genome"
)
tokenizer_donor = self.genome.definition.models[0]
Expand All @@ -224,7 +224,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
max_model_len = min(max_model_len or 1024, window_sz)
if max_model_len and max_model_len > 8192:
max_model_len = 8192
logger.warning(f"Clipping sequence length to {max_model_len}")
LOG.warning(f"Clipping sequence length to {max_model_len}")

mem_util = (
0.7 if self.merge_options.cuda else 0.9
Expand All @@ -248,13 +248,13 @@ def _maybe_init_model(self, config: MergeConfiguration):
if ai
else None
)
logger.info("Model initialized")
LOG.info("Model initialized")

def evaluate(self, genotype: torch.Tensor) -> dict:
try:
config = self.genome.genotype_merge_config(genotype)
except InvalidGenotypeError as e:
logger.error("Invalid genotype", exc_info=e)
LOG.error("Invalid genotype", exc_info=e)
return {"score": None, "results": None}

self._maybe_init_model(config)
Expand Down
Loading
Loading