Skip to content

Commit 068fc48

Browse files
authoredJun 13, 2023
Merge pull request axolotl-ai-cloud#199 from NanoCode012/chore/prompter-arg
chore: Refactor inf_kwargs out
2 parents aaadacf + dc77c8e commit 068fc48

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed
 

‎scripts/finetune.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]:
6363
return instruction
6464

6565

66-
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
66+
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
6767
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
6868

6969
for token, symbol in default_tokens.items():
@@ -257,13 +257,13 @@ def train(
257257

258258
if cfg.inference:
259259
logging.info("calling do_inference function")
260-
inf_kwargs: Dict[str, Any] = {}
260+
prompter: Optional[str] = "AlpacaPrompter"
261261
if "prompter" in kwargs:
262262
if kwargs["prompter"] == "None":
263-
inf_kwargs["prompter"] = None
263+
prompter = None
264264
else:
265-
inf_kwargs["prompter"] = kwargs["prompter"]
266-
do_inference(cfg, model, tokenizer, **inf_kwargs)
265+
prompter = kwargs["prompter"]
266+
do_inference(cfg, model, tokenizer, prompter=prompter)
267267
return
268268

269269
if "shard" in kwargs:

0 commit comments

Comments
 (0)