Skip to content

Commit 22bd74b

Browse files
committed
feat(server): replace Flask with FastAPI and Uvicorn
- replace Flask with FastAPI and Uvicorn - fix web page not found error - port is now defaulted to 7001 - bind to localhost (127.0.0.1) instead of 0.0.0.0 - improve performance by using Uvicorn - add OpenAPI docs for endpoints
1 parent db1733b commit 22bd74b

File tree

2 files changed

+158
-59
lines changed

2 files changed

+158
-59
lines changed

Diff for: requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ sentencepiece~=0.2.0
66
PyYAML~=6.0.2
77
pynvml~=11.5.3
88
PySide6~=6.7.2
9-
flask~=3.0.3
109
python-dotenv~=1.0.1
1110
safetensors~=0.4.4
1211
setuptools~=68.2.0
1312
huggingface-hub~=0.24.6
1413
transformers~=4.44.2
14+
fastapi~=0.112.2
15+
uvicorn~=0.30.6

Diff for: src/main.py

+156-58
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,178 @@
11
import os
22
import sys
33
import threading
4+
from enum import Enum
5+
from typing import List, Optional
46

57
from PySide6.QtCore import QTimer
68
from PySide6.QtWidgets import QApplication
9+
from fastapi import FastAPI, Query
10+
from pydantic import BaseModel, Field
11+
from uvicorn import Config, Server
12+
713
from AutoGGUF import AutoGGUF
8-
from flask import Flask, Response, jsonify
14+
from Localizations import AUTOGGUF_VERSION
915

10-
server = Flask(__name__)
16+
# OpenAPI license metadata, hoisted for readability.
_LICENSE_INFO = {
    "name": "Apache 2.0",
    "url": "https://raw.githubusercontent.com/leafspark/AutoGGUF/main/LICENSE",
}

# FastAPI application exposing AutoGGUF's models, tasks, and system endpoints.
app = FastAPI(
    title="AutoGGUF",
    description="API for AutoGGUF - automatically quant GGUF models",
    version=AUTOGGUF_VERSION,
    license_info=_LICENSE_INFO,
)

# Reference to the Qt main window; assigned by main() once the GUI is up.
# Endpoints treat a None window as "no data yet".
window = None
1228

13-
def main() -> None:
14-
@server.route("/v1/models", methods=["GET"])
15-
def models() -> Response:
16-
if window:
17-
return jsonify({"models": window.get_models_data()})
18-
return jsonify({"models": []})
19-
20-
@server.route("/v1/tasks", methods=["GET"])
21-
def tasks() -> Response:
22-
if window:
23-
return jsonify({"tasks": window.get_tasks_data()})
24-
return jsonify({"tasks": []})
25-
26-
@server.route("/v1/health", methods=["GET"])
27-
def ping() -> Response:
28-
return jsonify({"status": "alive"})
29-
30-
@server.route("/v1/backends", methods=["GET"])
31-
def get_backends() -> Response:
32-
backends = []
29+
30+
class ModelType(str, Enum):
    """Storage layout of a model file: a single GGUF file or a sharded set."""

    single = "single"
    sharded = "sharded"
33+
34+
35+
class Model(BaseModel):
    """A GGUF model visible to the application.

    ``size`` is optional because not every source reports a byte count.
    """

    name: str = Field(..., description="Name of the model")
    type: str = Field(..., description="Type of the model")
    path: str = Field(..., description="Path to the model file")
    size: Optional[int] = Field(None, description="Size of the model in bytes")

    class Config:
        # Example payload shown in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "name": "Llama-3.1-8B-Instruct.fp16.gguf",
                "type": "single",
                "path": "Llama-3.1-8B-Instruct.fp16.gguf",
                "size": 13000000000,
            }
        }
50+
51+
52+
class Task(BaseModel):
    """A quantization task tracked by the application."""

    id: str = Field(..., description="Unique identifier for the task")
    status: str = Field(..., description="Current status of the task")
    progress: float = Field(..., description="Progress of the task as a percentage")

    class Config:
        # Fix: the attribute was misspelled "json_json_schema_extra", so
        # pydantic ignored it and the example never reached the OpenAPI schema.
        json_schema_extra = {
            "example": {"id": "task_123", "status": "running", "progress": 75.5}
        }
61+
62+
63+
class Backend(BaseModel):
    """A llama.cpp backend executable registered in the GUI."""

    name: str = Field(..., description="Name of the backend")
    path: str = Field(..., description="Path to the backend executable")
66+
67+
68+
class Plugin(BaseModel):
    """Metadata describing an installed AutoGGUF plugin."""

    name: str = Field(..., description="Name of the plugin")
    version: str = Field(..., description="Version of the plugin")
    description: str = Field(..., description="Description of the plugin")
    author: str = Field(..., description="Author of the plugin")
73+
74+
75+
@app.get("/v1/models", response_model=List[Model], tags=["Models"])
76+
async def get_models(
77+
type: Optional[ModelType] = Query(None, description="Filter models by type")
78+
) -> List[Model]:
79+
"""
80+
Get a list of all available models.
81+
82+
- **type**: Optional filter for model type
83+
84+
Returns a list of Model objects containing name, type, path, and optional size.
85+
"""
86+
if window:
87+
models = window.get_models_data()
88+
if type:
89+
models = [m for m in models if m["type"] == type]
90+
91+
# Convert to Pydantic models, handling missing 'size' field
92+
return [Model(**m) for m in models]
93+
return []
94+
95+
96+
@app.get("/v1/tasks", response_model=List[Task], tags=["Tasks"])
97+
async def get_tasks() -> List[Task]:
98+
"""
99+
Get a list of all current tasks.
100+
101+
Returns a list of Task objects containing id, status, and progress.
102+
"""
103+
if window:
104+
return window.get_tasks_data()
105+
return []
106+
107+
108+
@app.get("/v1/health", tags=["System"])
109+
async def health_check() -> dict:
110+
"""
111+
Check the health status of the API.
112+
113+
Returns a simple status message indicating the API is alive.
114+
"""
115+
return {"status": "alive"}
116+
117+
118+
@app.get("/v1/backends", response_model=List[Backend], tags=["System"])
119+
async def get_backends() -> List[Backend]:
120+
"""
121+
Get a list of all available llama.cpp backends.
122+
123+
Returns a list of Backend objects containing name and path.
124+
"""
125+
backends = []
126+
if window:
33127
for i in range(window.backend_combo.count()):
34128
backends.append(
35-
{
36-
"name": window.backend_combo.itemText(i),
37-
"path": window.backend_combo.itemData(i),
38-
}
39-
)
40-
return jsonify({"backends": backends})
41-
42-
@server.route("/v1/plugins", methods=["GET"])
43-
def get_plugins() -> Response:
44-
if window:
45-
return jsonify(
46-
{
47-
"plugins": [
48-
{
49-
"name": plugin_data["data"]["name"],
50-
"version": plugin_data["data"]["version"],
51-
"description": plugin_data["data"]["description"],
52-
"author": plugin_data["data"]["author"],
53-
}
54-
for plugin_data in window.plugins.values()
55-
]
56-
}
57-
)
58-
return jsonify({"plugins": []})
59-
60-
def run_flask() -> None:
61-
if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled":
62-
server.run(
63-
host="0.0.0.0",
64-
port=int(os.environ.get("AUTOGGUF_SERVER_PORT", 5000)),
65-
debug=False,
66-
use_reloader=False,
129+
Backend(
130+
name=window.backend_combo.itemText(i),
131+
path=window.backend_combo.itemData(i),
132+
)
67133
)
134+
return backends
68135

69-
app = QApplication(sys.argv)
136+
137+
@app.get("/v1/plugins", response_model=List[Plugin], tags=["System"])
138+
async def get_plugins() -> List[Plugin]:
139+
"""
140+
Get a list of all installed plugins.
141+
142+
Returns a list of Plugin objects containing name, version, description, and author.
143+
"""
144+
if window:
145+
return [
146+
Plugin(**plugin_data["data"]) for plugin_data in window.plugins.values()
147+
]
148+
return []
149+
150+
151+
def run_uvicorn() -> None:
    """Run the Uvicorn server, but only when AUTOGGUF_SERVER is "enabled"."""
    if os.environ.get("AUTOGGUF_SERVER", "").lower() != "enabled":
        return

    # Loopback-only bind: the API is intended for local consumption.
    port = int(os.environ.get("AUTOGGUF_SERVER_PORT", 7001))
    config = Config(app=app, host="127.0.0.1", port=port, log_level="info")
    Server(config).run()
161+
162+
163+
def main() -> None:
    """Launch the Qt GUI, then start the API server in a background thread.

    Creates the QApplication and main window, schedules the Uvicorn server
    to start shortly after the event loop begins, and blocks in the Qt
    event loop until the application exits.
    """
    global window
    qt_app = QApplication(sys.argv)
    window = AutoGGUF(sys.argv)
    window.show()

    # QTimer.singleShot is a static method; the original created a throwaway
    # QTimer instance just to call it, which was dead weight. The server runs
    # in a daemon thread so it terminates together with the GUI.
    QTimer.singleShot(
        100, lambda: threading.Thread(target=run_uvicorn, daemon=True).start()
    )

    sys.exit(qt_app.exec())
78176

79177

80178
if __name__ == "__main__":

0 commit comments

Comments
 (0)