Skip to content

Commit cad7848

Browse files
committed
HumanEval: Rename new args to match other scripts
1 parent ef7cdda commit cad7848

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

eval/humaneval.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from __future__ import annotations
22

3-
import os
4-
import sys
5-
3+
import os, sys
64
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
75
from human_eval.data import write_jsonl, read_problems
86
from exllamav2 import model_init
@@ -25,9 +23,9 @@
2523
parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
2624
parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
2725
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6", default = 0.6)
28-
parser.add_argument("--top_k", type = int, help = "Top-k sampling, default: 50", default = 50)
29-
parser.add_argument("--top_p", type = float, help = "Top-p sampling, default: 0.6", default = 0.6)
30-
parser.add_argument("-trp", "--token_repetition_penalty", type = float, help = "Token repetition penalty, default: 1.0", default = 1.0)
26+
parser.add_argument("-topk", "--top_k", type = int, help = "Top-k sampling, default: 50", default = 50)
27+
parser.add_argument("-topp", "--top_p", type = float, help = "Top-p sampling, default: 0.6", default = 0.6)
28+
parser.add_argument("-repp", "--repetition_penalty", type = float, help = "Token repetition penalty, default: 1.0", default = 1.0)
3129
model_init.add_args(parser)
3230
args = parser.parse_args()
3331

@@ -124,10 +122,10 @@
124122
)
125123

126124
gen_settings = ExLlamaV2Sampler.Settings(
127-
token_repetition_penalty=args.token_repetition_penalty,
128-
temperature=args.temperature,
129-
top_k=args.top_k,
130-
top_p=args.top_p
125+
token_repetition_penalty = args.repetition_penalty,
126+
temperature = args.temperature,
127+
top_k = args.top_k,
128+
top_p = args.top_p
131129
)
132130

133131
# Get problems

0 commit comments

Comments
 (0)