
Commit f2c77aa

Committed Sep 3, 2024
Merge remote-tracking branch 'turboderp/master'
2 parents: 03255d4 + 40e37f4

15 files changed: +131 -36 lines
 

‎exllamav2/exllamav2_ext/cpp/safetensors.cpp

+62 -1

@@ -453,4 +453,65 @@ void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target)
         remaining -= chunk;
     }
 }
-}
+}
+
+void tensor_remap
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+)
+{
+    TORCH_CHECK_SHAPES(tensor, 1, index, 0, 1);
+    TORCH_CHECK_DTYPE(tensor, kInt);
+    TORCH_CHECK_DTYPE(index, kInt);
+
+    int rows = tensor.size(0);
+    int cols = tensor.size(1);
+    uint32_t* temp = (uint32_t*) calloc(cols, sizeof(int));
+    uint32_t* a = (uint32_t*) tensor.data_ptr();
+    uint32_t* idx = (uint32_t*) index.data_ptr();
+
+    for (int r = 0; r < rows; ++r)
+    {
+        memcpy(temp, a, sizeof(uint32_t) * cols);
+        for (int c = 0; c < cols; ++c)
+        {
+            *a++ = temp[idx[c]];
+        }
+    }
+    free(temp);
+}
+
+void tensor_remap_4bit
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+)
+{
+    TORCH_CHECK_SHAPES(index, 0, tensor, 1, 8);
+    TORCH_CHECK_DTYPE(tensor, kInt);
+    TORCH_CHECK_DTYPE(index, kInt);
+
+    int rows = tensor.size(0);
+    int cols = index.size(0);
+    uint32_t* temp = (uint32_t*) calloc(cols / 8, sizeof(int));
+    uint32_t* a = (uint32_t*) tensor.data_ptr();
+    uint32_t* idx = (uint32_t*) index.data_ptr();
+
+    for (int r = 0; r < rows; ++r)
+    {
+        memcpy(temp, a, sizeof(uint32_t) * cols / 8);
+        for (int c = 0; c < cols;)
+        {
+            uint32_t rv = 0;
+            for (int b = 0; b < 8; ++b, ++c)
+            {
+                uint32_t i = idx[c];
+                uint32_t v = (temp[i / 8] >> ((i & 7) * 4) & 0x0f);
+                rv |= v << (b * 4);
+            }
+            *a++ = rv;
+        }
+    }
+    free(temp);
+}
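The two functions above permute the columns of quantized weight tensors in place on the CPU: tensor_remap reorders the columns of a plain int32 matrix, and tensor_remap_4bit does the same for a matrix whose int32 elements each pack eight 4-bit values. As a reference for the semantics only (not part of this commit), the same result can be sketched in PyTorch with the pack_4bit / unpack_4bit helpers from exllamav2.util, which the old Python path in linear.py used; the _ref names below are illustrative:

    import torch
    from exllamav2.util import unpack_4bit, pack_4bit

    def tensor_remap_ref(tensor: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # out[r, c] = tensor[r, index[c]] for a 2-D int32 tensor
        return tensor[:, index.long()]

    def tensor_remap_4bit_ref(tensor: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # Each int32 packs eight 4-bit values; permute the unpacked nibbles, then repack
        return pack_4bit(unpack_4bit(tensor)[:, index.long()])

Unlike this sketch, the C++ versions write the result back into the input tensor and avoid materializing an unpacked copy, which is what lets linear.py (further down in this commit) remap weights on the CPU before moving them to the target device.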

‎exllamav2/exllamav2_ext/cpp/safetensors.h

+13

@@ -47,4 +47,17 @@ uintptr_t safetensors_open_fb(const char* filename);
 void safetensors_close_fb(uintptr_t handle);
 void safetensors_read_fb(uintptr_t handle, size_t beg, size_t size, torch::Tensor target);
 
+void tensor_remap
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+);
+
+void tensor_remap_4bit
+(
+    torch::Tensor tensor,
+    torch::Tensor index
+);
+
+
 #endif

‎exllamav2/exllamav2_ext/cuda/graph.cu

+12 -9

@@ -133,7 +133,7 @@ void Graph::attach_label(cudaStream_t stream, int label, int sublabel)
 }
 
 template <typename T>
-void Graph::update_param(int label, int sublabel, int param, T value)
+void Graph::update_param(int label, int sublabel, int param, T value, bool debug)
 {
     for (int i = 0; i < node_labels.size(); ++i)
     {
@@ -145,19 +145,22 @@ void Graph::update_param(int label, int sublabel, int param, T value)
 
         node_needs_update[i] = true;
 
-        // printf("-----------------------------------------------------\n");
-        // printf("UPDATED:\n");
-        // DBGI(i);
-        // inspect_graph();
+        if (debug)
+        {
+            printf("-----------------------------------------------------\n");
+            printf("UPDATED: ");
+            DBGI(i);
+            inspect_graph();
+        }
     }
 }
 
-void Graph::update_param_ptr(int label, int sublabel, int param, void* value)
+void Graph::update_param_ptr(int label, int sublabel, int param, void* value, bool debug)
 {
-    update_param<void*>(label, sublabel, param, value);
+    update_param<void*>(label, sublabel, param, value, debug);
 }
 
-void Graph::update_param_int(int label, int sublabel, int param, int value)
+void Graph::update_param_int(int label, int sublabel, int param, int value, bool debug)
 {
-    update_param<int>(label, sublabel, param, value);
+    update_param<int>(label, sublabel, param, value, debug);
 }

‎exllamav2/exllamav2_ext/cuda/graph.cuh

+3 -3

@@ -46,10 +46,10 @@ public:
     void attach_label(cudaStream_t stream, int label, int sublabel);
 
     template <typename T>
-    void update_param(int label, int sublabel, int param, T value);
+    void update_param(int label, int sublabel, int param, T value, bool debug);
 
-    void update_param_ptr(int label, int sublabel, int param, void* value);
-    void update_param_int(int label, int sublabel, int param, int value);
+    void update_param_ptr(int label, int sublabel, int param, void* value, bool debug = false);
+    void update_param_int(int label, int sublabel, int param, int value, bool debug = false);
 };
‎exllamav2/exllamav2_ext/cuda/q_mlp.cu

+2 -2

@@ -109,7 +109,7 @@ void QMLP::forward_
     if (graph->count())
     {
         graph->begin_capture(stream);
-        forward_run_(stream, cublas_handle, (half*) x, rows, columns, loras, lora_temp, graph);
+        forward_run_(stream, cublas_handle, (void*) x, rows, columns, loras, lora_temp, graph);
         graph->end_capture(stream);
         // printf("**** record ****\n");
         // DBGI2(rows, columns);
@@ -225,7 +225,7 @@ void QMLP::forward_run_
 
     else
     {
-        gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, graph, 0);
+        gemm_half_q_half_cuda(stream, cublas_handle, temp_a, down, temp_state, rows, columns, intermediate_size, true, temp_dq, false, NULL, 0, false, graph, 0);
         if (layernorm_is_rms)
             rms_norm_cuda(stream, temp_state, post_layernorm, x, norm_epsilon, rows, columns, true, false, residual_fp32, graph, KernelLabels::POST_NORM);
         else

‎exllamav2/exllamav2_ext/ext_bindings.cpp

+2

@@ -55,6 +55,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("safetensors_pinned_buffer", &safetensors_pinned_buffer, "safetensors_pinned_buffer");
     m.def("safetensors_free_pinned_buffer", &safetensors_free_pinned_buffer, "safetensors_free_pinned_buffer");
     m.def("safetensors_read_fb", &safetensors_read_fb, "safetensors_read_fb");
+    m.def("tensor_remap", &tensor_remap, "tensor_remap");
+    m.def("tensor_remap_4bit", &tensor_remap_4bit, "tensor_remap_4bit");
 
     // qmatrix
‎exllamav2/ext.py

+2 -2

@@ -173,9 +173,9 @@ def find_msvc():
 # gcc / cl.exe flags
 
 if windows:
-    extra_cflags = ["/Ox", "/openmp"]
+    extra_cflags = ["/Ox"]
 else:
-    extra_cflags = ["-Ofast", "-fopenmp"]
+    extra_cflags = ["-Ofast"]
 
 if ext_debug:
     extra_cflags += ["-ftime-report", "-DTORCH_USE_CUDA_DSA"]

‎exllamav2/fasttensors.py

+8 -1

@@ -189,6 +189,8 @@ def get_tensor(self,
                    out_dtype = None) -> torch.Tensor:
         global global_tensorcache
 
+        torch.cuda.synchronize()
+
         if self.tensor_remap and (not_fast or not self.fast):
             key = self.tensor_remap[key]
 
@@ -211,6 +213,8 @@ def get_tensor(self,
             size = end - beg
             numel = size // esize
             shape = h["shape"]
+            if device != "cpu":
+                torch.cuda.set_stream(torch.cuda.default_stream(device))
             tensor = torch.zeros(shape, dtype = dtype, device = device)
             assert tensor.is_contiguous, "Non-contiguous tensor"
             ext_c.safetensors_read_fb(self.handle_fb, beg + self.header_size, size, tensor)
@@ -224,7 +228,8 @@ def get_tensor(self,
             offset = data_offsets[0] + self.header_size
             length = data_offsets[1] - data_offsets[0]
             assert np.prod(sh) * dts == length, f"Tensor shape doesn't match storage size: {key}"
-
+            if device != "cpu":
+                torch.cuda.set_stream(torch.cuda.default_stream(device))
             tensor = torch.empty(sh, device = device, dtype = dtt)
             ext_c.safetensors_load(self.handle, tensor, offset, length)
@@ -236,4 +241,6 @@ def get_tensor(self,
                 global_tensorcache = global_tensorcache[1:]
             global_tensorcache.append((cachekey, tensor))
 
+        torch.cuda.synchronize()
+
         return tensor
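The loader changes above follow a common stream-safety pattern: synchronize, make sure the destination tensor is allocated on the device's default stream, let the extension write into it, then synchronize again before handing the tensor back. A minimal standalone sketch of that pattern, with a placeholder shape and dtype rather than values from the diff:

    import torch

    device = torch.device("cuda", 0)

    torch.cuda.synchronize()                                   # drain pending work on all streams
    torch.cuda.set_stream(torch.cuda.default_stream(device))   # allocate on the default stream
    target = torch.zeros((4, 8), dtype = torch.float16, device = device)

    # ... an extension call such as safetensors_read_fb would fill `target` here ...

    torch.cuda.synchronize()                                   # contents now visible to every stream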

‎exllamav2/linear.py

+7 -4

@@ -8,6 +8,7 @@
 from exllamav2.compat import safe_move_tensor
 from exllamav2.tensor_p import BROADCAST_VC
 from exllamav2.util import unpack_4bit, pack_4bit
+import gc
 
 from typing import TYPE_CHECKING
 
@@ -118,7 +119,7 @@ def load(self,
         cfg = self.model.config
 
         if self.f_key: w = self.load_weight_fused(self.f_key, self.f_beg, self.f_end, self.in_features, self.out_features, self.altpack_qkv)
-        if w is None: w = self.load_weight()
+        if w is None: w = self.load_weight(cpu = output_map is not None)
 
         # Load quantized linear layer from dictionary
 
@@ -137,7 +138,7 @@ def load(self,
             self.q_tensors = w
 
             if unmap and "q_perm" in w:
-                perm = w["q_perm"]
+                perm = w["q_perm"].cpu()
                 del w["q_perm"]
                 del w["q_invperm"]
                 # w["q_perm"] = torch.arange(0, w["q_perm"].shape[-1], dtype = w["q_perm"].dtype, device = w["q_perm"].device)
@@ -146,8 +147,10 @@ def load(self,
                 perm = None
 
             if output_map is not None:
-                w["q_weight"] = w["q_weight"][:, output_map]
-                w["q_scale"] = pack_4bit(unpack_4bit(w["q_scale"])[:, output_map])
+                ext_c.tensor_remap(w["q_weight"], output_map)
+                ext_c.tensor_remap_4bit(w["q_scale"], output_map)
+                for k in w.keys():
+                    w[k] = safe_move_tensor(w[k], self.device())
 
             self.q_handle = ext.make_q_matrix(w,
                                               self.temp_dq,

‎exllamav2/model.py

+4

@@ -989,6 +989,10 @@ def forward_chunk(self,
         if self.tp_context:
             self.tp_context.wait_streams()
 
+        if x is not None and x.is_cuda:
+            context = self.get_device_context(x.device.index)
+            torch.cuda.set_stream(context.stream)
+
         # Apply logit scale
 
         # if x is not None and self.config.logit_scale != 1:

‎exllamav2/model_init.py

+1 -1

@@ -34,7 +34,7 @@ def print_options(args):
 
     print_opts = []
     if args.gpu_split is not None: print_opts += [f"gpu_split: {args.gpu_split}"]
-    if args.tensor_parallel is not None: print_opts += ["tensor_parallel"]
+    if args.tensor_parallel: print_opts += ["tensor_parallel"]
     if args.length is not None: print_opts += [f"length: {args.length}"]
    if args.rope_scale is not None: print_opts += [f"rope_scale: {args.rope_scale}"]
    if args.rope_alpha is not None: print_opts += [f"rope_alpha: {args.rope_alpha}"]

‎exllamav2/module.py

+9 -7

@@ -60,7 +60,8 @@ def device(self) -> str:
     def load_multi(self,
                    key: str,
                    keys: list[str],
-                   measure: bool = False) -> int | dict[str: torch.Tensor]:
+                   measure: bool = False,
+                   cpu: bool = False) -> int | dict[str: torch.Tensor]:
 
         tensors = {}
         submap = {}
@@ -85,13 +86,14 @@ def load_multi(self,
             if measure:
                 size += stfile.measure(key + "." + k)
             else:
-                tensors[k] = stfile.get_tensor(key + "." + k, device = self.device())
+                tensors[k] = stfile.get_tensor(key + "." + k, device = self.device() if not cpu else "cpu")
 
         return size if measure else tensors
 
 
     def load_weight(self,
-                    override_key: str | None = None):
+                    override_key: str | None = None,
+                    cpu: bool = False):
 
         if override_key is not None:
             keys = [override_key]
@@ -105,14 +107,14 @@ def load_weight(self,
         # EXL2
 
         if key + ".q_weight" in self.model.config.tensor_file_map:
-            qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"])
+            qtensors = self.load_multi(key, ["q_weight", "q_invperm", "q_scale", "q_scale_max", "q_groups", "q_perm", "bias"], cpu = cpu)
             qtensors["q_perm"] = torch.argsort(qtensors["q_invperm"]).to(torch.int)
             return qtensors
 
         # GPTQ
 
         if key + ".qweight" in self.model.config.tensor_file_map:
-            qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"])
+            qtensors = self.load_multi(key, ["qweight", "qzeros", "scales", "g_idx", "bias"], cpu = cpu)
             if "bias" in qtensors and torch.all(qtensors["bias"].eq(0)):
                 del qtensors["bias"]
             qtensors["scales"] = qtensors["scales"].half()
@@ -122,14 +124,14 @@ def load_weight(self,
 
         if key + ".weight" in self.model.config.tensor_file_map:
             if key + ".bias" in self.model.config.tensor_file_map:
-                tensors = self.load_multi(key, ["weight", "bias"])
+                tensors = self.load_multi(key, ["weight", "bias"], cpu = cpu)
                 tensor = tensors["weight"].half()
                 bias = tensors["bias"].half()
                 if self.model.config.arch.orig_weights_transposed and len(tensor.shape) == 2:
                     tensor = tensor.T
                 return nn.Parameter(tensor, requires_grad = False), nn.Parameter(bias, requires_grad = False)
             else:
-                tensors = self.load_multi(key, ["weight"])
+                tensors = self.load_multi(key, ["weight"], cpu = cpu)
                 tensor = tensors["weight"].half()
                 # if self.model.config.arch.orig_weights_transposed:
                 #     tensor = tensor.T
‎exllamav2/tensor_p.py

+1 -1

@@ -155,7 +155,7 @@ def define_split(
 
     # Vocab split
 
-    vc_split = [s * 32 for s in integer_split(cfg.vocab_size // 32, gpu_split, 16)]
+    vc_split = [s * 32 for s in integer_split((cfg.vocab_size + 31) // 32, gpu_split, 16)]
 
     def set_split(raw_split):
         b = 0
‎exllamav2/util.py

+4 -4

@@ -291,19 +291,19 @@ def get_all_gpu_memory():
     try:
         nvidia_memory = get_nvidia_gpu_memory(visible_devices)
         gpu_memory.update(nvidia_memory)
-    except FileNotFoundError:
+    except:
         pass
         # print("nvidia-smi not found. Skipping NVIDIA GPU check.")
 
     try:
         amd_memory = get_amd_gpu_memory()
         gpu_memory.update(amd_memory)
-    except FileNotFoundError:
+    except:
         pass
-        # print("rocm-smi not found. Skipping AMD GPU check.") # TODO: remove warning on NVidia, test on AMD
+        # print("rocm-smi not found. Skipping AMD GPU check.") # TODO: test on AMD
 
     assert gpu_memory, \
-        "Unable to read available VRAM from nvidia-smi or rocm-smi"
+        "Unable to read available VRAM from either nvidia-smi or rocm-smi"
 
     return gpu_memory
‎exllamav2/version.py

+1 -1

@@ -1 +1 @@
-__version__ = "0.1.9"
+__version__ = "0.2.0"
