diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 12346b8a27..0371963c94 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -831,7 +831,9 @@ def build_collator(
             if "max_length" in kwargs:
                 kwargs.pop("max_length")
         elif use_batch_sampler_collator:
-            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
+            if self.cfg.flex_attention is True:
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             elif (
                 self.cfg.model_config_type in ["llama"]
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index c2772b4715..b6bb159a1b 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -95,6 +95,103 @@ def get_cu_seqlens(attn_mask):
     return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
 
 
+def get_packed_mask_from_pos_ids(position_ids):
+    if len(position_ids.shape) == 1:
+        position_ids = position_ids.unsqueeze(0)
+
+    device = position_ids.device
+    results = []
+
+    for row in position_ids:
+        # Count the number of consecutive zeros from the right side
+        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
+
+        # Adjust the row to exclude padding
+        adjusted_row = row[:-padding_length] if padding_length else row.clone()
+
+        # Find where the position resets to 0 (indicating a new sequence)
+        seq_starts = torch.cat(
+            [
+                torch.tensor([True], dtype=torch.bool, device=device),
+                adjusted_row[1:] == 0,
+            ]
+        )
+        # Get the indices where the sequence starts
+        start_indices = torch.cat(
+            [
+                torch.nonzero(seq_starts).unbind(dim=1)[0],
+                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = start_indices[1:] - start_indices[:-1]
+        # Assign a 1-based document id to each packed sequence; padding stays 0
+        doc_mask = torch.ones(len(row), dtype=torch.int32, device=device)
+        for i, seq_len in enumerate(seq_lengths):
+            start_id = start_indices[i]
+            doc_mask[start_id : start_id + seq_len] = (
+                (i + 1) * doc_mask[start_id : start_id + seq_len]
+            )
+        if padding_length:
+            doc_mask[len(adjusted_row) :] = 0 * doc_mask[len(adjusted_row) :]
+
+        results.append(doc_mask)
+
+    return torch.stack(results)
+
+
+def get_seqlens_from_pos_ids(position_ids):
+    """generate sequence lengths from pos ids for doc mask creation in flex attention"""
+    if len(position_ids.shape) == 1:
+        position_ids = position_ids.unsqueeze(0)
+    max_seq_len = position_ids.shape[1]
+
+    device = position_ids.device
+    results = []
+    totalseqlens = []
+
+    for row in position_ids:
+        # Count the number of consecutive zeros from the right side
+        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
+
+        # Adjust the row to exclude padding
+        adjusted_row = row[:-padding_length] if padding_length else row.clone()
+
+        # Find where the position resets to 0 (indicating a new sequence)
+        seq_starts = torch.cat(
+            [
+                torch.tensor([True], dtype=torch.bool, device=device),
+                adjusted_row[1:] == 0,
+            ]
+        )
+        # Get the indices where the sequence starts
+        start_indices = torch.cat(
+            [
+                torch.nonzero(seq_starts).unbind(dim=1)[0],
+                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = start_indices[1:] - start_indices[:-1]
+        # Append the padding length to the sequence lengths
+        if padding_length:
+            seq_lengths = torch.cat(
+                [
+                    seq_lengths,
+                    torch.tensor(
+                        [len(row) - torch.sum(seq_lengths)],
+                        dtype=torch.int32,
+                        device=device,
+                    ),
+                ]
+            )
+
+        results.append(seq_lengths)
+        totalseqlens.append(len(adjusted_row))
+
+    return results, torch.tensor(totalseqlens, dtype=torch.int32, device=device)
+
+
 def get_cu_seqlens_from_pos_ids(position_ids):
     """generate a cumulative sequence length mask for flash attention using pos ids"""
     if len(position_ids.shape) == 1:
@@ -176,7 +273,10 @@ def mask_2d_to_4d(
     when they attend to each other within that sequence.
     This expansion transforms the mask to lower triangular form to prevent future peeking.
     """
-    bsz, src_len = mask.size()
+
+    if len(mask.size()) == 4:
+        return mask
+    bsz, src_len = int(mask.size()[0]), int(mask.size()[1])
     tgt_len = tgt_len if tgt_len is not None else src_len
 
     mask = mask.unsqueeze(1).unsqueeze(2)
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 1810413bee..475f79b33f 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -823,6 +823,7 @@ class AxolotlInputConfig(
     xformers_attention: Optional[bool] = None
     sdp_attention: Optional[bool] = None
     s2_attention: Optional[bool] = None
+    flex_attention: Optional[bool] = None
     flash_attention: Optional[bool] = None
     flash_attn_cross_entropy: Optional[bool] = None
     flash_attn_rms_norm: Optional[bool] = None
@@ -1789,6 +1790,26 @@ def check_adopt_torch_version(cls, data):
                 )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_flex_torch_version(cls, data):
+        if (data.get("flex_attention") is not None) and (
+            data.get("flex_attention") is True
+        ):
+            env_capabilities = data.get("env_capabilities", {})
+            torch_version = env_capabilities.get("torch_version")
+
+            if torch_version is None:
+                import torch
+
+                torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+
+            if version.parse(torch_version) < version.parse("2.5.1"):
+                raise ValueError(
+                    "Flex attention is not supported on torch version < 2.5.1"
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_torch_compile_auto(cls, data):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c4c07dd33d..6d2dbe8d89 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -403,7 +403,7 @@ def apply_patches(self) -> None:
 
         if (
             self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
-            and self.cfg.flash_attention
+            and (self.cfg.flash_attention or self.cfg.flex_attention)
             and self.cfg.sample_packing
         ):
             if "auto_map" in self.model_config:
@@ -707,7 +707,13 @@ def set_attention_config(self) -> None:
         """
         sample packing uses custom FA2 patch
         """
-        if self.cfg.flash_attention:
+
+        if self.cfg.flex_attention:
+            self.model_kwargs["attn_implementation"] = "flex_attention"
+            self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                "flex_attention"
+            )
+        elif self.cfg.flash_attention:
            if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
             self.model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -1113,7 +1119,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
         should_convert = (
             # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
             # convert them back to fp16/bf16 for flash-attn compatibility.
-            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
+            ((needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp)
             or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
         )
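Reviewer note, not part of the patch: a quick sketch of what the two new position-id helpers return for a small packed batch (two documents of lengths 3 and 2, followed by two pad tokens). The values in the comments are traced from the implementations above.

import torch

from axolotl.monkeypatch.utils import (
    get_packed_mask_from_pos_ids,
    get_seqlens_from_pos_ids,
)

# One packed row: doc 1 at positions 0-2, doc 2 at positions 3-4, then 2 pad tokens.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 0]])

# Per-token document ids, numbered from 1, with 0 marking right padding:
# tensor([[1, 1, 1, 2, 2, 0, 0]])
doc_mask = get_packed_mask_from_pos_ids(position_ids)

# Per-row sequence lengths (padding appended as a final pseudo-sequence) plus the
# total number of non-padded tokens per row:
# [tensor([3, 2, 2])], tensor([5])
seq_lens, total_lens = get_seqlens_from_pos_ids(position_ids)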
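The document-id mask is the kind of input flex attention's block-mask machinery consumes for packed samples. The sketch below, assuming torch >= 2.5, shows one way such a mask could drive a mask_mod for torch.nn.attention.flex_attention.create_block_mask; it is illustrative only and does not reflect code in this diff.

import torch
from torch.nn.attention.flex_attention import create_block_mask

from axolotl.monkeypatch.utils import get_packed_mask_from_pos_ids

# Two packed documents of 64 tokens each (no padding), expressed as position ids.
position_ids = torch.cat([torch.arange(64), torch.arange(64)]).unsqueeze(0)
doc_ids = get_packed_mask_from_pos_ids(position_ids)[0]  # 64 ones, then 64 twos
seq_len = doc_ids.shape[0]


def doc_causal(b, h, q_idx, kv_idx):
    # Attend only within the same packed document, causally; a document id of 0
    # (right padding in the helper's convention) is masked out entirely.
    same_doc = doc_ids[q_idx] == doc_ids[kv_idx]
    not_pad = doc_ids[q_idx] != 0
    return same_doc & not_pad & (q_idx >= kv_idx)


# BlockMask that could be passed to flex_attention(q, k, v, block_mask=block_mask).
block_mask = create_block_mask(
    doc_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu"
)

The diff above only switches attn_implementation to "flex_attention" and adds the mask helpers; the sketch is just to make the document-mask semantics concrete.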