
Commit 7ced25f (1 parent: 4e45bfc)

fix model test

fix model init for TeleChat2ForCausalLM and llama4 V0

File tree (2 files changed, +11 -2 lines):

  vllm/attention/backends/flash_attn.py
  vllm/model_executor/models/telechat2.py

vllm/attention/backends/flash_attn.py (5 additions, 0 deletions)

@@ -617,10 +617,15 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
                 "FlashAttention does not support block-sparse attention.")
+        if use_irope:
+            logger.warning(
+                "Using irope in V0 is not supported yet, it will fall back to global attention for long context, which could impact accuracy"
+            )
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
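
Why accept a flag the backend cannot honor: llama4-style attention layers pass use_irope down to the backend implementation, so a V0 backend whose __init__ did not declare that keyword failed at model init with a TypeError. Below is a minimal sketch of the accept-and-warn pattern, using an illustrative class name rather than the actual vLLM backend:

import logging

logging.basicConfig()
logger = logging.getLogger(__name__)

class ToyFlashAttentionImpl:
    # Illustrative stand-in for a V0 attention backend implementation;
    # not the real vLLM class.
    def __init__(self, num_heads: int, use_irope: bool = False) -> None:
        # Accept the keyword so callers that pass it no longer crash,
        # but warn that the backend degrades to global attention
        # instead of honoring interleaved RoPE.
        if use_irope:
            logger.warning(
                "use_irope is not supported here; falling back to global "
                "attention, which may hurt long-context accuracy")
        self.num_heads = num_heads

# Before the fix, the equivalent of this call raised:
#   TypeError: __init__() got an unexpected keyword argument 'use_irope'
impl = ToyFlashAttentionImpl(num_heads=8, use_irope=True)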

vllm/model_executor/models/telechat2.py (6 additions, 2 deletions)

@@ -19,14 +19,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Set, Tuple, Type
 
 import torch
 
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel
 
+from .llama import LlamaDecoderLayer
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                     is_pp_missing_parameter)
 
@@ -120,7 +121,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
         },
     )
 
-    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+    def _init_model(self,
+                    vllm_config: VllmConfig,
+                    prefix: str = "",
+                    layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
         return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
 
     def load_weights(self, weights: Iterable[Tuple[str,
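
The new layer_type parameter keeps the override's signature compatible with LlamaForCausalLM, whose _init_model call site now supplies a layer class; TeleChat2Model builds its own layers, so the argument is accepted but unused. A minimal sketch of that signature-compatibility pattern, with stand-in class names and the vllm_config plumbing omitted:

from typing import Type

class DecoderLayer:
    # Stand-in for LlamaDecoderLayer.
    pass

class BaseForCausalLM:
    # Stand-in for LlamaForCausalLM: the base class forwards a
    # layer_type keyword when it builds the inner model.
    def __init__(self) -> None:
        self.model = self._init_model(prefix="model",
                                      layer_type=DecoderLayer)

    def _init_model(self, prefix: str = "",
                    layer_type: Type[DecoderLayer] = DecoderLayer):
        return ("base model", layer_type.__name__)

class TeleChat2Style(BaseForCausalLM):
    # The override must accept layer_type, even though it ignores it;
    # with the old two-argument signature, the base-class call above
    # raised: TypeError: _init_model() got an unexpected keyword
    # argument 'layer_type'
    def _init_model(self, prefix: str = "",
                    layer_type: Type[DecoderLayer] = DecoderLayer):
        return ("telechat2 model", prefix)

print(TeleChat2Style().model)  # ('telechat2 model', 'model')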
