Skip to content

Commit a91f804

Browse files
committed
fix: use proper status in TaskListItem
- use proper status in TaskListItem - make sure to pass quant_threads and Logger to TaskListItem - remove unnecessary logging in quantize_to_fp8_dynamic.py and optimize imports
1 parent a7f2dec commit a91f804

File tree

3 files changed

+51
-19
lines changed

3 files changed

+51
-19
lines changed

src/AutoGGUF.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def __init__(self, args: List[str]) -> None:
9696
self.delete_task = partial(TaskListItem.delete_task, self)
9797
self.show_task_context_menu = partial(TaskListItem.show_task_context_menu, self)
9898
self.show_task_properties = partial(TaskListItem.show_task_properties, self)
99-
self.cancel_task_by_item = partial(TaskListItem.cancel_task_by_item, self)
10099
self.toggle_gpu_offload_auto = partial(ui_update.toggle_gpu_offload_auto, self)
101100
self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self)
102101
self.update_threads_slider = partial(ui_update.update_threads_slider, self)
@@ -1036,7 +1035,13 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
10361035
self.quant_threads.append(thread)
10371036

10381037
task_name = f"Quantizing {os.path.basename(model_dir)} with AutoFP8"
1039-
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
1038+
task_item = TaskListItem(
1039+
task_name,
1040+
log_file,
1041+
show_progress_bar=False,
1042+
logger=self.logger,
1043+
quant_threads=self.quant_threads,
1044+
)
10401045
list_item = QListWidgetItem(self.task_list)
10411046
list_item.setSizeHint(task_item.sizeHint())
10421047
self.task_list.addItem(list_item)
@@ -1152,7 +1157,13 @@ def convert_hf_to_gguf(self) -> None:
11521157
self.quant_threads.append(thread)
11531158

11541159
task_name = CONVERTING_TO_GGUF.format(os.path.basename(model_dir))
1155-
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
1160+
task_item = TaskListItem(
1161+
task_name,
1162+
log_file,
1163+
show_progress_bar=False,
1164+
logger=self.logger,
1165+
quant_threads=self.quant_threads,
1166+
)
11561167
list_item = QListWidgetItem(self.task_list)
11571168
list_item.setSizeHint(task_item.sizeHint())
11581169
self.task_list.addItem(list_item)
@@ -1516,7 +1527,10 @@ def quantize_model(self) -> None:
15161527
self.quant_threads.append(thread)
15171528

15181529
task_item = TaskListItem(
1519-
QUANTIZING_MODEL_TO.format(model_name, quant_type), log_file
1530+
QUANTIZING_MODEL_TO.format(model_name, quant_type),
1531+
log_file,
1532+
show_properties=True,
1533+
logger=self.logger,
15201534
)
15211535
list_item = QListWidgetItem(self.task_list)
15221536
list_item.setSizeHint(task_item.sizeHint())
@@ -1687,7 +1701,13 @@ def generate_imatrix(self) -> None:
16871701
task_name = GENERATING_IMATRIX_FOR.format(
16881702
os.path.basename(self.imatrix_model.text())
16891703
)
1690-
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
1704+
task_item = TaskListItem(
1705+
task_name,
1706+
log_file,
1707+
show_progress_bar=False,
1708+
logger=self.logger,
1709+
quant_threads=self.quant_threads,
1710+
)
16911711
list_item = QListWidgetItem(self.task_list)
16921712
list_item.setSizeHint(task_item.sizeHint())
16931713
self.task_list.addItem(list_item)

src/TaskListItem.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List
2+
13
from PySide6.QtCore import *
24
from PySide6.QtGui import QAction
35
from PySide6.QtWidgets import *
@@ -15,19 +17,34 @@
1517
SHOWING_PROPERTIES_FOR_TASK,
1618
DELETE,
1719
RESTART,
20+
IN_PROGRESS,
21+
ERROR,
1822
)
1923
from ModelInfoDialog import ModelInfoDialog
24+
from QuantizationThread import QuantizationThread
25+
from Logger import Logger
2026

2127

2228
class TaskListItem(QWidget):
2329
def __init__(
24-
self, task_name, log_file, show_progress_bar=True, parent=None
30+
self,
31+
task_name,
32+
log_file,
33+
show_progress_bar=True,
34+
parent=None,
35+
show_properties=False,
36+
logger=Logger,
37+
quant_threads=List[QuantizationThread],
2538
) -> None:
2639
super().__init__(parent)
40+
self.quant_threads = quant_threads
2741
self.task_name = task_name
2842
self.log_file = log_file
43+
self.logger = logger
44+
self.show_properties = show_properties
2945
self.status = "Pending"
3046
layout = QHBoxLayout(self)
47+
3148
self.task_label = QLabel(task_name)
3249
self.progress_bar = QProgressBar()
3350
self.progress_bar.setRange(0, 100)
@@ -84,7 +101,8 @@ def show_task_properties(self, item) -> None:
84101
model_info_dialog.exec()
85102
break
86103

87-
def cancel_task_by_item(self, item) -> None:
104+
def cancel_task(self, item) -> None:
105+
self.logger.info(CANCELLING_TASK.format(item.text()))
88106
task_item = self.task_list.itemWidget(item)
89107
for thread in self.quant_threads:
90108
if thread.log_file == task_item.log_file:
@@ -93,15 +111,11 @@ def cancel_task_by_item(self, item) -> None:
93111
self.quant_threads.remove(thread)
94112
break
95113

96-
def cancel_task(self, item) -> None:
97-
self.logger.info(CANCELLING_TASK.format(item.text()))
98-
self.cancel_task_by_item(item)
99-
100114
def delete_task(self, item) -> None:
101115
self.logger.info(DELETING_TASK.format(item.text()))
102116

103117
# Cancel the task first
104-
self.cancel_task_by_item(item)
118+
self.cancel_task(item)
105119

106120
reply = QMessageBox.question(
107121
self,
@@ -121,21 +135,21 @@ def delete_task(self, item) -> None:
121135
def update_status(self, status) -> None:
122136
self.status = status
123137
self.status_label.setText(status)
124-
if status == "In Progress":
138+
if status == IN_PROGRESS:
125139
# Only start timer if showing percentage progress
126140
if self.progress_bar.isVisible():
127141
self.progress_bar.setRange(0, 100)
128142
self.progress_timer.start(100)
129-
elif status == "Completed":
143+
elif status == COMPLETED:
130144
self.progress_timer.stop()
131145
self.progress_bar.setValue(100)
132-
elif status == "Canceled":
146+
elif status == CANCELED:
133147
self.progress_timer.stop()
134148
self.progress_bar.setValue(0)
135149

136150
def set_error(self) -> None:
137-
self.status = "Error"
138-
self.status_label.setText("Error")
151+
self.status = ERROR
152+
self.status_label.setText(ERROR)
139153
self.status_label.setStyleSheet("color: red;")
140154
self.progress_bar.setRange(0, 100)
141155
self.progress_timer.stop()

src/quantize_to_fp8_dynamic.py

-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import tqdm
1010
from transformers import AutoModelForCausalLM, AutoTokenizer
1111

12-
from Logger import Logger
1312

1413
# https://github.com/neuralmagic/AutoFP8
1514

@@ -544,7 +543,6 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List
544543

545544

546545
def quantize_to_fp8_dynamic(input_model_dir: str, output_model_dir: str) -> None:
547-
print("Starting fp8 dynamic quantization")
548546
# Define quantization config with static activation scales
549547
quantize_config = BaseQuantizeConfig(
550548
quant_method="fp8", activation_scheme="dynamic"

0 commit comments

Comments
 (0)