@@ -733,6 +733,36 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
733
733
return inv_freq
734
734
735
735
736
+ class ExtendedRotaryEmbedding (RotaryEmbedding ):
737
+
738
+ def _compute_inv_freq (self , base : Union [int , float ]) -> torch .Tensor :
739
+ inv_freqs = super ()._compute_inv_freq (base )
740
+ return self .apply_scaling (inv_freqs )
741
+
742
+ def apply_scaling (self , freqs : torch .Tensor ):
743
+ scale_factor = 8
744
+ low_freq_factor = 1
745
+ high_freq_factor = 4
746
+ old_context_len = 8192
747
+
748
+ low_freq_wavelen = old_context_len / low_freq_factor
749
+ high_freq_wavelen = old_context_len / high_freq_factor
750
+ new_freqs = []
751
+ for freq in freqs :
752
+ wavelen = 2 * math .pi / freq
753
+ if wavelen < high_freq_wavelen :
754
+ new_freqs .append (freq )
755
+ elif wavelen > low_freq_wavelen :
756
+ new_freqs .append (freq / scale_factor )
757
+ else :
758
+ assert low_freq_wavelen != high_freq_wavelen
759
+ smooth = (old_context_len / wavelen - low_freq_factor ) / (
760
+ high_freq_factor - low_freq_factor )
761
+ new_freqs .append ((1 - smooth ) * freq / scale_factor +
762
+ smooth * freq )
763
+ return torch .tensor (new_freqs , dtype = freqs .dtype , device = freqs .device )
764
+
765
+
736
766
_ROPE_DICT : Dict [Tuple , RotaryEmbedding ] = {}
737
767
738
768
@@ -767,9 +797,13 @@ def get_rope(
767
797
scaling_type = rope_scaling ["type" ]
768
798
# The correct one should be "longrope" but keep "su" here
769
799
# for backward compatible
770
- if scaling_type != "su" and scaling_type != "longrope" :
800
+ if scaling_type not in { "su" , "longrope" , "extended" } :
771
801
scaling_factor = rope_scaling ["factor" ]
772
- if scaling_type == "linear" :
802
+ if scaling_type == "extended" :
803
+ rotary_emb = ExtendedRotaryEmbedding (head_size , rotary_dim ,
804
+ max_position , base ,
805
+ is_neox_style , dtype )
806
+ elif scaling_type == "linear" :
773
807
rotary_emb = LinearScalingRotaryEmbedding (head_size , rotary_dim ,
774
808
max_position , base ,
775
809
is_neox_style ,
0 commit comments