
Commit b8c633a

batch api HF adapter for ring-flash-attn; cleanup and improvements (axolotl-ai-cloud#2520)
* batch api HF adapter for ring-flash-attn; cleanup and improvements
* update
* adding all batch ring-flash-attn methods via single adapter
* removing pad_to_sequence_len=False for now
* fix
* updating docs to include batch SP
* review comments
* fixes for batch API funcs, simplify
* fixes
* fix
* updates
* add batch_zigzag smoke test
1 parent 682a9cf commit b8c633a

File tree: 13 files changed, +397 -49 lines

docs/config.qmd (+3)

@@ -693,6 +693,9 @@ sequence_parallel_degree:
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 # Must evenly divide the number of KV heads in your model.
 heads_k_stride: 1
+# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
+# in the sample packing case, and "batch_ring" in the non-sample packing case.
+ring_attn_func:
 
 # Path to torch distx for optim 'adamw_anyprecision'
 torchdistx_path:

docs/sequence_parallelism.qmd (+3)

@@ -27,6 +27,9 @@ To enable sequence parallelism, add the following to your configuration file:
 sequence_parallel_degree: 4 # Split sequences across 4 GPUs
 # Optional; strides across the key dimension. Larger values use more memory but should make training faster.
 heads_k_stride: 1
+# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
+# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
+ring_attn_func:
 ```
 
 The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
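For illustration only, a minimal Python sketch of the documented default selection. `resolve_ring_attn_func` is a hypothetical helper that does not exist in this commit; only the `RingAttnFunc` import reflects code actually added by the diff.

```python
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc


def resolve_ring_attn_func(configured: str | None, sample_packing: bool) -> RingAttnFunc:
    """Mirror the documented defaults for the `ring_attn_func` config key."""
    if configured is not None:
        # Config values arrive as plain strings; the str-based enum validates them.
        return RingAttnFunc(configured)
    # Documented defaults: varlen_llama3 with sample packing, batch_ring without.
    return RingAttnFunc.VARLEN_LLAMA3 if sample_packing else RingAttnFunc.BATCH_RING


assert resolve_ring_attn_func(None, sample_packing=True) is RingAttnFunc.VARLEN_LLAMA3
assert resolve_ring_attn_func("batch_zigzag", sample_packing=False) is RingAttnFunc.BATCH_ZIGZAG
```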

src/axolotl/core/trainer_builder.py (+2)

@@ -776,6 +776,7 @@ def build(self, total_num_steps):
         training_arguments_kwargs["sequence_parallel_degree"] = (
             self.cfg.sequence_parallel_degree
         )
+        training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
 
         if self.cfg.reward_model:
             training_args_cls = AxolotlRewardConfig
@@ -933,6 +934,7 @@ def build_collator(
         kwargs["return_tensors"] = "pt"
         if issubclass(collator, DataCollatorForSeq2Seq):
             kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
+            kwargs["ring_attn_func"] = training_args.ring_attn_func
 
         return collator(
             *collator_args,

src/axolotl/core/training_args.py (+8)

@@ -9,6 +9,8 @@
 from transformers import TrainingArguments
 from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
 
+from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
+
 
 @dataclass
 class AxolotlTrainingMixins:
@@ -218,6 +220,12 @@ class AxolotlTrainingMixins:
         default=1,
         metadata={"help": "The number of workers to use in sequence parallelism"},
     )
+    ring_attn_func: Optional[RingAttnFunc] = field(
+        default=None,
+        metadata={
+            "help": "The ring-flash-attn function to use in sequence parallelism"
+        },
+    )
 
     # multi-modal section
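Because `RingAttnFunc` subclasses `str` (see the `patch.py` diff below), a string taken from a YAML config can be converted to, and compared against, the enum members without extra glue; this is what lets the new `ring_attn_func` training argument default to `None` and accept config-provided strings. A self-contained sketch, with the enum values copied from the diff:

```python
from enum import Enum


class RingAttnFunc(str, Enum):
    """Local mirror of the enum added in ring_attn/patch.py (values copied from the diff)."""

    VARLEN_LLAMA3 = "varlen_llama3"
    BATCH_RING = "batch_ring"
    BATCH_ZIGZAG = "batch_zigzag"
    BATCH_STRIPE = "batch_stripe"


# Lookup by value returns the member itself, and str-mixin members compare
# equal to their plain-string values.
assert RingAttnFunc("batch_stripe") is RingAttnFunc.BATCH_STRIPE
assert RingAttnFunc.BATCH_STRIPE == "batch_stripe"
```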

src/axolotl/monkeypatch/attention/ring_attn/__init__.py (new file, +12)

@@ -0,0 +1,12 @@
+"""Init for ring attention monkeypatch module"""
+
+# pylint: disable=unused-import
+# flake8: noqa
+
+from .patch import (
+    RingAttnFunc,
+    get_ring_attn_group,
+    register_ring_attn,
+    set_ring_attn_group,
+    update_ring_attn_params,
+)
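Since `ring_attn.py` is renamed to `ring_attn/patch.py` in this commit, this new package `__init__` re-exports the public helpers so the shorter import path keeps working. A quick check (assumes this commit of axolotl and its `ring_flash_attn` dependency are installed):

```python
# Both import paths resolve to the same object; the package __init__ simply
# re-exports the names defined in the new patch submodule.
from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc as PatchRingAttnFunc

assert RingAttnFunc is PatchRingAttnFunc
```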

src/axolotl/monkeypatch/attention/ring_attn/adapters/__init__.py

Whitespace-only changes.
src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py (new file, +192)

@@ -0,0 +1,192 @@
+"""
+HuggingFace flash attention adapter for basic ring attention (batch API).
+
+Inspired by
+https://github.com/zhuzilin/ring-flash-attention/blob/ce9fd3935ca0e5f0592bb0826cbed18ec69da729/ring_flash_attn/adapters/hf_adapter.py.
+Our implementation closely follows the structure of that module, but we've minified it
+somewhat to support only the latest versions of transformers.
+"""
+
+# pylint: disable=protected-access,cyclic-import
+
+import os
+from typing import Callable
+
+import torch
+import torch.distributed as dist
+import transformers
+import transformers.modeling_flash_attention_utils
+from ring_flash_attn import (
+    ring_flash_attn_func,
+    stripe_flash_attn_func,
+    zigzag_ring_flash_attn_func,
+)
+from ring_flash_attn.adapters.hf_adapter import check_params
+from transformers.modeling_flash_attention_utils import (
+    _flash_supports_window_size,
+    is_flash_attn_greater_or_equal,
+)
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
+
+RING_ATTN_FUNC_MAPPING = {
+    RingAttnFunc.BATCH_RING: ring_flash_attn_func,
+    RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func,
+    RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func,
+}
+
+
+def create_flash_attn_forward(
+    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
+) -> Callable:
+    """
+    Create a ring flash attention forward function compatible with HuggingFace's
+    interface.
+
+    Args:
+        process_group: A PyTorch distributed process group.
+        ring_attn_func: Function from `ring_flash_attention` to replace HF flash
+            attention with.
+
+    Returns:
+        A function that implements the ring flash attention forward pass with the
+        signature expected by HuggingFace Transformers.
+    """
+
+    # transformers 4.48+
+    # pylint: disable=unused-argument
+    def _flash_attention_forward(
+        query_states: torch.Tensor,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        query_length: int,
+        is_causal: bool,
+        dropout: float = 0.0,
+        position_ids: torch.Tensor | None = None,
+        softmax_scale: float | None = None,
+        sliding_window: int | None = None,
+        use_top_left_mask: bool = False,
+        softcap: float | None = None,
+        deterministic: bool = None,
+        cu_seq_lens_q: torch.LongTensor | None = None,
+        cu_seq_lens_k: torch.LongTensor | None = None,
+        max_length_q: int | None = None,
+        max_length_k: int | None = None,
+        target_dtype: torch.dtype | None = None,
+        **kwargs,
+    ):
+        """
+        Calls the forward method of Ring Flash Attention.
+
+        Args:
+            query_states: Tensor containing the query vectors.
+            key_states: Tensor containing the key vectors.
+            value_states: Tensor containing the value vectors.
+            attention_mask: Not used in this implementation.
+            query_length: Integer representing the length of the query sequence.
+            is_causal: Boolean indicating whether to apply a causal mask to the attention.
+            dropout: Float representing the dropout probability. Default is 0.0.
+            position_ids: Not used in this implementation.
+            softmax_scale: Optional float value for the softmax scaling factor. Default is None.
+            sliding_window: Optional integer defining the size of the sliding attention window.
+                Default is None.
+            use_top_left_mask: Boolean indicating whether to use a top-left mask for the attention.
+                Default is False.
+            softcap: Not used in this implementation.
+            deterministic: Optional boolean to enforce deterministic computation. Default is None.
+            cu_seq_lens_q: Not used in this implementation.
+            cu_seq_lens_k: Not used in this implementation.
+            max_length_q: Not used in this implementation.
+            max_length_k: Not used in this implementation.
+            target_dtype: Not used in this implementation.
+            **kwargs: Additional keyword arguments. Not used in this implementation.
+
+        Returns:
+            torch.Tensor: The output of the attention mechanism, with shape
+                `[batch_size, query_length, num_heads, head_dim]`.
+        """
+        if not use_top_left_mask:
+            causal = is_causal
+        else:
+            causal = is_causal and query_length != 1
+
+        # Handle sliding window
+        use_sliding_windows = (
+            _flash_supports_window_size
+            and sliding_window is not None
+            and key_states.shape[1] > sliding_window
+        )
+        window_size = (
+            (sliding_window, sliding_window) if use_sliding_windows else (-1, -1)
+        )
+
+        # Handle deterministic mode
+        if is_flash_attn_greater_or_equal("2.4.1"):
+            if deterministic is None:
+                deterministic = (
+                    os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+                )
+
+        # Call ring flash attention function
+        attn_output = RING_ATTN_FUNC_MAPPING[ring_attn_func](
+            query_states,
+            key_states,
+            value_states,
+            dropout_p=dropout,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            alibi_slopes=None,
+            deterministic=deterministic,
+            return_attn_probs=False,
+            group=process_group,
+        )
+
+        return attn_output
+
+    return _flash_attention_forward
+
+
+def substitute_hf_flash_attn(
+    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
+):
+    """
+    Substitute HuggingFace's flash attention implementation with ring-based implementation.
+
+    Args:
+        process_group: PyTorch distributed process group for communication.
+        ring_attn_func: Function from `ring_flash_attention` to replace HF flash
+            attention with.
+    """
+    try:
+        # Substitute flash attention
+        old_flash_attention_forward = (
+            transformers.modeling_flash_attention_utils._flash_attention_forward
+        )
+        new_flash_attention_forward = create_flash_attn_forward(
+            process_group=process_group, ring_attn_func=ring_attn_func
+        )
+
+        if check_params(old_flash_attention_forward, new_flash_attention_forward):
+            transformers.modeling_flash_attention_utils._flash_attention_forward = (
+                new_flash_attention_forward
+            )
+        else:
+            raise ValueError(
+                "The signature of the new flash attention forward function does not match the old one."
+            )
+    except Exception as exception:
+        raise ValueError(
+            f"The current transformer version {transformers.__version__} is not supported. "
+            "Please use pip install -U transformers to upgrade to the latest version. "
+            "If the code failed with the latest version, "
+            f"please file an issue."
+        ) from exception
+
+    # Register with ALL_ATTENTION_FUNCTIONS if available
+    if ALL_ATTENTION_FUNCTIONS is not None:
+        from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
+
+        ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
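For context, a minimal sketch of wiring this adapter up by hand. Inside axolotl the same steps are driven by `register_ring_attn` in `ring_attn/patch.py`, so this is illustrative only; it assumes a `torchrun` launch with NCCL, CUDA device selection already handled, and `ring-flash-attn` installed.

```python
import torch.distributed as dist

from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
    substitute_hf_flash_attn,
)
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc

# One process per GPU, e.g. launched via torchrun.
dist.init_process_group(backend="nccl")

# Use all ranks as a single ring for simplicity; axolotl instead builds one
# sub-group per sequence-parallel replica.
group = dist.new_group(ranks=list(range(dist.get_world_size())))

# After this call, HF models loaded with attn_implementation="flash_attention_2"
# route attention through the selected ring-flash-attn batch function.
substitute_hf_flash_attn(process_group=group, ring_attn_func=RingAttnFunc.BATCH_ZIGZAG)
```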

src/axolotl/monkeypatch/attention/ring_attn.py renamed to src/axolotl/monkeypatch/attention/ring_attn/patch.py (+44 -21)

@@ -6,6 +6,8 @@
 their sequence parallel version of Flash Attention 2.
 """
 
+from enum import Enum
+
 import torch
 import torch.distributed as dist
 from accelerate.logging import get_logger
@@ -16,6 +18,7 @@
 configure_logging()
 LOG = get_logger(__name__)
 
+
 RING_ATTN_GROUP = None
 
 
@@ -40,14 +43,32 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
     RING_ATTN_GROUP = ring_attn_group
 
 
-def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
+class RingAttnFunc(str, Enum):
+    """Enum class for supported `ring-flash-attn` implementations"""
+
+    # VARLEN_RING = "varlen_ring"
+    # VARLEN_ZIGZAG = "varlen_zigzag"
+    VARLEN_LLAMA3 = "varlen_llama3"
+    BATCH_RING = "batch_ring"
+    BATCH_ZIGZAG = "batch_zigzag"
+    BATCH_STRIPE = "batch_stripe"
+
+
+def register_ring_attn(
+    sequence_parallel_degree: int,
+    heads_k_stride: int | None,
+    ring_attn_func: RingAttnFunc | None,
+):
     """
     Create ring attention group and substitute flash attn with ring flash attn.
 
     Args:
         sequence_parallel_degree: Sequence parallelism factor.
         heads_k_stride: Sequence parallelism K head stride size. Passed
             through to `ring_flash_attn.substitute_hf_flash_attn`.
+        ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
+            packing is enabled, it must be a `varlen` function; otherwise, it must be a
+            `batch` function.
     """
     if get_ring_attn_group() is not None:
         LOG.info("Ring attention already registered, exiting early...")
@@ -58,7 +79,9 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
         f"each sequence will be processed across {sequence_parallel_degree} GPUs"
     )
 
+    rank = dist.get_rank()
     world_size = dist.get_world_size()
+
     assert sequence_parallel_degree <= world_size, (
         f"sequence_parallel_degree ({sequence_parallel_degree}) "
        f"must be less than or equal to world_size ({world_size})"
@@ -68,10 +91,8 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
        f"must evenly divide world_size ({world_size})"
    )
 
-    # Detailed logging of group formation
-    rank = dist.get_rank()
+    # Assign ranks to sequence parallel groups
     group_assignments = {}
-
     for i in range(world_size // sequence_parallel_degree):
         ring_attn_ranks = list(
             range(
@@ -92,35 +113,37 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
     if rank == 0:
         LOG.info(f"Sequence parallel group assignments: {group_assignments}")
 
-    if heads_k_stride is None:
-        heads_k_stride = 1
+    if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
+        from ring_flash_attn import substitute_hf_flash_attn
 
-    from ring_flash_attn import substitute_hf_flash_attn
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
+        )
+    elif ring_attn_func in [
+        RingAttnFunc.BATCH_RING,
+        RingAttnFunc.BATCH_ZIGZAG,
+        RingAttnFunc.BATCH_STRIPE,
+    ]:
+        from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
+            substitute_hf_flash_attn,
+        )
 
-    substitute_hf_flash_attn(
-        process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
-    )
+        substitute_hf_flash_attn(
+            process_group=get_ring_attn_group(),
+            ring_attn_func=ring_attn_func,
+        )
 
 
-def update_ring_attn_params(batch: dict[str, torch.Tensor]):
+def update_ring_attn_params(position_ids: torch.Tensor | None):
     """
     Calculate the cumulative sequence lengths for the current forward pass and pass the
     value to the substituted `ring_flash_attn`.
 
     Args:
-        batch: A dictionary with a batch of data. May or may not contain `position_ids`
-            data; if not, we compute it.
+        position_ids: Optional tensor of position IDs (for sample packed data).
     """
     from ring_flash_attn import update_ring_flash_attn_params
 
-    input_ids = batch["input_ids"]
-    position_ids = batch.get("position_ids")
-    if position_ids is None:
-        seq_len = input_ids.shape[1]
-        position_ids = torch.arange(
-            0, seq_len, dtype=torch.long, device=input_ids.device
-        ).unsqueeze(0)
-
     cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
     cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
     update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
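The batch-API functions assume each rank already holds its shard of an equal-length (padded) batch; the variants differ only in which slice of the sequence each rank gets. The sketch below is illustrative, not axolotl's collator, and the authoritative per-rank layouts are defined by the `ring-flash-attn` package: `batch_ring` gives each rank one contiguous chunk, while `batch_zigzag` pairs chunk `i` with its mirror chunk so early and late tokens are mixed and causal-attention work is balanced across ranks.

```python
import torch

world_size, seq_len = 4, 16
tokens = torch.arange(seq_len)


def ring_shard(rank: int) -> torch.Tensor:
    # batch_ring: one contiguous chunk per rank.
    return tokens.chunk(world_size)[rank]


def zigzag_shard(rank: int) -> torch.Tensor:
    # batch_zigzag: split into 2 * world_size chunks and pair chunk i with its mirror.
    chunks = tokens.chunk(2 * world_size)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]])


print(ring_shard(0))    # tensor([0, 1, 2, 3])
print(zigzag_shard(0))  # tensor([ 0,  1, 14, 15])
print(zigzag_shard(3))  # tensor([6, 7, 8, 9])
```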
