@@ -503,6 +503,159 @@ def forward(
         return query.flatten(-2), key.flatten(-2)


+def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0
+
+
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        mscale: float = 1,
+        mscale_all_dim: float = 0,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation.
+        self.mscale = float(
+            yarn_get_mscale(self.scaling_factor, float(mscale)) /
+            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
+            attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
+                                                self.rotary_dim, self.base,
+                                                self.max_position_embeddings)
+        # Get n-d rotational scaling corrected for extrapolation.
+        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
+        inv_freq = inv_freq_interpolation * (
+            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         device="cuda",
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = (freqs.cos() * self.mscale)
+        sin = (freqs.sin() * self.mscale)
+        cache = torch.cat((cos, sin), dim=-1)
+        print("Cache shape", cache.shape)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+            positions.device)
+        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
+                                     if offsets is not None else positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
+
+class GemmaRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
+        inv_freq = 1.0 / (base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
+            self.rotary_dim))
+        return inv_freq
+
+
+class ExtendedRotaryEmbedding(RotaryEmbedding):
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        inv_freqs = super()._compute_inv_freq(base)
+        return self.apply_scaling(inv_freqs)
+
+    def apply_scaling(self, freqs: torch.Tensor):
+        scale_factor = 8
+        low_freq_factor = 1
+        high_freq_factor = 4
+        old_context_len = 8192
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        new_freqs = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (
+                    high_freq_factor - low_freq_factor)
+                new_freqs.append((1 - smooth) * freq / scale_factor +
+                                 smooth * freq)
+        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}

@@ -534,10 +687,17 @@ def get_rope(
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling["type"]
-        if scaling_type != "su":
+        scaling_type = rope_scaling[
+            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
+        # The correct name is "longrope", but keep "su" here
+        # for backward compatibility.
+        if scaling_type not in {"su", "longrope", "llama3"}:
             scaling_factor = rope_scaling["factor"]
-        if scaling_type == "linear":
+        if scaling_type == "llama3":
+            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
+                                                 max_position, base,
+                                                 is_neox_style, dtype)
+        elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
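
For reference, a small standalone check (not part of the commit) of the `yarn_get_mscale` attenuation that `DeepseekScalingRotaryEmbedding` applies to its cos/sin cache. The scale values 4.0 and 40.0 and the 0.707 mscale below are arbitrary illustrations:

```python
import math


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    # Same body as the helper added in the diff.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# No magnitude correction at or below the original context length.
print(yarn_get_mscale(1.0))   # 1.0
# Logarithmic growth with the context-extension factor.
print(yarn_get_mscale(4.0))   # ~1.139
print(yarn_get_mscale(40.0))  # ~1.369
# The class scales cos/sin by
# yarn_get_mscale(s, mscale) / yarn_get_mscale(s, mscale_all_dim) * attn_factor,
# which collapses to attn_factor when both mscale arguments are equal.
print(yarn_get_mscale(40.0, 0.707) / yarn_get_mscale(40.0, 0.707) * 1.0)  # 1.0
```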
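Likewise, a minimal sketch of the wavelength-based scaling performed by `ExtendedRotaryEmbedding.apply_scaling`, run on a toy tensor of inverse frequencies chosen to hit all three branches. The standalone `apply_llama3_scaling` name and the toy values are illustrative, not from the commit:

```python
import math

import torch


def apply_llama3_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Mirrors ExtendedRotaryEmbedding.apply_scaling from the diff.
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192

    low_freq_wavelen = old_context_len / low_freq_factor    # 8192
    high_freq_wavelen = old_context_len / high_freq_factor  # 2048
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelengths (high frequencies): keep unchanged.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Long wavelengths (low frequencies): stretch by scale_factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Transition band: interpolate between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


# One frequency per regime: kept, interpolated, divided by 8.
toy = torch.tensor([1.0, 2 * math.pi / 4096, 1e-4])
print(apply_llama3_scaling(toy))
```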
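Finally, a standalone mimic of the `rope_scaling` key handling introduced in `get_rope`: the legacy `"type"` key is preferred when present, otherwise `"rope_type"` is used, and only scaling types outside `{"su", "longrope", "llama3"}` read `"factor"`. The config dicts below are hypothetical examples:

```python
from typing import Any, Dict


def resolve_scaling_type(rope_scaling: Dict[str, Any]) -> str:
    # Mirrors the fallback added in get_rope.
    return rope_scaling["type"] if "type" in rope_scaling else rope_scaling[
        "rope_type"]


for cfg in (
    {"type": "linear", "factor": 4.0},   # legacy key spelling
    {"rope_type": "llama3"},             # new key, no "factor" required
    {"rope_type": "yarn", "factor": 8.0},
):
    scaling_type = resolve_scaling_type(cfg)
    # Only "su"/"longrope"/"llama3" skip the mandatory "factor" lookup.
    needs_factor = scaling_type not in {"su", "longrope", "llama3"}
    print(scaling_type, "-> reads 'factor':", needs_factor)
```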