Merge branch 'main' into cli_main_reorg

djsaunde authored Feb 17, 2025
2 parents 730ffca + 3d8425f commit e434b89
Showing 22 changed files with 3,102 additions and 22 deletions.
4 changes: 2 additions & 2 deletions cicd/cicd.sh
@@ -4,8 +4,8 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
4 changes: 1 addition & 3 deletions cicd/tests.py
@@ -1,6 +1,4 @@
"""
modal application to run axolotl gpu tests in Modal
"""
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code

import os
7 changes: 7 additions & 0 deletions docs/config.qmd
@@ -305,6 +305,13 @@ lora_modules_to_save:

lora_fan_in_fan_out: false

# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
127 changes: 127 additions & 0 deletions docs/lora_optims.qmd
@@ -0,0 +1,127 @@
---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
LoRA fine-tuning"
---

Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single-GPU and multi-GPU
training (in the DDP and DeepSpeed settings). These include (1) SwiGLU and GeGLU
activation function Triton kernels, and (2) LoRA MLP and attention custom autograd
functions. Our goal is to leverage operator fusion and tensor re-use to improve speed
and reduce memory usage during the forward and backward passes of these computations.

We currently support several common model architectures, including (but not limited to):
- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`

<details>

The set of models we support is currently limited by our attention patching strategy,
which assumes (and replaces) specific code blocks for query / key / value and output
projections:

```python
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)

ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
"\n"
)
```

These blocks are replaced with:

```python
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)

PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
"\n"
)
```

Here, `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.
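
As a rough illustration of the approach (and not Axolotl's exact implementation), a
patch of this kind can be applied by rewriting the attention module's `forward` source
and re-executing it in the module namespace. The function below is a simplified sketch:
the target class, indentation handling, and error handling are all assumptions.

```python
import inspect
import textwrap

import transformers.models.llama.modeling_llama as modeling_llama


def patch_llama_attention(original_code: str, patched_code: str) -> None:
    """Sketch: swap a known code block inside LlamaAttention.forward."""
    # Note: a real patch must also reconcile the indentation of these snippets
    # with the method's indentation inside the class body.
    source = textwrap.dedent(inspect.getsource(modeling_llama.LlamaAttention.forward))

    # Refuse to patch if the upstream source has drifted from what we expect
    if original_code not in source:
        raise RuntimeError("unexpected attention source; refusing to patch")

    patched_source = source.replace(original_code, patched_code)

    # Re-execute the rewritten function in a copy of the module namespace so that
    # names like apply_rotary_pos_emb still resolve, then rebind the method.
    namespace = dict(modeling_llama.__dict__)
    exec(patched_source, namespace)  # nosec - illustration only
    modeling_llama.LlamaAttention.forward = namespace["forward"]
```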

We welcome testing of other model architectures and / or PRs to expand our patching
logic to be compatible with more of them.

</details>

## Usage

These optimizations can be enabled in your Axolotl config YAML file. The
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
`lora_o_kernel` enable the fused query-key-value projection and optimized output
projection, respectively.

```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```

## Requirements
- One or more NVIDIA or AMD GPUs (required for the Triton kernels)
  - AMD GPUs can be used via experimental Triton support by setting the environment variable `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
- Targeted LoRA adapters cannot use dropout
  - This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
  - This may limit model expressivity

Models with pre-existing LoRA adapters that use dropout or bias terms may need to be
re-finetuned without these features in order to work with these optimizations.

## Implementation details

### Custom autograd functions

The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
LoRA and base weight computations together and provides a single, efficient backward
pass for the entire MLP block.

For attention, similar optimizations are provided through one function that handles
the query, key, and value projections and another that handles the output projection.
Both are designed to work with the existing `transformers` attention implementations
via the monkey-patching logic described above.
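
As a simplified sketch of the idea (not the actual `axolotl.kernels.lora` code), the
fused base-plus-LoRA path for a single projection can be expressed as one
`torch.autograd.Function` that reuses the intermediate LoRA activation and returns all
gradients from a single backward pass. Class and variable names here are illustrative:

```python
import torch


class FusedLoRALinear(torch.autograd.Function):
    """y = x @ W.T + scale * (x @ A.T) @ B.T with a single fused backward.

    W is the frozen base weight (out, in); A (r, in) and B (out, r) are the
    trainable LoRA factors; scale is alpha / r.
    """

    @staticmethod
    def forward(ctx, x, W, A, B, scale):
        x2d = x.reshape(-1, x.shape[-1])
        lora_mid = x2d @ A.t()  # (N, r) intermediate, reused in backward
        out = x2d @ W.t() + scale * (lora_mid @ B.t())
        ctx.save_for_backward(x2d, W, A, B, lora_mid)
        ctx.scale, ctx.x_shape = scale, x.shape
        return out.reshape(*x.shape[:-1], W.shape[0])

    @staticmethod
    def backward(ctx, grad_out):
        x2d, W, A, B, lora_mid = ctx.saved_tensors
        g = grad_out.reshape(-1, grad_out.shape[-1])

        grad_x = g @ W + ctx.scale * ((g @ B) @ A)  # base + LoRA paths together
        grad_A = ctx.scale * (g @ B).t() @ x2d      # (r, in)
        grad_B = ctx.scale * g.t() @ lora_mid       # (out, r)

        # No gradients for the frozen base weight W or for the scalar scale
        return grad_x.reshape(ctx.x_shape), None, grad_A, grad_B, None


# Usage: y = FusedLoRALinear.apply(x, base_weight, lora_A_weight, lora_B_weight, alpha / r)
```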

### Triton kernels

Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
improved speed and memory performance. These kernels handle both the forward and
backward passes.
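
For intuition, a minimal SwiGLU forward kernel in Triton might look like the sketch
below. This is a stripped-down illustration (element-wise only, fixed block size, no
autotuning, contiguous inputs assumed), not the kernel shipped in `axolotl.kernels`:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def swiglu_fwd_kernel(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of elements
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    gate = tl.load(gate_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask).to(tl.float32)

    # SwiGLU: silu(gate) * up, with silu(x) = x * sigmoid(x)
    result = gate * tl.sigmoid(gate) * up
    tl.store(out_ptr + offsets, result.to(out_ptr.dtype.element_ty), mask=mask)


def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(gate)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    swiglu_fwd_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)
    return out
```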

### Integration

The custom autograd functions and Triton kernels are designed to work together. The
autograd function manages the high-level computation flow and gradient tracking, and
calls the Triton kernels for the activation function computation. During the backward
pass, the kernel computes both the activation output and the required gradients, which
the autograd function then uses to produce the final gradients for the entire
computation path.
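
To make that flow concrete, the following plain-PyTorch reference shows the math a
fused SwiGLU backward produces; the real computation happens inside the Triton kernel,
and the function name here is purely illustrative:

```python
import torch


def swiglu_backward_reference(grad_out, gate, up):
    """Reference math for a fused SwiGLU backward, where h = silu(gate) * up."""
    sig = torch.sigmoid(gate)
    silu = gate * sig
    h = silu * up                                                 # recomputed activation
    grad_up = grad_out * silu                                     # dL/d(up)
    grad_gate = grad_out * up * sig * (1.0 + gate * (1.0 - sig))  # dL/d(gate)
    return h, grad_gate, grad_up
```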

## Future Work

- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions
82 changes: 82 additions & 0 deletions examples/llama-3/lora-1b-kernels.yml
@@ -0,0 +1,82 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 32
# Currently, we don't support dropout with our custom Triton kernels
# lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj

# These options enable our custom Triton kernels / autograd
# functions for MLP and attention calculations
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"
5 changes: 4 additions & 1 deletion src/axolotl/cli/main.py
@@ -95,7 +95,6 @@ def train(
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train

if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
@@ -129,6 +128,8 @@ def iter_configs():
try:
if accelerate:
if cloud:
from axolotl.cli.cloud import do_cli_train

cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
@@ -157,6 +158,8 @@ def iter_configs():
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
from axolotl.cli.cloud import do_cli_train

do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
Empty file added src/axolotl/kernels/__init__.py
Empty file.