diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 34a30db448..8097e9c56b 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -4,8 +4,8 @@ set -e python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ -# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ -pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ +pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure +pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ diff --git a/cicd/tests.py b/cicd/tests.py index b934d5316a..41ae2306fa 100644 --- a/cicd/tests.py +++ b/cicd/tests.py @@ -1,6 +1,4 @@ -""" - modal application to run axolotl gpu tests in Modal - """ +"""Modal app to run axolotl GPU tests""" # pylint: disable=duplicate-code import os diff --git a/docs/config.qmd b/docs/config.qmd index a7a1508625..19e33e7f0b 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -305,6 +305,13 @@ lora_modules_to_save: lora_fan_in_fan_out: false +# Apply custom LoRA autograd functions and activation function Triton kernels for +# speed and memory savings +# See: https://axolotl-ai-cloud.github.io/axolotl/docs/lora_optims.html +lora_mlp_kernel: true +lora_qkv_kernel: true +lora_o_kernel: true + # LoRA+ hyperparameters # For more details about the following options, see: # https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py` diff --git a/docs/lora_optims.qmd b/docs/lora_optims.qmd new file mode 100644 index 0000000000..a57aa854b1 --- /dev/null +++ b/docs/lora_optims.qmd @@ -0,0 +1,127 @@ +--- +title: "LoRA Optimizations" +description: "Custom autograd functions and Triton kernels in Axolotl for optimized +LoRA fine-tuning" +--- + +Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two +optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU +(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function +Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was +to leverage operator fusion and tensor re-use in order to improve speed and reduce +memory usage during the forward and backward passes of these calculations. + +We currently support several common model architectures, including (but not limited to): +- `llama` +- `mistral` +- `qwen2` +- `gemma` +- `gemma2` + +
+
+The set of models we support is currently limited by our attention patching strategy,
+which assumes (and replaces) specific code blocks for the query / key / value and
+output projections. The following original code:
+
+```python
+ORIGINAL_QKV_CODE = """
+    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+""".lstrip(
+    "\n"
+)
+
+ORIGINAL_O_CODE = """
+    attn_output = self.o_proj(attn_output)
+""".lstrip(
+    "\n"
+)
+```
+
+is replaced with:
+
+```python
+PATCHED_QKV_CODE = """
+    query_states, key_states, value_states = self.apply_qkv(hidden_states)
+    query_states = query_states.view(hidden_shape).transpose(1, 2)
+    key_states = key_states.view(hidden_shape).transpose(1, 2)
+    value_states = value_states.view(hidden_shape).transpose(1, 2)
+""".lstrip(
+    "\n"
+)
+
+PATCHED_O_CODE = """
+    attn_output = self.apply_o(attn_output)
+""".lstrip(
+    "\n"
+)
+```
+
+Here, `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.
+
+We welcome testing of other model architectures and/or PRs that expand our patching
+logic to cover more of them.
+
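+For illustration, here is a rough, unfused sketch of the math the patched hooks
+compute for each projection: the base weight matmul plus the scaled LoRA update. The
+helper name below is hypothetical; the actual implementations in
+`axolotl.kernels.lora` pull adapter weights via `get_lora_parameters`, handle 4-bit
+dequantization, and define fused custom backward passes.
+
+```python
+import torch
+
+
+def lora_linear(x, weight, lora_A=None, lora_B=None, scale=1.0):
+    """Reference LoRA projection: x @ W^T + scale * (x @ A^T) @ B^T."""
+    out = x @ weight.t()
+    if lora_A is not None and lora_B is not None:
+        out = out + (x @ lora_A.t()) @ (scale * lora_B.t())
+    return out
+
+
+# Shapes: W is [out_features, in_features], A is [rank, in_features],
+# B is [out_features, rank] -- matching peft's lora_A / lora_B layout.
+x = torch.randn(2, 16, 64)                # [batch, seq_len, in_features]
+W = torch.randn(128, 64)
+A, B = torch.randn(8, 64), torch.randn(128, 8)
+y = lora_linear(x, W, A, B, scale=2.0)    # [2, 16, 128]
+```
+
+The patched `apply_qkv` hook applies this computation to the query, key, and value
+projection weights and returns all three projections, while `apply_o` applies it to
+the output projection.
+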
+
+## Usage
+
+These optimizations can be enabled in your Axolotl config YAML file. The
+`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
+`lora_o_kernel` enable the fused query-key-value projection and the optimized output
+projection, respectively.
+
+```yaml
+lora_mlp_kernel: true
+lora_qkv_kernel: true
+lora_o_kernel: true
+```
+
+## Requirements
+
+- One or more NVIDIA or AMD GPUs (required for the Triton kernels)
+  - AMD GPUs can be used with experimental Triton support by setting the environment variable `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
+- Targeted LoRA adapters cannot use dropout
+  - This may limit model expressivity / cause overfitting
+- Targeted LoRA adapters cannot have bias terms
+  - This may limit model expressivity
+
+Models with pre-existing LoRA adapters that use dropout or bias terms may need to be
+fine-tuned again without these features before these optimizations can be applied.
+
+## Implementation details
+
+### Custom autograd functions
+
+The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
+LoRA and base weight computations together and provides a single, efficient backward
+pass for the entire MLP block.
+
+For attention components, similar optimizations are provided through a function that
+handles the query, key, and value projections, and a function that handles the output
+projection. They are designed to work with the existing `transformers` attention
+implementation via some monkey-patching logic.
+
+### Triton kernels
+
+Two activation functions (SwiGLU and GEGLU) are implemented with Triton kernels for
+improved speed and memory performance. These kernels handle both the forward and
+backward passes.
+
+### Integration
+
+The custom autograd functions and Triton kernels are designed to work together. The
+autograd function manages the high-level computation flow and gradient tracking, while
+calling the Triton kernels for the activation function computation. During the backward
+pass, the kernel computes both the activation output and the required gradients, which
+the autograd function then uses to compute the final gradients for the entire
+computation path.
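+
+To make that interplay concrete, here is a heavily simplified sketch of the control
+flow, with plain PyTorch standing in for the Triton activation kernels and the LoRA
+and quantization branches omitted. The class and argument names are illustrative, not
+the actual `LoRA_MLP` signature:
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+class FusedMLPSketch(torch.autograd.Function):
+    """Illustrates the forward/backward orchestration only."""
+
+    @staticmethod
+    def forward(ctx, x, gate_w, up_w, down_w):
+        gate = x @ gate_w.t()              # gate projection
+        up = x @ up_w.t()                  # up projection
+        h = F.silu(gate) * up              # SwiGLU -- a Triton kernel in Axolotl
+        ctx.save_for_backward(gate, up, gate_w, up_w, down_w)
+        return h @ down_w.t()              # down projection
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        gate, up, gate_w, up_w, down_w = ctx.saved_tensors
+        grad_h = grad_out @ down_w         # backprop through the down projection
+        # The Triton backward kernel returns the activation output and both
+        # activation gradients in one fused pass; recomputed in plain PyTorch here.
+        sig = torch.sigmoid(gate)
+        grad_up = grad_h * (gate * sig)    # d h / d up = silu(gate)
+        grad_gate = grad_h * up * sig * (1 + gate * (1 - sig))
+        grad_x = grad_gate @ gate_w + grad_up @ up_w
+        # Base weights stay frozen during LoRA fine-tuning, so no weight gradients
+        # are produced here; the real function returns LoRA A/B gradients instead.
+        return grad_x, None, None, None
+```
+
+In Axolotl, `LoRA_MLP`, `LoRA_QKV`, and `LoRA_O` follow this same pattern, with
+`swiglu_forward` / `swiglu_backward` (or the GEGLU equivalents) supplying the
+activation math and `matmul_lora` handling the fused base-plus-adapter projections.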
+ +## Future Work + +- Support for additional model architectures +- Support for the FSDP setting +- Support for dropout and bias +- Additional operator fusions diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml new file mode 100644 index 0000000000..9c47f266f1 --- /dev/null +++ b/examples/llama-3/lora-1b-kernels.yml @@ -0,0 +1,82 @@ +base_model: NousResearch/Llama-3.2-1B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: false +strict: false + +datasets: + - path: teknium/GPT4-LLM-Cleaned + type: alpaca +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: lora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 16 +lora_alpha: 32 +# Currently, we don't support dropout with our custom Triton kernels +# lora_dropout: 0.05 +lora_fan_in_fan_out: +lora_target_modules: + - gate_proj + - down_proj + - up_proj + - q_proj + - v_proj + - k_proj + - o_proj + +# These options enable our custom Triton kernels / autograd +# functions for MLP and attention calculations +lora_mlp_kernel: true +lora_qkv_kernel: true +lora_o_kernel: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_steps: 10 +evals_per_epoch: 4 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: "<|end_of_text|>" diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 4cdd2b6c30..0f132c133c 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -95,7 +95,6 @@ def train( """ # Enable expandable segments for cuda allocation to improve VRAM usage set_pytorch_cuda_alloc_conf() - from axolotl.cli.cloud import do_cli_train if "use_ray" in kwargs and kwargs["use_ray"]: accelerate = False @@ -129,6 +128,8 @@ def iter_configs(): try: if accelerate: if cloud: + from axolotl.cli.cloud import do_cli_train + cwd = os.getcwd() do_cli_train( cloud_config=cloud, @@ -157,6 +158,8 @@ def iter_configs(): subprocess.run(cmd, check=True) # nosec B603 else: if cloud: + from axolotl.cli.cloud import do_cli_train + do_cli_train( cloud_config=cloud, config=config, accelerate=False, **kwargs ) diff --git a/src/axolotl/kernels/__init__.py b/src/axolotl/kernels/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/axolotl/kernels/geglu.py b/src/axolotl/kernels/geglu.py new file mode 100644 index 0000000000..4dd70f4cc9 --- /dev/null +++ b/src/axolotl/kernels/geglu.py @@ -0,0 +1,159 @@ +""" +Module for definition of GEGLU Triton kernels. + +See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202). + +Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. 
+""" +# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code + +import torch +import triton +import triton.language as tl + +SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/π) + + +@triton.jit +def _geglu_fwd_kernel( + gate_ptr, + up_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + """GEGLU forward kernel. + + Args: + gate_ptr: Pointer to gate tensor [*, hidden_dim]. + up_ptr: Pointer to up-projection tensor [*, hidden_dim]. + out_ptr: Pointer to output tensor [*, hidden_dim]. + n_elements: Total number of elements in the input tensors. + BLOCK_SIZE: Size of thread blocks for parallel computation. + """ + block_idx = tl.program_id(0) + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32) + up = tl.load(up_ptr + offsets, mask=mask, other=0) + + # Compute activation in fp32 then convert back + gelu_gate = 0.5 * gate * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0) + gelu_gate = gelu_gate.to(up.dtype) + result = gelu_gate * up + + tl.store(out_ptr + offsets, result, mask=mask) + + +def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """GEGLU forward pass. + + Args: + gate: Input gate tensor of shape [batch, seq_len, hidden_dim]. + up: Up-projection tensor of shape [batch, seq_len, hidden_dim]. + + Returns: + torch.Tensor: Output tensor of shape [batch, seq_len, hidden_dim]. + """ + batch, seq_len, hidden_dim = gate.shape + n_elements = gate.numel() + out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda") + + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731 + _geglu_fwd_kernel[grid]( + gate_ptr=gate, + up_ptr=up, + out_ptr=out, + n_elements=n_elements, + BLOCK_SIZE=1024, + ) + return out + + +@triton.jit +def _geglu_bwd_kernel( + grad_out_ptr, + gate_ptr, + up_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + """GEGLU backward kernel. Stores gradient results in-place. + + Args: + grad_out_ptr: Pointer to gradient output tensor [*, hidden_dim]. + gate_ptr: Pointer to gate tensor [*, hidden_dim]. + up_ptr: Pointer to up-projection tensor [*, hidden_dim]. + n_elements: Total number of elements in the input tensors. + BLOCK_SIZE: Size of thread blocks for parallel computation. 
+ + Note: + After kernel execution, tensors are modified in-place: + - `grad_out_ptr` contains GEGLU activation output (`h`) + - `gate_ptr` contains gradient w.r.t gate (`grad_gate`) + - `up_ptr` contains gradient w.r.t up (`grad_up`) + """ + block_idx = tl.program_id(0) + offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0) + gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32) + up = tl.load(up_ptr + offsets, mask=mask, other=0) + + # Forward pass + gelu_partial = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * gate) + 1.0) + gelu_gate = gelu_partial * gate + gelu_gate = gelu_gate.to(grad_out.dtype) + + # Forward output + h = gelu_gate * up + + # Compute gradients + grad_up = grad_out * gelu_gate + + # Compute gate gradient using GELU derivative + temp = grad_out * up + t = 0.3989422804014327 # 1/sqrt(2*pi) + dgelu_dgate = gelu_partial + t * gate * tl.exp(-0.5 * gate * gate) + grad_gate = temp.to(tl.float32) * dgelu_dgate + grad_gate = grad_gate.to(grad_out.dtype) + + # Store results + tl.store(grad_out_ptr + offsets, h, mask=mask) + tl.store(gate_ptr + offsets, grad_gate, mask=mask) + tl.store(up_ptr + offsets, grad_up, mask=mask) + + +def geglu_backward( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """GEGLU backward pass using in-place operations. + + Args: + grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`. + gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`. + up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`. + + Returns: + Tuple containing: + - GEGLU activation output (`h`) + - Gradient with respect to gate (`grad_gate`) + - Gradient with respect to up (`grad_up`) + + Note: + This function modifies its input tensors in-place to store results. + """ + n_elements = grad_output.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) # noqa: E731 + _geglu_bwd_kernel[grid]( + grad_out_ptr=grad_output, + gate_ptr=gate, + up_ptr=up, + n_elements=n_elements, + BLOCK_SIZE=1024, + ) + + return grad_output, gate, up diff --git a/src/axolotl/kernels/lora.py b/src/axolotl/kernels/lora.py new file mode 100644 index 0000000000..1f8a8e787b --- /dev/null +++ b/src/axolotl/kernels/lora.py @@ -0,0 +1,779 @@ +""" +Module for definition of Low-Rank Adaptation (LoRA) Triton kernels. + +See "LoRA: Low-Rank Adaptation of Large Language Models" +(https://arxiv.org/abs/2106.09685). + +Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. +""" +# pylint: disable=invalid-name + +from typing import Callable + +import torch +from bitsandbytes.functional import QuantState +from torch import nn + +from .geglu import geglu_backward, geglu_forward +from .quantize import dequantize +from .swiglu import swiglu_backward, swiglu_forward +from .utils import torch_amp_custom_bwd, torch_amp_custom_fwd + + +def get_lora_parameters( + proj: nn.Module, +) -> tuple[ + torch.Tensor, + QuantState | None, + torch.Tensor | None, + torch.Tensor | None, + float | None, +]: + """ + Gets LoRA parameters from a projection module. + + Args: + proj: The projection module to extract parameters from. + + Returns: + A tuple containing the base weight matrix, quantization state, LoRA A matrix, + LoRA B matrix, and scaling factor. States and matrices may be None if not + available. 
+ """ + # For DPO or disabled adapters + base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj + W = base_layer.weight + + if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged: + quant_state = getattr(W, "quant_state", None) + return W, quant_state, None, None, None + + active_adapter = ( + proj.active_adapters[0] + if hasattr(proj, "active_adapters") + else proj.active_adapter + ) + A = proj.lora_A[active_adapter].weight + B = proj.lora_B[active_adapter].weight + s = proj.scaling[active_adapter] + + quant_state = getattr(W, "quant_state", None) + + return W, quant_state, A, B, s + + +def matmul_lora( + X: torch.Tensor, + W: torch.Tensor, + W_quant: QuantState, + A: torch.Tensor, + B: torch.Tensor, + s: float, + out: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Efficient fused matmul + LoRA computation. + + Args: + X: Input tensor [*, in_features] + W: Base weight matrix [out_features, in_features] + W_quant: Quantization state for W + A: LoRA A matrix [rank, in_features] + B: LoRA B matrix [out_features, rank] + s: LoRA scaling factor + out: Optional output tensor for inplace operations + + Returns: + Result of X @ W + X @ A @ B + """ + dtype = X.dtype + W = dequantize(W.t(), W_quant) + + if X.dim() == 3: + batch, seq_len, _ = X.shape + X = X.view(-1, X.shape[-1]) + reshape = True + else: + reshape = False + + out = torch.matmul(X, W, out=out) + if W_quant is not None: + del W + + if A is not None: + A, B = A.t(), B.t() + out += (X @ A.to(dtype)) @ (s * B.to(dtype)) + + return out.view(batch, seq_len, -1) if reshape else out + + +class LoRA_MLP(torch.autograd.Function): + """Optimized LoRA MLP implementation.""" + + @staticmethod + @torch_amp_custom_fwd + def forward( + ctx, + X: torch.Tensor, + gate_weight: torch.Tensor, + gate_quant: object | None, + gate_A: torch.Tensor | None, + gate_B: torch.Tensor | None, + gate_scale: float, + up_weight: torch.Tensor, + up_quant: object | None, + up_A: torch.Tensor | None, + up_B: torch.Tensor | None, + up_scale: float, + down_weight: torch.Tensor, + down_quant: object | None, + down_A: torch.Tensor | None, + down_B: torch.Tensor | None, + down_scale: float, + activation_fn: Callable, + activation_fn_backward: Callable, + inplace: bool | None = True, + ) -> torch.Tensor: + """ + Forward pass for LoRA MLP. 
+ + Args: + ctx: Autograd context + X: Input features + gate_weight: Gate projection weight + gate_quant: Gate quantization state + gate_A: Gate LoRA A matrix + gate_B: Gate LoRA B matrix + gate_scale: Gate LoRA scale + up_weight: Up-projection weight + up_quant: Up-projection quantization state + up_A: Up-projection LoRA A matrix + up_B: Up-projection LoRA B matrix + up_scale: Up-projection LoRA scale + down_weight: Down-projection weight + down_quant: Down-projection quantization state + down_A: Down-projection LoRA A matrix + down_B: Down-projection LoRA B matrix + down_scale: Down-projection LoRA scale + activation_fn: Forward activation function + activation_fn_backward: Backward activation function + inplace: Whether to perform operations in-place + + Returns: + Output transformed by multi-layer perceptron and activation function + """ + # Compute projections + gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale) + up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale) + + # Activation + hidden = activation_fn(gate, up) + + # Down projection + output = matmul_lora( + hidden, down_weight, down_quant, down_A, down_B, down_scale + ) + + # Save for backward + ctx.save_for_backward(X, gate, up, gate_A, gate_B, up_A, up_B, down_A, down_B) + ctx.scales = (gate_scale, up_scale, down_scale) + ctx.quants = (gate_quant, up_quant, down_quant) + ctx.weights = (gate_weight, up_weight, down_weight) + ctx.activation_fn = activation_fn + ctx.activation_fn_backward = activation_fn_backward + ctx.inplace = inplace + + return output + + @staticmethod + @torch_amp_custom_bwd + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, + ) -> tuple[ + torch.Tensor | None, + None, + None, + torch.Tensor | None, + torch.Tensor | None, + None, + None, + None, + torch.Tensor | None, + torch.Tensor | None, + None, + None, + None, + torch.Tensor | None, + torch.Tensor | None, + None, + None, + None, + None, + ]: + """ + Performs backward pass computation for LoRA MLP. 
+ + Args: + ctx: Context object storing tensors saved during forward pass + grad_output: Gradient of loss with respect to layer output + + Returns: + Tuple containing gradients for all inputs from forward pass: + - Input gradient tensor (or `None`) + - `None` for weights/quantization states + - LoRA A/B matrix gradients (or `None`) + - `None` for scaling factors + - `None` for activation functions and flags + """ + ( + X, + gate, + up, + gate_A, + gate_B, + up_A, + up_B, + down_A, + down_B, + ) = ctx.saved_tensors + gate_scale, up_scale, down_scale = ctx.scales + gate_quant, up_quant, down_quant = ctx.quants + gate_weight, up_weight, down_weight = ctx.weights + + # Transpose all LoRA matrices + gate_A, gate_B = ( + gate_A.t() if gate_A is not None else None, + gate_B.t() if gate_B is not None else None, + ) + up_A, up_B = ( + up_A.t() if up_A is not None else None, + up_B.t() if up_B is not None else None, + ) + down_A, down_B = ( + down_A.t() if down_A is not None else None, + down_B.t() if down_B is not None else None, + ) + + # Reshape inputs + batch, seq_len, hd = X.shape + grad_output = grad_output.view(-1, grad_output.shape[-1]) + X = X.view(-1, X.shape[-1]) + gate = gate.view(-1, gate.shape[-1]) + up = up.view(-1, up.shape[-1]) + dtype = X.dtype + + # Down projection + DW = matmul_lora( + grad_output, + down_weight.t(), + down_quant, + down_B, + down_A, + down_scale, + ) + + # Activation backward + h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up) + + # Initialize and compute LoRA gradients + d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None + + if down_A is not None: + d_down_A = h.t() @ (grad_output @ down_B.t()) + d_down_B = (down_A.t() @ h.t()) @ grad_output + d_down_A *= down_scale + d_down_B *= down_scale + + if up_A is not None: + d_up_A = X.t() @ (grad_up @ up_B.t()) + d_up_B = (up_A.t() @ X.t()) @ grad_up + d_up_A *= up_scale + d_up_B *= up_scale + + if gate_A is not None: + d_gate_A = X.t() @ (grad_gate @ gate_B.t()) + d_gate_B = (gate_A.t() @ X.t()) @ grad_gate + d_gate_A *= gate_scale + d_gate_B *= gate_scale + + # Compute input gradients + dX = torch.zeros_like(X) if ctx.needs_input_grad[0] else None + + if dX is not None: + # Up projection gradients + up_weight = dequantize(up_weight.t(), up_quant) + if ctx.inplace: + dX = torch.matmul(grad_up, up_weight.t(), out=X) + else: + dX = torch.matmul(grad_up, up_weight.t()) + del up_weight + + # Note the .to(dtype) only where mixing LoRA with base weights + if up_A is not None: + dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t()) + + # Gate projection gradients + gate_weight = dequantize(gate_weight.t(), gate_quant) + dX += grad_gate @ gate_weight.t() + del gate_weight + + if gate_A is not None: + dX += ( + grad_gate + @ gate_B.to(dtype).t() + @ (gate_scale * gate_A.to(dtype).t()) + ) + + # Reshape back + dX = dX.view(batch, seq_len, hd) + + # Return gradients in correct order matching forward inputs + return ( + dX, + None, + None, + d_gate_A.t() if d_gate_A is not None else None, + d_gate_B.t() if d_gate_B is not None else None, + None, + None, + None, + d_up_A.t() if d_up_A is not None else None, + d_up_B.t() if d_up_B is not None else None, + None, + None, + None, + d_down_A.t() if d_down_A is not None else None, + d_down_B.t() if d_down_B is not None else None, + None, + None, + None, + None, + ) + + +def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor: + """ + Applies LoRA to MLP layer with SwiGLU activation. 
+ + Args: + X: Input tensor for the MLP layer + inplace: Whether to perform operations in-place to save memory + + Returns: + Output tensor after applying LoRA-adapted MLP with SwiGLU activation + """ + gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) + upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) + downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + + out = LoRA_MLP.apply( + X, + gateW, + gateW_quant, + gateA, + gateB, + gateS, + upW, + upW_quant, + upA, + upB, + upS, + downW, + downW_quant, + downA, + downB, + downS, + swiglu_forward, + swiglu_backward, + inplace, + ) + + return out + + +def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.Tensor: + """ + Applies LoRA to MLP layer with GEGLU activation. + + Args: + X: Input tensor for the MLP layer + inplace: Whether to perform operations in-place to save memory + + Returns: + Output tensor after applying LoRA-adapted MLP with GEGLU activation + """ + gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj) + upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj) + downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj) + out = LoRA_MLP.apply( + X, + gateW, + gateW_quant, + gateA, + gateB, + gateS, + upW, + upW_quant, + upA, + upB, + upS, + downW, + downW_quant, + downA, + downB, + downS, + geglu_forward, + geglu_backward, + inplace, + ) + + return out + + +class LoRA_QKV(torch.autograd.Function): + """ + Optimized LoRA QKV implementation with quantization support. + + Implements efficient computation of query, key, value projections with LoRA, + supporting quantization and memory optimization. + """ + + @staticmethod + @torch_amp_custom_fwd + def forward( + ctx: torch.autograd.function.FunctionCtx, + X: torch.Tensor, + q_weight: torch.Tensor, + q_quant: QuantState | None, + q_A: torch.Tensor | None, + q_B: torch.Tensor | None, + q_scale: float, + k_weight: torch.Tensor, + k_quant: QuantState | None, + k_A: torch.Tensor | None, + k_B: torch.Tensor | None, + k_scale: float, + v_weight: torch.Tensor, + v_quant: QuantState | None, + v_A: torch.Tensor | None, + v_B: torch.Tensor | None, + v_scale: float, + inplace: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass computing Q, K, V projections with LoRA. 
+ + Args: + ctx: Autograd context + X: Input tensor + q_weight: Query projection weight + q_quant: Query quantization state + q_A: Query LoRA A matrix + q_B: Query LoRA B matrix + q_scale: Query LoRA scale + k_weight: Key projection weight + k_quant: Key quantization state + k_A: Key LoRA A matrix + k_B: Key LoRA B matrix + k_scale: Key LoRA scale + v_weight: Value projection weight + v_quant: Value quantization state + v_A: Value LoRA A matrix + v_B: Value LoRA B matrix + v_scale: Value LoRA scale + inplace: Whether to perform operations in-place + + Returns: + Tuple of (Query, Key, Value) projection tensors + """ + Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale) + K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale) + V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale) + + ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B) + ctx.scales = (q_scale, k_scale, v_scale) + ctx.quants = (q_quant, k_quant, v_quant) + ctx.weights = (q_weight, k_weight, v_weight) + ctx.inplace = inplace + + return Q, K, V + + @staticmethod + @torch_amp_custom_fwd + def backward( + ctx: torch.autograd.function.FunctionCtx, + q_grad: torch.Tensor, + k_grad: torch.Tensor, + v_grad: torch.Tensor, + ) -> tuple[ + torch.Tensor, + None, + None, + torch.Tensor | None, + torch.Tensor | None, + None, + None, + None, + torch.Tensor | None, + torch.Tensor | None, + None, + None, + None, + torch.Tensor | None, + torch.Tensor | None, + None, + None, + ]: + """ + Backward pass computing gradients for LoRA QKV. + + Args: + ctx: Autograd context + q_grad: Gradient for query projection + k_grad: Gradient for key projection + v_grad: Gradient for value projection + + Returns: + Tuple containing gradients for all forward inputs + """ + X, A_q, B_q, A_k, B_k, A_v, B_v = ctx.saved_tensors + q_weight, k_weight, v_weight = ctx.weights + q_quant, k_quant, v_quant = ctx.quants + q_scale, k_scale, v_scale = ctx.scales + dtype = X.dtype + + # Reshape gradients + batch, seq_len = X.shape[:2] + q_grad = q_grad.view(-1, q_grad.shape[-1]) + k_grad = k_grad.reshape(-1, k_grad.shape[-1]) + v_grad = v_grad.view(-1, v_grad.shape[-1]) + X = X.view(-1, X.shape[-1]) + + # Pre-transpose X once + X_t = X.t() + + # Initialize LoRA gradients as None + d_A_q = d_B_q = d_A_k = d_B_k = d_A_v = d_B_v = None + + # Compute q path LoRA gradients if adapters exist + if A_q is not None and B_q is not None: + A_q_scaled = (q_scale * A_q).to(dtype) + B_q_scaled = B_q.to(dtype) + d_A_q = torch.mm(X_t, torch.mm(q_grad, B_q_scaled)) + d_B_q = torch.mm(torch.mm(A_q_scaled, X_t), q_grad) + + # Compute k path LoRA gradients if adapters exist + if A_k is not None and B_k is not None: + A_k_scaled = (k_scale * A_k).to(dtype) + B_k_scaled = B_k.to(dtype) + d_A_k = torch.mm(X_t, torch.mm(k_grad, B_k_scaled)) + d_B_k = torch.mm(torch.mm(A_k_scaled, X_t), k_grad) + + # Compute v path LoRA gradients if adapters exist + if A_v is not None and B_v is not None: + A_v_scaled = (v_scale * A_v).to(dtype) + B_v_scaled = B_v.to(dtype) + d_A_v = torch.mm(X_t, torch.mm(v_grad, B_v_scaled)) + d_B_v = torch.mm(torch.mm(A_v_scaled, X_t), v_grad) + + # Compute input gradient, reusing X memory if possible + out_buffer = X if ctx.inplace else None + + # Q path + q_weight_t = dequantize(q_weight, q_quant) + grad_X = torch.mm(q_grad, q_weight_t, out=out_buffer) + del q_weight + del q_weight_t + if A_q is not None and B_q is not None: + grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled)) + + # K path + k_weight_t = dequantize(k_weight, k_quant) + grad_X.addmm_(k_grad, 
k_weight_t) + del k_weight + del k_weight_t + if A_k is not None and B_k is not None: + grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled)) + + # V path + v_weight_t = dequantize(v_weight, v_quant) + grad_X.addmm_(v_grad, v_weight_t) + del v_weight + del v_weight_t + if A_v is not None and B_v is not None: + grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled)) + + # Transpose gradients if needed + if d_A_q is not None: + d_A_q = d_A_q.t() + if d_B_q is not None: + d_B_q = d_B_q.t() + if d_A_k is not None: + d_A_k = d_A_k.t() + if d_B_k is not None: + d_B_k = d_B_k.t() + if d_A_v is not None: + d_A_v = d_A_v.t() + if d_B_v is not None: + d_B_v = d_B_v.t() + + return ( + grad_X.view(batch, seq_len, -1), + None, + None, + d_A_q, + d_B_q, + None, + None, + None, + d_A_k, + d_B_k, + None, + None, + None, + d_A_v, + d_B_v, + None, + None, + ) + + +def apply_lora_qkv( + self, X: torch.Tensor, inplace: bool = True +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Applies LoRA to compute Query, Key, Value projections. + + Args: + X: Input tensor + inplace: Whether to perform operations in-place + + Returns: + Tuple of (Query, Key, Value) projection tensors + """ + QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj) + KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj) + VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj) + Q, K, V = LoRA_QKV.apply( + X, + QW, + QW_quant, + QA, + QB, + QS, + KW, + KW_quant, + KA, + KB, + KS, + VW, + VW_quant, + VA, + VB, + VS, + inplace, + ) + + return Q, K, V + + +class LoRA_O(torch.autograd.Function): + """Optimized LoRA implementation for output projection.""" + + @staticmethod + @torch_amp_custom_fwd + def forward( + ctx: torch.autograd.function.FunctionCtx, + X: torch.Tensor, + W: torch.Tensor, + W_quant: QuantState | None, + A: torch.Tensor | None, + B: torch.Tensor | None, + S: float, + ) -> torch.Tensor: + """ + Forward pass for output projection with LoRA. + + Args: + ctx: Autograd context + X: Input tensor + W: Output projection weight + W_quant: Weight quantization state + A: LoRA A matrix + B: LoRA B matrix + S: LoRA scaling factor + + Returns: + Output projection tensor + """ + XW = matmul_lora(X, W, W_quant, A, B, S) + ctx.custom_saved_tensors = ( + W, + W_quant, + S, + ) + ctx.save_for_backward(A, B, X) + + return XW + + @staticmethod + @torch_amp_custom_bwd + def backward( + ctx: torch.autograd.function.FunctionCtx, + dY: torch.Tensor, + ) -> tuple[ + torch.Tensor, + None, + None, + torch.Tensor | None, + torch.Tensor | None, + None, + ]: + """ + Backward pass computing gradients for LoRA output projection. + + Args: + ctx: Autograd context + dY: Gradient of loss with respect to output + + Returns: + Tuple containing gradients for all forward inputs + """ + W, W_quant, S = ctx.custom_saved_tensors + A, B, X = ctx.saved_tensors + + batch, seq_len, hd = X.shape + dY = dY.reshape(-1, dY.shape[-1]) + X = X.reshape(-1, X.shape[-1]) + dtype = X.dtype + + # Weight projection + dY_X = X.t() @ dY + d_A = S * dY_X @ B + d_B = S * A @ dY_X + + # Get derivative for dX + W = dequantize(W.t(), W_quant) + dX = dY @ W.t() + del W + dX += dY @ B.to(dtype) @ (S * A.to(dtype)) + + # W, W_quant, A, B, S + return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None + + +def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor: + """ + Applies LoRA to output projection layer. 
+ + Args: + X: Input tensor + + Returns: + Transformed output tensor + """ + OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj) + output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS) + + return output diff --git a/src/axolotl/kernels/quantize.py b/src/axolotl/kernels/quantize.py new file mode 100644 index 0000000000..ea5ecf8e84 --- /dev/null +++ b/src/axolotl/kernels/quantize.py @@ -0,0 +1,149 @@ +"""Dequantization utilities for `bitsandbytes` integration.""" +# pylint: disable=invalid-name,global-statement + +import ctypes + +import bitsandbytes as bnb +import torch +from bitsandbytes.functional import QuantState, get_ptr +from packaging.version import Version + +cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 +cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 +cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 + +CUDA_STREAM: torch.cuda.Stream | None = None +HAS_CUDA_STREAM: bool = Version(bnb.__version__) > Version("0.43.3") + + +def dequantize( + W: torch.Tensor, + quant_state: QuantState | list | None = None, + out: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Fast NF4 dequantization using `bitsandbytes` CUDA kernels. + + Performs efficient dequantization of weights from NF4 format using `bitsandbytes`' + optimized CUDA implementations. Supports both legacy list and new `QuantState` + formats. + + Args: + W: Quantized weight tensor to dequantize + quant_state: Quantization state containing metadata needed for + dequantization. Can be either a `QuantState` object or legacy list format. + If None, returns `W` unchanged. + out: Optional output tensor for storing dequantized results. Must match + expected shape and dtype if provided. + + Returns: + Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if + input `W` was transposed. + + Raises: + AssertionError: If provided output tensor doesn't match expected shape / dtype. + + Note: + Uses CUDA streams for better performance when available in newer `bitsandbytes` + versions (>0.43.3). 
+ """ + if quant_state is None: + return W + + # Get the target device from input tensor W + target_device = W.device + + # Extract quantization state + if not isinstance(quant_state, list): + # New style quant_state class + absmax = quant_state.absmax.to(target_device) + shape = quant_state.shape + dtype = quant_state.dtype + blocksize = quant_state.blocksize + offset = quant_state.offset.to(target_device) + state2 = quant_state.state2 + absmax2 = state2.absmax.to(target_device) + code2 = state2.code.to(target_device) + blocksize2 = state2.blocksize + else: + # Legacy list format + absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state + absmax = absmax.to(target_device) + offset, state2 = compressed_stats + offset = offset.to(target_device) + absmax2, code2, blocksize2, _, _, _, _ = state2 + absmax2 = absmax2.to(target_device) + code2 = code2.to(target_device) + + # Setup output tensor on the same device as input + if out is None: + out = torch.empty(shape, dtype=dtype, device=target_device) + else: + assert out.shape == shape and out.dtype == dtype + out = out.to(target_device) + + # Dequantize statistics on the target device + n_elements_absmax: int = absmax.numel() + out_absmax: torch.Tensor = torch.empty( + n_elements_absmax, dtype=torch.float32, device=target_device + ) + ptr_out_absmax: int = get_ptr(out_absmax) + + # Use CUDA stream if available + if HAS_CUDA_STREAM: + global CUDA_STREAM + if CUDA_STREAM is None: + CUDA_STREAM = torch.cuda.current_stream(target_device) + + cdequantize_blockwise_fp32( + get_ptr(code2), + get_ptr(absmax), + get_ptr(absmax2), + ptr_out_absmax, + ctypes.c_int(blocksize2), + ctypes.c_int(n_elements_absmax), + CUDA_STREAM, + ) + else: + cdequantize_blockwise_fp32( + get_ptr(code2), + get_ptr(absmax), + get_ptr(absmax2), + ptr_out_absmax, + ctypes.c_int(blocksize2), + ctypes.c_int(n_elements_absmax), + ) + + out_absmax += offset + + # Choose appropriate dequantization function + fx = ( + cdequantize_blockwise_fp16_nf4 + if dtype == torch.float16 + else cdequantize_blockwise_bf16_nf4 + ) + + # Dequantize weights + if HAS_CUDA_STREAM: + fx( + get_ptr(None), + get_ptr(W), + ptr_out_absmax, + get_ptr(out), + ctypes.c_int(blocksize), + ctypes.c_int(out.numel()), + CUDA_STREAM, + ) + else: + fx( + get_ptr(None), + get_ptr(W), + ptr_out_absmax, + get_ptr(out), + ctypes.c_int(blocksize), + ctypes.c_int(out.numel()), + ) + + # Handle transposed data + is_transposed: bool = W.shape[0] == 1 + return out.t() if is_transposed else out diff --git a/src/axolotl/kernels/swiglu.py b/src/axolotl/kernels/swiglu.py new file mode 100644 index 0000000000..20c6e87a0f --- /dev/null +++ b/src/axolotl/kernels/swiglu.py @@ -0,0 +1,163 @@ +""" +Module for definition of SwiGLU Triton kernels. + +See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202). + +Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation. +""" +import torch +import triton +import triton.language as tl + + +@triton.jit +def _swiglu_fwd_kernel( + gate_ptr, + up_ptr, + out_ptr, + n_elements, + block_size: tl.constexpr, +): + """ + SwiGLU forward kernel. The kernel computes activation in fp32 precision for better + numerical stability, then converts back to original dtype for the final result. + + Args: + gate_ptr: Pointer to gate tensor `[*, hidden_dim]`. + up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`. + out_ptr: Pointer to output tensor `[*, hidden_dim]`. + n_elements: Total number of elements in the input tensors. 
+ block_size: Size of thread blocks for parallel computation. + """ + block_idx = tl.program_id(0) + offsets = block_idx * block_size + tl.arange(0, block_size) + mask = offsets < n_elements + + # Load gate in fp32, keep up in original dtype + gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32) + up = tl.load(up_ptr + offsets, mask=mask, other=0) + + # Compute activation in fp32 then convert back + f = gate * tl.sigmoid(gate) + f = f.to(up.dtype) + result = f * up + + tl.store(out_ptr + offsets, result, mask=mask) + + +@triton.jit +def _swiglu_bwd_kernel( + grad_out_ptr, + gate_ptr, + up_ptr, + n_elements, + block_size: tl.constexpr, +): + """ + SwiGLU backward kernel. Stores gradient results in-place. + + Args: + grad_out_ptr: Pointer to gradient output tensor `[*, hidden_dim]`. + gate_ptr: Pointer to gate tensor `[*, hidden_dim]`. + up_ptr: Pointer to up-projection tensor `[*, hidden_dim]`. + n_elements: Total number of elements in the input tensors. + block_size: Size of thread blocks for parallel computation. + + Note: + After kernel execution, tensors are modified in-place: + - `grad_out_ptr` contains forward output (`h`) + - `gate_ptr` contains gradient w.r.t gate (`grad_gate`) + - `up_ptr` contains gradient w.r.t up (`grad_up`) + """ + block_idx = tl.program_id(0) + offsets = block_idx * block_size + tl.arange(0, block_size) + mask = offsets < n_elements + + # Load values - only convert gate to fp32 + grad_out = tl.load(grad_out_ptr + offsets, mask=mask, other=0) + gate = tl.load(gate_ptr + offsets, mask=mask, other=0).to(tl.float32) + up = tl.load(up_ptr + offsets, mask=mask, other=0) + + # Compute SiLU and forward output + sigmoid_gate = tl.sigmoid(gate) + silu_gate = sigmoid_gate * gate + silu_gate = silu_gate.to(grad_out.dtype) + h = silu_gate * up + + # Compute gradients + grad_up = grad_out * silu_gate # gradient for up is grad_out * SiLU(gate) + + # Compute gate gradient + temp = grad_out * up + grad_gate = temp.to(tl.float32) * sigmoid_gate * (1.0 + gate * (1.0 - sigmoid_gate)) + grad_gate = grad_gate.to(grad_out.dtype) + + # Store results with correct gradient ordering + tl.store(grad_out_ptr + offsets, h, mask=mask) + tl.store(gate_ptr + offsets, grad_gate, mask=mask) # grad wrt gate + tl.store(up_ptr + offsets, grad_up, mask=mask) # grad wrt up + + +# pylint: disable=unnecessary-lambda-assignment +def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor: + """ + SwiGLU forward pass. Computes SwiGLU activation: `x * sigmoid(x) * up`, where + `x` is the gate tensor. + + Args: + gate: Input gate tensor of shape `[batch, seq_len, hidden_dim]`. + up: Up-projection tensor of shape `[batch, seq_len, hidden_dim]`. + + Returns: + Output tensor of shape `[batch, seq_len, hidden_dim]`. + """ + batch, seq_len, hidden_dim = gate.shape + n_elements = gate.numel() + out = torch.empty((batch, seq_len, hidden_dim), dtype=gate.dtype, device="cuda") + + grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731 + _swiglu_fwd_kernel[grid]( + gate_ptr=gate, + up_ptr=up, + out_ptr=out, + n_elements=n_elements, + block_size=1024, + ) + + return out + + +# pylint: disable=unnecessary-lambda-assignment +def swiglu_backward( + grad_output: torch.Tensor, gate: torch.Tensor, up: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + SwiGLU backward pass using in-place operations. + + Args: + grad_output: Gradient of loss with respect to output, shape `[batch, seq_len, hidden_dim]`. 
+ gate: Gate tensor from forward pass, shape `[batch, seq_len, hidden_dim]`. + up: Up-projection tensor from forward pass, shape `[batch, seq_len, hidden_dim]`. + + Returns: + Tuple containing: + - Forward pass output (`h`) + - Gradient with respect to gate (`df`) + - Gradient with respect to up-projection (`de`) + """ + n_elements = grad_output.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta["block_size"]),) # noqa: E731 + _swiglu_bwd_kernel[grid]( + grad_out_ptr=grad_output, + gate_ptr=gate, + up_ptr=up, + n_elements=n_elements, + block_size=1024, + ) + + # After kernel execution, tensors contain: + # grad_output: h (forward output) + # gate: grad_gate (grad wrt gate) + # up: grad_up (grad wrt up) + return grad_output, gate, up diff --git a/src/axolotl/kernels/utils.py b/src/axolotl/kernels/utils.py new file mode 100644 index 0000000000..59a7e00127 --- /dev/null +++ b/src/axolotl/kernels/utils.py @@ -0,0 +1,11 @@ +"""Utilities for `axolotl.kernels` submodules.""" + +import torch +from packaging.version import Version + +if Version(torch.__version__) < Version("2.4.0"): + torch_amp_custom_fwd = torch.cuda.amp.custom_fwd + torch_amp_custom_bwd = torch.cuda.amp.custom_bwd +else: + torch_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda") + torch_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py new file mode 100644 index 0000000000..d59fd22c93 --- /dev/null +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -0,0 +1,333 @@ +"""Module for patching custom LoRA Triton kernels and `torch.autograd` functions.""" + +import importlib +import inspect +import logging +import types +from typing import Type + +import torch +from accelerate.logging import get_logger +from peft import PeftModelForCausalLM +from torch import nn +from transformers import AutoConfig + +from axolotl.kernels.lora import ( + apply_lora_mlp_geglu, + apply_lora_mlp_swiglu, + apply_lora_o, + apply_lora_qkv, +) +from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.dict import DictDefault + +LOG = get_logger(__name__) + +ORIGINAL_QKV_CODE = """ + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) +""".lstrip( + "\n" +) + +PATCHED_QKV_CODE = """ + query_states, key_states, value_states = self.apply_qkv(hidden_states) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) +""".lstrip( + "\n" +) + +ORIGINAL_O_CODE = """ + attn_output = self.o_proj(attn_output) +""".lstrip( + "\n" +) + +PATCHED_O_CODE = """ + attn_output = self.apply_o(attn_output) +""".lstrip( + "\n" +) + +SUPPORTED_ACTIVATIONS = ["silu", "gelu"] +APPLY_FN_MAPPING = { + "silu": apply_lora_mlp_swiglu, + "gelu": apply_lora_mlp_geglu, +} + + +def original_apply_qkv( + self: nn.Module, hidden_states: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Original implementation of QKV projection without optimizations. + + Args: + self: The attention module instance. + hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]. + + Returns: + A tuple `(query_states, key_states, value_states)` containing the projected + states for query, key, and value. 
+ """ + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + return query_states, key_states, value_states + + +def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Original implementation of output projection without optimizations. + + Args: + self: The attention module instance. + hidden_states: Input tensor of shape `[`batch_size, seq_len, hidden_dim]`. + + Returns: + The output projection result. + """ + attn_output = self.o_proj(hidden_states) + + return attn_output + + +def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]: + """ + Get the appropriate attention class by inspecting the model config. + Uses dynamic import to support any model architecture that follows + the standard transformers naming convention. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + + Returns: + The appropriate attention class for the model. + + Raises: + ValueError: If `base_model` not specified or attention class cannot be imported + ImportError: If the model module or attention class doesn't exist + """ + if "base_model" not in cfg: + raise ValueError("base_model must be specified in config") + + # Get model config without loading the model + model_config = AutoConfig.from_pretrained(cfg["base_model"]) + model_type = model_config.model_type + + # Special case for model_type = "qwen2" + if model_type == "qwen2": + from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention + + return Qwen2Attention + + try: + # Dynamically import the module and attention class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + module = __import__( + module_path, fromlist=[f"{model_type.capitalize()}Attention"] + ) + attention_cls = getattr(module, f"{model_type.capitalize()}Attention") + + return attention_cls + except (ImportError, AttributeError) as e: + raise ValueError( + f"Could not import attention class for model_type: {model_type}. " + f"Error: {str(e)}" + ) from e + + +# pylint: disable=protected-access +def patch_self_attn_lora(cfg: DictDefault): + """ + Given an `axolotl` config, this method patches the inferred attention class forward + pass with optimized LoRA implementations. + + It modifies the attention class to use optimized QKV and output projections. The + original implementation is preserved and can be restored if needed. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + + Raises: + AssertionError: If the required code blocks are not found in the attention + implementation. 
+ """ + attention_cls = get_attention_cls_from_config(cfg) + + # Check if already patched + if hasattr(attention_cls, "_original_forward"): + LOG.info(f"{attention_cls.__name__} already patched") + return + + self_attn_forward = inspect.getsource(attention_cls.forward) + attention_cls._original_forward = self_attn_forward + self_attn_forward, _ = detab_code(self_attn_forward) + + assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found" + assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" + + self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE) + self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) + self_attn_forward = self_attn_forward.replace( + "def forward(", + "def axolotl_attn_forward(", + 1, + ) + + # Load necessary imports + module_name = attention_cls.__module__ + module = importlib.import_module(module_name) + + items_to_import = [] + for item in dir(module): + if item in self_attn_forward: + items_to_import.append(item) + + exec( # pylint: disable=exec-used # nosec B102 + f"from {module_name} import ({', '.join(items_to_import)})", + globals(), + ) + exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 + + LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}") + attention_cls.forward = ( + axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821 + ) + + +def apply_lora_kernel_patches( + model: PeftModelForCausalLM, cfg: DictDefault +) -> PeftModelForCausalLM: + """ + Applies optimized Triton kernel patches to a PEFT model. + + Patches a PEFT model with optimized implementations for MLP and attention + computations. The optimizations include custom Triton kernels for activation + functions and specialized autograd functions for LoRA computations. + + Args: + model: A PEFT model to be patched with optimized kernels. + cfg: Dictionary mapping `axolotl` config keys to values. + + Returns: + PeftModelForCausalLM: The patched model with optimized kernels. + + Raises: + TypeError: If the provided model is not a `PeftModelForCausalLM`. + NotImplementedError: If the model type is not supported. + AssertionError: If multiple adapters are active (currently unsupported). + + Note: + The optimizations require LoRA adapters with no dropout and no bias terms. The + function will skip patching if these conditions aren't met. 
+ """ + if not isinstance(model, PeftModelForCausalLM): + raise TypeError("Model must be a PeftModelForCausalLM") + + # Get active LoRA adapter config + if hasattr(model, "active_adapters"): + assert ( + len(model.active_adapters) == 1 + ), "Axolotl currently does not support LoRA Triton kernels for multiple adapters" + active_adapter = model.active_adapters[0] + else: + active_adapter = model.active_adapter + lora_config = model.model.peft_config[active_adapter] + + # Only patch if conditions are met + can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none" + + if not can_patch: + LOG.warning("Cannot patch layers - requires no dropout and no bias") + LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file") + return model + + # This needs to be reset after patching + original_level = LOG.getEffectiveLevel() + LOG.setLevel(logging.INFO) + + # Choose activation based on model type + activation = model.config.hidden_act + if activation not in SUPPORTED_ACTIVATIONS: + raise NotImplementedError(f"Activation {activation} is not supported") + + # Patch each layer + for layer in model.model.model.layers: + # Add QKV, O fallback implementations to start + # These will be overwritten later (if some conditions apply) + layer.self_attn.apply_qkv = types.MethodType( + original_apply_qkv, layer.self_attn + ) + layer.self_attn.apply_o = types.MethodType(original_apply_o, layer.self_attn) + + if cfg.lora_mlp_kernel: + # MLP patching + gate_proj = layer.mlp.gate_proj + up_proj = layer.mlp.up_proj + down_proj = layer.mlp.down_proj + + can_patch_mlp = all( + hasattr(proj, "lora_A") + and getattr(proj, "base_layer", proj).bias is None + and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 + for proj in (gate_proj, up_proj, down_proj) + ) + + if can_patch_mlp: + apply_fn = APPLY_FN_MAPPING[activation] + layer.mlp.forward = types.MethodType(apply_fn, layer.mlp) + else: + LOG.warning_once( + "Cannot patch some MLP layers - requires LoRA adapters with no bias" + ) + if cfg.lora_qkv_kernel: + # Query, key, value patching + layer_modules = [ + getattr(layer.self_attn, linear_proj) + for linear_proj in ["q_proj", "k_proj", "v_proj"] + ] + can_patch_qkv = all( + hasattr(module, "lora_A") + and getattr(module, "base_layer", module).bias is None + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + for module in layer_modules + ) + + if can_patch_qkv: + # Add optimized implementation + layer.self_attn.apply_qkv = types.MethodType( + apply_lora_qkv, layer.self_attn + ) + else: + LOG.warning_once( + "Cannot patch some attention QKV projections - requires LoRA adapters with no bias" + ) + if cfg.lora_o_kernel: + # Output patching + layer_modules = [ + getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] + ] + can_patch_o = all( + hasattr(module, "lora_A") + and getattr(module, "base_layer", module).bias is None + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + for module in layer_modules + ) + + if can_patch_o: + layer.self_attn.apply_o = types.MethodType( + apply_lora_o, layer.self_attn + ) + else: + LOG.warning_once( + "Cannot patch some attention output projection - requires LoRA adapters with no bias" + ) + + LOG.setLevel(original_level) + + return model diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 8b5e0074c0..515248fced 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -175,6 +175,7 @@ def terminate_handler(_, __, model_weakref): LOG.info("hang tight... 
sorting dataset for group_by_length") pretrain_hooks(cfg, trainer) + if cfg.flash_optimum: with torch.backends.cuda.sdp_kernel( # TODO configure these from the YAML w/ sdp_kernel_kwargs: ... @@ -185,6 +186,7 @@ def terminate_handler(_, __, model_weakref): trainer.train(resume_from_checkpoint=resume_from_checkpoint) else: trainer.train(resume_from_checkpoint=resume_from_checkpoint) + post_train_hooks(cfg, trainer) LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1e7a6aa8b7..38204d02d2 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1,7 +1,4 @@ -""" -Module for pydantic models for configuration -""" - +"""Module with Pydantic models for configuration.""" # pylint: disable=too-many-lines import logging @@ -810,6 +807,10 @@ class Config: unsloth_rms_norm: Optional[bool] = None unsloth_rope: Optional[bool] = None + lora_mlp_kernel: Optional[bool] = None + lora_qkv_kernel: Optional[bool] = None + lora_o_kernel: Optional[bool] = None + deepspeed: Optional[Union[str, Dict[str, Any]]] = None fsdp: Optional[List[str]] = None fsdp_config: Optional[Dict[str, Any]] = None @@ -1534,12 +1535,42 @@ def check_qlora_unsloth(cls, data): or data.get("unsloth_lora_qkv") or data.get("unsloth_lora_o") ): - if data.get("adapter") == "lora" or data.get("load_in_8bit"): + if data.get("adapter") == "lora" and data.get("load_in_8bit"): raise ValueError( "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA" ) return data + @model_validator(mode="before") + @classmethod + def check_lora_8bit(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ): + if data.get("adapter") == "lora" and data.get("load_in_8bit"): + raise ValueError( + "lora_mlp_kernel, lora_mlp_kernel, and lora_mlp_kernel are not compatible with 8-bit LoRA" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_lora_axolotl_unsloth(cls, data): + is_lora_kernel = any( + data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] + ) + is_unsloth_lora = any( + data.get(k) + for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] + ) + if is_lora_kernel and is_unsloth_lora: + raise ValueError( + "both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)" + ) + return data + @model_validator(mode="before") @classmethod def check_torch_compile_deepspeed(cls, data): @@ -1672,6 +1703,29 @@ def check_multigpu_unsloth(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def check_multigpu_lora_kernels(cls, data): + if ( + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") + ): + capabilities = data.get("capabilities") + is_fsdp = data.get("fsdp") is not None + is_deepspeed = data.get("deepspeed") is not None + + if capabilities and capabilities.get("n_gpu", 0) > 1: + if is_fsdp: + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP." + ) + if is_deepspeed: + raise ValueError( + "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with DeepSpeed." 
+                    )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_adopt_torch_version(cls, data):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index a96ecb0cf6..377f086052 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -414,6 +414,7 @@ def apply_patches(self) -> None:
             has_remote_code = "AutoModelForCausalLM" in auto_map_config
         else:
             has_remote_code = False
+
         if has_remote_code and self.cfg.trust_remote_code is False:
             # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
             has_remote_code = self.cfg.trust_remote_code
@@ -425,10 +426,6 @@ def apply_patches(self) -> None:
         if self.cfg.is_llama_derived_model:
             self.patch_loss_llama()
 
-            if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
-                from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
-
-                patch_self_attn_lora()
         elif self.cfg.is_llama_derived_model:
             self.patch_llama_derived_model()
@@ -442,6 +439,11 @@ def apply_patches(self) -> None:
 
             patch_mistral_cross_entropy()
 
+        if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
+            from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
+
+            patch_self_attn_lora(self.cfg)
+
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
             if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -472,9 +474,7 @@ def has_flash_attn(self) -> bool:
         return importlib.util.find_spec("flash_attn") is not None
 
     def patch_loss_llama(self) -> None:
-        """
-        Patch loss functions
-        """
+        """Patch loss functions and other optimizations"""
         if self.has_flash_attn:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 patch_fa_llama_cross_entropy,
@@ -494,15 +494,14 @@ def patch_loss_llama(self) -> None:
             from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
 
             patch_unsloth_layernorm()
+
         if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
             from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
 
             patch_self_attn_lora()
 
     def patch_llama_derived_model(self) -> None:
-        """
-        Modify all llama derived models in one block
-        """
+        """Modify all llama derived models in one block"""
        self.patch_loss_llama()
 
         if self.cfg.flash_attention:
@@ -1013,7 +1012,8 @@ def convert_embedding_modules_dtype(
                 if hasattr(module, "weight"):
                     module.to(dist_dtype)
 
-    def apply_lora_patch(self) -> None:
+    # TODO: Deprecate this.
+ def apply_unsloth_lora_patch(self) -> None: if self.cfg.unsloth_lora_mlp: from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch @@ -1027,6 +1027,16 @@ def apply_lora_patch(self) -> None: integrate_rope_embeddings() + def apply_lora_patch(self) -> None: + if ( + self.cfg.lora_mlp_kernel + or self.cfg.lora_qkv_kernel + or self.cfg.lora_o_kernel + ): + from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches + + apply_lora_kernel_patches(self.model, self.cfg) + def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: self.apply_patches() self.set_auto_model_loader() @@ -1171,6 +1181,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: if self.cfg.adapter is not None: log_gpu_memory_usage(LOG, "after adapters", self.model.device) + self.apply_unsloth_lora_patch() self.apply_lora_patch() for _ in range(3): diff --git a/tests/e2e/kernels/test_geglu.py b/tests/e2e/kernels/test_geglu.py new file mode 100644 index 0000000000..c720bbce7b --- /dev/null +++ b/tests/e2e/kernels/test_geglu.py @@ -0,0 +1,76 @@ +"""Tests for GEGLU activation function Triton kernels.""" +# pylint: disable=duplicate-code + +import torch +import torch.nn.functional as F + +from axolotl.kernels.geglu import geglu_backward, geglu_forward + + +def test_geglu_forward_shape(): + """Test that GEGLU forward pass preserves expected shapes.""" + batch, seq_len, hidden_dim = 2, 3, 64 + gate = torch.randn(batch, seq_len, hidden_dim, device="cuda") + up = torch.randn(batch, seq_len, hidden_dim, device="cuda") + + out = geglu_forward(gate, up) + assert out.shape == (batch, seq_len, hidden_dim) + assert out.dtype == gate.dtype + assert out.device == gate.device + + +def test_geglu_forward_values(): + """Test GEGLU forward pass matches PyTorch reference implementation.""" + gate = torch.randn(2, 3, 64, device="cuda") + up = torch.randn(2, 3, 64, device="cuda") + + # Custom implementation + triton_out = geglu_forward(gate.clone(), up.clone()) + + # PyTorch reference + torch_out = F.gelu(gate) * up + + assert torch.allclose(triton_out, torch_out, rtol=1e-3) + + +def test_geglu_backward(): + """Test GEGLU backward pass matches PyTorch autograd.""" + gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True) + up = torch.randn(2, 3, 64, device="cuda", requires_grad=True) + grad_output = torch.randn(2, 3, 64, device="cuda") + + # PyTorch reference - compute intermediates + gelu_gate = F.gelu(gate) + torch_out = gelu_gate * up + torch_out.backward(grad_output) + + # Custom backward pass + gate_clone = gate.clone().detach() + up_clone = up.clone().detach() + grad_output_clone = grad_output.clone() + + h, grad_gate, grad_up = geglu_backward(grad_output_clone, gate_clone, up_clone) + + # Compare outputs and gradients + assert torch.allclose(h, torch_out, rtol=1e-3) + assert torch.allclose(grad_gate, gate.grad, rtol=1e-3) + assert torch.allclose(grad_up, up.grad, rtol=1e-3) + + +def test_geglu_inplace_preservation(): + """Test that GEGLU backward doesn't modify original tensors unexpectedly.""" + gate = torch.randn(2, 3, 64, device="cuda") + up = torch.randn(2, 3, 64, device="cuda") + grad_output = torch.randn(2, 3, 64, device="cuda") + + gate_copy = gate.clone() + up_copy = up.clone() + grad_copy = grad_output.clone() + + geglu_backward(grad_output, gate, up) + + assert not torch.equal(gate, gate_copy), "Gate should be modified in-place" + assert not torch.equal(up, up_copy), "Up should be modified in-place" + assert not torch.equal( + grad_output, grad_copy + ), "Grad output 
should be modified in-place" diff --git a/tests/e2e/kernels/test_lora.py b/tests/e2e/kernels/test_lora.py new file mode 100644 index 0000000000..c8becf2da8 --- /dev/null +++ b/tests/e2e/kernels/test_lora.py @@ -0,0 +1,531 @@ +"""Tests for LoRA custom autograd.""" +# pylint: disable=invalid-name,redefined-outer-name + +import pytest +import torch +from bitsandbytes.functional import QuantState +from torch import nn + +from axolotl.kernels.geglu import geglu_backward, geglu_forward +from axolotl.kernels.lora import ( + LoRA_MLP, + LoRA_O, + LoRA_QKV, + apply_lora_mlp_geglu, + apply_lora_mlp_swiglu, + get_lora_parameters, + matmul_lora, +) +from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward + + +@pytest.fixture +def mock_quantstate(): + """Creates a mock QuantState for testing""" + shape = (64, 64) + n_blocks = shape[0] # Assuming blockwise quantization along first dimension + + # Create nested state first + nested_state = QuantState( + absmax=torch.ones(n_blocks, device="cuda"), # One value per block + shape=shape, + code=torch.randint(0, 15, shape, device="cuda"), # NF4 range is 0-15 + dtype=torch.float16, + blocksize=64, + quant_type="nf4", + offset=None, + state2=None, + ) + + # Create main state with nested state + return QuantState( + absmax=torch.ones(n_blocks, device="cuda"), + shape=shape, + code=torch.randint(0, 15, shape, device="cuda"), + dtype=torch.float16, + blocksize=64, + quant_type="nf4", + offset=torch.zeros(n_blocks, dtype=torch.int32, device="cuda"), + state2=nested_state, + ) + + +@pytest.fixture +def sample_tensors(): + """Creates sample tensors for testing""" + torch.manual_seed(42) + batch_size, seq_len, hidden_dim = 2, 3, 64 + rank = 8 + out_dim = hidden_dim + + return { + "X": torch.randn( + batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16 + ), + "W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16), + "scale": 0.5, + "shapes": { + "batch": batch_size, + "seq": seq_len, + "hidden": hidden_dim, + "out": out_dim, + "rank": rank, + }, + } + + +@pytest.fixture +def mock_proj(): + """Creates a mock projection module for testing.""" + + class MockProj(nn.Module): + """Mock projection class.""" + + def __init__(self, in_features=64, out_features=128, rank=8): + super().__init__() + self.base_layer = nn.Linear(in_features, out_features) + self.base_layer.to("cuda") + self.lora_A = nn.ModuleDict( + {"default": nn.Linear(in_features, rank, bias=False).to("cuda")} + ) + self.lora_B = nn.ModuleDict( + {"default": nn.Linear(rank, out_features, bias=False).to("cuda")} + ) + self.scaling = {"default": 0.5} + self.active_adapter = "default" + self.disable_adapters = False + self.merged = False + + return MockProj() + + +def test_get_lora_parameters(mock_proj): + """Tests get_lora_parameters function""" + # Test with LoRA enabled + W, _, A, B, s = get_lora_parameters(mock_proj) + + assert isinstance(W, torch.Tensor) + assert W.shape == (128, 64) + assert A.shape == (8, 64) + assert B.shape == (128, 8) + assert s == 0.5 + + # Test with LoRA disabled + mock_proj.disable_adapters = True + W, _, A, B, s = get_lora_parameters(mock_proj) + assert A is None and B is None and s is None + + # Test with merged state + mock_proj.disable_adapters = False + mock_proj.merged = True + W, _, A, B, s = get_lora_parameters(mock_proj) + assert A is None and B is None and s is None + + +def test_matmul_lora(sample_tensors): + """Tests matmul_lora function""" + X = sample_tensors["X"] + W = sample_tensors["W"] + scale = sample_tensors["scale"] + + 
shapes = sample_tensors["shapes"] + hidden_dim = shapes["hidden"] + out_dim = shapes["out"] + rank = shapes["rank"] + + A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16) + B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) + + # Test base matmul + out1 = matmul_lora(X, W, None, None, None, None) + expected1 = torch.matmul(X, W.t()) + assert torch.allclose(out1, expected1, rtol=1e-3) + + # Test with LoRA + out2 = matmul_lora(X, W, None, A, B, scale) + lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t()) + expected2 = expected1 + lora_term + assert torch.allclose(out2, expected2, rtol=1e-3) + + # Test 3D input reshaping + X_3d = X.clone() + out3 = matmul_lora(X_3d, W, None, A, B, scale) + assert out3.shape == (X.shape[0], X.shape[1], W.shape[0]) + + +@pytest.mark.parametrize( + "activation_forward,activation_backward", + [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)], +) +def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward): + """Tests LoRA_MLP directly with different activation functions""" + X = sample_tensors["X"] + shapes = sample_tensors["shapes"] + hidden_dim = shapes["hidden"] + out_dim = shapes["out"] + + # Create linear layers + gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16) + up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16) + down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16) + + # Test SwiGLU path + X.requires_grad = True + output = LoRA_MLP.apply( + X, + gate_proj.weight, + None, # gate_quant + None, # gate_A + None, # gate_B + None, # gate_scale + up_proj.weight, + None, # up_quant + None, # up_A + None, # up_B + None, # up_scale + down_proj.weight, + None, # down_quant + None, # down_A + None, # down_B + None, # down_scale + activation_forward, + activation_backward, + True, # inplace + ) + + assert output.shape == X.shape + assert not torch.isnan(output).any() + + # Test backward pass + loss = output.sum() + loss.backward() + assert X.grad is not None + assert not torch.isnan(X.grad).any() + + +@pytest.mark.parametrize( + "activation_forward,activation_backward", + [(swiglu_forward, swiglu_backward), (geglu_forward, geglu_backward)], +) +def test_lora_mlp_with_adapters( + sample_tensors, activation_forward, activation_backward +): + """Tests LoRA_MLP with LoRA adapters""" + X = sample_tensors["X"] + shapes = sample_tensors["shapes"] + hidden_dim = shapes["hidden"] + out_dim = shapes["out"] + rank = shapes["rank"] + + # Create LoRA components + gate_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16) + gate_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) + up_A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16) + up_B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) + down_A = torch.randn(rank, out_dim, device="cuda", dtype=torch.float16) + down_B = torch.randn(hidden_dim, rank, device="cuda", dtype=torch.float16) + scale = 0.5 + + gate_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16) + up_proj = nn.Linear(hidden_dim, out_dim).to(device="cuda", dtype=torch.float16) + down_proj = nn.Linear(out_dim, hidden_dim).to(device="cuda", dtype=torch.float16) + + X.requires_grad = True + gate_A.requires_grad = True + gate_B.requires_grad = True + up_A.requires_grad = True + up_B.requires_grad = True + down_A.requires_grad = True + down_B.requires_grad = True + + # Forward pass with adapters + output = 
LoRA_MLP.apply( + X, + gate_proj.weight, + None, + gate_A, + gate_B, + scale, + up_proj.weight, + None, + up_A, + up_B, + scale, + down_proj.weight, + None, + down_A, + down_B, + scale, + activation_forward, + activation_backward, + True, + ) + + assert output.shape == X.shape + assert not torch.isnan(output).any() + + # Test backward pass + loss = output.sum() + loss.backward() + + # Check all gradients + assert X.grad is not None + assert gate_A.grad is not None + assert gate_B.grad is not None + assert up_A.grad is not None + assert up_B.grad is not None + assert down_A.grad is not None + assert down_B.grad is not None + + assert not torch.isnan(X.grad).any() + assert not torch.isnan(gate_A.grad).any() + assert not torch.isnan(gate_B.grad).any() + assert not torch.isnan(up_A.grad).any() + assert not torch.isnan(up_B.grad).any() + assert not torch.isnan(down_A.grad).any() + assert not torch.isnan(down_B.grad).any() + + +def test_lora_qkv(sample_tensors): + """Tests LoRA QKV implementation with and without adapters""" + X = sample_tensors["X"] + shapes = sample_tensors["shapes"] + hidden_dim = shapes["hidden"] + rank = shapes["rank"] + + # Create base weights + q_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16) + k_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16) + v_weight = torch.randn(hidden_dim, hidden_dim, device="cuda", dtype=torch.float16) + + # Create LoRA matrices + q_A = torch.randn( + rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True + ) + q_B = torch.randn( + hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True + ) + k_A = torch.randn( + rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True + ) + k_B = torch.randn( + hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True + ) + v_A = torch.randn( + rank, hidden_dim, device="cuda", dtype=torch.float16, requires_grad=True + ) + v_B = torch.randn( + hidden_dim, rank, device="cuda", dtype=torch.float16, requires_grad=True + ) + scale = 0.5 + + X.requires_grad = True + + # Test without LoRA adapters + Q1, K1, V1 = LoRA_QKV.apply( + X, + q_weight, + None, + None, + None, + None, + k_weight, + None, + None, + None, + None, + v_weight, + None, + None, + None, + None, + True, + ) + + assert Q1.shape == K1.shape == V1.shape == X.shape + loss1 = (Q1 + K1 + V1).sum() + loss1.backward() + assert X.grad is not None + + # Clear gradients + X.grad = None + + # Test with LoRA adapters + Q2, K2, V2 = LoRA_QKV.apply( + X, + q_weight, + None, + q_A, + q_B, + scale, + k_weight, + None, + k_A, + k_B, + scale, + v_weight, + None, + v_A, + v_B, + scale, + True, + ) + + assert Q2.shape == K2.shape == V2.shape == X.shape + loss2 = (Q2 + K2 + V2).sum() + loss2.backward() + + # Check gradients + assert X.grad is not None + assert q_A.grad is not None + assert q_B.grad is not None + assert k_A.grad is not None + assert k_B.grad is not None + assert v_A.grad is not None + assert v_B.grad is not None + + # Check for NaN values + assert not torch.isnan(X.grad).any() + assert not torch.isnan(q_A.grad).any() + assert not torch.isnan(q_B.grad).any() + assert not torch.isnan(k_A.grad).any() + assert not torch.isnan(k_B.grad).any() + assert not torch.isnan(v_A.grad).any() + assert not torch.isnan(v_B.grad).any() + + +def test_lora_o(sample_tensors): + """Tests LoRA output projection""" + X = sample_tensors["X"] + W = sample_tensors["W"] + scale = sample_tensors["scale"] + + shapes = sample_tensors["shapes"] + hidden_dim = 
shapes["hidden"] + out_dim = shapes["out"] + rank = shapes["rank"] + + A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16) + B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) + + # Test forward pass + X.requires_grad = True + output = LoRA_O.apply(X, W, None, A, B, scale) + + assert output.shape == (X.shape[0], X.shape[1], W.shape[0]) + + # Test backward pass + loss = output.sum() + loss.backward() + assert X.grad is not None + + +def test_with_quantization(sample_tensors, mock_quantstate): + """Tests LoRA with quantized weights""" + X = sample_tensors["X"] # [batch, seq, hidden] + W = sample_tensors["W"] # [out, hidden] + scale = 0.5 + + shapes = sample_tensors["shapes"] + hidden_dim = shapes["hidden"] + out_dim = shapes["out"] + rank = shapes["rank"] + + A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16) + B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) + + # Test matmul with quantization + out = matmul_lora(X, W, mock_quantstate, A, B, scale) + assert out.shape == (X.shape[0], X.shape[1], W.shape[0]) + assert not torch.isnan(out).any() + + # Test with different batch sizes + X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16) + out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale) + assert out2.shape == (4, 6, W.shape[0]) + assert not torch.isnan(out2).any() + + +@pytest.mark.parametrize( + "batch,seq,hidden,rank,out", + [ + (1, 1, 32, 4, 64), + (2, 3, 64, 8, 128), + (4, 5, 128, 16, 256), + ], +) +def test_shapes_and_dimensions(batch, seq, hidden, rank, out): + """Tests various input shapes and dimensions""" + X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16) + W = torch.randn(out, hidden, device="cuda", dtype=torch.float16) + A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16) + B = torch.randn(out, rank, device="cuda", dtype=torch.float16) + scale = 0.5 + + result = matmul_lora(X, W, None, A, B, scale) + assert result.shape == (batch, seq, out) + + +def test_gradient_flow(sample_tensors): + """Tests gradient flow through LoRA layers""" + X = sample_tensors["X"].clone() + W = sample_tensors["W"].clone() + scale = sample_tensors["scale"] + + shapes = sample_tensors["shapes"] + hidden_dim = shapes["hidden"] + out_dim = shapes["out"] + rank = shapes["rank"] + + A = torch.randn(rank, hidden_dim, device="cuda", dtype=torch.float16) + B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16) + + X.requires_grad = True + A.requires_grad = True + B.requires_grad = True + + # Forward pass + out = matmul_lora(X, W, None, A, B, scale) + loss = out.sum() + + # Backward pass + loss.backward() + + assert X.grad is not None + assert A.grad is not None + assert B.grad is not None + assert not torch.isnan(X.grad).any() + assert not torch.isnan(A.grad).any() + assert not torch.isnan(B.grad).any() + + +@pytest.mark.parametrize( + "apply_function", + [apply_lora_mlp_swiglu, apply_lora_mlp_geglu], +) +def test_inplace_operations(sample_tensors, apply_function): + """Tests inplace operation behavior""" + X = sample_tensors["X"] + shapes = sample_tensors["shapes"] + + # Create MLP with both inplace=True and inplace=False + mlp = type( + "MLPModule", + (), + { + "gate_proj": nn.Linear(shapes["hidden"], shapes["out"]).to( + device="cuda", dtype=torch.float16 + ), + "up_proj": nn.Linear(shapes["hidden"], shapes["out"]).to( + device="cuda", dtype=torch.float16 + ), + "down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to( + device="cuda", dtype=torch.float16 + ), + }, + ) 
+ + out1 = apply_function(mlp, X.clone(), inplace=True) + out2 = apply_function(mlp, X.clone(), inplace=False) + + assert torch.allclose(out1, out2, rtol=1e-3) diff --git a/tests/e2e/kernels/test_quantize.py b/tests/e2e/kernels/test_quantize.py new file mode 100644 index 0000000000..e4beb846e9 --- /dev/null +++ b/tests/e2e/kernels/test_quantize.py @@ -0,0 +1,103 @@ +"""Tests for quantization utility functions.""" +# pylint: disable=invalid-name + +import torch +from bitsandbytes.functional import QuantState + +from axolotl.kernels.quantize import dequantize + + +def test_dequantize_null_state(): + """Test that dequantize returns input unchanged when quant_state is None""" + W = torch.randn(32, 32) + assert torch.equal(dequantize(W, None), W) + + +def test_dequantize_shape_preservation(): + """Test that dequantization preserves expected shapes""" + shape = (32, 32) + W = torch.randn(shape, device="cuda") + + quant_state = QuantState( + absmax=torch.ones(shape[0], device="cuda"), + shape=shape, + code=torch.randint(0, 15, shape, device="cuda"), + dtype=torch.float16, + blocksize=32, + quant_type="nf4", + offset=torch.zeros(shape[0], dtype=torch.int32, device="cuda"), + state2=QuantState( + absmax=torch.ones(shape[0], device="cuda"), + shape=shape, + code=torch.randint(0, 15, shape, device="cuda"), + dtype=torch.float16, + blocksize=32, + quant_type="nf4", + offset=None, + state2=None, + ), + ) + + result = dequantize(W, quant_state) + assert result.shape == shape + assert result.dtype == torch.float16 + assert result.device == W.device + + +def test_dequantize_transposed(): + """Test that transposed input produces transposed output""" + shape = (32, 32) + W = torch.randn(1, shape[1], device="cuda") # Transposed input + + quant_state = QuantState( + absmax=torch.ones(1), + shape=shape, + code=torch.randint(0, 15, shape), + dtype=torch.float16, + blocksize=32, + quant_type="nf4", + offset=torch.zeros(1, dtype=torch.int32), + state2=QuantState( + absmax=torch.ones(1), + shape=shape, + code=torch.randint(0, 15, shape), + dtype=torch.float16, + blocksize=32, + quant_type="nf4", + offset=None, + state2=None, + ), + ) + + result = dequantize(W, quant_state) + assert result.shape[0] == shape[0] + + +def test_dequantize_output_tensor(): + """Test dequantization with provided output tensor""" + shape = (32, 32) + W = torch.randn(shape, device="cuda") + out = torch.empty(shape, dtype=torch.float16, device="cuda") + + quant_state = QuantState( + absmax=torch.ones(shape[0]), + shape=shape, + code=torch.randint(0, 15, shape), + dtype=torch.float16, + blocksize=32, + quant_type="nf4", + offset=torch.zeros(shape[0], dtype=torch.int32), + state2=QuantState( + absmax=torch.ones(shape[0]), + shape=shape, + code=torch.randint(0, 15, shape), + dtype=torch.float16, + blocksize=32, + quant_type="nf4", + offset=None, + state2=None, + ), + ) + + result = dequantize(W, quant_state, out=out) + assert result is out diff --git a/tests/e2e/kernels/test_swiglu.py b/tests/e2e/kernels/test_swiglu.py new file mode 100644 index 0000000000..3717402de1 --- /dev/null +++ b/tests/e2e/kernels/test_swiglu.py @@ -0,0 +1,78 @@ +"""Tests for SwiGLU activation function Triton kernels.""" +# pylint: disable=duplicate-code + +import torch +import torch.nn.functional as F + +from axolotl.kernels.swiglu import swiglu_backward, swiglu_forward + + +def test_swiglu_forward_shape(): + """Test that SwiGLU forward pass preserves expected shapes""" + batch, seq_len, hidden_dim = 2, 3, 64 + gate = torch.randn(batch, seq_len, hidden_dim, 
device="cuda") + up = torch.randn(batch, seq_len, hidden_dim, device="cuda") + + out = swiglu_forward(gate, up) + assert out.shape == (batch, seq_len, hidden_dim) + assert out.dtype == gate.dtype + assert out.device == gate.device + + +def test_swiglu_forward_values(): + """Test SwiGLU forward pass matches PyTorch reference implementation""" + gate = torch.randn(2, 3, 64, device="cuda") + up = torch.randn(2, 3, 64, device="cuda") + + # Custom implementation + triton_out = swiglu_forward(gate.clone(), up.clone()) + + # PyTorch reference + torch_out = F.silu(gate) * up + + assert torch.allclose(triton_out, torch_out, rtol=1e-3) + + +def test_swiglu_backward(): + """Test SwiGLU backward pass matches PyTorch autograd""" + gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True) + up = torch.randn(2, 3, 64, device="cuda", requires_grad=True) + grad_output = torch.randn(2, 3, 64, device="cuda") + + # PyTorch reference - compute intermediates + silu_gate = F.silu(gate) + torch_out = silu_gate * up + torch_out.backward(grad_output) + + # Custom backward pass + gate_clone = gate.clone().detach() + up_clone = up.clone().detach() + grad_output_clone = grad_output.clone() + + h, our_grad_gate, our_grad_up = swiglu_backward( + grad_output_clone, gate_clone, up_clone + ) + + # Compare outputs and gradients + assert torch.allclose(h, torch_out, rtol=1e-3) + assert torch.allclose(our_grad_gate, gate.grad, rtol=1e-3) + assert torch.allclose(our_grad_up, up.grad, rtol=1e-3) + + +def test_swiglu_inplace_preservation(): + """Test that SwiGLU backward doesn't modify original tensors unexpectedly""" + gate = torch.randn(2, 3, 64, device="cuda") + up = torch.randn(2, 3, 64, device="cuda") + grad_output = torch.randn(2, 3, 64, device="cuda") + + gate_copy = gate.clone() + up_copy = up.clone() + grad_copy = grad_output.clone() + + swiglu_backward(grad_output, gate, up) + + assert not torch.equal(gate, gate_copy), "Gate should be modified in-place" + assert not torch.equal(up, up_copy), "Up should be modified in-place" + assert not torch.equal( + grad_output, grad_copy + ), "Grad output should be modified in-place" diff --git a/tests/e2e/patched/lora_kernels/__init__.py b/tests/e2e/patched/lora_kernels/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py new file mode 100644 index 0000000000..4e33733673 --- /dev/null +++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py @@ -0,0 +1,414 @@ +"""Integration tests for LoRA activation and attention kernels.""" +# pylint: disable=redefined-outer-name + +import pytest +import torch +from accelerate.state import PartialState +from peft import PeftModelForCausalLM, get_peft_config +from transformers import AutoModelForCausalLM, LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention + +from axolotl.kernels.lora import ( + apply_lora_mlp_geglu, + apply_lora_mlp_swiglu, + apply_lora_o, + apply_lora_qkv, +) +from axolotl.monkeypatch.lora_kernels import ( + apply_lora_kernel_patches, + patch_self_attn_lora, +) +from axolotl.utils.dict import DictDefault + +MODEL_CONFIGS = [ + { + "name": "openaccess-ai-collective/tiny-mistral", + "expected_activation": apply_lora_mlp_swiglu, + "dtype": torch.float16, + }, + { + "name": "Qwen/Qwen2-7B", + "expected_activation": apply_lora_mlp_swiglu, + "dtype": torch.float16, + }, + { + 
"name": "HuggingFaceTB/SmolLM2-135M", + "expected_activation": apply_lora_mlp_swiglu, + "dtype": torch.float32, + }, + { + "name": "mhenrichsen/gemma-2b", + "expected_activation": apply_lora_mlp_geglu, + "dtype": torch.float16, + }, +] + + +@pytest.fixture(autouse=True) +def init_accelerate(): + """Initialize Accelerate state before tests.""" + _ = PartialState() + + +@pytest.fixture +def small_llama_model(): + """Create a small LLaMA model for testing.""" + config = { + "vocab_size": 100, + "hidden_size": 128, + "intermediate_size": 256, + "num_hidden_layers": 2, + "num_attention_heads": 4, + } + + return LlamaForCausalLM(LlamaConfig(**config)) + + +def test_attention_patching_integration(): + """Test attention patching in integration context.""" + cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"} + + # Store the original implementation + original_forward = getattr(LlamaAttention, "forward") + + # Apply patch + patch_self_attn_lora(cfg) + + # Get the new forward method + patched_forward = LlamaAttention.forward + + # Check the forward method was replaced + assert original_forward is not patched_forward + assert patched_forward.__name__ == "axolotl_attn_forward" + + # Check original implementation was stored + assert hasattr(LlamaAttention, "_original_forward") + + # Clean up + setattr(LlamaAttention, "forward", original_forward) + delattr(LlamaAttention, "_original_forward") + + +def test_swiglu_mlp_integration(small_llama_model): + """Test SwiGLU activation in LoRA MLP context.""" + peft_config = get_peft_config( + { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0, + "bias": "none", + } + ) + model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda") + cfg = DictDefault({"lora_mlp_kernel": True}) + + # Apply patches + patched_model = apply_lora_kernel_patches(model, cfg) + + # Verify patches + layer = patched_model.model.model.layers[0] + assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu + + # Test forward pass + batch_size, seq_len = 2, 10 + hidden_states = torch.randn( + batch_size, seq_len, model.config.hidden_size, device=model.device + ) + position_ids = ( + torch.arange(seq_len, device=model.device).unsqueeze(0).expand(batch_size, -1) + ) + cos, sin = model.model.model.rotary_emb(hidden_states, position_ids) + + inputs = { + "hidden_states": hidden_states, + "attention_mask": None, + "position_embeddings": (cos, sin), + "output_attentions": False, + "use_cache": False, + "past_key_value": None, + } + + # Compare outputs + with torch.no_grad(): + original_output = model.model.model.layers[0](**inputs)[0] + patched_output = layer(**inputs)[0] + + assert torch.allclose(original_output, patched_output, rtol=1e-4) + + +def test_geglu_model_integration(): + """Test GeGLU activation with Gemma model.""" + model = AutoModelForCausalLM.from_pretrained( + "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda" + ) + peft_config = get_peft_config( + { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0, + "bias": "none", + } + ) + model = PeftModelForCausalLM(model, peft_config) + + cfg = DictDefault({"lora_mlp_kernel": True}) + patched_model = apply_lora_kernel_patches(model, cfg) + + # Verify patches + layer = patched_model.model.model.layers[0] + assert layer.mlp.forward.__func__ is apply_lora_mlp_geglu + + # Test end-to-end + 
inputs = torch.randint(0, 100, (1, 20), device=model.device, dtype=torch.long) + with torch.no_grad(): + original_output = model(inputs).logits + patched_output = patched_model(inputs).logits + + assert torch.allclose(original_output, patched_output, rtol=1e-4) + + +@pytest.mark.parametrize( + "model_name,expected_activation", + [ + ("HuggingFaceTB/SmolLM2-135M", apply_lora_mlp_swiglu), + ("mhenrichsen/gemma-2b", apply_lora_mlp_geglu), + ], +) +def test_model_specific_activation(model_name, expected_activation): + """Test that each model type gets the correct activation function.""" + model = AutoModelForCausalLM.from_pretrained(model_name) + peft_config = get_peft_config( + { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0, + "bias": "none", + } + ) + model = PeftModelForCausalLM(model, peft_config) + cfg = DictDefault({"lora_mlp_kernel": True}) + + patched_model = apply_lora_kernel_patches(model, cfg) + layer = patched_model.model.model.layers[0] + assert layer.mlp.forward.__func__ is expected_activation + + +def test_kernel_patch_conditions(): + """Test various conditions that should prevent kernel patching.""" + test_configs = [ + # Dropout prevents patching + { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0.1, + "bias": "none", + }, + # Bias prevents patching + { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0, + "bias": "lora_only", + }, + ] + + for config in test_configs: + model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M") + peft_config = get_peft_config(config) + model = PeftModelForCausalLM(model, peft_config) + cfg = DictDefault({"lora_mlp_kernel": True}) + + # Should not patch + patched_model = apply_lora_kernel_patches(model, cfg) + layer = patched_model.model.model.layers[0].mlp + + # Verify no patches applied + assert layer.forward.__func__ is not apply_lora_mlp_swiglu + assert layer.forward.__func__ is not apply_lora_mlp_geglu + + +def test_kernel_config_options(): + """Test that kernel configuration options are respected.""" + # Test different configurations + test_configs = [ + ( + {"lora_mlp_kernel": True, "lora_qkv_kernel": False, "lora_o_kernel": False}, + lambda layer: ( + layer.mlp.forward.__func__ is apply_lora_mlp_swiglu + and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv + and layer.self_attn.apply_o.__func__ is not apply_lora_o + ), + ), + ( + {"lora_mlp_kernel": False, "lora_qkv_kernel": True, "lora_o_kernel": False}, + lambda layer: ( + layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu + and layer.self_attn.apply_qkv.__func__ is apply_lora_qkv + and layer.self_attn.apply_o.__func__ is not apply_lora_o + ), + ), + ( + {"lora_mlp_kernel": False, "lora_qkv_kernel": False, "lora_o_kernel": True}, + lambda layer: ( + layer.mlp.forward.__func__ is not apply_lora_mlp_swiglu + and layer.self_attn.apply_qkv.__func__ is not apply_lora_qkv + and layer.self_attn.apply_o.__func__ is apply_lora_o + ), + ), + ] + + for config_dict, check_fn in test_configs: + # Create fresh model for each test + config = { + "vocab_size": 100, + "hidden_size": 128, + "intermediate_size": 256, + "num_hidden_layers": 2, + "num_attention_heads": 4, + } + small_llama_model = LlamaForCausalLM(LlamaConfig(**config)) + + 
peft_config = get_peft_config( + { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": [ + "gate_proj", + "up_proj", + "down_proj", + "q_proj", + "k_proj", + "v_proj", + "o_proj", + ], + "lora_dropout": 0, + "bias": "none", + } + ) + model = PeftModelForCausalLM(small_llama_model, peft_config).to("cuda") + cfg = DictDefault(config_dict) + patched_model = apply_lora_kernel_patches(model, cfg) + + # Verify only requested optimizations were applied + for layer in patched_model.model.model.layers: + assert check_fn(layer), f"Failed for config: {config_dict}" + + # Clean up + del model + del small_llama_model + del patched_model + + +def get_lora_config(): + """Get standard LoRA configuration for testing.""" + return { + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + "r": 8, + "lora_alpha": 16, + "target_modules": ["gate_proj", "up_proj", "down_proj"], + "lora_dropout": 0, + "bias": "none", + } + + +def get_test_inputs(model, seq_length=20): + """Generate test inputs for model evaluation.""" + return torch.randint( + 0, + model.config.vocab_size, + (1, seq_length), + device=model.device, + dtype=torch.long, + ) + + +@pytest.mark.parametrize("model_config", MODEL_CONFIGS) +def test_model_architecture(model_config): + """Test LoRA kernel patches across different model architectures.""" + # Load model with appropriate dtype + model = AutoModelForCausalLM.from_pretrained( + model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda" + ) + + # Apply LoRA configuration + peft_config = get_peft_config(get_lora_config()) + model = PeftModelForCausalLM(model, peft_config) + + # Apply kernel patches + cfg = DictDefault({"lora_mlp_kernel": True}) + patched_model = apply_lora_kernel_patches(model, cfg) + + # Verify correct activation function + layer = patched_model.model.model.layers[0] + assert ( + layer.mlp.forward.__func__ is model_config["expected_activation"] + ), f"Wrong activation for {model_config['name']}" + + # Test forward pass + inputs = get_test_inputs(model) + with torch.no_grad(): + original_output = model(inputs).logits + patched_output = patched_model(inputs).logits + + # Check outputs match + assert torch.allclose( + original_output, patched_output, rtol=1e-4 + ), f"Outputs don't match for {model_config['name']}" + + +# pylint: disable=duplicate-code +def test_kernel_training_integration(): + """Test model loading with kernel patches enabled.""" + from axolotl.cli.utils import load_model_and_tokenizer + + # Create minimal config + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_config": "HuggingFaceTB/SmolLM2-135M", + "learning_rate": 0.000001, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.0, + "lora_target_linear": True, + "sequence_len": 1024, + "lora_mlp_kernel": True, + "lora_qkv_kernel": True, + "lora_o_kernel": True, + } + ) + + # Load model + model, _ = load_model_and_tokenizer(cfg=cfg) + + # Verify correct activation function + layer = model.model.model.layers[0] + assert layer.mlp.forward.__func__ is apply_lora_mlp_swiglu
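
For reference, here is a condensed usage sketch of how the new pieces fit together outside the trainer, mirroring the integration tests above: build a PEFT LoRA model with zero dropout and no bias, then apply the kernel patches via the new config flags. The base model name and LoRA hyperparameters below are illustrative choices only, and a CUDA GPU is assumed since the Triton kernels run on GPU.

```python
# Minimal sketch based on the tests above; model name and hyperparameters are
# illustrative, not required values.
from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
from axolotl.utils.dict import DictDefault

base_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(
    {
        "peft_type": "LORA",
        "task_type": "CAUSAL_LM",
        "r": 8,
        "lora_alpha": 16,
        "target_modules": [
            "gate_proj", "up_proj", "down_proj",
            "q_proj", "k_proj", "v_proj", "o_proj",
        ],
        "lora_dropout": 0,  # the kernel patches require no dropout...
        "bias": "none",  # ...and no bias on the LoRA layers
    }
)
model = PeftModelForCausalLM(base_model, peft_config).to("cuda")

# Enable the optimized MLP / QKV / O code paths via the new config flags
cfg = DictDefault(
    {"lora_mlp_kernel": True, "lora_qkv_kernel": True, "lora_o_kernel": True}
)
model = apply_lora_kernel_patches(model, cfg)
```

In a normal training run these patches are applied automatically when the `lora_*_kernel` options are set in the YAML config; `patch_self_attn_lora` additionally swaps in the patched attention forward so that the `apply_qkv` / `apply_o` hooks are actually invoked.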