Skip to content

Commit 39d9bd5

Browse files
authored
Apply chat template in mergekit-evolve (#508)
Minor update to allow passing `apply_chat_template` and `fewshot_as_multiturn` when running `mergekit-evolve`.
1 parent 30b67a2 commit 39d9bd5

File tree

4 files changed

+28
-10
lines changed

4 files changed

+28
-10
lines changed

Diff for: mergekit/evo/actors.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
from mergekit.options import MergeOptions
4040
from mergekit.plan import MergePlanner
4141

42+
logger = logging.getLogger(__name__)
43+
4244

4345
class MergeActorBase:
4446
def __init__(
@@ -87,18 +89,18 @@ def evaluate_genotype(
8789
) -> dict:
8890
gc.collect()
8991
torch.cuda.empty_cache()
90-
logging.info("Merging model")
92+
logger.info("Merging model")
9193
merged_path = merge_model(
9294
genotype, self.genome, self.model_storage_path, self.merge_options
9395
)
9496
if not merged_path:
95-
logging.error("Model merge failed")
97+
logger.error("Model merge failed")
9698
return {"score": None, "results": None}
9799

98100
kwargs = {}
99101
if self.quantization_config is not None:
100102
kwargs["quantization_config"] = self.quantization_config
101-
logging.info(f"Model merged to {merged_path}")
103+
logger.info(f"Model merged to {merged_path}")
102104
return evaluate_model(
103105
merged_path,
104106
self.config.tasks,
@@ -107,6 +109,8 @@ def evaluate_genotype(
107109
vllm=self.vllm,
108110
batch_size=self.batch_size,
109111
task_manager=self.task_manager,
112+
apply_chat_template=self.config.apply_chat_template,
113+
fewshot_as_multiturn=self.config.fewshot_as_multiturn,
110114
**kwargs,
111115
)
112116

@@ -163,7 +167,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
163167
continue
164168

165169
if getattr(cfg_out, key) != getattr(self.arch_info.config, key, None):
166-
logging.warn(f"Config key {key} changed, reinitializing model")
170+
logger.warn(f"Config key {key} changed, reinitializing model")
167171
different = True
168172
break
169173

@@ -202,7 +206,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
202206
del inner_model
203207
tokenizer_donor = self.genome.definition.base_model
204208
if tokenizer_donor is None:
205-
logging.warning(
209+
logger.warning(
206210
"Base model not set, using tokenizer from first model in genome"
207211
)
208212
tokenizer_donor = self.genome.definition.models[0]
@@ -220,7 +224,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
220224
max_model_len = min(max_model_len or 1024, window_sz)
221225
if max_model_len and max_model_len > 8192:
222226
max_model_len = 8192
223-
logging.warn(f"Clipping sequence length to {max_model_len}")
227+
logger.warning(f"Clipping sequence length to {max_model_len}")
224228

225229
mem_util = (
226230
0.7 if self.merge_options.cuda else 0.9
@@ -237,13 +241,13 @@ def _maybe_init_model(self, config: MergeConfiguration):
237241
else:
238242
self.model = lm_eval.models.huggingface.HFLM(pretrained=inner_model)
239243
self.arch_info = ConfiguredArchitectureInfo(info=ai, config=cfg_out)
240-
logging.info("Model initialized")
244+
logger.info("Model initialized")
241245

242246
def evaluate(self, genotype: torch.Tensor) -> dict:
243247
try:
244248
config = self.genome.genotype_merge_config(genotype)
245249
except InvalidGenotypeError as e:
246-
logging.error("Invalid genotype", exc_info=e)
250+
logger.error("Invalid genotype", exc_info=e)
247251
return {"score": None, "results": None}
248252

249253
self._maybe_init_model(config)
@@ -262,7 +266,13 @@ def evaluate(self, genotype: torch.Tensor) -> dict:
262266
assert (
263267
model.llm_engine.parallel_config.world_size == 1
264268
), "Must be single GPU"
265-
worker = model.llm_engine.driver_worker
269+
engine = model.llm_engine
270+
if hasattr(engine, "model_executor"):
271+
worker = engine.model_executor.worker
272+
elif hasattr(engine, "driver_worker"):
273+
worker = engine.driver_worker
274+
else:
275+
raise ValueError("Unknown LLM engine type")
266276
model = worker.model_runner.model
267277
param_dict = dict(model.named_parameters())
268278

@@ -311,6 +321,8 @@ def evaluate(self, genotype: torch.Tensor) -> dict:
311321
limit=self.config.limit,
312322
task_manager=self.task_manager,
313323
batch_size=self.batch_size,
324+
apply_chat_template=self.config.apply_chat_template,
325+
fewshot_as_multiturn=self.config.fewshot_as_multiturn,
314326
)
315327

316328
def evaluate_genotype(

Diff for: mergekit/evo/config.py

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class EvolMergeConfiguration(BaseModel, frozen=True):
2828
num_fewshot: Optional[int] = None
2929
shuffle: bool = False
3030
random_init: bool = False
31+
apply_chat_template: bool = True
32+
fewshot_as_multiturn: bool = True
3133

3234

3335
NAUGHTY_PREFIXES = [

Diff for: mergekit/evo/strategy.py

+4
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ async def process_queue(self):
165165
vllm=self.vllm,
166166
batch_size=self.batch_size,
167167
task_manager=self.task_manager,
168+
apply_chat_template=self.config.apply_chat_template,
169+
fewshot_as_multiturn=self.config.fewshot_as_multiturn,
168170
**kwargs,
169171
)
170172
] = future_result
@@ -265,6 +267,8 @@ def evaluate_genotype_serial(
265267
vllm=vllm,
266268
batch_size=batch_size,
267269
task_manager=task_manager,
270+
apply_chat_template=config.apply_chat_template,
271+
fewshot_as_multiturn=config.fewshot_as_multiturn,
268272
**kwargs,
269273
)
270274
)

Diff for: pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
dev = ["black~=24.10.0", "isort~=5.13.2", "pre-commit~=4.1.0"]
3333
test = ["pytest~=8.3.4"]
3434
evolve = ["ray", "cma", "lm_eval", "wandb"]
35-
vllm = ["vllm==0.3.2", "lm_eval[vllm]"]
35+
vllm = ["vllm==0.7.2", "lm_eval[vllm]"]
3636

3737
[project.urls]
3838
repository = "https://github.com/cg123/mergekit"

0 commit comments

Comments (0)