We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a7811ad commit 6cdcb8dCopy full SHA for 6cdcb8d
src/axolotl/cli/train.py
@@ -17,6 +17,7 @@
17
from axolotl.common.datasets import load_datasets, load_preference_datasets
18
from axolotl.integrations.base import PluginManager
19
from axolotl.train import train
20
+from axolotl.utils import set_pytorch_cuda_alloc_conf
21
from axolotl.utils.config import normalize_config, resolve_dtype
22
from axolotl.utils.dict import DictDefault
23
@@ -33,6 +34,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
33
34
cfg: Dictionary mapping `axolotl` config keys to values.
35
cli_args: Training-specific CLI arguments.
36
"""
37
+ # Enable expandable segments for cuda allocation to improve VRAM usage
38
+ set_pytorch_cuda_alloc_conf()
39
+
40
print_axolotl_text_art()
41
check_accelerate_default_config()
42
if int(os.getenv("LOCAL_RANK", "0")) == 0:
0 commit comments