"""Utilities for handling sweeps over configs for axolotl train CLI command"""

import random
from copy import deepcopy
from itertools import product
from typing import Any

# Fixed seed so the randomized trial order is reproducible from run to run.
_SWEEP_SHUFFLE_SEED = 42


def generate_sweep_configs(
    base_config: dict[str, Any], sweeps_config: dict[str, Any]
) -> list[dict[str, Any]]:
    """
    Generate every configuration obtained by applying the sweeps to the base config.

    The Cartesian product of all regular sweep parameters is taken and, when
    paired values are present under the '_' key, each product entry is combined
    with each paired set (paired keys win on a name clash). The resulting
    combinations are shuffled with a fixed seed and then overlaid, one at a
    time, on a deep copy of the base config.

    Args:
        base_config (dict): The original configuration dictionary. It is never
            mutated; every returned config starts from a deep copy of it.
        sweeps_config (dict): Dictionary where keys are parameters and values are either:
            - lists of values to sweep independently
            - or for paired values, a list of dicts under the '_' key

    Returns:
        list: One configuration dictionary per combination, in a shuffled but
            deterministic order. With no sweeps at all, the list holds a single
            copy of the base config.

    Example:
        sweeps_config = {
            'learning_rate': [0.1, 0.01],
            '_': [
                {'load_in_8bit': True, 'adapter': 'lora'},
                {'load_in_4bit': True, 'adapter': 'qlora'}
            ]
        }
    """
    # Separate paired values from regular sweeps; a missing or empty '_' entry
    # simply means there are no paired sets.
    paired_values = sweeps_config.get("_") or []
    regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
    param_names = list(regular_sweeps)

    # product() with no iterables yields a single empty tuple, so a sweep that
    # only has paired values (or nothing at all) still goes through this loop once.
    all_combinations: list[dict[str, Any]] = []
    for reg_combo in product(*regular_sweeps.values()):
        regular_params = dict(zip(param_names, reg_combo))
        if paired_values:
            all_combinations.extend(
                {**regular_params, **paired_set} for paired_set in paired_values
            )
        else:
            all_combinations.append(regular_params)

    # Randomize the order of trials with a private generator: this yields the
    # same order as seeding the module-level RNG, without resetting the global
    # random state for the rest of the process.
    random.Random(_SWEEP_SHUFFLE_SEED).shuffle(all_combinations)

    # Overlay each combination on its own copy of the base config. Swept values
    # are copied too, so mutable values are not shared between the results or
    # with sweeps_config.
    result_configs = []
    for combination in all_combinations:
        new_config = deepcopy(base_config)
        for param_name, param_value in combination.items():
            new_config[param_name] = deepcopy(param_value)
        result_configs.append(new_config)

    return result_configs