@@ -444,17 +444,23 @@ def weight_loader(self,
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
 
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
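The new branch replaces the old eager materialization (which sized the tensor up front from GGUF block sizes) with a deferred, container-based scheme: each tensor-parallel rank narrows out its slice of the raw quantized tensor, records which fused shard it belongs to, and materializes the parameter only once both shards (e.g. gate and up projections) have been collected. Below is a minimal standalone sketch of that pattern; the names data_container, shard_id_map, and materialize_nested mirror the diff, but the padding-based packing is an assumption made for illustration, not vLLM's actual implementation.

    import torch

    def narrow_to_rank(weight: torch.Tensor, dim: int, tp_rank: int,
                       tp_size: int) -> torch.Tensor:
        """Slice out this rank's contiguous shard along `dim`."""
        shard_size = weight.size(dim) // tp_size
        return weight.narrow(dim, tp_rank * shard_size, shard_size)

    class DeferredFusedWeight:
        """Collects per-shard tensors, then packs them into one buffer.

        Hypothetical stand-in for the GGUF parameter in the diff; not
        vLLM's actual class.
        """

        def __init__(self, expected_shards: int):
            self.expected_shards = expected_shards
            self.shard_id_map: dict = {}
            self.data_container: list = []

        def add_shard(self, shard_id, weight: torch.Tensor) -> None:
            self.shard_id_map[shard_id] = len(self.data_container)
            self.data_container.append(weight)

        def ready(self) -> bool:
            return len(self.data_container) == self.expected_shards

        def materialize(self) -> torch.Tensor:
            # GGUF shards may use different quant types, so their byte
            # rows differ in width; pad to the widest row, then stack.
            max_cols = max(w.size(1) for w in self.data_container)
            rows = sum(w.size(0) for w in self.data_container)
            out = torch.zeros(rows, max_cols,
                              dtype=self.data_container[0].dtype)
            offset = 0
            for w in self.data_container:
                out[offset:offset + w.size(0), :w.size(1)] = w
                offset += w.size(0)
            return out

Deferring the build this way avoids guessing the final packed shape before all shard quant types are known, which is exactly what the removed GGML_QUANT_SIZES arithmetic had to do.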
@@ -522,18 +528,6 @@ def weight_loader(self,
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
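With the GGUF path now returning early in the first hunk, this block had become dead weight: by the time execution reaches the per-shard slicing here, a GGUF tensor has already been narrowed to the rank's slice. The tensor-parallel split itself is the standard contiguous narrow; a quick worked example with hypothetical sizes:

    import torch

    # Hypothetical fused weight: 8 output rows split across tp_size = 2
    # ranks, so each rank keeps a contiguous block of 4 rows.
    weight = torch.arange(16).reshape(8, 2)
    tp_size, output_dim = 2, 0
    shard_size = weight.size(output_dim) // tp_size  # 8 // 2 == 4

    for tp_rank in range(tp_size):
        start_idx = tp_rank * shard_size  # rows 0 and 4
        shard = weight.narrow(output_dim, start_idx, shard_size)
        assert shard.shape == (4, 2)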
@@ -790,17 +784,23 @@ def weight_loader(self,
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
 
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
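The QKV variant is identical except that it waits for three shards before materializing. Reusing the hypothetical DeferredFusedWeight sketch from above, the bookkeeping after all three loads would look like:

    import torch

    # Toy byte tensors standing in for quantized q/k/v shards; the q
    # projection is deliberately wider to show the jagged case.
    q_weight = torch.zeros(4, 6, dtype=torch.uint8)
    k_weight = torch.zeros(2, 4, dtype=torch.uint8)
    v_weight = torch.zeros(2, 4, dtype=torch.uint8)

    fused = DeferredFusedWeight(expected_shards=3)
    for shard_id, w in (("q", q_weight), ("k", k_weight), ("v", v_weight)):
        fused.add_shard(shard_id, w)

    assert fused.ready()
    assert fused.shard_id_map == {"q": 0, "k": 1, "v": 2}
    packed = fused.materialize()  # shape (8, 6): rows stacked, k/v padded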
@@ -891,18 +891,6 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
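As in the MergedColumnParallelLinear hunk above, this per-branch GGUF block is removed because the unified early-return path now narrows the weight and defers materialization before any of the downstream param_data slicing runs, so the q/k/v offset arithmetic no longer needs a GGUF special case.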