Skip to content

Commit 6cdcb8d

Browse files
authored
Set the pytorch_cuda_alloc_conf env in the train module (axolotl-ai-cloud#2447)
1 parent a7811ad commit 6cdcb8d

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

src/axolotl/cli/train.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from axolotl.common.datasets import load_datasets, load_preference_datasets
1818
from axolotl.integrations.base import PluginManager
1919
from axolotl.train import train
20+
from axolotl.utils import set_pytorch_cuda_alloc_conf
2021
from axolotl.utils.config import normalize_config, resolve_dtype
2122
from axolotl.utils.dict import DictDefault
2223

@@ -33,6 +34,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
3334
cfg: Dictionary mapping `axolotl` config keys to values.
3435
cli_args: Training-specific CLI arguments.
3536
"""
37+
# Enable expandable segments for cuda allocation to improve VRAM usage
38+
set_pytorch_cuda_alloc_conf()
39+
3640
print_axolotl_text_art()
3741
check_accelerate_default_config()
3842
if int(os.getenv("LOCAL_RANK", "0")) == 0:

0 commit comments

Comments
 (0)