File tree: 4 files changed, +13 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,10 @@ def normalize_config(cfg):
119
119
model_config = load_model_config (cfg )
120
120
cfg .model_config_type = model_config .model_type
121
121
122
+ cfg .tokenizer_config = (
123
+ cfg .tokenizer_config or cfg .base_model_config or cfg .base_model
124
+ )
125
+
122
126
# figure out if the model is llama
123
127
cfg .is_llama_derived_model = (
124
128
(hasattr (model_config , "model_type" ) and model_config .model_type == "llama" )
Original file line number Diff line number Diff line change @@ -134,7 +134,7 @@ def load_tokenized_prepared_datasets(
134
134
split = "train" ,
135
135
) -> Tuple [DatasetDict , List [Prompter ]]:
136
136
cfg_datasets = cfg .test_datasets if split == "test" else cfg .datasets
137
- tokenizer_name = tokenizer . __class__ . __name__
137
+ tokenizer_name = cfg . tokenizer_config
138
138
ds_hash = str (
139
139
md5 (
140
140
(
Original file line number Diff line number Diff line change @@ -134,9 +134,8 @@ def load_tokenizer(cfg):
134
134
if cfg .tokenizer_type :
135
135
tokenizer_cls = getattr (transformers , cfg .tokenizer_type )
136
136
137
- tokenizer_config = cfg .tokenizer_config or cfg .base_model_config or cfg .base_model
138
137
tokenizer = tokenizer_cls .from_pretrained (
139
- tokenizer_config ,
138
+ cfg . tokenizer_config ,
140
139
trust_remote_code = cfg .trust_remote_code or False ,
141
140
use_fast = use_fast ,
142
141
** tokenizer_kwargs ,
Original file line number Diff line number Diff line change 1
1
"""
2
2
unit tests for axolotl.core.trainer_builder
3
3
"""
4
+
4
5
import pytest
5
6
6
7
from axolotl .core .trainer_builder import HFDPOTrainerBuilder
8
+ from axolotl .utils .config import normalize_config
7
9
from axolotl .utils .dict import DictDefault
8
10
from axolotl .utils .models import load_model , load_tokenizer
9
11
10
12
11
13
@pytest .fixture (name = "cfg" )
12
14
def fixture_cfg ():
13
- return DictDefault (
15
+ cfg = DictDefault (
14
16
{
15
17
"base_model" : "TinyLlama/TinyLlama-1.1B-Chat-v0.6" ,
16
18
"model_type" : "AutoModelForCausalLM" ,
@@ -34,6 +36,10 @@ def fixture_cfg():
34
36
}
35
37
)
36
38
39
+ normalize_config (cfg )
40
+
41
+ return cfg
42
+
37
43
38
44
@pytest .fixture (name = "tokenizer" )
39
45
def fixture_tokenizer (cfg ):
You can’t perform that action at this time.
0 commit comments