fix model testing for TeleChat2ForCausalLM and V0 llama4 #6

Open: wants to merge 2 commits into base branch init_pr
5 changes: 5 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -617,10 +617,15 @@ def __init__(
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
                 "FlashAttention does not support block-sparse attention.")
+        if use_irope:
+            logger.warning(
+                "Using irope in V0 is not supported yet, it will fall back "
+                "to global attention for long context.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
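For context, the hunk above lets the V0 FlashAttention backend accept the use_irope flag that Llama4's attention layers pass down, warning and falling back to global attention instead of rejecting the keyword. Below is a minimal sketch of that graceful-degradation pattern; AttentionImplSketch is a hypothetical stand-in for the real FlashAttentionImpl, trimmed to the behavior this PR touches.

import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

class AttentionImplSketch:
    # Hypothetical stand-in for vLLM's FlashAttentionImpl (not the real class).
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        use_irope: bool = False,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        if use_irope:
            # iRoPE (Llama4's interleaved local/global attention) is not
            # implemented in the V0 backend, so keep global attention and
            # warn instead of failing model initialization.
            logger.warning(
                "Using irope in V0 is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)

impl = AttentionImplSketch(num_heads=32, head_size=128, scale=128 ** -0.5,
                           use_irope=True)  # logs the fall-back warning

Accepting the flag rather than erroring on an unknown keyword is what lets the same Llama4 model definition at least construct under V0, with the feature honored only where it is actually implemented.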
8 changes: 6 additions & 2 deletions vllm/model_executor/models/telechat2.py
@@ -19,14 +19,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, Set, Tuple
+from typing import Iterable, Set, Tuple, Type

 import torch

 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

+from .llama import LlamaDecoderLayer
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                     is_pp_missing_parameter)
@@ -120,7 +121,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
         },
     )

-    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
+    def _init_model(self,
+                    vllm_config: VllmConfig,
+                    prefix: str = "",
+                    layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
         return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)

     def load_weights(self, weights: Iterable[Tuple[str,
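Why this signature change fixes model testing: the base-class _init_model in vLLM's llama.py takes a layer_type keyword, so when the base class forwards layer_type to an override that still has the old two-argument signature, TeleChat2ForCausalLM fails with a TypeError before any weights load. A hedged, self-contained sketch of the failure mode and the fix follows; BaseCausalLM, SubclassCausalLM, and DecoderLayer are hypothetical stand-ins for LlamaForCausalLM, TeleChat2ForCausalLM, and LlamaDecoderLayer, not vLLM's real classes.

from typing import Type

class DecoderLayer:  # stand-in for LlamaDecoderLayer
    pass

class BaseCausalLM:  # stand-in for LlamaForCausalLM
    def __init__(self) -> None:
        # The base class forwards layer_type to _init_model; an override
        # that lacks the keyword raises TypeError right here.
        self.model = self._init_model(prefix="model", layer_type=DecoderLayer)

    def _init_model(self, prefix: str = "",
                    layer_type: Type[DecoderLayer] = DecoderLayer):
        return f"llama-style model built with {layer_type.__name__}"

class BrokenSubclass(BaseCausalLM):
    def _init_model(self, prefix: str = ""):  # old signature, no layer_type
        return "never reached"

class SubclassCausalLM(BaseCausalLM):  # mirrors the TeleChat2 fix above
    # Accept layer_type (even if unused) so the forwarded call resolves.
    def _init_model(self, prefix: str = "",
                    layer_type: Type[DecoderLayer] = DecoderLayer):
        return "telechat2-style model"

# BrokenSubclass()  # would raise: TypeError: ... unexpected keyword 'layer_type'
print(SubclassCausalLM().model)  # -> telechat2-style model

The override still ignores layer_type when constructing TeleChat2Model, so runtime behavior is unchanged; only the call signature is brought in line with the base class.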