Skip to content

Commit

Permalink
Take split param from config in all load_dataset instances (#2281)
Browse files: browse the repository at this point in the history
  • Loading branch information
mashdragon authored Jan 24, 2025
1 parent 74f9782 commit b2774af
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/axolotl/utils/data/shared.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
except (FileNotFoundError, ConnectionError):
pass

# gather extra args from the config
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
else:
load_ds_kwargs["split"] = None

# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
Expand All @@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
**load_ds_kwargs,
)
else:
try:
Expand All @@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token):
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
**load_ds_kwargs,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
Expand All @@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
**load_ds_kwargs,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
Expand All @@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
Expand All @@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
else:
if isinstance(config_dataset.data_files, str):
Expand Down Expand Up @@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
**load_ds_kwargs,
)
if not ds:
raise ValueError("unhandled dataset load")
Expand Down

0 comments on commit b2774af

Please sign in to comment.