@@ -43,7 +43,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -229,6 +229,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        self.config = config
 
         self.embed_dim = config.hidden_size
 
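The one functional change in this hunk is `self.config = config`: the `load_weights` method moved onto `BloomModel` in the next hunk reads `self.config.num_attention_heads` for the fused-QKV conversion, so the constructor now keeps the HF config on the instance.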
@@ -278,6 +279,37 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # NOTE: BLOOM's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
 
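The QKV conversion preserved in `BloomModel.load_weights` above can be checked in isolation. Below is a minimal sketch of the same view/transpose/reshape round trip on a dummy tensor; the sizes and the `output_dim = 0` choice are illustrative assumptions, not values read from a real checkpoint:

    import torch

    # Hypothetical sizes, for illustration only.
    num_heads, head_size, hidden = 4, 8, 32

    # BLOOM-style fused QKV weight: output rows grouped per head as
    # (num_heads, 3, head_size), i.e. q/k/v interleaved head by head.
    fused = torch.randn(num_heads * 3 * head_size, hidden)

    # Same steps as in load_weights, with output_dim = 0 (the row dimension).
    output_dim = 0
    shape = fused.shape
    w = fused.view(shape[:output_dim] + (num_heads, 3, -1) +
                   shape[output_dim + 1:])
    w = w.transpose(output_dim, output_dim + 1)  # (3, num_heads, head_size, hidden)
    w = w.reshape(shape)  # rows now grouped as (3, num_heads, head_size)
    assert w.shape == fused.shape

In words: the per-head q/k/v interleaving of the HF checkpoint is undone so that all query rows come first, then all key rows, then all value rows, which is the layout the fused QKV linear layer expects.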
@@ -325,35 +357,13 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if name == "lm_head.weight":
-                continue
-            if not name.startswith("transformer."):
-                name = "transformer." + name
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # NOTE: BLOOM's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
+        weights = _add_transformer_prefix(weights)
+        return loader.load_weights(weights)
+
+
+def _add_transformer_prefix(weights: Iterable[tuple[str, torch.Tensor]]
+                            ) -> Iterable[tuple[str, torch.Tensor]]:
+    for name, tensor in weights:
+        if not name.startswith('transformer.'):
+            name = 'transformer.' + name
+        yield name, tensor
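For context, `AutoWeightsLoader` (imported from `.utils` in the first hunk) walks the module tree and defers to a submodule's own `load_weights` where one is defined, so the per-parameter QKV logic now lives once on `BloomModel`, while `skip_prefixes` drops entries such as the tied `lm_head.weight`. The only remaining manual step is key renaming, handled by the generator above; note it takes no `self` and is called unqualified, so it must sit at module level. A quick sketch of its behaviour, with made-up weight names:

    import torch

    ckpt = [
        ("word_embeddings.weight", torch.empty(2, 2)),  # bare HF-style key
        ("transformer.ln_f.weight", torch.empty(2)),    # already prefixed
    ]
    print([name for name, _ in _add_transformer_prefix(ckpt)])
    # ['transformer.word_embeddings.weight', 'transformer.ln_f.weight']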