|
| 1 | +""" |
| 2 | +HuggingFace flash attention adapter for basic ring attention (batch API). |
| 3 | +
|
| 4 | +Inspired by |
| 5 | +https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py. |
| 6 | +Our implementation closely follows the structure of that module, but we've minified it |
| 7 | +somewhat to support only the latest versions of transformers. |
| 8 | +""" |
| 9 | + |
| 10 | +# pylint: disable=protected-access,cyclic-import |
| 11 | + |
| 12 | +import os |
| 13 | +from typing import Callable |
| 14 | + |
| 15 | +import torch |
| 16 | +import torch.distributed as dist |
| 17 | +import transformers |
| 18 | +import transformers.modeling_flash_attention_utils |
| 19 | +from ring_flash_attn import ( |
| 20 | + ring_flash_attn_func, |
| 21 | + stripe_flash_attn_func, |
| 22 | + zigzag_ring_flash_attn_func, |
| 23 | +) |
| 24 | +from ring_flash_attn.adapters.hf_adapter import check_params |
| 25 | +from transformers.modeling_flash_attention_utils import ( |
| 26 | + _flash_supports_window_size, |
| 27 | + is_flash_attn_greater_or_equal, |
| 28 | +) |
| 29 | +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| 30 | + |
| 31 | +from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc |
| 32 | + |
| 33 | +RING_ATTN_FUNC_MAPPING = { |
| 34 | + RingAttnFunc.BATCH_RING: ring_flash_attn_func, |
| 35 | + RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func, |
| 36 | + RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func, |
| 37 | +} |
| 38 | + |
| 39 | + |
| 40 | +def create_flash_attn_forward( |
| 41 | + process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc |
| 42 | +) -> Callable: |
| 43 | + """ |
| 44 | + Create a ring flash attention forward function compatible with HuggingFace's |
| 45 | + interface. |
| 46 | +
|
| 47 | + Args: |
| 48 | + process_group: A PyTorch distributed process group. |
| 49 | + ring_attn_func: Function from `ring_flash_attention` to replace HF flash |
| 50 | + attention with. |
| 51 | +
|
| 52 | + Returns: |
| 53 | + A function that implements the ring flash attention forward pass with the |
| 54 | + signature expected by HuggingFace Transformers. |
| 55 | + """ |
| 56 | + |
| 57 | + # transformers 4.48+ |
| 58 | + # pylint: disable=unused-argument |
| 59 | + def _flash_attention_forward( |
| 60 | + query_states: torch.Tensor, |
| 61 | + key_states: torch.Tensor, |
| 62 | + value_states: torch.Tensor, |
| 63 | + attention_mask: torch.Tensor, |
| 64 | + query_length: int, |
| 65 | + is_causal: bool, |
| 66 | + dropout: float = 0.0, |
| 67 | + position_ids: torch.Tensor | None = None, |
| 68 | + softmax_scale: float | None = None, |
| 69 | + sliding_window: int | None = None, |
| 70 | + use_top_left_mask: bool = False, |
| 71 | + softcap: float | None = None, |
| 72 | + deterministic: bool = None, |
| 73 | + cu_seq_lens_q: torch.LongTensor | None = None, |
| 74 | + cu_seq_lens_k: torch.LongTensor | None = None, |
| 75 | + max_length_q: int | None = None, |
| 76 | + max_length_k: int | None = None, |
| 77 | + target_dtype: torch.dtype | None = None, |
| 78 | + **kwargs, |
| 79 | + ): |
| 80 | + """ |
| 81 | + Calls the forward method of Ring Flash Attention. |
| 82 | +
|
| 83 | + Args: |
| 84 | + query_states: Tensor containing the query vectors. |
| 85 | + key_states: Tensor containing the key vectors. |
| 86 | + value_states: Tensor containing the value vectors. |
| 87 | + attention_mask: Not used in this implementation. |
| 88 | + query_length: Integer representing the length of the query sequence. |
| 89 | + is_causal: Boolean indicating whether to apply a causal mask to the attention. |
| 90 | + dropout: Float representing the dropout probability. Default is 0.0. |
| 91 | + position_ids: Not used in this implementation. |
| 92 | + softmax_scale: Optional float value for the softmax scaling factor. Default is None. |
| 93 | + sliding_window: Optional integer defining the size of the sliding attention window. |
| 94 | + Default is None. |
| 95 | + use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention. |
| 96 | + Default is False. |
| 97 | + softcap: Not used in this implementation. |
| 98 | + deterministic: Optional boolean to enforce deterministic computation. Default is None. |
| 99 | + cu_seq_lens_q: Not used in this implementation. |
| 100 | + cu_seq_lens_k: Not used in this implementation. |
| 101 | + max_length_q: Not used in this implementation. |
| 102 | + max_length_k: Not used in this implementation. |
| 103 | + target_dtype: Not used in this implementation. |
| 104 | + **kwargs: Additional keyword arguments. Not used in this implementation. |
| 105 | +
|
| 106 | + Returns: |
| 107 | + torch.Tensor: The output of the attention mechanism, with shape |
| 108 | + `[batch_size, query_length, num_heads, head_dim]`. |
| 109 | + """ |
| 110 | + if not use_top_left_mask: |
| 111 | + causal = is_causal |
| 112 | + else: |
| 113 | + causal = is_causal and query_length != 1 |
| 114 | + |
| 115 | + # Handle sliding window |
| 116 | + use_sliding_windows = ( |
| 117 | + _flash_supports_window_size |
| 118 | + and sliding_window is not None |
| 119 | + and key_states.shape[1] > sliding_window |
| 120 | + ) |
| 121 | + window_size = ( |
| 122 | + (sliding_window, sliding_window) if use_sliding_windows else (-1, -1) |
| 123 | + ) |
| 124 | + |
| 125 | + # Handle deterministic mode |
| 126 | + if is_flash_attn_greater_or_equal("2.4.1"): |
| 127 | + if deterministic is None: |
| 128 | + deterministic = ( |
| 129 | + os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" |
| 130 | + ) |
| 131 | + |
| 132 | + # Call ring flash attention function |
| 133 | + attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func]( |
| 134 | + query_states, |
| 135 | + key_states, |
| 136 | + value_states, |
| 137 | + dropout_p=dropout, |
| 138 | + softmax_scale=softmax_scale, |
| 139 | + causal=causal, |
| 140 | + window_size=window_size, |
| 141 | + alibi_slopes=None, |
| 142 | + deterministic=deterministic, |
| 143 | + return_attn_probs=False, |
| 144 | + group=process_group, |
| 145 | + ) |
| 146 | + |
| 147 | + return attn_output |
| 148 | + |
| 149 | + return _flash_attention_forward |
| 150 | + |
| 151 | + |
| 152 | +def substitute_hf_flash_attn( |
| 153 | + process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc |
| 154 | +): |
| 155 | + """ |
| 156 | + Substitute HuggingFace's flash attention implementation with ring-based implementation. |
| 157 | +
|
| 158 | + Args: |
| 159 | + process_group: PyTorch distributed process group for communication. |
| 160 | + ring_attn_func: Function from `ring_flash_attention` to replace HF flash |
| 161 | + attention with. |
| 162 | + """ |
| 163 | + try: |
| 164 | + # Substitute flash attention |
| 165 | + old_flash_attention_forward = ( |
| 166 | + transformers.modeling_flash_attention_utils._flash_attention_forward |
| 167 | + ) |
| 168 | + new_flash_attention_forward = create_flash_attn_forward( |
| 169 | + process_group=process_group, ring_attn_func=ring_attn_func |
| 170 | + ) |
| 171 | + |
| 172 | + if check_params(old_flash_attention_forward, new_flash_attention_forward): |
| 173 | + transformers.modeling_flash_attention_utils._flash_attention_forward = ( |
| 174 | + new_flash_attention_forward |
| 175 | + ) |
| 176 | + else: |
| 177 | + raise ValueError( |
| 178 | + "The signature of the new flash attention forward function does not match the old one." |
| 179 | + ) |
| 180 | + except Exception as exception: |
| 181 | + raise ValueError( |
| 182 | + f"The current transformer version {transformers.__version__} is not supported. " |
| 183 | + "Please use pip install -U transformers to upgrade to the latest version. " |
| 184 | + "If the code failed with the latest version, " |
| 185 | + f"please file an issue." |
| 186 | + ) from exception |
| 187 | + |
| 188 | + # Register with ALL_ATTENTION_FUNCTIONS if available |
| 189 | + if ALL_ATTENTION_FUNCTIONS is not None: |
| 190 | + from ring_flash_attn.adapters.hf_adapter import flash_attention_forward |
| 191 | + |
| 192 | + ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward |
0 commit comments