Skip to content

Commit ff939d8

Browse files
authored
fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path (axolotl-ai-cloud#1298)
* fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path
* fix: normalize config
1 parent 324d59e commit ff939d8

File tree

4 files changed

+13
-4
lines changed

4 files changed

+13
-4
lines changed

src/axolotl/utils/config/__init__.py

+4
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,10 @@ def normalize_config(cfg):
119119
model_config = load_model_config(cfg)
120120
cfg.model_config_type = model_config.model_type
121121

122+
cfg.tokenizer_config = (
123+
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
124+
)
125+
122126
# figure out if the model is llama
123127
cfg.is_llama_derived_model = (
124128
(hasattr(model_config, "model_type") and model_config.model_type == "llama")

src/axolotl/utils/data.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -134,7 +134,7 @@ def load_tokenized_prepared_datasets(
134134
split="train",
135135
) -> Tuple[DatasetDict, List[Prompter]]:
136136
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
137-
tokenizer_name = tokenizer.__class__.__name__
137+
tokenizer_name = cfg.tokenizer_config
138138
ds_hash = str(
139139
md5(
140140
(

src/axolotl/utils/models.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -134,9 +134,8 @@ def load_tokenizer(cfg):
134134
if cfg.tokenizer_type:
135135
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
136136

137-
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
138137
tokenizer = tokenizer_cls.from_pretrained(
139-
tokenizer_config,
138+
cfg.tokenizer_config,
140139
trust_remote_code=cfg.trust_remote_code or False,
141140
use_fast=use_fast,
142141
**tokenizer_kwargs,

tests/core/test_trainer_builder.py

+7-1
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,18 @@
11
"""
22
unit tests for axolotl.core.trainer_builder
33
"""
4+
45
import pytest
56

67
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
8+
from axolotl.utils.config import normalize_config
79
from axolotl.utils.dict import DictDefault
810
from axolotl.utils.models import load_model, load_tokenizer
911

1012

1113
@pytest.fixture(name="cfg")
1214
def fixture_cfg():
13-
return DictDefault(
15+
cfg = DictDefault(
1416
{
1517
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
1618
"model_type": "AutoModelForCausalLM",
@@ -34,6 +36,10 @@ def fixture_cfg():
3436
}
3537
)
3638

39+
normalize_config(cfg)
40+
41+
return cfg
42+
3743

3844
@pytest.fixture(name="tokenizer")
3945
def fixture_tokenizer(cfg):

0 commit comments

Comments
 (0)