Merge branch 'main' into cli_main_reorg
Showing 22 changed files with 3,102 additions and 22 deletions.
@@ -0,0 +1,127 @@
---
title: "LoRA Optimizations"
description: "Custom autograd functions and Triton kernels in Axolotl for optimized
LoRA fine-tuning"
---

Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single-GPU and multi-GPU
(DDP and DeepSpeed) training: (1) Triton kernels for the SwiGLU and GeGLU activation
functions, and (2) custom autograd functions for the LoRA MLP and attention paths. Our
goal is to leverage operator fusion and tensor re-use to improve speed and reduce memory
usage during the forward and backward passes of these computations.

We currently support several common model architectures, including (but not limited to):

- `llama`
- `mistral`
- `qwen2`
- `gemma`
- `gemma2`

<details>

The set of models we support is currently limited by our attention patching strategy,
which assumes (and replaces) specific code blocks for the query / key / value and output
projections:

```python
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
""".lstrip(
    "\n"
)
```

This is replaced with:

```python
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
    "\n"
)

PATCHED_O_CODE = """
attn_output = self.apply_o(attn_output)
""".lstrip(
    "\n"
)
```

Here, `apply_qkv` and `apply_o` are defined in the `axolotl.kernels.lora` module.
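
For intuition, here is a minimal, non-fused reference of what a helper like `apply_qkv`
computes, assuming the projections are PEFT LoRA `Linear` layers using the `"default"`
adapter. This is an illustrative sketch only, not the actual `axolotl.kernels.lora`
implementation, which fuses these steps:

```python
import torch
from torch import nn


def reference_qkv(attn: nn.Module, hidden_states: torch.Tensor):
    """Illustrative, non-fused equivalent of a fused QKV + LoRA projection.

    Assumes attn.q_proj / k_proj / v_proj are PEFT LoRA Linear layers with the
    "default" adapter and no dropout (as required by these kernels). Each output
    is the frozen base projection plus the low-rank LoRA update scaled by
    alpha / r; the fused implementation computes the same quantities while
    re-using intermediates and avoiding extra trips to GPU memory.
    """
    outputs = []
    for proj in (attn.q_proj, attn.k_proj, attn.v_proj):
        base = proj.base_layer(hidden_states)  # W x
        lora = proj.lora_B["default"](proj.lora_A["default"](hidden_states))
        outputs.append(base + proj.scaling["default"] * lora)  # W x + (alpha/r) B A x
    return tuple(outputs)
```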

We welcome testing of other model architectures and/or PRs that expand our patching
logic to cover more of them.

</details>

## Usage

These optimizations can be enabled in your Axolotl config YAML file. The
`lora_mlp_kernel` option enables the optimized MLP path, while `lora_qkv_kernel` and
`lora_o_kernel` enable the fused query-key-value projection and the optimized output
projection, respectively.

```yaml
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```

## Requirements

- One or more NVIDIA or AMD GPUs (required for the Triton kernels)
  - AMD GPUs can be used with experimental Triton support by setting the environment
    variable `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
- Targeted LoRA adapters cannot use dropout
  - This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
  - This may limit model expressivity

Models with pre-existing LoRA adapters that use dropout or have bias terms may need to
be re-finetuned without these features in order to use these optimizations.
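
As a quick sanity check for the dropout and bias constraints above, something along
these lines can inspect a PEFT `LoraConfig` before enabling the kernels
(`check_kernel_compat` is a hypothetical helper written for this doc, not part of
Axolotl):

```python
from peft import LoraConfig


def check_kernel_compat(cfg: LoraConfig) -> None:
    """Hypothetical helper: raise if a LoRA config would be incompatible with the
    fused kernels, which assume no dropout and no bias on targeted adapters."""
    if cfg.lora_dropout and cfg.lora_dropout > 0.0:
        raise ValueError("lora_dropout must be 0.0 to use the fused LoRA kernels")
    if cfg.bias != "none":
        raise ValueError("LoRA bias must be 'none' to use the fused LoRA kernels")


check_kernel_compat(LoraConfig(r=16, lora_alpha=32, lora_dropout=0.0, bias="none"))
```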

## Implementation details

### Custom autograd functions

The LoRA MLP autograd function optimizes the entire MLP computation path. It fuses the
LoRA and base weight computations together and provides a single, efficient backward
pass for the entire MLP block.
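
The general pattern looks roughly like the following sketch of a fused base-plus-LoRA
linear layer. This is a simplified illustration of the approach, not the actual Axolotl
code, which fuses the full gate/up/down MLP path:

```python
import torch


class FusedLoRALinearFn(torch.autograd.Function):
    """Illustrative fused base + LoRA linear: y = x W^T + s * (x A^T) B^T."""

    @staticmethod
    def forward(ctx, x, W, A, B, s):
        xA = x @ A.t()  # low-rank intermediate, re-used in the backward pass
        y = x @ W.t() + s * (xA @ B.t())
        ctx.save_for_backward(x, W, A, B, xA)
        ctx.s = s
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, W, A, B, xA = ctx.saved_tensors
        s = ctx.s
        grad_x = grad_y @ W + s * (grad_y @ B) @ A
        grad_A = s * (grad_y @ B).t() @ x  # dL/dA for the LoRA down-projection
        grad_B = s * grad_y.t() @ xA       # dL/dB for the LoRA up-projection
        return grad_x, None, grad_A, grad_B, None  # base weight W stays frozen
```

Calling `FusedLoRALinearFn.apply(x, W, A, B, scaling)` is a drop-in replacement for the
separate base and LoRA matmuls: a single backward produces gradients for the input and
both LoRA matrices without building separate graphs for the base and adapter branches.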

For attention components, similar optimizations are provided through a function that
handles the query, key, and value projections, and a function that handles the output
projection. They are designed to work with the existing `transformers` attention
implementation via some monkey-patching logic.

### Triton kernels

Two activation functions (SwiGLU and GeGLU) are implemented with Triton kernels for
improved speed and memory performance. These kernels handle both the forward and
backward passes.
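
For a sense of what these kernels look like, here is a minimal forward-only SwiGLU
sketch. The actual kernels in `axolotl.kernels` also implement the backward pass and
differ in details such as launch grids and stride handling:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def swiglu_fwd_kernel(gate_ptr, up_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance processes one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    gate = tl.load(gate_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(up_ptr + offsets, mask=mask).to(tl.float32)
    # SwiGLU: silu(gate) * up, fused into a single pass over memory.
    result = gate * tl.sigmoid(gate) * up
    tl.store(out_ptr + offsets, result.to(out_ptr.dtype.element_ty), mask=mask)


def swiglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    assert gate.is_contiguous() and up.is_contiguous() and gate.shape == up.shape
    out = torch.empty_like(gate)
    n_elements = gate.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    swiglu_fwd_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE=1024)
    return out
```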

### Integration

The custom autograd functions and Triton kernels are designed to work together. The
autograd function manages the high-level computation flow and gradient tracking, while
calling the Triton kernels for the activation function computation. During the backward
pass, the kernel computes both the activation output and the required gradients, which
the autograd function then uses to compute the final gradients for the entire
computation path.
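
A skeletal version of that wiring, assuming the `swiglu_forward` helper sketched above.
Here the backward is written in plain PyTorch for brevity, whereas the real
implementation keeps the backward math in Triton as well:

```python
import torch


class SwiGLUFunction(torch.autograd.Function):
    """Illustrative glue between the autograd layer and the activation kernel."""

    @staticmethod
    def forward(ctx, gate, up):
        ctx.save_for_backward(gate, up)
        return swiglu_forward(gate, up)  # Triton kernel does the fused element-wise math

    @staticmethod
    def backward(ctx, grad_out):
        gate, up = ctx.saved_tensors
        sig = torch.sigmoid(gate)
        silu = gate * sig
        d_silu = sig + silu * (1 - sig)  # d/dgate [gate * sigmoid(gate)]
        return grad_out * up * d_silu, grad_out * silu
```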

## Future Work

- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions
@@ -0,0 +1,82 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 32
# Currently, we don't support dropout with our custom Triton kernels
# lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# These options enable our custom Triton kernels / autograd
# functions for MLP and attention calculations
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"