Commit c5df56f

Add support for a rope extension method (#6553)
Parent: 1689219

2 files changed (+48 −4 lines)


vllm/config.py

Lines changed: 12 additions & 2 deletions
@@ -151,6 +151,15 @@ def __init__(
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
+        if (getattr(self.hf_config, "max_position_embeddings", 0) == 131072
+                and getattr(self.hf_config, "rope_scaling", None) is None):
+            # Note(simon): this is a special case for a model that doesn't
+            # supply rope_scaling. We should remove this once the model is
+            # updated.
+            self.hf_config.update({"rope_scaling": {
+                "type": "extended",
+            }})
+
         if (not self.disable_sliding_window
                 and self.hf_text_config.model_type == "gemma2"
                 and self.hf_text_config.sliding_window is not None):
@@ -1442,8 +1451,9 @@ def _get_and_verify_max_len(
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     # The correct one should be "longrope", kept "su" here
     # to be backward compatible
-    if rope_scaling is not None and rope_scaling["type"] != "su" \
-            and rope_scaling["type"] != "longrope":
+    if rope_scaling is not None and rope_scaling["type"] not in {
+            "su", "longrope", "extended"
+    }:
         if disable_sliding_window:
             # TODO(robertgshaw): Find a model that supports rope_scaling
             # with sliding window to see if this case should be allowed.
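
Note on the config.py change: the new branch in __init__ only fires when a checkpoint's config reports max_position_embeddings == 131072 but ships no rope_scaling entry, in which case a {"type": "extended"} entry is injected so the later verification accepts it. A minimal sketch of the same check against a dict-backed stand-in for the HF config (the MockConfig class below is hypothetical, for illustration only):

# Minimal sketch of the special case above, using a dict-backed stand-in
# for the HF config object (MockConfig is hypothetical, for illustration).
class MockConfig(dict):

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)


cfg = MockConfig(max_position_embeddings=131072)  # no rope_scaling supplied
if (getattr(cfg, "max_position_embeddings", 0) == 131072
        and getattr(cfg, "rope_scaling", None) is None):
    # Same injection as the hunk above: tag the model with the "extended"
    # rope scaling type so _get_and_verify_max_len and get_rope accept it.
    cfg.update({"rope_scaling": {"type": "extended"}})

assert cfg.rope_scaling == {"type": "extended"}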

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 36 additions & 2 deletions
@@ -733,6 +733,36 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         return inv_freq
 
 
+class ExtendedRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freqs = super()._compute_inv_freq(base)
+        return self.apply_scaling(inv_freqs)
+
+    def apply_scaling(self, freqs: torch.Tensor):
+        scale_factor = 8
+        low_freq_factor = 1
+        high_freq_factor = 4
+        old_context_len = 8192
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        new_freqs = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (
+                    high_freq_factor - low_freq_factor)
+                new_freqs.append((1 - smooth) * freq / scale_factor +
+                                 smooth * freq)
+        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
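Note on apply_scaling: frequencies whose wavelength is shorter than old_context_len / high_freq_factor are kept as-is, those whose wavelength exceeds old_context_len / low_freq_factor are divided by scale_factor, and everything in between is linearly interpolated between the two regimes. A standalone sketch of that logic with one sample frequency per regime (the helper name and sample values are illustrative, not part of the commit):

import math

import torch


def apply_scaling_sketch(freqs: torch.Tensor) -> torch.Tensor:
    # Same constants as ExtendedRotaryEmbedding.apply_scaling above.
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192

    low_freq_wavelen = old_context_len / low_freq_factor    # 8192.0
    high_freq_wavelen = old_context_len / high_freq_factor  # 2048.0
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelength (high frequency): keep unchanged.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Long wavelength (low frequency): stretch by scale_factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Transition band: interpolate between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor +
                             smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


# One sample frequency per regime (wavelengths of about 6.3, 16384 and 4096).
sample = torch.tensor([1.0, 2 * math.pi / 16384, 2 * math.pi / 4096])
print(apply_scaling_sketch(sample))
# First entry is unchanged, second is divided by 8, third is interpolated
# with smooth = (8192 / 4096 - 1) / (4 - 1) = 1 / 3.
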
@@ -767,9 +797,13 @@ def get_rope(
         scaling_type = rope_scaling["type"]
         # The correct one should be "longrope" but keep "su" here
         # for backward compatible
-        if scaling_type != "su" and scaling_type != "longrope":
+        if scaling_type not in {"su", "longrope", "extended"}:
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "linear":
+        if scaling_type == "extended":
+            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
+                                                 max_position, base,
+                                                 is_neox_style, dtype)
+        elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
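
Note on get_rope: with this change, a rope_scaling entry of {"type": "extended"} dispatches to ExtendedRotaryEmbedding without requiring a "factor" key. A rough usage sketch, assuming get_rope's surrounding parameters (head_size, rotary_dim, max_position, base, is_neox_style, rope_scaling, dtype) and placeholder numeric values:

import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

rope = get_rope(
    head_size=128,          # placeholder values, not taken from any model
    rotary_dim=128,
    max_position=131072,
    base=10000,
    is_neox_style=True,
    rope_scaling={"type": "extended"},  # no "factor" key needed for this type
    dtype=torch.float16,
)
print(type(rope).__name__)  # ExtendedRotaryEmbedding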
