From 7037e3c836960b315b469ab5659c1695b2fe0583 Mon Sep 17 00:00:00 2001
From: Aman Gupta Karmani
Date: Tue, 27 Aug 2024 20:52:40 -0700
Subject: [PATCH] deepseekv2 liger support (#1878)

* deepseekv2 liger support

* add comment

* add missing impl
---
 src/axolotl/integrations/liger/__init__.py    |  26 ++++
 .../integrations/liger/models/deepseekv2.py   | 127 ++++++++++++++++++
 2 files changed, 153 insertions(+)
 create mode 100644 src/axolotl/integrations/liger/models/deepseekv2.py

diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
index f78083300d..2a3e95163b 100644
--- a/src/axolotl/integrations/liger/__init__.py
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -19,6 +19,7 @@
 It is designed to be performant, correct, and light-weight.
 """
 import logging
+import sys
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
@@ -119,3 +120,28 @@ def pre_model_load(self, cfg):
                 modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
             if cfg.liger_fused_linear_cross_entropy:
                 modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+
+        elif cfg.model_config_type == "deepseek_v2":
+            from accelerate import init_empty_weights
+            from transformers import AutoModelForCausalLM
+
+            with init_empty_weights():
+                model = AutoModelForCausalLM.from_pretrained(
+                    cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
+                )
+                modeling_mod = sys.modules[model.__class__.__module__]
+
+            from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
+
+            if cfg.liger_rope:
+                # The DeepseekV2 version of RoPE is different than upstream LLaMA.
+                # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
+                logging.warning("Fused liger_rope is not supported for DeepseekV2.")
+            if cfg.liger_rms_norm:
+                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
+            if cfg.liger_swiglu:
+                modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
+            if cfg.liger_cross_entropy:
+                modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
+            if cfg.liger_fused_linear_cross_entropy:
+                modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py
new file mode 100644
index 0000000000..79fb274360
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/deepseekv2.py
@@ -0,0 +1,127 @@
+"""
+DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
+"""
+# pylint: disable=duplicate-code
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+
+# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
+# @replace_return_docstrings(
+#     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+# )
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
+
+    >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training:
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
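
The `sys.modules` lookup in the `__init__.py` hunk is how the plugin reaches modeling code loaded via `trust_remote_code`: such models have no stable `transformers.models.*` import path, so the patch instantiates the model once on the meta device (`init_empty_weights`) purely to discover which dynamically created module defines its class, then monkey-patches that module's globals. A minimal sketch of the lookup, using an ordinary `torch.nn` class as a stand-in for the remote DeepseekV2 module:

```python
import sys

import torch


def resolve_modeling_module(model: torch.nn.Module):
    """Return the live module object that defines `model`'s class."""
    # For a trust_remote_code checkpoint this resolves to a dynamically
    # registered "transformers_modules...modeling_deepseek" module; names
    # rebound on it (e.g. DeepseekV2RMSNorm) are picked up when classes in
    # that module are constructed afterwards.
    return sys.modules[model.__class__.__module__]


# Stand-in demo: torch.nn.Linear is defined in torch.nn.modules.linear.
mod = resolve_modeling_module(torch.nn.Linear(2, 2))
print(mod.__name__)  # torch.nn.modules.linear
```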
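In `lce_forward`, the training path never materializes logits: `LigerFusedLinearCrossEntropyLoss` fuses the `lm_head` projection with the cross entropy so the full `(batch * seq_len, vocab_size)` logits tensor is never allocated at once. An unfused reference sketch of what that call computes, in plain PyTorch with toy shapes chosen only for illustration:

```python
import torch
import torch.nn.functional as F

# Toy shapes, assumptions for this example only.
batch, seq_len, hidden_size, vocab_size = 2, 8, 16, 32
hidden_states = torch.randn(batch, seq_len, hidden_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))
lm_head_weight = torch.randn(vocab_size, hidden_size)  # stand-in for self.lm_head.weight

# Shift so that tokens < n predict n, exactly as in lce_forward.
shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

# Unfused equivalent: project to logits, then take cross entropy.
# The Liger kernel produces the same loss on GPU without holding the
# full (N, vocab_size) logits tensor in memory at one time.
logits = shift_hidden @ lm_head_weight.T
loss = F.cross_entropy(logits, shift_labels, ignore_index=-100)
print(loss)
```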