File tree 1 file changed +5
-5
lines changed
1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]:
63
63
return instruction
64
64
65
65
66
- def do_inference (cfg , model , tokenizer , prompter = "AlpacaPrompter" ):
66
+ def do_inference (cfg , model , tokenizer , prompter : Optional [ str ] ):
67
67
default_tokens = {"unk_token" : "<unk>" , "bos_token" : "<s>" , "eos_token" : "</s>" }
68
68
69
69
for token , symbol in default_tokens .items ():
@@ -257,13 +257,13 @@ def train(
257
257
258
258
if cfg .inference :
259
259
logging .info ("calling do_inference function" )
260
- inf_kwargs : Dict [str , Any ] = {}
260
+ prompter : Optional [str ] = "AlpacaPrompter"
261
261
if "prompter" in kwargs :
262
262
if kwargs ["prompter" ] == "None" :
263
- inf_kwargs [ " prompter" ] = None
263
+ prompter = None
264
264
else :
265
- inf_kwargs [ " prompter" ] = kwargs ["prompter" ]
266
- do_inference (cfg , model , tokenizer , ** inf_kwargs )
265
+ prompter = kwargs ["prompter" ]
266
+ do_inference (cfg , model , tokenizer , prompter = prompter )
267
267
return
268
268
269
269
if "shard" in kwargs :
You can’t perform that action at this time.
0 commit comments