Enable flex attention support #2255

Closed
bursteratom wants to merge 76 commits into main from flx_attn_support
76 commits
61ad375
config validation for flex attention
bursteratom Jan 6, 2025
bcd9ad4
flex attention support
bursteratom Jan 7, 2025
543daaf
llama test
bursteratom Jan 9, 2025
0b47281
Fixing OSX installation (#2231)
SalmanMohammadi Jan 7, 2025
2346f21
Merge group queue (#2248)
winglian Jan 9, 2025
9c9ac1c
add hf cache caching for GHA (#2247)
winglian Jan 9, 2025
fac059a
fix: mistral nemo does not recognize token_type_ids in forward (#2233)
NanoCode012 Jan 9, 2025
e0d4b88
update modal version for ci (#2242)
winglian Jan 9, 2025
da97a21
Use SequentialSampler if curriculum_sampling is enabled with sample_p…
v-dicicco Jan 9, 2025
b7d27bd
update upstream HF deps (#2239)
winglian Jan 9, 2025
5eae134
feat: add support for data_files in pretraining (#2238)
NanoCode012 Jan 9, 2025
bd62d6e
rename liger test so it properly runs in ci (#2246)
winglian Jan 9, 2025
888cd94
use 2.5.1 docker images as latest tag as it seems stable (#2198)
winglian Jan 10, 2025
f99cae0
llama test
bursteratom Jan 12, 2025
2319ac7
Merge branch 'main' into flx_attn_support
bursteratom Jan 13, 2025
8b47e45
revert to transformers 4.47.1
bursteratom Jan 13, 2025
d3a0cb5
transformers version
bursteratom Jan 13, 2025
c06a6be
flex_attn sample packing WIP
bursteratom Jan 14, 2025
dbcd11e
revert seq len in multipack sampler
bursteratom Jan 14, 2025
a6f2c5d
flex sample packing WIP
bursteratom Jan 16, 2025
aad6242
not sure if this is necessary actually
bursteratom Jan 16, 2025
013a9b7
fix transformers version for testing
bursteratom Jan 16, 2025
a5360c1
llama hijacking
bursteratom Jan 17, 2025
80bfc50
get seqlens from position ids for foc masking
bursteratom Jan 17, 2025
b2a3438
sample packing doc mask creation WIP
bursteratom Jan 21, 2025
5f9f77f
llama patch
bursteratom Jan 22, 2025
d1be6e2
llama sdpa patching WIP
bursteratom Jan 23, 2025
cee310d
llama sdpa patching WIP
bursteratom Jan 23, 2025
b7deb52
llama sdpa patching WIP
bursteratom Jan 23, 2025
f3bec17
llama sdpa patching WIP - static class function import
bursteratom Jan 23, 2025
d7b133d
llama sdpa patching WIP - static class function import
bursteratom Jan 23, 2025
06f83a5
llama sdpa patching WIP - static class function import
bursteratom Jan 23, 2025
2753282
llama sdpa patching WIP - static class function import
bursteratom Jan 23, 2025
152e988
llama sdpa patching WIP - static class function import
bursteratom Jan 23, 2025
0dd18a3
llama sdpa patching WIP - static class function import
bursteratom Jan 23, 2025
bb9bea3
mask expansion
bursteratom Jan 23, 2025
8b3eec7
mask expansion
bursteratom Jan 23, 2025
f2f23c8
mask expansion
bursteratom Jan 23, 2025
85752cd
mask expansion
bursteratom Jan 23, 2025
e8b2789
revert mask expand
bursteratom Jan 23, 2025
555aa57
skip mask conversion if already 4d
bursteratom Jan 23, 2025
8c34c65
dummy
bursteratom Jan 23, 2025
0149de7
mask to bool
bursteratom Jan 23, 2025
5ca57cb
undo bool conversion
bursteratom Jan 23, 2025
b31796a
Merge branch 'main' into flx_attn_support
bursteratom Jan 28, 2025
ba88bc7
wip flex block mask creation
bursteratom Jan 29, 2025
96ad741
flex batching WIP
bursteratom Jan 30, 2025
065f6d4
flex batching WIP
bursteratom Jan 30, 2025
93a268e
--no-verify
bursteratom Jan 30, 2025
8496000
reset llama_patch_multipack.py
bursteratom Jan 30, 2025
3ed9c11
try vanilla mask
bursteratom Feb 1, 2025
48c3c47
vanills mask
bursteratom Feb 1, 2025
3f4fd3c
remove padding self attention
bursteratom Feb 2, 2025
907424a
stuff
bursteratom Feb 2, 2025
fa73554
test
bursteratom Feb 2, 2025
10de67e
more test
bursteratom Feb 2, 2025
9a43a09
more test
bursteratom Feb 2, 2025
2319e52
more test
bursteratom Feb 2, 2025
b692d39
more test
bursteratom Feb 2, 2025
b832b11
stuff
bursteratom Feb 2, 2025
e98581f
BLOCK SIZE
bursteratom Feb 2, 2025
0ebab63
test
bursteratom Feb 2, 2025
d3ea379
figure out slight diff from flash result
bursteratom Feb 2, 2025
b0871c8
attempt - mask padding
bursteratom Feb 3, 2025
9f6c89b
undo my stupidity
bursteratom Feb 3, 2025
e5b3690
misc
bursteratom Feb 3, 2025
8e1adc1
stuff
bursteratom Feb 3, 2025
470ba65
make doc mask instead of the whole block mask in collator
bursteratom Feb 5, 2025
adcbc74
misc
bursteratom Feb 5, 2025
3f6be51
stack
bursteratom Feb 5, 2025
d0e739d
attempt at getting around bf16 error
bursteratom Feb 5, 2025
c0a1d20
packed doc mask starts at 1, 0 means masked out
bursteratom Feb 7, 2025
0ef1f01
Merge branch 'main' into flx_attn_support
bursteratom Feb 12, 2025
82d04ea
test v2batch w/ flex attn
bursteratom Feb 13, 2025
e792b54
remove unnecessary components
bursteratom Feb 21, 2025
328bb04
Merge branch 'main' into flx_attn_support
bursteratom Feb 21, 2025
4 changes: 3 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -831,7 +831,9 @@ def build_collator(
if "max_length" in kwargs:
kwargs.pop("max_length")
elif use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
if self.cfg.flex_attention is True:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]
102 changes: 101 additions & 1 deletion src/axolotl/monkeypatch/utils.py
@@ -95,6 +95,103 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)


def get_packed_mask_from_pos_ids(position_ids):
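"""generate a per-token document id mask from packed position ids (tokens of the i-th packed sequence map to i + 1, padding maps to 0)"""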
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)

device = position_ids.device
results = []

for i, row in enumerate(position_ids):
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()

# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()

# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Build the document-id mask: tokens of the i-th packed sequence get the value i + 1
doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
for i, seq_len in enumerate(seq_lengths):
start_id = start_indices[i]
doc_mask[start_id : start_id + seq_len] = (
(i+1) * doc_mask[start_id : start_id + seq_len]
)
if padding_length:
doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]

results.append(doc_mask)

return torch.stack(results)
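For reference (not part of the PR diff), a small worked example of what this helper returns: each packed sequence is numbered from 1 and trailing padding is zeroed.

import torch

pos_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 0]])  # two packed sequences, then 2 padding tokens
get_packed_mask_from_pos_ids(pos_ids)
# tensor([[1, 1, 1, 1, 2, 2, 0, 0]], dtype=torch.int32)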


def get_seqlens_from_pos_ids(position_ids):
"""generate a sequence length set using pos ids for doc mask creation in flex attention"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
max_seq_len = position_ids.shape[1]

device = position_ids.device
results = []
totalseqlens = []

for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()

# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()

# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
torch.nonzero(seq_starts).unbind(dim=1)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Append the padding length to the sequence lengths
if padding_length:
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[len(row) - torch.sum(seq_lengths)],
dtype=torch.int32,
device=device,
),
]
)

results.append(seq_lengths)
totalseqlens.append(len(adjusted_row))

return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
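Again for reference (not from the diff), the same packed example yields one length per sequence plus one entry covering the trailing padding, together with the unpadded length of each row:

import torch

pos_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 0]])
seqlens, total = get_seqlens_from_pos_ids(pos_ids)
# seqlens -> [tensor([4, 2, 2])]  (the final 2 covers the padding tokens)
# total   -> tensor([6], dtype=torch.int32)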


def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
@@ -176,7 +273,10 @@ def mask_2d_to_4d(
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()

if len(mask.size()) == 4:
return mask
bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
tgt_len = tgt_len if tgt_len is not None else src_len

mask = mask.unsqueeze(1).unsqueeze(2)
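To connect these helpers to flex attention itself, a minimal sketch of building a causal, per-document BlockMask is shown below. It assumes torch >= 2.5 (for torch.nn.attention.flex_attention) and a document_ids tensor produced by get_packed_mask_from_pos_ids above; it is illustrative only and not code from this PR.

import torch
from torch.nn.attention.flex_attention import create_block_mask


def make_doc_causal_block_mask(document_ids: torch.Tensor, seq_len: int):
    # document_ids: (batch, seq_len) ints, 1-based per packed sequence, 0 = padding
    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
        not_padding = document_ids[b, q_idx] > 0
        causal = q_idx >= kv_idx
        return same_doc & not_padding & causal

    return create_block_mask(
        mask_mod,
        B=document_ids.shape[0],
        H=None,  # broadcast across attention heads
        Q_LEN=seq_len,
        KV_LEN=seq_len,
        device=document_ids.device,
    )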
21 changes: 21 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -823,6 +823,7 @@ class AxolotlInputConfig(
xformers_attention: Optional[bool] = None
sdp_attention: Optional[bool] = None
s2_attention: Optional[bool] = None
flex_attention: Optional[bool] = None
flash_attention: Optional[bool] = None
flash_attn_cross_entropy: Optional[bool] = None
flash_attn_rms_norm: Optional[bool] = None
@@ -1789,6 +1790,26 @@ def check_adopt_torch_version(cls, data):
)
return data

@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
if (data.get("flex_attention") is not None) and (
data.get("flex_attention") is True
):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")

if torch_version is None:
import torch

torch_version = str(torch.__version__).split("+", maxsplit=1)[0]

if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"Flex attention is not supported on torch version < 2.5.1"
)
return data
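For reference, the guard above strips any local build suffix from the torch version string before comparing, e.g. (assuming the version name in this module refers to packaging.version):

from packaging import version

torch_version = "2.5.1+cu124".split("+", maxsplit=1)[0]  # -> "2.5.1"
version.parse(torch_version) >= version.parse("2.5.1")  # -> True, so the config passes validation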

@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):
12 changes: 9 additions & 3 deletions src/axolotl/utils/models.py
@@ -403,7 +403,7 @@ def apply_patches(self) -> None:

if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and self.cfg.flash_attention
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.sample_packing
):
if "auto_map" in self.model_config:
@@ -707,7 +707,13 @@ def set_attention_config(self) -> None:
"""
sample packing uses custom FA2 patch
"""
if self.cfg.flash_attention:

if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
)
elif self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -1113,7 +1119,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
should_convert = (
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
)

Expand Down
Loading