Skip to content

Commit 3eda4ec

Browse files
authored
support ignore patterns in model loader (#6673)
1 parent 22fa2e3 commit 3eda4ec

File tree

4 files changed

+51
-10
lines changed

4 files changed

+51
-10
lines changed

vllm/config.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -599,12 +599,16 @@ class LoadConfig:
599599
mainly for profiling.
600600
"tensorizer" will use CoreWeave's tensorizer library for
601601
fast weight loading.
602+
ignore_patterns: The list of patterns to ignore when loading the model.
603+
Default to "original/**/*" to avoid repeated loading of llama's
604+
checkpoints.
602605
"""
603606

604607
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
605608
download_dir: Optional[str] = None
606609
model_loader_extra_config: Optional[Union[str, dict]] = field(
607610
default_factory=dict)
611+
ignore_patterns: Optional[Union[List[str], str]] = None
608612

609613
def __post_init__(self):
610614
model_loader_extra_config = self.model_loader_extra_config or {}
@@ -613,6 +617,13 @@ def __post_init__(self):
613617
model_loader_extra_config)
614618
self._verify_load_format()
615619

620+
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
621+
logger.info(
622+
"Ignoring the following patterns when downloading weights: %s",
623+
self.ignore_patterns)
624+
else:
625+
self.ignore_patterns = ["original/**/*"]
626+
616627
def _verify_load_format(self) -> None:
617628
if not isinstance(self.load_format, str):
618629
return
@@ -801,7 +812,9 @@ def __init__(self,
801812
# for higher throughput.
802813
self.max_num_batched_tokens = max(max_model_len, 2048)
803814
if enable_chunked_prefill:
804-
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
815+
logger.info(
816+
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
817+
max_num_batched_tokens)
805818

806819
self.max_num_seqs = max_num_seqs
807820
self.max_model_len = max_model_len

vllm/engine/arg_utils.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,7 @@ class EngineArgs:
9595
num_gpu_blocks_override: Optional[int] = None
9696
num_lookahead_slots: int = 0
9797
model_loader_extra_config: Optional[dict] = None
98+
ignore_patterns: Optional[Union[str, List[str]]] = None
9899
preemption_mode: Optional[str] = None
99100

100101
scheduler_delay_factor: float = 0.0
@@ -619,6 +620,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
619620
'corresponding to the chosen load_format. '
620621
'This should be a JSON string that will be '
621622
'parsed into a dictionary.')
623+
parser.add_argument(
624+
'--ignore-patterns',
625+
action="append",
626+
type=str,
627+
default=[],
628+
help="The pattern(s) to ignore when loading the model."
629+
"Default to 'original/**/*' to avoid repeated loading of llama's "
630+
"checkpoints.")
622631
parser.add_argument(
623632
'--preemption-mode',
624633
type=str,
@@ -824,6 +833,7 @@ def create_engine_config(self, ) -> EngineConfig:
824833
load_format=self.load_format,
825834
download_dir=self.download_dir,
826835
model_loader_extra_config=self.model_loader_extra_config,
836+
ignore_patterns=self.ignore_patterns,
827837
)
828838

829839
prompt_adapter_config = PromptAdapterConfig(

vllm/model_executor/model_loader/loader.py

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,7 @@ def _maybe_download_from_modelscope(
161161
cache_dir=self.load_config.download_dir,
162162
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
163163
revision=revision,
164+
ignore_patterns=self.load_config.ignore_patterns,
164165
)
165166
else:
166167
model_path = model
@@ -196,9 +197,13 @@ def _prepare_weights(self, model_name_or_path: str,
196197
allow_patterns += ["*.pt"]
197198

198199
if not is_local:
199-
hf_folder = download_weights_from_hf(model_name_or_path,
200-
self.load_config.download_dir,
201-
allow_patterns, revision)
200+
hf_folder = download_weights_from_hf(
201+
model_name_or_path,
202+
self.load_config.download_dir,
203+
allow_patterns,
204+
revision,
205+
ignore_patterns=self.load_config.ignore_patterns,
206+
)
202207
else:
203208
hf_folder = model_name_or_path
204209

@@ -489,9 +494,13 @@ def _prepare_weights(self, model_name_or_path: str,
489494
return model_name_or_path
490495
else:
491496
allow_patterns = ["*.safetensors"]
492-
return download_weights_from_hf(model_name_or_path,
493-
self.load_config.download_dir,
494-
allow_patterns, revision)
497+
return download_weights_from_hf(
498+
model_name_or_path,
499+
self.load_config.download_dir,
500+
allow_patterns,
501+
revision,
502+
ignore_patterns=self.load_config.ignore_patterns,
503+
)
495504

496505
def load_model(self, *, model_config: ModelConfig,
497506
device_config: DeviceConfig,
@@ -663,8 +672,12 @@ def _get_weight_files(
663672
matching_files = fnmatch.filter(repo_files, pattern)
664673
if matching_files:
665674
hf_folder = download_weights_from_hf(
666-
model_name_or_path, self.load_config.download_dir,
667-
[pattern], revision)
675+
model_name_or_path,
676+
self.load_config.download_dir,
677+
[pattern],
678+
revision,
679+
ignore_patterns=self.load_config.ignore_patterns,
680+
)
668681
return glob.glob(os.path.join(hf_folder, pattern)), pattern
669682

670683
raise RuntimeError(

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import os
77
import tempfile
88
from collections import defaultdict
9-
from typing import Any, Generator, Iterable, List, Optional, Tuple
9+
from typing import Any, Generator, Iterable, List, Optional, Tuple, Union
1010

1111
import filelock
1212
import huggingface_hub.constants
@@ -189,6 +189,7 @@ def download_weights_from_hf(
189189
cache_dir: Optional[str],
190190
allow_patterns: List[str],
191191
revision: Optional[str] = None,
192+
ignore_patterns: Optional[Union[str, List[str]]] = None,
192193
) -> str:
193194
"""Download model weights from Hugging Face Hub.
194195
@@ -200,6 +201,9 @@ def download_weights_from_hf(
200201
weight files. Files matched by any of the patterns will be
201202
downloaded.
202203
revision (Optional[str]): The revision of the model.
204+
ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
205+
filter out the weight files. Files matched by any of the patterns
206+
will be ignored.
203207
204208
Returns:
205209
str: The path to the downloaded model weights.
@@ -223,6 +227,7 @@ def download_weights_from_hf(
223227
hf_folder = snapshot_download(
224228
model_name_or_path,
225229
allow_patterns=allow_patterns,
230+
ignore_patterns=ignore_patterns,
226231
cache_dir=cache_dir,
227232
tqdm_class=DisabledTqdm,
228233
revision=revision,

0 commit comments

Comments
 (0)