Apply chat template in mergekit-evolve #508

Merged 1 commit on Feb 8, 2025

30 changes: 21 additions & 9 deletions mergekit/evo/actors.py
@@ -39,6 +39,8 @@
 from mergekit.options import MergeOptions
 from mergekit.plan import MergePlanner
 
+logger = logging.getLogger(__name__)
+
 
 class MergeActorBase:
     def __init__(
@@ -87,18 +89,18 @@ def evaluate_genotype(
     ) -> dict:
         gc.collect()
         torch.cuda.empty_cache()
-        logging.info("Merging model")
+        logger.info("Merging model")
         merged_path = merge_model(
             genotype, self.genome, self.model_storage_path, self.merge_options
         )
         if not merged_path:
-            logging.error("Model merge failed")
+            logger.error("Model merge failed")
             return {"score": None, "results": None}
 
         kwargs = {}
         if self.quantization_config is not None:
             kwargs["quantization_config"] = self.quantization_config
-        logging.info(f"Model merged to {merged_path}")
+        logger.info(f"Model merged to {merged_path}")
         return evaluate_model(
             merged_path,
             self.config.tasks,
@@ -107,6 +109,8 @@ def evaluate_genotype(
             vllm=self.vllm,
             batch_size=self.batch_size,
             task_manager=self.task_manager,
+            apply_chat_template=self.config.apply_chat_template,
+            fewshot_as_multiturn=self.config.fewshot_as_multiturn,
             **kwargs,
         )

@@ -163,7 +167,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
                     continue
 
                 if getattr(cfg_out, key) != getattr(self.arch_info.config, key, None):
-                    logging.warn(f"Config key {key} changed, reinitializing model")
+                    logger.warn(f"Config key {key} changed, reinitializing model")
                     different = True
                     break
@@ -202,7 +206,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
                 del inner_model
                 tokenizer_donor = self.genome.definition.base_model
                 if tokenizer_donor is None:
-                    logging.warning(
+                    logger.warning(
                         "Base model not set, using tokenizer from first model in genome"
                     )
                     tokenizer_donor = self.genome.definition.models[0]
@@ -220,7 +224,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
                     max_model_len = min(max_model_len or 1024, window_sz)
                 if max_model_len and max_model_len > 8192:
                     max_model_len = 8192
-                    logging.warn(f"Clipping sequence length to {max_model_len}")
+                    logger.warning(f"Clipping sequence length to {max_model_len}")
 
                 mem_util = (
                     0.7 if self.merge_options.cuda else 0.9
@@ -237,13 +241,13 @@ def _maybe_init_model(self, config: MergeConfiguration):
             else:
                 self.model = lm_eval.models.huggingface.HFLM(pretrained=inner_model)
             self.arch_info = ConfiguredArchitectureInfo(info=ai, config=cfg_out)
-            logging.info("Model initialized")
+            logger.info("Model initialized")
 
     def evaluate(self, genotype: torch.Tensor) -> dict:
         try:
             config = self.genome.genotype_merge_config(genotype)
         except InvalidGenotypeError as e:
-            logging.error("Invalid genotype", exc_info=e)
+            logger.error("Invalid genotype", exc_info=e)
             return {"score": None, "results": None}
 
         self._maybe_init_model(config)
@@ -262,7 +266,13 @@ def evaluate(self, genotype: torch.Tensor) -> dict:
             assert (
                 model.llm_engine.parallel_config.world_size == 1
             ), "Must be single GPU"
-            worker = model.llm_engine.driver_worker
+            engine = model.llm_engine
+            if hasattr(engine, "model_executor"):
+                worker = engine.model_executor.worker
+            elif hasattr(engine, "driver_worker"):
+                worker = engine.driver_worker
+            else:
+                raise ValueError("Unknown LLM engine type")
             model = worker.model_runner.model
             param_dict = dict(model.named_parameters())

@@ -311,6 +321,8 @@ def evaluate(self, genotype: torch.Tensor) -> dict:
             limit=self.config.limit,
             task_manager=self.task_manager,
             batch_size=self.batch_size,
+            apply_chat_template=self.config.apply_chat_template,
+            fewshot_as_multiturn=self.config.fewshot_as_multiturn,
         )
 
     def evaluate_genotype(
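
The two new keyword arguments line up with the `apply_chat_template` and `fewshot_as_multiturn` options in lm-eval-harness. A minimal sketch of the downstream call, assuming `evaluate_model` ultimately hands these flags to `lm_eval.simple_evaluate`; the model path and task list below are placeholders, not part of this PR:

```python
import lm_eval

# Hypothetical stand-in for the call mergekit's evaluate_model makes downstream.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=path/to/merged-model",
    tasks=["gsm8k"],
    batch_size=4,
    apply_chat_template=True,   # render each request with the tokenizer's chat template
    fewshot_as_multiturn=True,  # present few-shot examples as earlier chat turns
)
```

Note that lm-eval-harness only honors `fewshot_as_multiturn` when `apply_chat_template` is enabled, which matches the two options defaulting on together in the config change below.
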
2 changes: 2 additions & 0 deletions mergekit/evo/config.py
@@ -28,6 +28,8 @@ class EvolMergeConfiguration(BaseModel, frozen=True):
     num_fewshot: Optional[int] = None
     shuffle: bool = False
     random_init: bool = False
+    apply_chat_template: bool = True
+    fewshot_as_multiturn: bool = True
 
 
 NAUGHTY_PREFIXES = [
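
`EvolMergeConfiguration` is the Pydantic model behind the `mergekit-evolve` YAML config, so with this change chat templating is on by default and can be toggled per run. A hypothetical excerpt, with the genome, task definitions, and other required fields omitted:

```yaml
# Hypothetical mergekit-evolve config excerpt; required fields omitted.
apply_chat_template: true    # evaluate through the tokenizer's chat template
fewshot_as_multiturn: false  # keep few-shot examples in a single prompt
```
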
4 changes: 4 additions & 0 deletions mergekit/evo/strategy.py
@@ -165,6 +165,8 @@ async def process_queue(self):
                         vllm=self.vllm,
                         batch_size=self.batch_size,
                         task_manager=self.task_manager,
+                        apply_chat_template=self.config.apply_chat_template,
+                        fewshot_as_multiturn=self.config.fewshot_as_multiturn,
                         **kwargs,
                     )
                 ] = future_result
@@ -265,6 +267,8 @@ def evaluate_genotype_serial(
             vllm=vllm,
             batch_size=batch_size,
             task_manager=task_manager,
+            apply_chat_template=config.apply_chat_template,
+            fewshot_as_multiturn=config.fewshot_as_multiturn,
             **kwargs,
         )
     )
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
 dev = ["black~=24.10.0", "isort~=5.13.2", "pre-commit~=4.1.0"]
 test = ["pytest~=8.3.4"]
 evolve = ["ray", "cma", "lm_eval", "wandb"]
-vllm = ["vllm==0.3.2", "lm_eval[vllm]"]
+vllm = ["vllm==0.7.2", "lm_eval[vllm]"]
 
 [project.urls]
 repository = "https://github.com/cg123/mergekit"
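
Since the `vllm` extra now pins `vllm==0.7.2`, picking up the new engine-access path above is a matter of reinstalling with the relevant extras; a sketch, assuming a local checkout of the repository:

```
pip install -e ".[evolve,vllm]"
```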