|
1 | 1 | import os
|
2 | 2 | import sys
|
3 | 3 | import threading
|
| 4 | +from enum import Enum |
| 5 | +from typing import List, Optional |
4 | 6 |
|
5 | 7 | from PySide6.QtCore import QTimer
|
6 | 8 | from PySide6.QtWidgets import QApplication
|
| 9 | +from fastapi import FastAPI, Query |
| 10 | +from pydantic import BaseModel, Field |
| 11 | +from uvicorn import Config, Server |
| 12 | + |
7 | 13 | from AutoGGUF import AutoGGUF
|
8 |
| -from flask import Flask, Response, jsonify |
| 14 | +from Localizations import AUTOGGUF_VERSION |
9 | 15 |
|
10 |
| -server = Flask(__name__) |
# FastAPI application serving the local AutoGGUF HTTP API; this metadata
# feeds the generated OpenAPI/Swagger documentation.
app = FastAPI(
    title="AutoGGUF",
    description="API for AutoGGUF - automatically quant GGUF models",
    version=AUTOGGUF_VERSION,
    license_info={
        "name": "Apache 2.0",
        "url": "https://raw.githubusercontent.com/leafspark/AutoGGUF/main/LICENSE",
    },
)
11 | 25 |
|
# Global reference to the AutoGGUF main window; assigned in main() and read
# by the API endpoints below. Stays None until the GUI has been created, and
# every endpoint falls back to an empty result while it is None.
window = None
12 | 28 |
|
13 |
| -def main() -> None: |
14 |
| - @server.route("/v1/models", methods=["GET"]) |
15 |
| - def models() -> Response: |
16 |
| - if window: |
17 |
| - return jsonify({"models": window.get_models_data()}) |
18 |
| - return jsonify({"models": []}) |
19 |
| - |
20 |
| - @server.route("/v1/tasks", methods=["GET"]) |
21 |
| - def tasks() -> Response: |
22 |
| - if window: |
23 |
| - return jsonify({"tasks": window.get_tasks_data()}) |
24 |
| - return jsonify({"tasks": []}) |
25 |
| - |
26 |
| - @server.route("/v1/health", methods=["GET"]) |
27 |
| - def ping() -> Response: |
28 |
| - return jsonify({"status": "alive"}) |
29 |
| - |
30 |
| - @server.route("/v1/backends", methods=["GET"]) |
31 |
| - def get_backends() -> Response: |
32 |
| - backends = [] |
| 29 | + |
class ModelType(str, Enum):
    """Allowed values for the ``type`` query filter on ``/v1/models``."""

    single = "single"
    sharded = "sharded"
| 33 | + |
| 34 | + |
class Model(BaseModel):
    """Response schema for one model entry returned by ``/v1/models``."""

    name: str = Field(..., description="Name of the model")
    type: str = Field(..., description="Type of the model")
    path: str = Field(..., description="Path to the model file")
    size: Optional[int] = Field(None, description="Size of the model in bytes")

    class Config:
        # Example payload surfaced in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "name": "Llama-3.1-8B-Instruct.fp16.gguf",
                "type": "single",
                "path": "Llama-3.1-8B-Instruct.fp16.gguf",
                "size": 13000000000,
            }
        }
| 50 | + |
| 51 | + |
class Task(BaseModel):
    """Response schema for one task entry returned by ``/v1/tasks``."""

    id: str = Field(..., description="Unique identifier for the task")
    status: str = Field(..., description="Current status of the task")
    progress: float = Field(..., description="Progress of the task as a percentage")

    class Config:
        # Fixed typo: was `json_json_schema_extra`, which Pydantic silently
        # ignores, so the example never appeared in the OpenAPI docs.
        # Now consistent with Model.Config above.
        json_schema_extra = {
            "example": {"id": "task_123", "status": "running", "progress": 75.5}
        }
| 61 | + |
| 62 | + |
class Backend(BaseModel):
    """Response schema for one backend entry returned by ``/v1/backends``."""

    name: str = Field(..., description="Name of the backend")
    path: str = Field(..., description="Path to the backend executable")
| 66 | + |
| 67 | + |
class Plugin(BaseModel):
    """Response schema for one plugin entry returned by ``/v1/plugins``."""

    name: str = Field(..., description="Name of the plugin")
    version: str = Field(..., description="Version of the plugin")
    description: str = Field(..., description="Description of the plugin")
    author: str = Field(..., description="Author of the plugin")
| 73 | + |
| 74 | + |
| 75 | +@app.get("/v1/models", response_model=List[Model], tags=["Models"]) |
| 76 | +async def get_models( |
| 77 | + type: Optional[ModelType] = Query(None, description="Filter models by type") |
| 78 | +) -> List[Model]: |
| 79 | + """ |
| 80 | + Get a list of all available models. |
| 81 | +
|
| 82 | + - **type**: Optional filter for model type |
| 83 | +
|
| 84 | + Returns a list of Model objects containing name, type, path, and optional size. |
| 85 | + """ |
| 86 | + if window: |
| 87 | + models = window.get_models_data() |
| 88 | + if type: |
| 89 | + models = [m for m in models if m["type"] == type] |
| 90 | + |
| 91 | + # Convert to Pydantic models, handling missing 'size' field |
| 92 | + return [Model(**m) for m in models] |
| 93 | + return [] |
| 94 | + |
| 95 | + |
| 96 | +@app.get("/v1/tasks", response_model=List[Task], tags=["Tasks"]) |
| 97 | +async def get_tasks() -> List[Task]: |
| 98 | + """ |
| 99 | + Get a list of all current tasks. |
| 100 | +
|
| 101 | + Returns a list of Task objects containing id, status, and progress. |
| 102 | + """ |
| 103 | + if window: |
| 104 | + return window.get_tasks_data() |
| 105 | + return [] |
| 106 | + |
| 107 | + |
| 108 | +@app.get("/v1/health", tags=["System"]) |
| 109 | +async def health_check() -> dict: |
| 110 | + """ |
| 111 | + Check the health status of the API. |
| 112 | +
|
| 113 | + Returns a simple status message indicating the API is alive. |
| 114 | + """ |
| 115 | + return {"status": "alive"} |
| 116 | + |
| 117 | + |
| 118 | +@app.get("/v1/backends", response_model=List[Backend], tags=["System"]) |
| 119 | +async def get_backends() -> List[Backend]: |
| 120 | + """ |
| 121 | + Get a list of all available llama.cpp backends. |
| 122 | +
|
| 123 | + Returns a list of Backend objects containing name and path. |
| 124 | + """ |
| 125 | + backends = [] |
| 126 | + if window: |
33 | 127 | for i in range(window.backend_combo.count()):
|
34 | 128 | backends.append(
|
35 |
| - { |
36 |
| - "name": window.backend_combo.itemText(i), |
37 |
| - "path": window.backend_combo.itemData(i), |
38 |
| - } |
39 |
| - ) |
40 |
| - return jsonify({"backends": backends}) |
41 |
| - |
42 |
| - @server.route("/v1/plugins", methods=["GET"]) |
43 |
| - def get_plugins() -> Response: |
44 |
| - if window: |
45 |
| - return jsonify( |
46 |
| - { |
47 |
| - "plugins": [ |
48 |
| - { |
49 |
| - "name": plugin_data["data"]["name"], |
50 |
| - "version": plugin_data["data"]["version"], |
51 |
| - "description": plugin_data["data"]["description"], |
52 |
| - "author": plugin_data["data"]["author"], |
53 |
| - } |
54 |
| - for plugin_data in window.plugins.values() |
55 |
| - ] |
56 |
| - } |
57 |
| - ) |
58 |
| - return jsonify({"plugins": []}) |
59 |
| - |
60 |
| - def run_flask() -> None: |
61 |
| - if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled": |
62 |
| - server.run( |
63 |
| - host="0.0.0.0", |
64 |
| - port=int(os.environ.get("AUTOGGUF_SERVER_PORT", 5000)), |
65 |
| - debug=False, |
66 |
| - use_reloader=False, |
| 129 | + Backend( |
| 130 | + name=window.backend_combo.itemText(i), |
| 131 | + path=window.backend_combo.itemData(i), |
| 132 | + ) |
67 | 133 | )
|
| 134 | + return backends |
68 | 135 |
|
69 |
| - app = QApplication(sys.argv) |
| 136 | + |
| 137 | +@app.get("/v1/plugins", response_model=List[Plugin], tags=["System"]) |
| 138 | +async def get_plugins() -> List[Plugin]: |
| 139 | + """ |
| 140 | + Get a list of all installed plugins. |
| 141 | +
|
| 142 | + Returns a list of Plugin objects containing name, version, description, and author. |
| 143 | + """ |
| 144 | + if window: |
| 145 | + return [ |
| 146 | + Plugin(**plugin_data["data"]) for plugin_data in window.plugins.values() |
| 147 | + ] |
| 148 | + return [] |
| 149 | + |
| 150 | + |
def run_uvicorn() -> None:
    """
    Run the Uvicorn server hosting the FastAPI app, if enabled.

    The server starts only when the AUTOGGUF_SERVER environment variable is
    set to "enabled" (case-insensitive); the port defaults to 7001 and can
    be overridden via AUTOGGUF_SERVER_PORT.
    """
    if os.environ.get("AUTOGGUF_SERVER", "").lower() != "enabled":
        return
    port = int(os.environ.get("AUTOGGUF_SERVER_PORT", 7001))
    config = Config(
        app=app,
        host="127.0.0.1",
        port=port,
        log_level="info",
    )
    Server(config).run()
| 161 | + |
| 162 | + |
def main() -> None:
    """
    Launch the AutoGGUF Qt application and (optionally) the API server.

    Creates the QApplication and main window, publishes the window via the
    module-level `window` global for the API endpoints, then schedules the
    Uvicorn server (started only when AUTOGGUF_SERVER=enabled) on a daemon
    thread shortly after the Qt event loop begins.
    """
    global window
    qt_app = QApplication(sys.argv)
    window = AutoGGUF(sys.argv)
    window.show()

    # Start Uvicorn in a separate thread after a short delay.
    # QTimer.singleShot is a static method, so no throwaway QTimer
    # instance is needed (the original created one and never used it).
    QTimer.singleShot(
        100, lambda: threading.Thread(target=run_uvicorn, daemon=True).start()
    )

    sys.exit(qt_app.exec())
78 | 176 |
|
79 | 177 |
|
80 | 178 | if __name__ == "__main__":
|
|
0 commit comments