review comments first pass
djsaunde committed Feb 11, 2025
1 parent b2701dd commit e4021fb
Showing 9 changed files with 50 additions and 41 deletions.
2 changes: 1 addition & 1 deletion docs/config.qmd
@@ -302,7 +302,7 @@ lora_fan_in_fan_out: false

# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/kernels.html
# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
58 changes: 36 additions & 22 deletions docs/kernels.qmd → docs/lora_optims.qmd
@@ -11,7 +11,39 @@ kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.

## Custom autograd functions
## Usage

These optimizations can be enabled in your Axolotl config YAML file. The
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
`lora_o_kernel` enable the fused query-key-value projection and optimized output
projection, respectively.

```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```

## Requirements

- One or more NVIDIA GPUs (required for the Triton kernels)
- Targeted LoRA adapters cannot use dropout
  - Disabling dropout may increase the risk of overfitting
- Targeted LoRA adapters cannot have bias terms
  - Omitting bias terms may somewhat limit model expressivity
- One of the following model architectures (`model.config.model_type`):
- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`

Models with pre-existing LoRA adapters that use dropout or bias terms may need to be
fine-tuned again without those features before these optimizations can be applied.
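
As a rough illustration of these constraints, the sketch below checks whether a base
model and adapter settings qualify. It assumes a Hugging Face `transformers` config
object and PEFT-style `lora_dropout`/`bias` values; the helper name is hypothetical
and not part of Axolotl's API.

```python
import torch
from transformers import AutoConfig

SUPPORTED_MODEL_TYPES = {"llama", "mistral", "qwen2", "gemma", "gemma2"}


def can_use_lora_kernels(base_model: str, lora_dropout: float, lora_bias: str) -> bool:
    """Hypothetical check mirroring the documented requirements above."""
    config = AutoConfig.from_pretrained(base_model)
    return (
        torch.cuda.is_available()  # Triton kernels need an NVIDIA GPU
        and config.model_type in SUPPORTED_MODEL_TYPES
        and lora_dropout == 0.0  # no dropout on targeted adapters
        and lora_bias == "none"  # PEFT uses "none" when no bias terms are trained
    )
```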

## Implementation details

### Custom autograd functions

The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
LoRA and base weight computations together and provides a single, efficient backward
@@ -22,13 +22,13 @@ handles the query, key, and value projections, and a function that handles the o
projection. They are designed to work with the existing `transformers` attention
implementation via some monkey-patching logic.
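
For intuition, the sketch below shows the general shape of such a fused autograd
function: a forward pass that computes the base and low-rank matmuls together, and a
hand-written backward that reuses the cached low-rank projection and only produces
input and adapter gradients. This is a simplified, pure-PyTorch illustration of the
pattern, not Axolotl's implementation, and all names are illustrative.

```python
import torch


class FusedLoRALinear(torch.autograd.Function):
    """Sketch of a fused LoRA linear: y = x @ W.T + scale * (x @ A.T) @ B.T."""

    @staticmethod
    def forward(ctx, x, W, A, B, scale):
        xa = x @ A.t()  # low-rank projection, cached for reuse in backward
        y = x @ W.t() + scale * (xa @ B.t())
        ctx.save_for_backward(x, W, A, B, xa)
        ctx.scale = scale
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, W, A, B, xa = ctx.saved_tensors
        scale = ctx.scale
        m = grad_y @ B  # shared intermediate for grad_x and grad_A
        grad_x = grad_y @ W + scale * (m @ A)
        # The base weight W is frozen under LoRA, so only adapter grads are needed.
        grad_A = scale * m.reshape(-1, m.shape[-1]).t() @ x.reshape(-1, x.shape[-1])
        grad_B = scale * grad_y.reshape(-1, grad_y.shape[-1]).t() @ xa.reshape(-1, xa.shape[-1])
        return grad_x, None, grad_A, grad_B, None
```

Calling `FusedLoRALinear.apply(x, W, A, B, scale)` would then behave like a linear
layer with a LoRA adapter attached, but with a single fused backward pass.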

## Triton kernels
### Triton kernels

Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
improved speed and memory performance. These kernels handle both the forward and
backward passes.
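
As a rough sketch of what one of these kernels looks like, the example below implements
a fused element-wise SwiGLU forward in Triton. It is a simplified illustration that
assumes contiguous CUDA tensors, not the kernel shipped with Axolotl (which also fuses
the backward pass), and the function names are made up for this example.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def swiglu_fwd_kernel(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    gate = tl.load(gate_ptr + offsets, mask=mask)
    up = tl.load(up_ptr + offsets, mask=mask)
    # silu(gate) * up, computed in a single pass over memory
    silu = gate * (1.0 / (1.0 + tl.exp(-gate)))
    tl.store(out_ptr + offsets, silu * up, mask=mask)


def swiglu_forward_sketch(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(gate)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    swiglu_fwd_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)
    return out
```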

## Integration
### Integration

The custom autograd functions and Triton kernels are designed to work together. The
autograd function manages the high-level computation flow and gradient tracking, while
@@ -37,27 +69,9 @@ pass, the kernel computes both the activation output and the required gradients,
the autograd function then uses to compute the final gradients for the entire
computation path.
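
To make that concrete, the reference below spells out what a fused SwiGLU backward has
to produce: the recomputed activation output plus the gradients with respect to both
inputs. This is plain-PyTorch reference math, not the Triton implementation, and the
`(h, grad_gate, grad_up)` return convention is shown only for illustration.

```python
import torch


def swiglu_backward_reference(grad_output, gate, up):
    """Reference math for a fused SwiGLU backward (illustrative only)."""
    sig = torch.sigmoid(gate)
    silu = gate * sig
    h = silu * up  # recomputed forward output
    grad_up = grad_output * silu  # d(silu(gate) * up) / d(up)
    grad_gate = grad_output * up * (sig + gate * sig * (1.0 - sig))  # d(silu)/d(gate)
    return h, grad_gate, grad_up
```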

## Axolotl config

These optimizations can be enabled in your Axolotl config YAML file. The
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
`lora_o_kernel` enable the fused query-key-value projection and optimized output
projection, respectively.

```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```
Several requirements must be met for these optimizations to be applied:
- A GPU (in order to use the Triton kernels)
- Targeted LoRA adapters cannot use Dropout
- Targeted LoRA adapters cannot have bias terms
## Future Work

- Support for more architectures
- Support for additional model architectures
- Support for FSDP and DeepSpeed multi-GPU settings
- Support for dropout and bias
- Use Triton autotune to improve kernel performance
1 change: 0 additions & 1 deletion examples/llama-3/lora-1b-kernels.yml
@@ -18,7 +18,6 @@ lora_model_dir:

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

lora_r: 16
4 changes: 4 additions & 0 deletions src/axolotl/monkeypatch/lora_kernels.py
@@ -140,9 +140,11 @@ def apply_lora_kernel_patches(model: PeftModelForCausalLM, cfg: DictDefault):

    if not can_patch:
        LOG.warning("Cannot patch layers - requires no dropout and no bias")
        LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
        return model

    # This needs to be reset after patching
    original_level = LOG.getEffectiveLevel()
    LOG.setLevel(logging.INFO)

    # Choose activation based on model type
@@ -232,4 +234,6 @@ def apply_lora_kernel_patches(model: PeftModelForCausalLM, cfg: DictDefault):
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)

LOG.setLevel(original_level)

return model
4 changes: 2 additions & 2 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1524,7 +1524,7 @@ def check_qlora_unsloth(cls, data):
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
@@ -1538,7 +1538,7 @@ def check_lora_8bit(cls, data):
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
raise ValueError(
"lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA"
)
4 changes: 1 addition & 3 deletions src/axolotl/utils/models.py
@@ -469,9 +469,7 @@ def has_flash_attn(self) -> bool:
        return importlib.util.find_spec("flash_attn") is not None

    def patch_loss_llama(self) -> None:
        """
        Patch loss functions
        """
        """Patch loss functions and other optimizations"""
        if self.has_flash_attn:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                patch_fa_llama_cross_entropy,
10 changes: 4 additions & 6 deletions tests/e2e/kernels/test_geglu.py
@@ -21,13 +21,13 @@ def test_geglu_forward_shape():

def test_geglu_forward_values():
    """Test GEGLU forward pass matches PyTorch reference implementation."""
    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")

    # Custom implementation
    triton_out = geglu_forward(gate.clone(), up.clone())

    # PyTorch reference - GELU instead of SiLU
    # PyTorch reference
    torch_out = F.gelu(gate) * up

    assert torch.allclose(triton_out, torch_out, rtol=1e-4)
@@ -49,9 +49,7 @@ def test_geglu_backward():
    up_clone = up.clone().detach()
    grad_output_clone = grad_output.clone()

    h, grad_gate, grad_up = geglu_backward(
        grad_output_clone.clone(), gate_clone.clone(), up_clone.clone()
    )
    h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)

    # Compare outputs and gradients
    assert torch.allclose(h, torch_out, rtol=1e-4)
4 changes: 0 additions & 4 deletions tests/e2e/kernels/test_lora.py
@@ -153,7 +153,6 @@ def test_matmul_lora(sample_tensors):
    assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
@pytest.mark.parametrize(
    "activation_forward,activation_backward",
    [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
@@ -204,7 +203,6 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
    assert not torch.isnan(X.grad).any()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
@pytest.mark.parametrize(
    "activation_forward,activation_backward",
    [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
@@ -288,7 +286,6 @@ def test_lora_mlp_with_adapters(
    assert not torch.isnan(down_B.grad).any()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
def test_lora_qkv(sample_tensors):
    """Tests LoRA QKV implementation with and without adapters"""
    X = sample_tensors["X"]
@@ -502,7 +499,6 @@ def test_gradient_flow(sample_tensors):
    assert not torch.isnan(B.grad).any()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
@pytest.mark.parametrize(
    "apply_function",
    [apply_lora_mlp_swiglu, apply_lora_mlp_geglu],
4 changes: 2 additions & 2 deletions tests/e2e/kernels/test_swiglu.py
@@ -21,8 +21,8 @@ def test_swiglu_forward_shape():

def test_swiglu_forward_values():
    """Test SwiGLU forward pass matches PyTorch reference implementation"""
    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")

    # Custom implementation
    triton_out = swiglu_forward(gate.clone(), up.clone())
