Skip to content

Commit a0bd57c

Browse files
authored
add kv override and cuda backend bugfix
1 parent 4951c95 commit a0bd57c

File tree

3 files changed

+191
-26
lines changed

3 files changed

+191
-26
lines changed

src/AutoGGUF.py

Lines changed: 156 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
import psutil
7+
import shutil
78
import subprocess
89
import time
910
import signal
@@ -12,17 +13,18 @@
1213
import requests
1314
import zipfile
1415
from datetime import datetime
15-
from imports_and_globals import ensure_directory
16-
from DownloadThread import *
17-
from ModelInfoDialog import *
18-
from TaskListItem import *
19-
from QuantizationThread import *
16+
from imports_and_globals import ensure_directory, open_file_safe
17+
from DownloadThread import DownloadThread
18+
from ModelInfoDialog import ModelInfoDialog
19+
from TaskListItem import TaskListItem
20+
from QuantizationThread import QuantizationThread
21+
from KVOverrideEntry import KVOverrideEntry
2022

2123
class AutoGGUF(QMainWindow):
2224
def __init__(self):
2325
super().__init__()
2426
self.setWindowTitle("AutoGGUF (automated GGUF model quantizer)")
25-
self.setGeometry(100, 100, 1200, 1000)
27+
self.setGeometry(100, 100, 1300, 1100)
2628

2729
main_layout = QHBoxLayout()
2830
left_layout = QVBoxLayout()
@@ -45,7 +47,7 @@ def __init__(self):
4547
backend_layout.addWidget(self.refresh_backends_button)
4648
left_layout.addLayout(backend_layout)
4749

48-
# Add Download llama.cpp section
50+
# Modify the Download llama.cpp section
4951
download_group = QGroupBox("Download llama.cpp")
5052
download_layout = QFormLayout()
5153

@@ -58,8 +60,19 @@ def __init__(self):
5860
download_layout.addRow("Select Release:", release_layout)
5961

6062
self.asset_combo = QComboBox()
63+
self.asset_combo.currentIndexChanged.connect(self.update_cuda_option)
6164
download_layout.addRow("Select Asset:", self.asset_combo)
6265

66+
self.cuda_extract_checkbox = QCheckBox("Extract CUDA files")
67+
self.cuda_extract_checkbox.setVisible(False)
68+
download_layout.addRow(self.cuda_extract_checkbox)
69+
70+
self.cuda_backend_label = QLabel("Select CUDA Backend:")
71+
self.cuda_backend_label.setVisible(False)
72+
self.backend_combo_cuda = QComboBox()
73+
self.backend_combo_cuda.setVisible(False)
74+
download_layout.addRow(self.cuda_backend_label, self.backend_combo_cuda)
75+
6376
self.download_progress = QProgressBar()
6477
self.download_button = QPushButton("Download")
6578
self.download_button.clicked.connect(self.download_llama_cpp)
@@ -144,20 +157,20 @@ def __init__(self):
144157
quant_options_layout.addRow(self.create_label("Exclude Weights:", "Don't use importance matrix for these tensors"), self.exclude_weights)
145158

146159
self.use_output_tensor_type = QCheckBox("Use Output Tensor Type")
147-
self.use_output_tensor_type.stateChanged.connect(self.toggle_output_tensor_type)
148160
self.output_tensor_type = QComboBox()
149161
self.output_tensor_type.addItems(["F32", "F16", "Q4_0", "Q4_1", "Q5_0", "Q5_1", "Q8_0"])
150162
self.output_tensor_type.setEnabled(False)
163+
self.use_output_tensor_type.toggled.connect(lambda checked: self.output_tensor_type.setEnabled(checked))
151164
output_tensor_layout = QHBoxLayout()
152165
output_tensor_layout.addWidget(self.use_output_tensor_type)
153166
output_tensor_layout.addWidget(self.output_tensor_type)
154167
quant_options_layout.addRow(self.create_label("Output Tensor Type:", "Use this type for the output.weight tensor"), output_tensor_layout)
155168

156169
self.use_token_embedding_type = QCheckBox("Use Token Embedding Type")
157-
self.use_token_embedding_type.stateChanged.connect(self.toggle_token_embedding_type)
158170
self.token_embedding_type = QComboBox()
159171
self.token_embedding_type.addItems(["F32", "F16", "Q4_0", "Q4_1", "Q5_0", "Q5_1", "Q8_0"])
160172
self.token_embedding_type.setEnabled(False)
173+
self.use_token_embedding_type.toggled.connect(lambda checked: self.token_embedding_type.setEnabled(checked))
161174
token_embedding_layout = QHBoxLayout()
162175
token_embedding_layout.addWidget(self.use_token_embedding_type)
163176
token_embedding_layout.addWidget(self.token_embedding_type)
@@ -166,7 +179,24 @@ def __init__(self):
166179
self.keep_split = QCheckBox("Keep Split")
167180
self.override_kv = QLineEdit()
168181
quant_options_layout.addRow(self.create_label("", "Will generate quantized model in the same shards as input"), self.keep_split)
169-
quant_options_layout.addRow(self.create_label("Override KV:", "Override model metadata by key in the quantized model"), self.override_kv)
182+
# KV Override section
183+
self.kv_override_widget = QWidget()
184+
self.kv_override_layout = QVBoxLayout(self.kv_override_widget)
185+
self.kv_override_entries = []
186+
187+
add_override_button = QPushButton("Add new override")
188+
add_override_button.clicked.connect(self.add_kv_override)
189+
190+
kv_override_scroll = QScrollArea()
191+
kv_override_scroll.setWidgetResizable(True)
192+
kv_override_scroll.setWidget(self.kv_override_widget)
193+
kv_override_scroll.setMinimumHeight(200)
194+
195+
kv_override_main_layout = QVBoxLayout()
196+
kv_override_main_layout.addWidget(kv_override_scroll)
197+
kv_override_main_layout.addWidget(add_override_button)
198+
199+
quant_options_layout.addRow(self.create_label("KV Overrides:", "Override model metadata"), kv_override_main_layout)
170200

171201
quant_options_widget.setLayout(quant_options_layout)
172202
quant_options_scroll.setWidget(quant_options_widget)
@@ -219,13 +249,17 @@ def __init__(self):
219249
# GPU Offload for IMatrix
220250
gpu_offload_layout = QHBoxLayout()
221251
self.gpu_offload_slider = QSlider(Qt.Orientation.Horizontal)
222-
self.gpu_offload_slider.setRange(0, 100)
252+
self.gpu_offload_slider.setRange(0, 200)
223253
self.gpu_offload_slider.valueChanged.connect(self.update_gpu_offload_spinbox)
254+
224255
self.gpu_offload_spinbox = QSpinBox()
225-
self.gpu_offload_spinbox.setRange(0, 100)
256+
self.gpu_offload_spinbox.setRange(0, 1000)
226257
self.gpu_offload_spinbox.valueChanged.connect(self.update_gpu_offload_slider)
258+
self.gpu_offload_spinbox.setMinimumWidth(75) # Set the minimum width to 75 pixels
259+
227260
self.gpu_offload_auto = QCheckBox("Auto")
228261
self.gpu_offload_auto.stateChanged.connect(self.toggle_gpu_offload_auto)
262+
229263
gpu_offload_layout.addWidget(self.gpu_offload_slider)
230264
gpu_offload_layout.addWidget(self.gpu_offload_spinbox)
231265
gpu_offload_layout.addWidget(self.gpu_offload_auto)
@@ -260,12 +294,44 @@ def refresh_backends(self):
260294
llama_bin = os.path.abspath("llama_bin")
261295
if not os.path.exists(llama_bin):
262296
os.makedirs(llama_bin)
263-
297+
264298
self.backend_combo.clear()
299+
valid_backends = []
265300
for item in os.listdir(llama_bin):
266301
item_path = os.path.join(llama_bin, item)
267-
if os.path.isdir(item_path):
268-
self.backend_combo.addItem(item, userData=item_path)
302+
if os.path.isdir(item_path) and "cudart-llama" not in item.lower():
303+
valid_backends.append((item, item_path))
304+
305+
if valid_backends:
306+
for name, path in valid_backends:
307+
self.backend_combo.addItem(name, userData=path)
308+
self.backend_combo.setEnabled(True) # Enable the combo box if there are valid backends
309+
else:
310+
self.backend_combo.addItem("No backends available")
311+
self.backend_combo.setEnabled(False)
312+
313+
def download_finished(self, extract_dir):
    """Report a finished download and bring the backend lists up to date.

    Re-enables the download button, fills the progress bar, optionally
    copies the CUDA runtime files into the chosen CUDA backend, and then
    selects the freshly extracted backend in the backend combo box.

    NOTE(review): a second ``download_finished`` is defined later in this
    file; if both live in the same class, the later one replaces this one
    at class-creation time -- confirm which is intended.

    Args:
        extract_dir: Directory the downloaded archive was extracted into.
    """
    self.download_button.setEnabled(True)
    self.download_progress.setValue(100)

    cuda_requested = (
        self.cuda_extract_checkbox.isChecked()
        and self.cuda_extract_checkbox.isVisible()
    )
    if not cuda_requested:
        QMessageBox.information(self, "Download Complete", f"llama.cpp binary downloaded and extracted to {extract_dir}")
    else:
        cuda_backend = self.backend_combo_cuda.currentData()
        usable = cuda_backend and cuda_backend != "No suitable CUDA backends found"
        if usable:
            self.extract_cuda_files(extract_dir, cuda_backend)
            QMessageBox.information(self, "Download Complete", f"llama.cpp binary downloaded and extracted to {extract_dir}\nCUDA files extracted to {cuda_backend}")
        else:
            QMessageBox.warning(self, "CUDA Extraction Failed", "No suitable CUDA backend found for extraction")

    # The new directory only shows up in the combos after a rescan.
    self.refresh_backends()
    self.update_cuda_option()

    # Pre-select the backend that was just downloaded, if it is listed.
    position = self.backend_combo.findText(os.path.basename(extract_dir))
    if position >= 0:
        self.backend_combo.setCurrentIndex(position)
269335

270336
def refresh_releases(self):
271337
try:
@@ -285,6 +351,16 @@ def update_assets(self):
285351
if release:
286352
for asset in release['assets']:
287353
self.asset_combo.addItem(asset['name'], userData=asset)
354+
self.update_cuda_option()
355+
356+
def update_cuda_option(self):
    """Show or hide the CUDA extraction controls for the selected asset.

    The controls are shown only when the selected release asset's name
    contains "cudart" (case-insensitive). When shown, the CUDA backend
    combo box is repopulated via ``update_cuda_backends``.
    """
    asset = self.asset_combo.currentData()
    # currentData() yields None when the combo is empty (e.g. right after it
    # is cleared). `asset and ...` would then evaluate to None, and
    # setVisible() needs a real bool, so coerce explicitly.
    is_cuda = bool(asset) and "cudart" in asset['name'].lower()

    for widget in (
        self.cuda_extract_checkbox,
        self.cuda_backend_label,
        self.backend_combo_cuda,
    ):
        widget.setVisible(is_cuda)

    if is_cuda:
        self.update_cuda_backends()
288364

289365
def download_llama_cpp(self):
290366
asset = self.asset_combo.currentData()
@@ -301,21 +377,68 @@ def download_llama_cpp(self):
301377
self.download_thread = DownloadThread(asset['browser_download_url'], save_path)
302378
self.download_thread.progress_signal.connect(self.update_download_progress)
303379
self.download_thread.finished_signal.connect(self.download_finished)
304-
self.download_thread.error_signal.connect(self.show_error)
380+
self.download_thread.error_signal.connect(self.download_error)
305381
self.download_thread.start()
306382

307383
self.download_button.setEnabled(False)
308384
self.download_progress.setValue(0)
309385

386+
def update_cuda_backends(self):
    """Repopulate the CUDA backend combo box from the ``llama_bin`` folder.

    A sub-directory of ``llama_bin`` is listed when its name contains
    "cu1" and does not contain "cudart-llama" (both case-insensitive).
    If nothing qualifies, a placeholder entry is added and the combo box
    is disabled; otherwise it is enabled.
    """
    combo = self.backend_combo_cuda
    combo.clear()

    bin_root = os.path.abspath("llama_bin")
    if os.path.exists(bin_root):
        for entry in os.listdir(bin_root):
            entry_path = os.path.join(bin_root, entry)
            if not os.path.isdir(entry_path):
                continue
            lowered = entry.lower()
            # Skip the CUDA runtime bundle itself and any non-CUDA build.
            if "cudart-llama" in lowered or "cu1" not in lowered:
                continue
            combo.addItem(entry, userData=entry_path)

    has_backends = combo.count() != 0
    if not has_backends:
        combo.addItem("No suitable CUDA backends found")
    combo.setEnabled(has_backends)
401+
310402
def update_download_progress(self, progress):
    """Mirror the reported download progress on the progress bar.

    Args:
        progress: Value forwarded unchanged to the bar's ``setValue``.
    """
    bar = self.download_progress
    bar.setValue(progress)
312404

313405
def download_finished(self, extract_dir):
    """Report a finished download and bring the backend lists up to date.

    Re-enables the download button, fills the progress bar, and, when the
    "Extract CUDA files" checkbox is both visible and checked, copies the
    CUDA runtime files into the selected CUDA backend directory. Afterwards
    the backend lists are rescanned and the newly extracted backend is
    selected, so it can be used immediately.

    Args:
        extract_dir: Directory the downloaded archive was extracted into.
    """
    self.download_button.setEnabled(True)
    self.download_progress.setValue(100)

    done_message = f"llama.cpp binary downloaded and extracted to {extract_dir}"
    cuda_requested = (
        self.cuda_extract_checkbox.isChecked()
        and self.cuda_extract_checkbox.isVisible()
    )
    if not cuda_requested:
        QMessageBox.information(self, "Download Complete", done_message)
    else:
        # The "no backends" placeholder entry carries no userData, so
        # currentData() is falsy when there is nothing valid to extract into.
        cuda_backend = self.backend_combo_cuda.currentData()
        if cuda_backend:
            self.extract_cuda_files(extract_dir, cuda_backend)
            QMessageBox.information(self, "Download Complete", f"{done_message}\nCUDA files extracted to {cuda_backend}")
        else:
            QMessageBox.warning(self, "CUDA Extraction Failed", "No CUDA backend selected for extraction")

    # Rescan so the new directory appears, and refresh the CUDA target list
    # in case the download itself was a CUDA-capable backend.
    self.refresh_backends()
    self.update_cuda_option()

    # Pre-select the backend that was just downloaded, if it is listed.
    index = self.backend_combo.findText(os.path.basename(extract_dir))
    if index >= 0:
        self.backend_combo.setCurrentIndex(index)
420+
421+
def extract_cuda_files(self, extract_dir, destination):
    """Copy every ``.dll`` found under *extract_dir* into *destination*.

    The tree is walked recursively; the extension match is
    case-insensitive. Files are copied flat (sub-directory structure is
    not recreated), so a later file with the same name overwrites an
    earlier one.

    Args:
        extract_dir: Root directory to search.
        destination: Existing directory that receives the copies.
    """
    for folder, _subdirs, filenames in os.walk(extract_dir):
        dll_names = [name for name in filenames if name.lower().endswith('.dll')]
        for dll_name in dll_names:
            shutil.copy2(
                os.path.join(folder, dll_name),
                os.path.join(destination, dll_name),
            )
428+
429+
430+
def download_error(self, error_message):
    """Reset the download controls, report the failure, and tidy up.

    After showing the error, makes a best-effort attempt to delete the
    partially downloaded archive for the currently selected asset from
    the ``llama_bin`` directory.

    Args:
        error_message: Description of the failure; shown to the user
            prefixed with "Download failed: ".
    """
    self.download_button.setEnabled(True)
    self.download_progress.setValue(0)
    self.show_error(f"Download failed: {error_message}")

    # Clean up any partially downloaded file.
    asset = self.asset_combo.currentData()
    if not asset:
        return

    partial_file = os.path.join(os.path.abspath("llama_bin"), asset['name'])
    try:
        os.remove(partial_file)
    except OSError:
        # Best-effort only: the download thread also removes its save path
        # on failure, so the file may already be gone (or still be locked).
        # The user has already been told about the failure above; a cleanup
        # problem must not raise out of this slot.
        pass
441+
319442
def show_task_context_menu(self, position):
320443
item = self.task_list.itemAt(position)
321444
if item is not None:
@@ -422,12 +545,6 @@ def update_system_info(self):
422545
self.ram_bar.setValue(int(ram.percent))
423546
self.ram_bar.setFormat(f"{ram.percent:.1f}% ({ram.used // 1024 // 1024} MB / {ram.total // 1024 // 1024} MB)")
424547
self.cpu_label.setText(f"CPU Usage: {cpu:.1f}%")
425-
426-
def toggle_output_tensor_type(self, state):
427-
self.output_tensor_type.setEnabled(state == Qt.CheckState.Checked)
428-
429-
def toggle_token_embedding_type(self, state):
430-
self.token_embedding_type.setEnabled(state == Qt.CheckState.Checked)
431548

432549
def validate_quantization_inputs(self):
433550
if not self.backend_combo.currentData():
@@ -439,6 +556,17 @@ def validate_quantization_inputs(self):
439556
if not self.logs_input.text():
440557
raise ValueError("Logs path is required")
441558

559+
def add_kv_override(self):
    """Append a new, empty KV override row to the override list.

    The row's ``deleted`` signal is wired to ``remove_kv_override`` so the
    row can remove itself; the row is tracked in ``kv_override_entries``.
    """
    new_row = KVOverrideEntry()
    new_row.deleted.connect(self.remove_kv_override)
    self.kv_override_layout.addWidget(new_row)
    self.kv_override_entries.append(new_row)
564+
565+
def remove_kv_override(self, entry):
    """Detach a KV override row from the UI and stop tracking it.

    Args:
        entry: The row widget to remove; it is taken out of the layout and
            out of ``kv_override_entries``, then scheduled for deletion.
    """
    layout, tracked = self.kv_override_layout, self.kv_override_entries
    layout.removeWidget(entry)
    tracked.remove(entry)
    # Defer destruction to the widget's own deleteLater().
    entry.deleteLater()
569+
442570
def quantize_model(self):
443571
try:
444572
self.validate_quantization_inputs()
@@ -480,7 +608,10 @@ def quantize_model(self):
480608
if self.keep_split.isChecked():
481609
command.append("--keep-split")
482610
if self.override_kv.text():
483-
command.extend(["--override-kv", self.override_kv.text()])
611+
for entry in self.kv_override_entries:
612+
override_string = entry.get_override_string()
613+
if override_string:
614+
command.extend(["--override-kv", override_string])
484615

485616
command.extend([input_path, output_path, quant_type])
486617

src/DownloadThread.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,5 @@ def run(self):
5050
self.finished_signal.emit(extract_dir)
5151
except Exception as e:
5252
self.error_signal.emit(str(e))
53-
53+
if os.path.exists(self.save_path):
54+
os.remove(self.save_path)

src/KVOverrideEntry.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from PyQt6.QtWidgets import QWidget, QHBoxLayout, QLineEdit, QComboBox, QPushButton
2+
from PyQt6.QtCore import pyqtSignal
3+
4+
class KVOverrideEntry(QWidget):
    """A single editable KV override row: key, type, value, delete button."""

    # Emitted with this widget as the argument when the delete button is clicked.
    deleted = pyqtSignal(QWidget)

    def __init__(self, parent=None):
        """Build the row's widgets in a horizontal, margin-free layout.

        Args:
            parent: Optional parent widget, forwarded to ``QWidget``.
        """
        super().__init__(parent)
        layout = QHBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)

        self.key_input = QLineEdit()
        self.key_input.setPlaceholderText("Key")
        layout.addWidget(self.key_input)

        self.type_combo = QComboBox()
        self.type_combo.addItems(["int", "str", "float"])
        layout.addWidget(self.type_combo)

        self.value_input = QLineEdit()
        self.value_input.setPlaceholderText("Value")
        layout.addWidget(self.value_input)

        delete_button = QPushButton("X")
        delete_button.setFixedSize(30, 30)
        delete_button.clicked.connect(self.delete_clicked)
        layout.addWidget(delete_button)

    def delete_clicked(self):
        """Announce that this row should be removed by emitting ``deleted``."""
        self.deleted.emit(self)

    def get_override_string(self):
        """Return the row as ``key=type:value``, or ``""`` if incomplete.

        An untouched row used to produce the truthy string ``"=int:"``,
        which defeats callers that skip rows with ``if override_string:``.
        A row with a blank key or a blank value now yields an empty string
        so such callers can filter it out.

        Returns:
            The override string with surrounding whitespace stripped from
            the key, or an empty string when key or value is blank.
        """
        key = self.key_input.text().strip()
        value = self.value_input.text()
        if not key or not value.strip():
            return ""
        return f"{key}={self.type_combo.currentText()}:{value}"

0 commit comments

Comments
 (0)