Add embedding model loading #5

Open

wants to merge 1 commit into main
webui.py: 133 changes (129 additions, 4 deletions)
@@ -6,6 +6,7 @@
import aiohttp
import gradio as gr
import requests
from requests import HTTPError

conn_url = None
conn_key = None
@@ -17,6 +18,7 @@
loras = []
templates = []
overrides = []
embedding_models = []

model_load_task = None
model_load_state = False
@@ -178,6 +180,7 @@ def connect(api_url, admin_key, silent=False):
global loras
global templates
global overrides
global embedding_models

try:
a = requests.get(
@@ -212,6 +215,10 @@
url=api_url + "/v1/sampling/override/list", headers={"X-api-key": admin_key}
)
so.raise_for_status()
em = requests.get(
url=api_url + "/v1/model/embedding/list", headers={"X-api-key": admin_key}
)
em.raise_for_status()
except Exception as e:
raise gr.Error(e)

@@ -243,6 +250,11 @@ def connect(api_url, admin_key, silent=False):
overrides.append(override)
overrides.sort(key=str.lower)

embedding_models = []
for model in em.json().get("data"):
embedding_models.append(model.get("id"))
embedding_models.sort(key=str.lower)

if not silent:
gr.Info("TabbyAPI connected.")
return (
@@ -256,6 +268,9 @@
get_override_list(),
get_current_model(),
get_current_loras(),
get_embedding_model_list(),
gr.Textbox(value=", ".join(embedding_models), visible=True),
get_current_embedding_model(),
)


@@ -287,11 +302,11 @@ def get_current_model():
return gr.Textbox(value=None)
params = model_card.get("parameters")
draft_model_card = params.get("draft")
-model = f'{model_card.get("id")} (context: {params.get("max_seq_len")}, cache size: {params.get("cache_size")}, rope scale: {params.get("rope_scale")}, rope alpha: {params.get("rope_alpha")})'
+model = f"{model_card.get('id')} (context: {params.get('max_seq_len')}, cache size: {params.get('cache_size')}, rope scale: {params.get('rope_scale')}, rope alpha: {params.get('rope_alpha')})"

if draft_model_card:
draft_params = draft_model_card.get("parameters")
-model += f' | {draft_model_card.get("id")} (rope scale: {draft_params.get("rope_scale")}, rope alpha: {draft_params.get("rope_alpha")})'
+model += f" | {draft_model_card.get('id')} (rope scale: {draft_params.get('rope_scale')}, rope alpha: {draft_params.get('rope_alpha')})"
return gr.Textbox(value=model)


@@ -302,10 +317,22 @@ def get_current_loras():
lora_list = lo.get("data")
loras = []
for lora in lora_list:
-loras.append(f'{lora.get("id")} (scaling: {lora.get("scaling")})')
+loras.append(f"{lora.get('id')} (scaling: {lora.get('scaling')})")
return gr.Textbox(value=", ".join(loras))


def get_current_embedding_model():
try:
model_card = requests.get(
url=conn_url + "/v1/model/embedding", headers={"X-api-key": conn_key}
).json()
if not model_card.get("id"):
return gr.Textbox(value=None)
return gr.Textbox(value=model_card.get("id"))
except Exception:
return gr.Textbox(value=None)


def update_loras_table(loras):
array = []
for lora in loras:
@@ -591,7 +618,7 @@ async def download(repo_id, revision, repo_type, folder_name, token, include, ex
r.raise_for_status()
content = await r.json()
gr.Info(
-f'{repo_type} {repo_id} downloaded to folder: {content.get("download_path")}.'
+f"{repo_type} {repo_id} downloaded to folder: {content.get('download_path')}."
)
except asyncio.CancelledError:
gr.Info("Download canceled.")
@@ -608,6 +635,62 @@ def cancel_download():
download_task.cancel()


def get_embedding_model_list():
try:
r = requests.get(
url=conn_url + "/v1/model/embedding/list", headers={"X-api-key": conn_key}
)
r.raise_for_status()
embedding_models = []
for model in r.json().get("data"):
embedding_models.append(model.get("id"))
embedding_models.sort(key=str.lower)
return gr.Dropdown(choices=[""] + embedding_models, value=None)
except Exception as e:
raise gr.Error(e)


async def load_embedding_model(embedding_model_name, device):
if not embedding_model_name:
raise gr.Error("Specify an embedding model to load!")

request = {
"embedding_model_name": embedding_model_name,
"embeddings_device": device,
}

try:
requests.post(
url=conn_url + "/v1/model/embedding/unload",
headers={"X-admin-key": conn_key},
)
r = requests.post(
url=conn_url + "/v1/model/embedding/load",
headers={"X-admin-key": conn_key},
json=request,
)
r.raise_for_status()
gr.Info("Embedding model successfully loaded.")
return get_current_embedding_model()
except Exception as e:
raise gr.Error(e)


def unload_embedding_model():
try:
r = requests.post(
url=conn_url + "/v1/model/embedding/unload",
headers={"X-admin-key": conn_key},
)
r.raise_for_status()
gr.Info("Embedding model unloaded.")
return get_current_embedding_model()
except Exception as e:
if type(e) is HTTPError:
raise gr.Error(e.response.content.decode())
raise gr.Error(e)


# Auto-attempt connection if admin key is provided
init_model_text = None
init_lora_text = None
@@ -628,6 +711,7 @@ def cancel_download():
)
current_model = gr.Textbox(value=init_model_text, label="Current Model:")
current_loras = gr.Textbox(value=init_lora_text, label="Current Loras:")
current_embedding_model = gr.Textbox(value=None, label="Current Embedding Model:")

with gr.Tab("Connect to API"):
connect_btn = gr.Button(value="Connect", variant="primary")
@@ -648,6 +732,11 @@ def cancel_download():
lora_list = gr.Textbox(
value=", ".join(loras), label="Available Loras:", visible=bool(conn_key)
)
embedding_model_list = gr.Textbox(
value=", ".join(embedding_models),
label="Available Embedding Models:",
visible=bool(conn_key),
)

with gr.Tab("Load Model"):
with gr.Row():
@@ -855,6 +944,29 @@ def cancel_download():
interactive=True,
)

with gr.Tab("Load Embedding Model"):
with gr.Row():
load_embedding_btn = gr.Button(
value="Load Embedding Model", variant="primary"
)
unload_embedding_btn = gr.Button(
value="Unload Embedding Model", variant="stop"
)

with gr.Row():
embedding_models_drop = gr.Dropdown(
choices=[""],
label="Select Embedding Model:",
interactive=True,
)
embeddings_device = gr.Radio(
choices=["auto", "cpu", "cuda"],
value="cuda",
label="Device:",
interactive=True,
info="Device to load the embedding model on.",
)

with gr.Tab("HF Downloader"):
with gr.Row():
download_btn = gr.Button(value="Download", variant="primary")
@@ -924,6 +1036,9 @@ def cancel_download():
sampler_override,
current_model,
current_loras,
embedding_models_drop,
embedding_model_list,
current_embedding_model,
],
)

@@ -1049,6 +1164,16 @@ def cancel_download():
)
cancel_download_btn.click(fn=cancel_download)

# Embeddings
load_embedding_btn.click(
fn=load_embedding_model,
inputs=[embedding_models_drop, embeddings_device],
outputs=current_embedding_model,
)
unload_embedding_btn.click(
fn=unload_embedding_model, outputs=current_embedding_model
)

webui.launch(
inbrowser=args.autolaunch,
show_api=False,
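For reference, below is a minimal standalone sketch of the TabbyAPI embedding endpoints this diff drives. The endpoint paths, headers, and request body mirror the calls in the diff above; API_URL and ADMIN_KEY are placeholder assumptions, not values from this PR.

import requests

API_URL = "http://127.0.0.1:5000"  # assumed local TabbyAPI address, not from this PR
ADMIN_KEY = "example-admin-key"    # placeholder admin key

# List the embedding models available on the server (same call as connect()).
r = requests.get(
    url=API_URL + "/v1/model/embedding/list",
    headers={"X-api-key": ADMIN_KEY},
)
r.raise_for_status()
models = sorted((card.get("id") for card in r.json().get("data", [])), key=str.lower)
print("available:", models)

if models:
    # Load the first model; the body matches the request built in load_embedding_model().
    r = requests.post(
        url=API_URL + "/v1/model/embedding/load",
        headers={"X-admin-key": ADMIN_KEY},
        json={"embedding_model_name": models[0], "embeddings_device": "cuda"},
    )
    r.raise_for_status()

    # Confirm which embedding model is active, as get_current_embedding_model() does.
    card = requests.get(
        url=API_URL + "/v1/model/embedding",
        headers={"X-api-key": ADMIN_KEY},
    ).json()
    print("loaded:", card.get("id"))

    # Unload it again, as unload_embedding_model() does.
    requests.post(
        url=API_URL + "/v1/model/embedding/unload",
        headers={"X-admin-key": ADMIN_KEY},
    ).raise_for_status()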