
Commit 9d9ced1

Isotr0py authored and LeiWang1999 committed
[Core] Refactor GGUF parameters packing and forwarding (vllm-project#8859)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent f2b9bb0 commit 9d9ced1

File tree: 4 files changed, +64 -62 lines changed

tests/models/decoder_only/language/test_gguf.py

Lines changed: 6 additions & 6 deletions
@@ -19,12 +19,12 @@

 # FIXME: Move this to confest
 MODELS = [
-    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-     hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-                     filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")),
-    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-     hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
-                     filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
+    ("meta-llama/Llama-3.2-1B-Instruct",
+     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
+                     filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
+    ("meta-llama/Llama-3.2-1B-Instruct",
+     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
+                     filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
     ("Qwen/Qwen2-1.5B-Instruct",
      hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
                      filename="qwen2-1_5b-instruct-q4_k_m.gguf")),

vllm/model_executor/layers/linear.py

Lines changed: 32 additions & 44 deletions
@@ -444,17 +444,23 @@ def weight_loader(self,
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return

-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)

-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)

@@ -522,18 +528,6 @@ def weight_loader(self,
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id

-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size

@@ -790,17 +784,23 @@ def weight_loader(self,
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return

-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()

-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return

         param_data = param.data
         output_dim = getattr(param, "output_dim", None)

@@ -891,18 +891,6 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)

-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
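The new weight_loader path slices each incoming GGUF shard along its output dimension for the local tensor-parallel rank and stashes it in the parameter's data_container until every fused shard has arrived (2 for gate/up, 3 for q/k/v). A standalone sketch of just the slicing and bookkeeping step, with made-up shapes, rank values, and shard id:

import torch

# Hypothetical stand-ins for values the loader reads at runtime.
tp_size, tp_rank = 2, 1
output_dim = 0

# A fake packed shard: rows are output channels, columns are quantized bytes.
loaded_weight = torch.arange(8 * 6, dtype=torch.uint8).reshape(8, 6)

# Each rank keeps a contiguous slice of the output dimension.
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size
local_shard = loaded_weight.narrow(output_dim, start_idx, shard_size)

# The loader then records the shard and its position; packing is deferred
# until all fused shards are present (see materialize_nested in gguf.py).
data_container = [local_shard]
shard_id_map = {0: len(data_container) - 1}

print(local_shard.shape)  # torch.Size([4, 6])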

vllm/model_executor/layers/quantization/gguf.py

Lines changed: 25 additions & 11 deletions
@@ -86,15 +86,16 @@ def create_weights(self, layer: torch.nn.Module,
         output_size_per_partition = sum(output_partition_sizes)

         tensor_shape = (output_size_per_partition, input_size_per_partition)
-        qweight = UninitializedParameter(requires_grad=False)
+        qweight = GGUFUninitializedParameter(requires_grad=False)
         set_weight_attrs(
             qweight, {
                 "input_dim": 1,
                 "output_dim": 0,
                 "tensor_shape": tensor_shape,
                 "is_gguf_weight": True,
-                "shard_size": {},
+                "data_container": [],
                 "shard_id": [],
+                "shard_id_map": {},
             })
         set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("qweight", qweight)

@@ -116,21 +117,17 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        shard_size = getattr(layer.qweight, "shard_size", None)
         shard_id = getattr(layer.qweight, "shard_id", None)

-        if shard_id and shard_size:
-            result = []
-            offset = 0
+        if shard_id:
             # dequantize shard weights respectively
             shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
+            qweight = layer.qweight.unbind(0)
+            result = []
             for id in shard_id:
-                shard_weight = layer.qweight[
-                    offset:offset +
-                    shard_size[id][0], :shard_size[id][1]].contiguous()
+                q_idx = layer.qweight.shard_id_map[id]
                 qweight_type = layer.qweight_type.shard_weight_type[id]
-                result.append(_fuse_mul_mat(x, shard_weight, qweight_type))
-                offset += shard_size[id][0]
+                result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight

@@ -162,3 +159,20 @@ def embedding(self, layer: torch.nn.Module,
         dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
                                       x_flat.shape[0])
         return dequant.view(*x.shape, hidden_size)
+
+
+class GGUFUninitializedParameter(UninitializedParameter):
+    cls_to_become = Parameter
+    data_container: List[torch.Tensor]
+
+    def materialize_nested(self) -> Parameter:
+        nested_data = torch.nested.nested_tensor(self.data_container,
+                                                 device=self.device,
+                                                 dtype=torch.uint8)
+        self.data_container.clear()
+        param = torch.Tensor._make_subclass(self.cls_to_become,
+                                            nested_data,
+                                            require_grad=False)
+        for k, v in self.__dict__.items():
+            setattr(param, k, v)
+        return param
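Because the packed shards of a fused GGUF layer can have different row widths (each GGML quant type has its own bytes-per-block), materialize_nested stores them in a nested uint8 tensor instead of padding to a common shape, and apply() unbinds them again at forward time. A minimal sketch of that round trip, with invented shapes:

import torch

# Two quantized shards whose packed widths differ, as q/k/v or gate/up
# projections can after GGUF quantization.
shard_a = torch.zeros(4, 144, dtype=torch.uint8)
shard_b = torch.zeros(4, 128, dtype=torch.uint8)

# Pack without padding: a nested tensor keeps each shard's own shape.
packed = torch.nested.nested_tensor([shard_a, shard_b], dtype=torch.uint8)

# At forward time, unbind(0) recovers the per-shard tensors; shard_id_map
# tells apply() which entry belongs to which projection.
shards = packed.unbind(0)
print([tuple(s.shape) for s in shards])  # [(4, 144), (4, 128)]

The nested layout avoids padding narrower shards up to the widest quant type, which is what the old shard_size bookkeeping and max(row_size) materialization had to compensate for.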

vllm/model_executor/models/llama.py

Lines changed: 1 addition & 1 deletion
@@ -512,7 +512,7 @@ def __init__(
             quant_config=quant_config,
         )
         if config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens

         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
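With GGUF quantization the embedding layer carries quantized state (e.g. qweight and qweight_type) rather than a single weight tensor, so tying only .weight would miss that state; assigning the whole module shares everything. A toy illustration with plain nn modules and an invented attribute name, not the vLLM classes themselves:

import torch.nn as nn


class TiedToy(nn.Module):
    """Illustrative only: mimics module-level weight tying."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(128, 16)
        # Pretend the quantization method attached extra state to the module.
        self.embed_tokens.qweight_type = "Q4_K_M"
        # Module-level tying: lm_head shares every attribute of embed_tokens
        # (the weight and any quantization metadata), not just .weight.
        self.lm_head = self.embed_tokens


model = TiedToy()
assert model.lm_head is model.embed_tokens
print(model.lm_head.qweight_type)  # Q4_K_M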
