diff --git a/docs/config.qmd b/docs/config.qmd
index cc3a0232a5..b5fab25ba6 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -302,7 +302,7 @@ lora_fan_in_fan_out: false
 
 # Apply custom LoRA autograd functions and activation function Triton kernels for
 # speed and memory savings
-# See: https://axolotl-ai-cloud.github.io/axolotl/docs/kernels.html
+# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html
 lora_mlp_kernel: true
 lora_qkv_kernel: true
 lora_o_kernel: true
diff --git a/docs/kernels.qmd b/docs/lora_optims.qmd
similarity index 79%
rename from docs/kernels.qmd
rename to docs/lora_optims.qmd
index 91aef0ee19..18587d663f 100644
--- a/docs/kernels.qmd
+++ b/docs/lora_optims.qmd
@@ -11,7 +11,39 @@ kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
 leverage operator fusion and tensor re-use in order to improve speed and reduce
 memory usage during the forward and backward passes of these calculations.
 
-## Custom autograd functions
+## Usage
+
+These optimizations can be enabled in your Axolotl config YAML file. The
+`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
+`lora_o_kernel` enable the fused query-key-value projection and optimized output
+projection, respectively.
+
+```yaml
+lora_mlp_kernel: true
+lora_qkv_kernel: true
+lora_o_kernel: true
+```
+
+## Requirements
+
+- One or more NVIDIA GPUs (in order to use the Triton kernels)
+- Targeted LoRA adapters cannot use Dropout
+  - This may limit model expressivity / cause overfitting
+- Targeted LoRA adapters cannot have bias terms
+  - This may limit model expressivity
+- One of the following model architectures (`model.config.model_type`):
+  - `llama`
+  - `mistral`
+  - `qwen2`
+  - `gemma`
+  - `gemma2`
+
+Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
+be re-finetuned without these features before these optimizations can be applied.
+
+## Implementation details
+
+### Custom autograd functions
 
 The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
 LoRA and base weight computations together and provides a single, efficient backward
@@ -22,13 +54,13 @@ handles the query, key, and value projections, and a function that handles the o
 projection. They are designed to work with the existing `transformers` attention
 implementation via some monkey-patching logic.
 
-## Triton kernels
+### Triton kernels
 
 Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
 improved speed and memory performance. These kernels handle both the forward and
 backward passes.
 
-## Integration
+### Integration
 
 The custom autograd functions and Triton kernels are designed to work together. The
 autograd function manages the high-level computation flow and gradient tracking, while
@@ -37,27 +69,9 @@ pass, the kernel computes both the activation output and the required gradients,
 the autograd function then uses to compute the final gradients for the entire
 computation path.
 
-## Axolotl config
-
-These optimizations can be enabled in your Axolotl config YAML file. The
-`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
-`lora_o_kernel` enable the fused query-key-value projection and optimized output
-projection, respectively.
-
-```yaml
-lora_mlp_kernel: true
-lora_qkv_kernel: true
-lora_o_kernel: true
-```
-
-Several requirements must be met for these optimizations to be applied:
-- A GPU (in order to use the Triton kernels)
-- Targeted LoRA adapters cannot use Dropout
-- Targeted LoRA adapters cannot have bias terms
-
 ## Future Work
 
-- Support for more architectures
+- Support for additional model architectures
 - Support for FSDP and DeepSpeed multi-GPU settings
 - Support for dropout and bias
 - Use Triton autotune to improve kernel performance
diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml
index aa3981a269..9c47f266f1 100644
--- a/examples/llama-3/lora-1b-kernels.yml
+++ b/examples/llama-3/lora-1b-kernels.yml
@@ -18,7 +18,6 @@ lora_model_dir:
 
 sequence_len: 2048
 sample_packing: true
-eval_sample_packing: true
 pad_to_sequence_len: true
 
 lora_r: 16
diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index 559f7f88a1..e76e828d4a 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -140,9 +140,11 @@ def apply_lora_kernel_patches(model: PeftModelForCausalLM, cfg: DictDefault):
 
     if not can_patch:
         LOG.warning("Cannot patch layers - requires no dropout and no bias")
+        LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
         return model
 
     # This needs to be reset after patching
+    original_level = LOG.getEffectiveLevel()
     LOG.setLevel(logging.INFO)
 
     # Choose activation based on model type
@@ -232,4 +234,6 @@ def apply_lora_kernel_patches(model: PeftModelForCausalLM, cfg: DictDefault):
                 "Cannot patch some attention output projection - requires LoRA adapters with no bias"
             )
 
+    LOG.setLevel(original_level)
+
     return model
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 0e7afc2b53..4f4bc9fbc7 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1524,7 +1524,7 @@ def check_qlora_unsloth(cls, data):
             or data.get("unsloth_lora_qkv")
             or data.get("unsloth_lora_o")
         ):
-            if data.get("adapter") == "lora" or data.get("load_in_8bit"):
+            if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                 raise ValueError(
                     "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
                 )
@@ -1538,7 +1538,7 @@ def check_lora_8bit(cls, data):
             or data.get("lora_qkv_kernel")
             or data.get("lora_o_kernel")
         ):
-            if data.get("adapter") == "lora" or data.get("load_in_8bit"):
+            if data.get("adapter") == "lora" and data.get("load_in_8bit"):
                 raise ValueError(
                     "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
                 )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 26eb7dedc5..032e902a3a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -469,9 +469,7 @@ def has_flash_attn(self) -> bool:
         return importlib.util.find_spec("flash_attn") is not None
 
     def patch_loss_llama(self) -> None:
-        """
-        Patch loss functions
-        """
+        """Patch loss functions and other optimizations"""
         if self.has_flash_attn:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 patch_fa_llama_cross_entropy,
diff --git a/tests/e2e/kernels/test_geglu.py b/tests/e2e/kernels/test_geglu.py
index bd09a9765b..b6e4c14a8f 100644
--- a/tests/e2e/kernels/test_geglu.py
+++ b/tests/e2e/kernels/test_geglu.py
@@ -21,13 +21,13 @@ def test_geglu_forward_shape():
 
 def test_geglu_forward_values():
     """Test GEGLU forward pass matches PyTorch reference implementation."""
-    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
-    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
+    gate = torch.randn(2, 3, 64, device="cuda")
+    up = torch.randn(2, 3, 64, device="cuda")
 
     # Custom implementation
     triton_out = geglu_forward(gate.clone(), up.clone())
 
-    # PyTorch reference - GELU instead of SiLU
+    # PyTorch reference
     torch_out = F.gelu(gate) * up
 
     assert torch.allclose(triton_out, torch_out, rtol=1e-4)
@@ -49,9 +49,7 @@ def test_geglu_backward():
     up_clone = up.clone().detach()
     grad_output_clone = grad_output.clone()
 
-    h, grad_gate, grad_up = geglu_backward(
-        grad_output_clone.clone(), gate_clone.clone(), up_clone.clone()
-    )
+    h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone)
 
     # Compare outputs and gradients
     assert torch.allclose(h, torch_out, rtol=1e-4)
diff --git a/tests/e2e/kernels/test_lora.py b/tests/e2e/kernels/test_lora.py
index 17dbd7c39d..195749863a 100644
--- a/tests/e2e/kernels/test_lora.py
+++ b/tests/e2e/kernels/test_lora.py
@@ -153,7 +153,6 @@ def test_matmul_lora(sample_tensors):
     assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
 @pytest.mark.parametrize(
     "activation_forward,activation_backward",
     [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
@@ -204,7 +203,6 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
     assert not torch.isnan(X.grad).any()
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
 @pytest.mark.parametrize(
     "activation_forward,activation_backward",
     [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)],
@@ -288,7 +286,6 @@ def test_lora_mlp_with_adapters(
     assert not torch.isnan(down_B.grad).any()
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
 def test_lora_qkv(sample_tensors):
     """Tests LoRA QKV implementation with and without adapters"""
     X = sample_tensors["X"]
@@ -502,7 +499,6 @@ def test_gradient_flow(sample_tensors):
     assert not torch.isnan(B.grad).any()
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires CUDA")
 @pytest.mark.parametrize(
     "apply_function",
     [apply_lora_mlp_swiglu, apply_lora_mlp_geglu],
diff --git a/tests/e2e/kernels/test_swiglu.py b/tests/e2e/kernels/test_swiglu.py
index cf7098ba17..dc5879989a 100644
--- a/tests/e2e/kernels/test_swiglu.py
+++ b/tests/e2e/kernels/test_swiglu.py
@@ -21,8 +21,8 @@ def test_swiglu_forward_shape():
 
 def test_swiglu_forward_values():
     """Test SwiGLU forward pass matches PyTorch reference implementation"""
-    gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
-    up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
+    gate = torch.randn(2, 3, 64, device="cuda")
+    up = torch.randn(2, 3, 64, device="cuda")
 
     # Custom implementation
     triton_out = swiglu_forward(gate.clone(), up.clone())
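
As a companion to the implementation-details section added in `docs/lora_optims.qmd` above, here is a minimal, self-contained sketch of the fused-LoRA autograd idea that section describes: a single `torch.autograd.Function` that covers the frozen base weight and the LoRA `A`/`B` matrices in one backward pass. The class name `FusedLoRALinear`, the shapes, and the `scale` argument are illustrative assumptions for this sketch only; it is not the implementation patched in by `src/axolotl/monkeypatch/lora_kernels.py`.

```python
import torch


class FusedLoRALinear(torch.autograd.Function):
    """Illustrative fused LoRA linear: y = x @ W^T + (x @ A^T) @ B^T * scale."""

    @staticmethod
    def forward(ctx, x, W, A, B, scale):
        # x: (batch, seq, d_in), W: (d_out, d_in), A: (r, d_in), B: (d_out, r)
        xA = x @ A.t()  # (batch, seq, r); saved and reused in the backward pass
        y = x @ W.t() + (xA @ B.t()) * scale
        ctx.save_for_backward(x, A, B, W, xA)
        ctx.scale = scale
        return y

    @staticmethod
    def backward(ctx, grad_out):
        x, A, B, W, xA = ctx.saved_tensors
        scale = ctx.scale
        # Gradient w.r.t. the input flows through both the base and LoRA paths
        grad_x = grad_out @ W + (grad_out @ B) @ A * scale
        # LoRA adapter gradients; the base weight W stays frozen (grad is None)
        grad_out_2d = grad_out.flatten(0, -2)
        grad_A = (grad_out_2d @ B).t() @ x.flatten(0, -2) * scale
        grad_B = grad_out_2d.t() @ xA.flatten(0, -2) * scale
        return grad_x, None, grad_A, grad_B, None


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 4, 16, requires_grad=True)
    W = torch.randn(32, 16)  # frozen base weight
    A = torch.randn(8, 16, requires_grad=True)  # LoRA down-projection
    B = torch.randn(32, 8, requires_grad=True)  # LoRA up-projection
    scale = 0.5

    FusedLoRALinear.apply(x, W, A, B, scale).sum().backward()

    # Cross-check against the unfused reference computation
    x2 = x.detach().clone().requires_grad_(True)
    A2 = A.detach().clone().requires_grad_(True)
    B2 = B.detach().clone().requires_grad_(True)
    (x2 @ W.t() + (x2 @ A2.t()) @ B2.t() * scale).sum().backward()

    assert torch.allclose(x.grad, x2.grad, rtol=1e-4, atol=1e-5)
    assert torch.allclose(A.grad, A2.grad, rtol=1e-4, atol=1e-5)
    assert torch.allclose(B.grad, B2.grad, rtol=1e-4, atol=1e-5)
```

The actual patches additionally route the activation through the SwiGLU/GeGLU Triton kernels and hook the fused functions into the existing `transformers` MLP and attention modules, as described in the documentation above.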