Commit 0890497

jerryzh168 authored and Yuqi Zhang committed

Add pt_load_map_location to allow loading to cuda (vllm-project#16869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>

1 parent 649a8c3 commit 0890497
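
In practice, the new option is exposed both as a --pt-load-map-location CLI flag (dict values must be valid JSON, so the inner strings are double quoted) and through engine arguments. Below is a minimal sketch of offline usage, assuming a CUDA device; it reuses the torchao checkpoint from the new test and assumes the LLM entry point forwards the keyword to EngineArgs:

from vllm import LLM, SamplingParams

# Sketch: load a torchao-quantized checkpoint whose weights must be
# deserialized directly onto GPU instead of the default "cpu".
llm = LLM(
    model="jerryzh168/opt-125m-int4wo",
    quantization="torchao",
    dtype="bfloat16",
    # String form; a dict such as {"cuda:1": "cuda:0"} remaps devices.
    pt_load_map_location="cuda:0",
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs)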

File tree

6 files changed, +74 -3 lines changed

tests/quantization/test_torchao.py

Lines changed: 26 additions & 0 deletions
@@ -3,6 +3,7 @@
 import importlib.util
 
 import pytest
+import torch
 
 DTYPE = ["bfloat16"]
 

@@ -21,5 +22,30 @@ def test_pre_quantized_model(vllm_runner):
     print(output)
 
 
+@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
+@pytest.mark.parametrize(
+    "pt_load_map_location",
+    [
+        "cuda:0",
+        # {"": "cuda"},
+    ])
+def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
+                                                   pt_load_map_location):
+    """
+    Test loading the quantized opt-125m int4wo model with an explicit
+    pt_load_map_location.
+    """
+    torch._dynamo.reset()
+    model_name = "jerryzh168/opt-125m-int4wo"
+    with vllm_runner(model_name=model_name,
+                     quantization="torchao",
+                     dtype="bfloat16",
+                     pt_load_map_location=pt_load_map_location) as llm:
+        output = llm.generate_greedy(["The capital of France is"],
+                                     max_tokens=32)
+
+    assert output
+    print(output)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

tests/test_config.py

Lines changed: 15 additions & 1 deletion
@@ -5,7 +5,8 @@
 
 import pytest
 
-from vllm.config import ModelConfig, PoolerConfig, config, get_field
+from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
+                         config, get_field)
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
 

@@ -410,3 +411,16 @@ def test_generation_config_loading():
         override_generation_config=override_generation_config)
 
     assert model_config.get_diff_sampling_param() == override_generation_config
+
+
+@pytest.mark.parametrize("pt_load_map_location", [
+    "cuda",
+    {
+        "": "cuda"
+    },
+])
+def test_load_config_pt_load_map_location(pt_load_map_location):
+    load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
+    config = VllmConfig(load_config=load_config)
+
+    assert config.load_config.pt_load_map_location == pt_load_map_location

vllm/config.py

Lines changed: 10 additions & 0 deletions
@@ -1564,6 +1564,16 @@ class LoadConfig:
     use_tqdm_on_load: bool = True
     """Whether to enable tqdm for showing progress bar when loading model
    weights."""
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu"
+    """
+    pt_load_map_location: the map location for loading pytorch checkpoint, to
+    support loading checkpoints that can only be loaded on certain devices
+    like "cuda". Passing a plain string such as "cuda" is equivalent to
+    {"": "cuda"}. Another supported format is mapping between devices, e.g.
+    from GPU 1 to GPU 0: {"cuda:1": "cuda:0"}. Note that when passed from the
+    command line, the strings in the dictionary need to be double quoted for
+    json parsing. For more details, see the original docs for `map_location`
+    in https://pytorch.org/docs/stable/generated/torch.load.html
+    """
 
     def compute_hash(self) -> str:
         """

vllm/engine/arg_utils.py

Lines changed: 17 additions & 1 deletion
@@ -64,6 +64,13 @@ def _optional_type(val: str) -> Optional[T]:
     return _optional_type
 
 
+def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
+    if not re.match("^{.*}$", val):
+        return str(val)
+    else:
+        return optional_type(json.loads)(val)
+
+
 @deprecated(
     "Passing a JSON argument as a string containing comma separated key=value "
     "pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "

@@ -187,6 +194,10 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             kwargs[name]["type"] = human_readable_int
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
+        elif contains_type(type_hints,
+                           dict) and (contains_type(type_hints, str) or any(
+                               is_not_builtin(th) for th in type_hints)):
+            kwargs[name]["type"] = union_dict_and_str
         elif contains_type(type_hints, dict):
             # Dict arguments will always be optional
             kwargs[name]["type"] = optional_type(json.loads)

@@ -371,6 +382,7 @@ class EngineArgs:
     reasoning_parser: str = DecodingConfig.reasoning_backend
 
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
+    pt_load_map_location: str = LoadConfig.pt_load_map_location
 
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`

@@ -491,6 +503,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                 type=str,
                                 default=None,
                                 help='Name or path of the QLoRA adapter.')
+        load_group.add_argument('--pt-load-map-location',
+                                **load_kwargs["pt_load_map_location"])
 
         # Guided decoding arguments
         guided_decoding_kwargs = get_kwargs(DecodingConfig)

@@ -883,12 +897,14 @@ def create_load_config(self) -> LoadConfig:
 
         if self.quantization == "bitsandbytes":
             self.load_format = "bitsandbytes"
+
         return LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
             model_loader_extra_config=self.model_loader_extra_config,
             ignore_patterns=self.ignore_patterns,
             use_tqdm_on_load=self.use_tqdm_on_load,
+            pt_load_map_location=self.pt_load_map_location,
         )

@@ -1513,7 +1529,7 @@ def _warn_or_fallback(feature_name: str) -> bool:
 def human_readable_int(value):
     """Parse human-readable integers like '1k', '2M', etc.
     Including decimal values with decimal multipliers.
-
+
     Examples:
     - '1k' -> 1,000
     - '1K' -> 1,024
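
The parsing rule introduced by union_dict_and_str is easiest to see on the two CLI spellings. A simplified sketch (the committed helper additionally routes the JSON branch through optional_type, so the sketch carries a hypothetical name):

import json
import re
from typing import Optional, Union

def union_dict_and_str_sketch(
        val: str) -> Optional[Union[str, dict[str, str]]]:
    # Anything not shaped like a JSON object stays a plain string.
    if not re.match("^{.*}$", val):
        return str(val)
    # Otherwise parse it as JSON, so inner strings must be double quoted.
    return json.loads(val)

assert union_dict_and_str_sketch("cuda:0") == "cuda:0"
assert union_dict_and_str_sketch('{"cuda:1": "cuda:0"}') == {
    "cuda:1": "cuda:0"
}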

vllm/model_executor/model_loader/loader.py

Lines changed: 2 additions & 0 deletions
@@ -384,6 +384,7 @@ def _get_weights_iterator(
             weights_iterator = pt_weights_iterator(
                 hf_weights_files,
                 self.load_config.use_tqdm_on_load,
+                self.load_config.pt_load_map_location,
             )
 
         if current_platform.is_tpu():

@@ -890,6 +891,7 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
             iterator = pt_weights_iterator(
                 hf_weights_files,
                 self.load_config.use_tqdm_on_load,
+                self.load_config.pt_load_map_location,
             )
             for org_name, param in iterator:
                 # mapping weight names from transformers to vllm while preserving

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 4 additions & 1 deletion
@@ -502,6 +502,7 @@ def fastsafetensors_weights_iterator(
 def pt_weights_iterator(
     hf_weights_files: List[str],
     use_tqdm_on_load: bool,
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu",
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
     for bin_file in tqdm(

@@ -510,7 +511,9 @@ def pt_weights_iterator(
         disable=not enable_tqdm(use_tqdm_on_load),
         bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu", weights_only=True)
+        state = torch.load(bin_file,
+                           map_location=pt_load_map_location,
+                           weights_only=True)
         yield from state.items()
         del state
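
Taken together, the change threads the configured map location from LoadConfig down to torch.load. A standalone sketch equivalent to the updated iterator, with vLLM's tqdm plumbing omitted and a hypothetical name to avoid confusion with the real function:

from typing import Dict, Generator, List, Tuple, Union

import torch

def pt_weights_iter(
    files: List[str],
    map_location: Union[str, Dict[str, str]] = "cpu",
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    # Deserialize each .bin/.pt shard onto the requested device(s),
    # stream out (name, tensor) pairs, and free the shard afterwards.
    for f in files:
        state = torch.load(f, map_location=map_location, weights_only=True)
        yield from state.items()
        del state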
