
Commit dd1a208

gshtras, simon-mo, and WoosukKwon authored
Llama3.1 (#129)
* Add support for a rope extension method (vllm-project#6553)
* [BugFix] Fix RoPE error in Llama 3.1 (vllm-project#6693)

---------

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent f49dff3 commit dd1a208

2 files changed (+202 −18 lines changed)

vllm/config.py

Lines changed: 39 additions & 15 deletions
@@ -11,7 +11,8 @@
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.transformers_utils.config import get_config, get_hf_text_config
-from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron
+from vllm.utils import (get_cpu_memory, is_cpu, is_hip, is_neuron,
+                        print_warning_once)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup

@@ -133,6 +134,17 @@ def __init__(
            code_revision, rope_scaling)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
+
+        if (not self.disable_sliding_window
+                and self.hf_text_config.model_type == "gemma2"
+                and self.hf_text_config.sliding_window is not None):
+            print_warning_once(
+                "Gemma 2 uses sliding window attention for every odd layer, "
+                "which is currently not supported by vLLM. Disabling sliding "
+                "window and capping the max length to the sliding window size "
+                f"({self.hf_text_config.sliding_window}).")
+            self.disable_sliding_window = True
+
         self.max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
             max_model_len=max_model_len,

@@ -1225,20 +1237,32 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None and rope_scaling["type"] != "su":
-        if disable_sliding_window:
-            # TODO(robertgshaw): Find a model that supports rope_scaling
-            # with sliding window to see if this case should be allowed.
-            raise NotImplementedError(
-                "Disabling sliding window is not supported for models "
-                "with rope_scaling. Please raise an issue so we can "
-                "investigate.")
-        assert "factor" in rope_scaling
-        scaling_factor = rope_scaling["factor"]
-        if rope_scaling["type"] == "yarn":
-            derived_max_model_len = rope_scaling[
-                "original_max_position_embeddings"]
-        derived_max_model_len *= scaling_factor
+    if rope_scaling is not None:
+        if "type" in rope_scaling:
+            rope_type = rope_scaling["type"]
+        elif "rope_type" in rope_scaling:
+            rope_type = rope_scaling["rope_type"]
+        else:
+            raise ValueError(
+                "rope_scaling must have a 'type' or 'rope_type' key.")
+
+        # The correct one should be "longrope", kept "su" here
+        # to be backward compatible
+        if rope_type not in ("su", "longrope", "llama3"):
+            if disable_sliding_window:
+                # TODO(robertgshaw): Find a model that supports rope_scaling
+                # with sliding window to see if this case should be allowed.
+                raise NotImplementedError(
+                    "Disabling sliding window is not supported for models "
+                    "with rope_scaling. Please raise an issue so we can "
+                    "investigate.")
+
+            assert "factor" in rope_scaling
+            scaling_factor = rope_scaling["factor"]
+            if rope_type == "yarn":
+                derived_max_model_len = rope_scaling[
+                    "original_max_position_embeddings"]
+            derived_max_model_len *= scaling_factor
 
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
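
For context, Llama 3.1 style configs publish their RoPE scaling parameters under a "rope_type" key rather than the older "type" key, which is why _get_and_verify_max_len now accepts either, and the "su"/"longrope"/"llama3" types are excluded from the factor-based max-length stretch. A minimal standalone sketch of that logic follows; the rope_scaling dict is only illustrative of the shape such configs take, not copied from a real checkpoint.

# Illustrative only: a rope_scaling dict roughly in the shape Llama 3.1 style
# configs ship (field values are assumptions, not read from a real model).
rope_scaling = {
    "rope_type": "llama3",   # newer configs; older ones used the key "type"
    "factor": 8.0,
    "original_max_position_embeddings": 8192,
}

if "type" in rope_scaling:
    rope_type = rope_scaling["type"]
elif "rope_type" in rope_scaling:
    rope_type = rope_scaling["rope_type"]
else:
    raise ValueError("rope_scaling must have a 'type' or 'rope_type' key.")

derived_max_model_len = 8192  # stand-in for whatever the HF config implies
if rope_type not in ("su", "longrope", "llama3"):
    # Only the remaining types (linear, dynamic, yarn, ...) stretch the
    # derived max length by the scaling factor.
    derived_max_model_len *= rope_scaling["factor"]

print(rope_type, derived_max_model_len)  # llama3 8192 (length left untouched)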

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 163 additions & 3 deletions
@@ -503,6 +503,159 @@ def forward(
         return query.flatten(-2), key.flatten(-2)
 
 
+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        mscale: float = 1,
+        mscale_all_dim: float = 0,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation.
+        self.mscale = float(
+            yarn_get_mscale(self.scaling_factor, float(mscale)) /
+            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
+            attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         device="cuda",
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        print("Cache shape", cache.shape)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
+        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
+                                     if offsets is not None else positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
+
+class GemmaRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
+        inv_freq = 1.0 / (base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
+            self.rotary_dim))
+        return inv_freq
+
+
+class ExtendedRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freqs = super()._compute_inv_freq(base)
+        return self.apply_scaling(inv_freqs)
+
+    def apply_scaling(self, freqs: torch.Tensor):
+        scale_factor = 8
+        low_freq_factor = 1
+        high_freq_factor = 4
+        old_context_len = 8192
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        new_freqs = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (
+                    high_freq_factor - low_freq_factor)
+                new_freqs.append((1 - smooth) * freq / scale_factor +
+                                 smooth * freq)
+        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
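ExtendedRotaryEmbedding.apply_scaling above is the Llama 3.1 frequency remapping: long-wavelength (low-frequency) rotary components are compressed by the scale factor, short-wavelength ones pass through unchanged, and the band in between is blended smoothly. A self-contained sketch using plain PyTorch follows; the constants mirror the hard-coded values in the diff, and the helper name is ours, not a vLLM API.

import math
import torch

def llama31_scale_inv_freq(freqs: torch.Tensor) -> torch.Tensor:
    # Constants as hard-coded in ExtendedRotaryEmbedding.apply_scaling above.
    scale_factor, low_freq_factor, high_freq_factor = 8, 1, 4
    old_context_len = 8192
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    out = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:      # high frequency: keep as-is
            out.append(freq)
        elif wavelen > low_freq_wavelen:     # low frequency: divide by 8
            out.append(freq / scale_factor)
        else:                                # middle band: smooth interpolation
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            out.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(out, dtype=freqs.dtype)

# Standard RoPE inverse frequencies for a 128-dim rotary embedding, base 10000.
inv_freq = 1.0 / (10000 ** (torch.arange(0, 128, 2).float() / 128))
scaled = llama31_scale_inv_freq(inv_freq)
print(inv_freq[:3], scaled[:3])    # highest frequencies are unchanged
print(inv_freq[-3:], scaled[-3:])  # lowest frequencies are divided by 8

Note that the 8x scale factor and 8192 original context length are fixed in the class rather than read from rope_scaling, matching the diff above.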

@@ -534,10 +687,17 @@ def get_rope(
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling["type"]
-        if scaling_type != "su":
+        scaling_type = rope_scaling[
+            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
+        # The correct one should be "longrope" but keep "su" here
+        # for backward compatible
+        if scaling_type not in {"su", "longrope", "llama3"}:
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "linear":
+        if scaling_type == "llama3":
+            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
+                                                 max_position, base,
+                                                 is_neox_style, dtype)
+        elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,