File tree Expand file tree Collapse file tree 2 files changed +11
-2
lines changed Expand file tree Collapse file tree 2 files changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -617,10 +617,15 @@ def __init__(
617
617
blocksparse_params : Optional [Dict [str , Any ]] = None ,
618
618
logits_soft_cap : Optional [float ] = None ,
619
619
attn_type : str = AttentionType .DECODER ,
620
+ use_irope : bool = False ,
620
621
) -> None :
621
622
if blocksparse_params is not None :
622
623
raise ValueError (
623
624
"FlashAttention does not support block-sparse attention." )
625
+ if use_irope :
626
+ logger .warning (
627
+ "Using irope in V0 is not supported yet, it will fall back to global attention for long context, which could impact accuracy"
628
+ )
624
629
self .num_heads = num_heads
625
630
self .head_size = head_size
626
631
self .scale = float (scale )
Original file line number Diff line number Diff line change 19
19
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
20
# See the License for the specific language governing permissions and
21
21
# limitations under the License.
22
- from typing import Iterable , Set , Tuple
22
+ from typing import Iterable , Set , Tuple , Type
23
23
24
24
import torch
25
25
26
26
from vllm .config import VllmConfig
27
27
from vllm .model_executor .model_loader .weight_utils import default_weight_loader
28
28
from vllm .model_executor .models .llama import LlamaForCausalLM , LlamaModel
29
29
30
+ from .llama import LlamaDecoderLayer
30
31
from .utils import (AutoWeightsLoader , PPMissingLayer , WeightsMapper ,
31
32
is_pp_missing_parameter )
32
33
@@ -120,7 +121,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
120
121
},
121
122
)
122
123
123
- def _init_model (self , vllm_config : VllmConfig , prefix : str = "" ):
124
+ def _init_model (self ,
125
+ vllm_config : VllmConfig ,
126
+ prefix : str = "" ,
127
+ layer_type : Type [LlamaDecoderLayer ] = LlamaDecoderLayer ):
124
128
return TeleChat2Model (vllm_config = vllm_config , prefix = prefix )
125
129
126
130
def load_weights (self , weights : Iterable [Tuple [str ,
You can’t perform that action at this time.
0 commit comments