@@ -60,7 +60,8 @@ def device(self) -> str:
60
60
def load_multi (self ,
61
61
key : str ,
62
62
keys : list [str ],
63
- measure : bool = False ) -> int | dict [str : torch .Tensor ]:
63
+ measure : bool = False ,
64
+ cpu : bool = False ) -> int | dict [str : torch .Tensor ]:
64
65
65
66
tensors = {}
66
67
submap = {}
@@ -85,13 +86,14 @@ def load_multi(self,
85
86
if measure :
86
87
size += stfile .measure (key + "." + k )
87
88
else :
88
- tensors [k ] = stfile .get_tensor (key + "." + k , device = self .device ())
89
+ tensors [k ] = stfile .get_tensor (key + "." + k , device = self .device () if not cpu else "cpu" )
89
90
90
91
return size if measure else tensors
91
92
92
93
93
94
def load_weight (self ,
94
- override_key : str | None = None ):
95
+ override_key : str | None = None ,
96
+ cpu : bool = False ):
95
97
96
98
if override_key is not None :
97
99
keys = [override_key ]
@@ -105,14 +107,14 @@ def load_weight(self,
105
107
# EXL2
106
108
107
109
if key + ".q_weight" in self .model .config .tensor_file_map :
108
- qtensors = self .load_multi (key , ["q_weight" , "q_invperm" , "q_scale" , "q_scale_max" , "q_groups" , "q_perm" , "bias" ])
110
+ qtensors = self .load_multi (key , ["q_weight" , "q_invperm" , "q_scale" , "q_scale_max" , "q_groups" , "q_perm" , "bias" ], cpu = cpu )
109
111
qtensors ["q_perm" ] = torch .argsort (qtensors ["q_invperm" ]).to (torch .int )
110
112
return qtensors
111
113
112
114
# GPTQ
113
115
114
116
if key + ".qweight" in self .model .config .tensor_file_map :
115
- qtensors = self .load_multi (key , ["qweight" , "qzeros" , "scales" , "g_idx" , "bias" ])
117
+ qtensors = self .load_multi (key , ["qweight" , "qzeros" , "scales" , "g_idx" , "bias" ], cpu = cpu )
116
118
if "bias" in qtensors and torch .all (qtensors ["bias" ].eq (0 )):
117
119
del qtensors ["bias" ]
118
120
qtensors ["scales" ] = qtensors ["scales" ].half ()
@@ -122,14 +124,14 @@ def load_weight(self,
122
124
123
125
if key + ".weight" in self .model .config .tensor_file_map :
124
126
if key + ".bias" in self .model .config .tensor_file_map :
125
- tensors = self .load_multi (key , ["weight" , "bias" ])
127
+ tensors = self .load_multi (key , ["weight" , "bias" ], cpu = cpu )
126
128
tensor = tensors ["weight" ].half ()
127
129
bias = tensors ["bias" ].half ()
128
130
if self .model .config .arch .orig_weights_transposed and len (tensor .shape ) == 2 :
129
131
tensor = tensor .T
130
132
return nn .Parameter (tensor , requires_grad = False ), nn .Parameter (bias , requires_grad = False )
131
133
else :
132
- tensors = self .load_multi (key , ["weight" ])
134
+ tensors = self .load_multi (key , ["weight" ], cpu = cpu )
133
135
tensor = tensors ["weight" ].half ()
134
136
# if self.model.config.arch.orig_weights_transposed:
135
137
# tensor = tensor.T
0 commit comments