Skip to content

chore: small refactor for /models/pull #2144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/static/openapi/cortex.json
Original file line number Diff line number Diff line change
Expand Up @@ -5913,10 +5913,11 @@
"properties": {
"model": {
"type": "string",
"description": "The identifier or URL of the model to use. It can be a model ID on Cortexso (https://huggingface.co/cortexso) or a HuggingFace URL pointing to the model file. For example: 'gpt2' or 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf'",
"description": "The identifier or URL of the model to use. It can be a model ID on Cortexso (https://huggingface.co/cortexso) or a HuggingFace URL pointing to the model GGUF file. For example: 'tinyllama:1b', 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf', or 'TheBloke:Mistral-7B-Instruct-v0.1-GGUF:mistral-7b-instruct-v0.1.Q2_K.gguf'",
"examples": [
"tinyllama:1b",
"https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf"
"https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf",
"TheBloke:Mistral-7B-Instruct-v0.1-GGUF:mistral-7b-instruct-v0.1.Q2_K.gguf"
]
},
"id": {
Expand All @@ -5926,7 +5927,7 @@
},
"name": {
"type": "string",
"description": "The name which will be used to overwrite the model name.",
"description": "The name which will be used to overwrite the model name. This only affects single GGUF file models.",
"examples": "my-custom-model-name"
}
}
Expand Down
31 changes: 1 addition & 30 deletions engine/controllers/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,36 +50,7 @@ void Models::PullModel(const HttpRequestPtr& req,
desired_model_name = name_value;
}

auto handle_model_input =
[&, model_handle]() -> cpp::result<DownloadTask, std::string> {
CTL_INF("Handle model input, model handle: " + model_handle);
if (string_utils::StartsWith(model_handle, "https")) {
return model_service_->HandleDownloadUrlAsync(
model_handle, desired_model_id, desired_model_name);
} else if (model_handle.find(":") != std::string::npos) {
auto model_and_branch = string_utils::SplitBy(model_handle, ":");
if (model_and_branch.size() == 3) {
auto mh = url_parser::Url{
.protocol = "https",
.host = kHuggingFaceHost,
.pathParams = {
model_and_branch[0],
model_and_branch[1],
"resolve",
"main",
model_and_branch[2],
}}.ToFullPath();
return model_service_->HandleDownloadUrlAsync(mh, desired_model_id,
desired_model_name);
}
return model_service_->DownloadModelFromCortexsoAsync(
model_and_branch[0], model_and_branch[1], desired_model_id);
}

return cpp::fail("Invalid model handle or not supported!");
};

auto result = handle_model_input();
auto result = model_service_->PullModel(model_handle, desired_model_id, desired_model_name);
if (result.has_error()) {
Json::Value ret;
ret["message"] = result.error();
Expand Down
63 changes: 26 additions & 37 deletions engine/e2e-test/api/model/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class TestApiModel:
@pytest.fixture(autouse=True)
def setup_and_teardown(self):
# Setup
# Setup
success = start_server()
if not success:
raise Exception("Failed to start server")
Expand All @@ -20,42 +20,31 @@ def setup_and_teardown(self):

# Teardown
stop_server()
# Pull with direct url

# Pull with direct url
@pytest.mark.asyncio
async def test_model_pull_with_direct_url_should_be_success(self):
run(
"Delete model",
[
"models",
"delete",
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
],
)

myobj = {
"model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf"
}
@pytest.mark.parametrize(
"request_model",
[
"https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf",
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
]
)
async def test_model_pull_with_direct_url_should_be_success(self, request_model):
model_id = "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
run("Delete model", ["models", "delete", model_id])

myobj = {"model": request_model}
response = requests.post("http://localhost:3928/v1/models/pull", json=myobj)
assert response.status_code == 200
await wait_for_websocket_download_success_event(timeout=None)
get_model_response = requests.get(
"http://127.0.0.1:3928/v1/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
f"http://127.0.0.1:3928/v1/models/{model_id}"
)
assert get_model_response.status_code == 200
assert (
get_model_response.json()["model"]
== "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf"
)

run(
"Delete model",
[
"models",
"delete",
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
],
)
assert get_model_response.json()["model"] == model_id

run("Delete model", ["models", "delete", model_id])

@pytest.mark.asyncio
async def test_model_pull_with_direct_url_should_have_desired_name(self):
Expand All @@ -75,7 +64,7 @@ async def test_model_pull_with_direct_url_should_have_desired_name(self):
get_model_response.json()["name"]
== "smol_llama_100m"
)

run(
"Delete model",
[
Expand All @@ -84,7 +73,7 @@ async def test_model_pull_with_direct_url_should_have_desired_name(self):
"afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf",
],
)

@pytest.mark.asyncio
async def test_models_start_stop_should_be_successful(self):
print("Install engine")
Expand All @@ -99,12 +88,12 @@ async def test_models_start_stop_should_be_successful(self):
response = requests.post("http://localhost:3928/v1/models/pull", json=json_body)
assert response.status_code == 200, f"Failed to pull model: tinyllama:1b"
await wait_for_websocket_download_success_event(timeout=None)

# get API
print("Get model")
response = requests.get("http://localhost:3928/v1/models/tinyllama:1b")
assert response.status_code == 200

# list API
print("List model")
response = requests.get("http://localhost:3928/v1/models")
Expand All @@ -120,7 +109,7 @@ async def test_models_start_stop_should_be_successful(self):
print("Stop model")
response = requests.post("http://localhost:3928/v1/models/stop", json=json_body)
assert response.status_code == 200, f"status_code: {response.status_code}"

# update API
print("Update model")
body_json = {'model': 'tinyllama:1b'}
Expand All @@ -131,14 +120,14 @@ async def test_models_start_stop_should_be_successful(self):
print("Delete model")
response = requests.delete("http://localhost:3928/v1/models/tinyllama:1b")
assert response.status_code == 200

def test_models_sources_api(self):
json_body = {"source": "https://huggingface.co/cortexso/tinyllama"}
response = requests.post(
"http://localhost:3928/v1/models/sources", json=json_body
)
assert response.status_code == 200, f"status_code: {response.status_code}"

json_body = {"source": "https://huggingface.co/cortexso/tinyllama"}
response = requests.delete(
"http://localhost:3928/v1/models/sources", json=json_body
Expand Down
35 changes: 35 additions & 0 deletions engine/services/model_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,41 @@ cpp::result<bool, std::string> ModelService::GetModelStatus(
}
}

// Resolves a user-supplied model handle into a download task. Accepts three
// input forms:
//   1. "https..."                       — direct download URL (forwarded as-is).
//   2. "model:branch"                   — cortexso-hosted model (two parts).
//   3. "author:model:gguf_filename"     — single GGUF file on Hugging Face;
//                                         rebuilt into a resolve/main URL.
// Any other shape fails with an error string.
//
// @param model_handle        identifier or URL to pull
// @param desired_model_id    optional id to register the downloaded model under
// @param desired_model_name  optional display name; only honored by the
//                            URL-based (single GGUF file) download paths
// @return the created DownloadTask, or an error message for unsupported input
cpp::result<DownloadTask, std::string> ModelService::PullModel(
    const std::string& model_handle,
    const std::optional<std::string>& desired_model_id,
    const std::optional<std::string>& desired_model_name) {
  CTL_INF("Handle model input, model handle: " + model_handle);

  // NOTE(review): only a "https" prefix is recognized as a URL; a plain
  // "http://..." handle falls through to the ':'-splitting logic below and
  // will be rejected — confirm that is intended.
  if (string_utils::StartsWith(model_handle, "https"))
    return HandleDownloadUrlAsync(model_handle, desired_model_id,
                                  desired_model_name);

  // Without a ':' separator the handle matches none of the supported formats.
  if (model_handle.find(":") == std::string::npos)
    return cpp::fail("Invalid model handle or not supported!");

  auto model_and_branch = string_utils::SplitBy(model_handle, ":");

  // cortexso format - model:branch
  // NOTE: desired_model_name is not used by cortexso downloader
  if (model_and_branch.size() == 2)
    return DownloadModelFromCortexsoAsync(
        model_and_branch[0], model_and_branch[1], desired_model_id);

  // single file GGUF format - author_id:model_id:GGUF_filename
  if (model_and_branch.size() == 3) {
    // Rebuild the Hugging Face direct-download URL:
    //   https://<hf-host>/<author>/<model>/resolve/main/<filename>
    url_parser::Url url;
    url.protocol = "https";
    url.host = kHuggingFaceHost;
    url.pathParams = {model_and_branch[0], model_and_branch[1], "resolve",
                      "main", model_and_branch[2]};
    return HandleDownloadUrlAsync(url.ToFullPath(), desired_model_id,
                                  desired_model_name);
  }

  // More than three ':'-separated parts (or any other shape) is unsupported.
  return cpp::fail("Invalid model handle or not supported!");
}

cpp::result<ModelPullInfo, std::string> ModelService::GetModelPullInfo(
const std::string& input) {
if (input.empty()) {
Expand Down
21 changes: 13 additions & 8 deletions engine/services/model_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ class ModelService {
std::shared_ptr<EngineServiceI> engine_svc,
cortex::TaskQueue& task_queue);

cpp::result<DownloadTask, std::string> PullModel(
const std::string& model_handle,
const std::optional<std::string>& desired_model_id,
const std::optional<std::string>& desired_model_name);

cpp::result<std::string, std::string> AbortDownloadModel(
const std::string& task_id);

cpp::result<DownloadTask, std::string> DownloadModelFromCortexsoAsync(
const std::string& name, const std::string& branch = "main",
std::optional<std::string> temp_model_id = std::nullopt);

std::optional<config::ModelConfig> GetDownloadedModel(
const std::string& modelId) const;

Expand All @@ -67,10 +68,6 @@ class ModelService {
cpp::result<ModelPullInfo, std::string> GetModelPullInfo(
const std::string& model_handle);

cpp::result<DownloadTask, std::string> HandleDownloadUrlAsync(
const std::string& url, std::optional<std::string> temp_model_id,
std::optional<std::string> temp_name);

bool HasModel(const std::string& id) const;

std::optional<hardware::Estimation> GetEstimation(
Expand All @@ -89,6 +86,14 @@ class ModelService {
std::string GetEngineByModelId(const std::string& model_id) const;

private:
cpp::result<DownloadTask, std::string> DownloadModelFromCortexsoAsync(
const std::string& name, const std::string& branch = "main",
std::optional<std::string> temp_model_id = std::nullopt);

cpp::result<DownloadTask, std::string> HandleDownloadUrlAsync(
const std::string& url, std::optional<std::string> temp_model_id,
std::optional<std::string> temp_name);

cpp::result<std::optional<std::string>, std::string> MayFallbackToCpu(
const std::string& model_path, int ngl, int ctx_len, int n_batch = 2048,
int n_ubatch = 2048, const std::string& kv_cache_type = "f16");
Expand Down
Loading