
Commit 47cff55

Merge branch 'main' into support-ov-models-via-genai
2 parents: dc60929 + 928e8bb

30 files changed: +278 -88 lines

lm_eval/api/model.py

Lines changed: 5 additions & 2 deletions
@@ -283,8 +283,11 @@ def fn(requests):
             eval_logger.info(
                 f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
             )
-            # actually run the LM on the requests that do not have cached results
-            rem_res = getattr(self.lm, attr)(remaining_reqs)
+            if remaining_reqs:
+                # actually run the LM on the requests that do not have cached results
+                rem_res = getattr(self.lm, attr)(remaining_reqs)
+            else:
+                rem_res = []
 
             # stick the new ones back into the list and also cache any of the new ones
             resptr = 0
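
For context, a minimal sketch of the behavior this guard introduces (illustrative names and a stub model, not the harness's exact code): when every request is already cached, the wrapped model is no longer called at all.

```python
# Illustrative sketch (assumed names) of the cache-miss dispatch above.
class _StubLM:
    def loglikelihood(self, reqs):
        print(f"model called with {len(reqs)} requests")
        return [(-1.0, True) for _ in reqs]

lm = _StubLM()
cache = {"req-a": (-0.5, True), "req-b": (-0.9, False)}  # pretend both results are cached
requests = ["req-a", "req-b"]

remaining_reqs = [r for r in requests if r not in cache]
if remaining_reqs:
    rem_res = lm.loglikelihood(remaining_reqs)  # only query the model for cache misses
else:
    rem_res = []  # fully cached: previously the model was still called with an empty list
```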

lm_eval/api/samplers.py

Lines changed: 2 additions & 1 deletion
@@ -82,8 +82,9 @@ def get_context(self, doc, num_fewshot):
                 if self.config.doc_to_choice is None or isinstance(doc_content, str)
                 else self.doc_to_choice(doc)[doc_content]
             )
-            labeled_examples += self.target_delimiter
+
             if doc_target != "":
+                labeled_examples += self.target_delimiter
                 labeled_examples += (
                     str(doc_target[0])
                     if isinstance(doc_target, list)

lm_eval/evaluator.py

Lines changed: 10 additions & 2 deletions
@@ -208,7 +208,9 @@ def simple_evaluate(
         )
     else:
         if not isinstance(model, lm_eval.api.model.LM):
-            raise TypeError
+            raise TypeError(
+                f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
+            )
         eval_logger.info("Using pre-initialized model")
         lm = model
 

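The new error message points at the expected fix; here is a rough sketch of that wrapping (the model checkpoint and task name are placeholders, not part of this commit):

```python
# Sketch: wrapping an already-initialized Hugging Face model before calling simple_evaluate().
import lm_eval
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM, AutoTokenizer

my_model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
my_tokenizer = AutoTokenizer.from_pretrained("gpt2")

lm = HFLM(pretrained=my_model, tokenizer=my_tokenizer)   # an lm_eval.api.model.LM subclass
results = lm_eval.simple_evaluate(model=lm, tasks=["lambada_openai"])  # example task
```
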
@@ -287,12 +289,18 @@ def _adjust_config(task_dict):
     if check_integrity:
         run_task_tests(task_list=tasks)
 
+    # hotfix: delete when chat_template fixed
+    try:
+        chat = lm.chat_template(apply_chat_template)
+    except:  # noqa: E722
+        chat = None
+
     if evaluation_tracker is not None:
         evaluation_tracker.general_config_tracker.log_experiment_args(
             model_source=model,
             model_args=model_args,
             system_instruction=system_instruction,
-            chat_template=lm.chat_template(apply_chat_template),
+            chat_template=chat,
             fewshot_as_multiturn=fewshot_as_multiturn,
         )
 

lm_eval/models/api_models.py

Lines changed: 10 additions & 3 deletions
@@ -104,7 +104,9 @@ def __init__(
         self._truncate = truncate
         self._max_gen_toks = int(max_gen_toks)
         self._seed = int(seed)
-        self.max_length = max_length
+        # max_length - 1 as we always have 1 token for generation
+        eval_logger.info(f"Using max length {max_length} - 1")
+        self.max_length = max_length - 1
         if int(num_concurrent) <= 1:
             eval_logger.info(
                 "Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."

@@ -417,6 +419,7 @@ def batch_logliklehood_requests(
         cache_keys = []
         for chunk in chunks:
             for cache_key, context_enc, continuation_enc in chunk:
+                # max_length - 1 as we always have 1 token for generation
                 inp = (context_enc + continuation_enc)[-(self.max_length) :]
                 ctxlen = len(context_enc) - max(
                     0, len(context_enc) + len(continuation_enc) - (self.max_length)

@@ -510,7 +513,7 @@ def _collate(req: LogLikelihoodInputs):
         ):
             if answer_ is not None:
                 res.append(answer_)
-                # partial caching
+                # cache requests that aren't from a loglikelihood_rolling request
                 if cache_key is not None:
                     self.cache_hook.add_partial(
                         "loglikelihood", cache_key, answer_

@@ -619,7 +622,8 @@ def loglikelihood_rolling(
                     utils.get_rolling_token_windows(
                         token_list=self.tok_encode(string),
                         prefix_token=self.prefix_token_id,
-                        max_seq_len=self.max_length,
+                        # max_seq_len - (1 for context)
+                        max_seq_len=self.max_length - 1,
                         context_len=1,
                     ),
                 )

@@ -638,4 +642,7 @@ def loglikelihood_rolling(
 
             string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
+
+            # cache this loglikelihood_rolling request
+            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
         return loglikelihoods
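
The caching comments above capture a convention this commit applies across several backends; here is a rough, self-contained sketch of it (stubbed cache hook and made-up values, not the harness's actual CacheHook implementation):

```python
# Rough sketch of the caching convention: per-window scores issued on behalf of
# loglikelihood_rolling carry cache_key=None and are skipped, while the rolling
# request itself is cached once per document under its own hook name.
class StubCacheHook:
    def __init__(self):
        self.store = {}

    def add_partial(self, attr, req, res):
        self.store[(attr, repr(req))] = res

cache_hook = StubCacheHook()

scored = [(("ctx", "cont"), (-0.7, True)),  # ordinary loglikelihood request
          (None, (-1.3, False))]            # window belonging to a rolling request
for cache_key, answer in scored:
    if cache_key is not None:
        cache_hook.add_partial("loglikelihood", cache_key, answer)

# the rolling request is cached once, keyed by the raw string, after summing its windows
string, string_nll = "some long document ...", -42.0
cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
```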

lm_eval/models/huggingface.py

Lines changed: 11 additions & 2 deletions
@@ -688,10 +688,10 @@ def _create_model(
                 raise AssertionError("load_in_4bit requires peft >= 0.4.0")
             if self._model.config.vocab_size != len(self.tokenizer):
                 # resize model for LoRAs with added tokens
-                self._model.resize_token_embeddings(len(self.tokenizer))
                 eval_logger.info(
                     f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
                 )
+                self._model.resize_token_embeddings(len(self.tokenizer))
             self._model = PeftModel.from_pretrained(
                 self._model, peft, revision=revision
             )

@@ -1018,6 +1018,9 @@ def loglikelihood_rolling(
             string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
 
+            # cache this loglikelihood_rolling request
+            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+
         return loglikelihoods
 
     def _batch_scheduler(self, pos, n_reordered_requests):

@@ -1246,7 +1249,13 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
 
                 res.append(answer)
 
-                self.cache_hook.add_partial("loglikelihood", request_str, answer)
+                if request_str is not None:
+                    # special case: loglikelihood_rolling produces a number of loglikelihood requests
+                    # all with cache key None. instead do add_partial on the per-example level
+                    # in the loglikelihood_rolling() function for those.
+                    self.cache_hook.add_partial(
+                        "loglikelihood", request_str, answer
+                    )
                 pbar.update(1)
 
         pbar.close()

lm_eval/models/nemo_lm.py

Lines changed: 6 additions & 0 deletions
@@ -386,6 +386,9 @@ def loglikelihood_rolling(
 
             string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
+
+            # cache this loglikelihood_rolling request
+            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
         return loglikelihoods
 
     def _loglikelihood_tokens(self, requests, disable_tqdm=False):

@@ -468,6 +471,9 @@ def _collate(x):
             answer = (logprob, is_greedy)
 
             if cache_key is not None:
+                # special case: loglikelihood_rolling produces a number of loglikelihood requests
+                # all with cache key None. instead do add_partial on the per-example level
+                # in the loglikelihood_rolling() function for those.
                 self.cache_hook.add_partial("loglikelihood", cache_key, answer)
 
             res.append(answer)

lm_eval/models/neuralmagic.py

Lines changed: 3 additions & 0 deletions
@@ -321,6 +321,9 @@ def _collate(x):
             res.append(answer)
 
             if cache_key is not None:
+                # special case: loglikelihood_rolling produces a number of loglikelihood requests
+                # all with cache key None. instead do add_partial on the per-example level
+                # in the loglikelihood_rolling() function for those.
                 self.cache_hook.add_partial("loglikelihood", cache_key, answer)
 
         return re_ord.get_original(res)

lm_eval/models/neuron_optimum.py

Lines changed: 7 additions & 2 deletions
@@ -502,7 +502,8 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
 
             string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
-
+            # cache this loglikelihood_rolling request
+            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
         return loglikelihoods
 
     def _loglikelihood_tokens(

@@ -620,7 +621,11 @@ def _collate(x):
 
             res.append(answer)
 
-            self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+            if cache_key is not None:
+                # special case: loglikelihood_rolling produces a number of loglikelihood requests
+                # all with cache key None. instead do add_partial on the per-example level
+                # in the loglikelihood_rolling() function for those.
+                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
 
         return re_ord.get_original(res)
 

lm_eval/models/vllm_causallms.py

Lines changed: 9 additions & 2 deletions
@@ -289,7 +289,8 @@ def loglikelihood_rolling(
                     make_disjoint_window,
                     get_rolling_token_windows(
                         token_list=self.tok_encode(string),
-                        prefix_token=self.eot_token_id,
+                        prefix_token=self.prefix_token_id,
+                        # max_seq_len - (1 for context)
                         max_seq_len=self.max_length - 1,
                         context_len=1,
                     ),

@@ -307,6 +308,10 @@ def loglikelihood_rolling(
 
             string_nll = sum(string_nll)
             loglikelihoods.append(string_nll)
+
+            # cache this loglikelihood_rolling request
+            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+
         return loglikelihoods
 
     def generate_until(

@@ -453,8 +458,10 @@ def _collate(x):
 
             res.append(answer)
 
-            # partial caching
             if cache_key is not None:
+                # special case: loglikelihood_rolling produces a number of loglikelihood requests
+                # all with cache key None. instead do add_partial on the per-example level
+                # in the loglikelihood_rolling() function for those.
                 self.cache_hook.add_partial("loglikelihood", cache_key, answer)
             pbar.update(1)
         pbar.close()

lm_eval/tasks/README.md

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@
 | [mgsm](mgsm/README.md) | Benchmark of multilingual grade-school math problems. | Spanish, French, German, Russian, Chinese, Japanese, Thai, Swahili, Bengali, Telugu |
 | [minerva_math](minerva_math/README.md) | Mathematics-focused tasks requiring numerical reasoning and problem-solving skills. | English |
 | mmlu | Massive Multitask Language Understanding benchmark for broad domain language evaluation. Several variants are supported. | English |
-| [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigourous. | English |
+| [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigorous. | English |
 | model_written_evals | Evaluation tasks auto-generated for evaluating a collection of AI Safety concerns. | |
 | [mutual](mutual/README.md) | A retrieval-based dataset for multi-turn dialogue reasoning. | English |
 | [nq_open](nq_open/README.md) | Open domain question answering tasks based on the Natural Questions dataset. | English |

lm_eval/tasks/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -492,7 +492,7 @@ def _get_task_and_group(self, task_dir: str):
                     "`group` and `group_alias` keys in tasks' configs will no longer be used in the next release of lm-eval. "
                     "`tag` will be used to allow to call a collection of tasks just like `group`. "
                     "`group` will be removed in order to not cause confusion with the new ConfigurableGroup "
-                    "which will be the offical way to create groups with addition of group-wide configuations."
+                    "which will be the official way to create groups with addition of group-wide configurations."
                 )
                 print_info = False
                 # attr = "tag"

lm_eval/tasks/aclue/README.md

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ Homepage: https://github.com/isen-zhang/ACLUE
 
 ```bibtex
 @inproceedings{zhang-li-2023-large,
-    title = "Can Large Langauge Model Comprehend {A}ncient {C}hinese? A Preliminary Test on {ACLUE}",
+    title = "Can Large Language Model Comprehend {A}ncient {C}hinese? A Preliminary Test on {ACLUE}",
     author = "Zhang, Yixuan and Li, Haonan",
     booktitle = "Proceedings of the Ancient Language Processing Workshop",
     month = sep,

lm_eval/tasks/asdiv/README.md

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,11 @@ Homepage: https://github.com/chaochun/nlu-asdiv-dataset
 #### Tasks
 
 * `asdiv`
+* `asdiv_cot_llama`: ASDIV with prompt formatting modified to conform to the evaluation settings described by Meta here: https://huggingface.co/datasets/meta-llama/Meta-Llama-3.1-8B-Instruct-evals/viewer/Meta-Llama-3.1-8B-Instruct-evals__gsm8k__details?row=0
+  - Note that the CoT prompt from (https://arxiv.org/pdf/2201.11903) is used exactly as in GSM8k-CoT
+  - This file is setup to run identically to the task `gsm8k_cot_llama` but for asdiv.
+  - Use this task with --fewshot_as_multiturn and --apply_chat_template to run correctly with Llama Instruct models.
+
 
 ### Checklist
 

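A rough sketch of running the new task through the Python API, with the two options the README note mentions mapped to their keyword arguments (the model checkpoint is a placeholder, not part of this commit):

```python
# Sketch: evaluating asdiv_cot_llama with chat templating and multiturn few-shot,
# as the README suggests for Llama Instruct models. The pretrained path is a placeholder.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Meta-Llama-3.1-8B-Instruct",
    tasks=["asdiv_cot_llama"],
    apply_chat_template=True,   # --apply_chat_template on the CLI
    fewshot_as_multiturn=True,  # --fewshot_as_multiturn on the CLI
)
print(results["results"]["asdiv_cot_llama"])
```
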
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+dataset_path: EleutherAI/asdiv
+doc_to_target: "{{answer.split(' (')[0] if answer is defined else target}}"
+doc_to_text: "Given the following problem, reason and give a final answer to the problem.\nProblem: {{body if body is defined}} {{question}}\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n"
+fewshot_config:
+  sampler: first_n
+  samples:
+  - question: There are 15 trees in the grove. Grove workers will plant trees in the
+      grove today. After they are done, there will be 21 trees. How many trees did
+      the grove workers plant today?
+    target: There are 15 trees originally. Then there were 21 trees after some more
+      were planted. So there must have been 21 - 15 = 6. The final answer is 6
+  - question: If there are 3 cars in the parking lot and 2 more cars arrive, how many
+      cars are in the parking lot?
+    target: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The final answer
+      is 5
+  - question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many
+      pieces do they have left in total?
+    target: Originally, Leah had 32 chocolates. Her sister had 42. So in total they
+      had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The final answer is 39
+  - question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12
+      lollipops. How many lollipops did Jason give to Denny?
+    target: Jason started with 20 lollipops. Then he had 12 after giving some to Denny.
+      So he gave Denny 20 - 12 = 8. The final answer is 8
+  - question: Shawn has five toys. For Christmas, he got two toys each from his mom and
+      dad. How many toys does he have now?
+    target: Shawn started with 5 toys. If he got 2 toys each from his mom and dad,
+      then that is 4 more toys. 5 + 4 = 9. The final answer is 9
+  - question: There were nine computers in the server room. Five more computers were
+      installed each day, from monday to thursday. How many computers are now in the
+      server room?
+    target: There were originally 9 computers. For each of 4 days, 5 more computers
+      were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The final answer is
+      29
+  - question: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday,
+      he lost 2 more. How many golf balls did he have at the end of wednesday?
+    target: Michael started with 58 golf balls. After losing 23 on tuesday, he had
+      58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The final answer
+      is 33
+  - question: Olivia has $23. She bought five bagels for $3 each. How much money does
+      she have left?
+    target: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15
+      dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The final answer is 8
+filter_list:
+- filter:
+  - function: regex
+    group_select: -1
+    regex_pattern: The final answer is ((-?[$0-9.,]{2,})|(-?[0-9]+))
+  - function: take_first
+  name: strict-match
+- filter:
+  - function: regex
+    group_select: -1
+    regex_pattern: (-?[$0-9.,]{2,})|(-?[0-9]+)
+  - function: take_first
+  name: flexible-extract
+generation_kwargs:
+  do_sample: false
+  until:
+  - '<|eot_id|>'
+  - '<|start_header_id|>user<|end_header_id|>'
+  - 'Q:'
+  - </s>
+  - <|im_end|>
+tag:
+- chain_of_thought
+metadata:
+  version: 1.0
+metric_list:
+- aggregation: mean
+  higher_is_better: true
+  ignore_case: true
+  ignore_punctuation: false
+  metric: exact_match
+  regexes_to_ignore:
+  - ','
+  - \$
+  - '(?s).*#### '
+  - \.$
+num_fewshot: 8
+output_type: generate_until
+repeats: 1
+task: asdiv_cot_llama
+validation_split: validation
+test_split: validation
+should_decontaminate: true
+doc_to_decontamination_query: "{{body}} {{question}}"
+dataset_kwargs:
+  trust_remote_code: true
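
For readers unfamiliar with the filter configuration, a hedged illustration of what the strict-match filter above extracts (plain `re`, not the harness's actual filter classes; `group_select: -1` corresponds to keeping the last match):

```python
import re

# Illustration only: extract the final answer the way the strict-match filter is configured to.
STRICT = r"The final answer is ((-?[$0-9.,]{2,})|(-?[0-9]+))"
completion = (
    "There are 15 trees originally. Then there were 21 trees after some more were "
    "planted. So there must have been 21 - 15 = 6. The final answer is 6"
)

matches = re.findall(STRICT, completion)  # one tuple of capture groups per match
answer = matches[-1][0] if matches else "[invalid]"  # last match, outer capture group
print(answer)  # -> 6
```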

lm_eval/tasks/eq_bench/README.md

Lines changed: 2 additions & 2 deletions
@@ -16,8 +16,8 @@ Homepage: https://eqbench.com/
 NOTE: There are some key differences between the lm-evaluation-harness version and the implementation described in the EQ-Bench paper (These have been OK'd by the author):
 
 - The lm-eval version uses the EQ-Bench v2 test set (171 questions) and score calculation. It does not incorporate the revision part of the prompt, as per v2.1 (https://github.com/EQ-bench/EQ-Bench)
-- No retries in lm-eval version (EQ-Bench pipeline retries with successively higher temps if it encounters unparseable answers)
-- In the original implementation, unparseable answers are excluded from the final score, and 83% of answers have to be parseable or a fail is returned. The lm-eval version instead assigns 0 to unparsable answers and has no fail criteria. So for lower performing models, there may be differences with the EQ-Bench leaderboard.
+- No retries in lm-eval version (EQ-Bench pipeline retries with successively higher temps if it encounters unparsable answers)
+- In the original implementation, unparsable answers are excluded from the final score, and 83% of answers have to be parseable or a fail is returned. The lm-eval version instead assigns 0 to unparsable answers and has no fail criteria. So for lower performing models, there may be differences with the EQ-Bench leaderboard.
 
 
 ### Citation

lm_eval/tasks/ifeval/ifeval.yaml

Lines changed: 1 addition & 1 deletion
@@ -26,4 +26,4 @@ metric_list:
     aggregation: !function utils.agg_inst_level_acc
     higher_is_better: true
 metadata:
-  version: 3.0
+  version: 4.0
