From ff939d8a644c27cbe42889e772a1fc5502596759 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 25 Mar 2024 15:34:54 +0900
Subject: [PATCH] fix(dataset): normalize tokenizer config and change hash
 from tokenizer class to tokenizer path (#1298)

* fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path

* fix: normalize config
---
 src/axolotl/utils/config/__init__.py | 4 ++++
 src/axolotl/utils/data.py            | 2 +-
 src/axolotl/utils/models.py          | 3 +--
 tests/core/test_trainer_builder.py   | 8 +++++++-
 4 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 3e743bda9f..cd3f752c1a 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -119,6 +119,10 @@ def normalize_config(cfg):
     model_config = load_model_config(cfg)
     cfg.model_config_type = model_config.model_type
 
+    cfg.tokenizer_config = (
+        cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
+    )
+
     # figure out if the model is llama
     cfg.is_llama_derived_model = (
         (hasattr(model_config, "model_type") and model_config.model_type == "llama")
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 9e0049e659..6c9bc68159 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -134,7 +134,7 @@ def load_tokenized_prepared_datasets(
     split="train",
 ) -> Tuple[DatasetDict, List[Prompter]]:
     cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
-    tokenizer_name = tokenizer.__class__.__name__
+    tokenizer_name = cfg.tokenizer_config
     ds_hash = str(
         md5(
             (
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 41fd471e65..0a59eb2a4d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -134,9 +134,8 @@ def load_tokenizer(cfg):
     if cfg.tokenizer_type:
         tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
 
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
     tokenizer = tokenizer_cls.from_pretrained(
-        tokenizer_config,
+        cfg.tokenizer_config,
         trust_remote_code=cfg.trust_remote_code or False,
         use_fast=use_fast,
         **tokenizer_kwargs,
diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py
index 19042639f1..541fdb343d 100644
--- a/tests/core/test_trainer_builder.py
+++ b/tests/core/test_trainer_builder.py
@@ -1,16 +1,18 @@
 """
 unit tests for axolotl.core.trainer_builder
 """
+
 import pytest
 
 from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 
 
 @pytest.fixture(name="cfg")
 def fixture_cfg():
-    return DictDefault(
+    cfg = DictDefault(
         {
             "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
             "model_type": "AutoModelForCausalLM",
@@ -34,6 +36,10 @@ def fixture_cfg():
         }
     )
 
+    normalize_config(cfg)
+
+    return cfg
+
 
 @pytest.fixture(name="tokenizer")
 def fixture_tokenizer(cfg):
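
Note for review context (not part of the patch itself): the prepared-dataset cache key previously hashed the tokenizer class name, so two different checkpoints whose tokenizers share a class (for example, two distinct LlamaTokenizer checkpoints) could collide and silently reuse each other's prepared dataset. Hashing the normalized tokenizer path (cfg.tokenizer_config, which now falls back to base_model_config and then base_model) keeps the cache keyed to the actual tokenizer. The snippet below is a simplified sketch of that behavior only; ds_hash and its seed format are illustrative assumptions, not the exact code in data.py.

    from hashlib import md5

    def ds_hash(tokenizer_name: str, seq_len: int, datasets: list) -> str:
        # simplified stand-in for the md5 seed built in load_tokenized_prepared_datasets
        seed = "@".join([str(seq_len), tokenizer_name, str(sorted(datasets))])
        return md5(seed.encode("utf-8")).hexdigest()

    # Before this patch the key was tokenizer.__class__.__name__, so two different
    # Llama checkpoints both reduce to "LlamaTokenizer" and get the same cache entry.
    assert ds_hash("LlamaTokenizer", 2048, ["alpaca"]) == ds_hash("LlamaTokenizer", 2048, ["alpaca"])

    # After this patch the key is the normalized tokenizer path, so the entries differ.
    assert ds_hash("TinyLlama/TinyLlama-1.1B-Chat-v0.6", 2048, ["alpaca"]) != ds_hash(
        "NousResearch/Llama-2-7b-hf", 2048, ["alpaca"]
    )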