Skip to content

Commit

Permalink
rename references to dpo dataset prep to pref data (#2258)
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian authored Jan 15, 2025
1 parent 1ed4de7 commit 19cd83d
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
10 changes: 5 additions & 5 deletions src/axolotl/common/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
Expand Down Expand Up @@ -103,9 +103,9 @@ def load_preference_datasets(
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for DPO training, calling
`axolotl.utils.data.rl.load_prepare_dpo_datasets`. Optionally, logs out debug
information.
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Expand All @@ -115,7 +115,7 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
load_prepare_datasets,
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/data/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type")


def load_prepare_dpo_datasets(cfg):
def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from transformers import AutoTokenizer

from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault


Expand Down Expand Up @@ -280,7 +280,7 @@ def test_load_hub_with_dpo(self):
}
)

train_dataset, _ = load_prepare_dpo_datasets(cfg)
train_dataset, _ = load_prepare_preference_datasets(cfg)

assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
Expand Down Expand Up @@ -329,7 +329,7 @@ def test_load_hub_with_revision_with_dpo(self):
}
)

train_dataset, _ = load_prepare_dpo_datasets(cfg)
train_dataset, _ = load_prepare_preference_datasets(cfg)

assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
Expand Down
6 changes: 3 additions & 3 deletions tests/test_exact_deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from transformers import AutoTokenizer

from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_dpo_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
Expand Down Expand Up @@ -236,7 +236,7 @@ def test_load_with_deduplication(self):
"""Verify that loading with deduplication removes duplicates."""

# Load the dataset using the deduplication setting
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
train_dataset, _ = load_prepare_preference_datasets(self.cfg)

# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
Expand All @@ -245,7 +245,7 @@ def test_load_without_deduplication(self):
"""Verify that loading without deduplication retains duplicates."""
self.cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_dpo_datasets(self.cfg)
train_dataset, _ = load_prepare_preference_datasets(self.cfg)

# Verify that the dataset retains duplicates
assert (
Expand Down

0 comments on commit 19cd83d

Please sign in to comment.