diff --git a/docs/static/openapi/cortex.json b/docs/static/openapi/cortex.json index 23970ef51..78deab912 100644 --- a/docs/static/openapi/cortex.json +++ b/docs/static/openapi/cortex.json @@ -5913,10 +5913,11 @@ "properties": { "model": { "type": "string", - "description": "The identifier or URL of the model to use. It can be a model ID on Cortexso (https://huggingface.co/cortexso) or a HuggingFace URL pointing to the model file. For example: 'gpt2' or 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf'", + "description": "The identifier or URL of the model to use. It can be a model ID on Cortexso (https://huggingface.co/cortexso) or a HuggingFace URL pointing to the model GGUF file. For example: 'tinyllama:1b', 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf', or 'TheBloke:Mistral-7B-Instruct-v0.1-GGUF:mistral-7b-instruct-v0.1.Q2_K.gguf'", "examples": [ "tinyllama:1b", - "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf" + "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf", + "TheBloke:Mistral-7B-Instruct-v0.1-GGUF:mistral-7b-instruct-v0.1.Q2_K.gguf" ] }, "id": { @@ -5926,7 +5927,7 @@ }, "name": { "type": "string", - "description": "The name which will be used to overwrite the model name.", + "description": "The name which will be used to overwrite the model name. This only affects single GGUF file models.", "examples": "my-custom-model-name" } } diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 3215da753..8d300ef25 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -50,36 +50,7 @@ void Models::PullModel(const HttpRequestPtr& req, desired_model_name = name_value; } - auto handle_model_input = - [&, model_handle]() -> cpp::result { - CTL_INF("Handle model input, model handle: " + model_handle); - if (string_utils::StartsWith(model_handle, "https")) { - return model_service_->HandleDownloadUrlAsync( - model_handle, desired_model_id, desired_model_name); - } else if (model_handle.find(":") != std::string::npos) { - auto model_and_branch = string_utils::SplitBy(model_handle, ":"); - if (model_and_branch.size() == 3) { - auto mh = url_parser::Url{ - .protocol = "https", - .host = kHuggingFaceHost, - .pathParams = { - model_and_branch[0], - model_and_branch[1], - "resolve", - "main", - model_and_branch[2], - }}.ToFullPath(); - return model_service_->HandleDownloadUrlAsync(mh, desired_model_id, - desired_model_name); - } - return model_service_->DownloadModelFromCortexsoAsync( - model_and_branch[0], model_and_branch[1], desired_model_id); - } - - return cpp::fail("Invalid model handle or not supported!"); - }; - - auto result = handle_model_input(); + auto result = model_service_->PullModel(model_handle, desired_model_id, desired_model_name); if (result.has_error()) { Json::Value ret; ret["message"] = result.error(); diff --git a/engine/e2e-test/api/model/test_api_model.py b/engine/e2e-test/api/model/test_api_model.py index bacf7e1b0..b65f4f7ac 100644 --- a/engine/e2e-test/api/model/test_api_model.py +++ b/engine/e2e-test/api/model/test_api_model.py @@ -11,7 +11,7 @@ class TestApiModel: @pytest.fixture(autouse=True) def setup_and_teardown(self): - # Setup + # Setup success = start_server() if not success: raise Exception("Failed to start server") @@ -20,42 +20,31 @@ def setup_and_teardown(self): # Teardown stop_server() - - # Pull with direct url + + # Pull with direct url @pytest.mark.asyncio - async def test_model_pull_with_direct_url_should_be_success(self): - run( - "Delete model", - [ - "models", - "delete", - "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf", - ], - ) - - myobj = { - "model": "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf" - } + @pytest.mark.parametrize( + "request_model", + [ + "https://huggingface.co/afrideva/zephyr-smol_llama-100m-sft-full-GGUF/blob/main/zephyr-smol_llama-100m-sft-full.q2_k.gguf", + "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf", + ] + ) + async def test_model_pull_with_direct_url_should_be_success(self, request_model): + model_id = "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf" + run("Delete model", ["models", "delete", model_id]) + + myobj = {"model": request_model} response = requests.post("http://localhost:3928/v1/models/pull", json=myobj) assert response.status_code == 200 await wait_for_websocket_download_success_event(timeout=None) get_model_response = requests.get( - "http://127.0.0.1:3928/v1/models/afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf" + f"http://127.0.0.1:3928/v1/models/{model_id}" ) assert get_model_response.status_code == 200 - assert ( - get_model_response.json()["model"] - == "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf" - ) - - run( - "Delete model", - [ - "models", - "delete", - "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf", - ], - ) + assert get_model_response.json()["model"] == model_id + + run("Delete model", ["models", "delete", model_id]) @pytest.mark.asyncio async def test_model_pull_with_direct_url_should_have_desired_name(self): @@ -75,7 +64,7 @@ async def test_model_pull_with_direct_url_should_have_desired_name(self): get_model_response.json()["name"] == "smol_llama_100m" ) - + run( "Delete model", [ @@ -84,7 +73,7 @@ async def test_model_pull_with_direct_url_should_have_desired_name(self): "afrideva:zephyr-smol_llama-100m-sft-full-GGUF:zephyr-smol_llama-100m-sft-full.q2_k.gguf", ], ) - + @pytest.mark.asyncio async def test_models_start_stop_should_be_successful(self): print("Install engine") @@ -99,12 +88,12 @@ async def test_models_start_stop_should_be_successful(self): response = requests.post("http://localhost:3928/v1/models/pull", json=json_body) assert response.status_code == 200, f"Failed to pull model: tinyllama:1b" await wait_for_websocket_download_success_event(timeout=None) - + # get API print("Get model") response = requests.get("http://localhost:3928/v1/models/tinyllama:1b") assert response.status_code == 200 - + # list API print("List model") response = requests.get("http://localhost:3928/v1/models") @@ -120,7 +109,7 @@ async def test_models_start_stop_should_be_successful(self): print("Stop model") response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) assert response.status_code == 200, f"status_code: {response.status_code}" - + # update API print("Update model") body_json = {'model': 'tinyllama:1b'} @@ -131,14 +120,14 @@ async def test_models_start_stop_should_be_successful(self): print("Delete model") response = requests.delete("http://localhost:3928/v1/models/tinyllama:1b") assert response.status_code == 200 - + def test_models_sources_api(self): json_body = {"source": "https://huggingface.co/cortexso/tinyllama"} response = requests.post( "http://localhost:3928/v1/models/sources", json=json_body ) assert response.status_code == 200, f"status_code: {response.status_code}" - + json_body = {"source": "https://huggingface.co/cortexso/tinyllama"} response = requests.delete( "http://localhost:3928/v1/models/sources", json=json_body diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index b0a692eb5..7ff2ef359 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -790,6 +790,41 @@ cpp::result ModelService::GetModelStatus( } } +cpp::result ModelService::PullModel( + const std::string& model_handle, + const std::optional& desired_model_id, + const std::optional& desired_model_name) { + CTL_INF("Handle model input, model handle: " + model_handle); + + if (string_utils::StartsWith(model_handle, "https")) + return HandleDownloadUrlAsync(model_handle, desired_model_id, + desired_model_name); + + if (model_handle.find(":") == std::string::npos) + return cpp::fail("Invalid model handle or not supported!"); + + auto model_and_branch = string_utils::SplitBy(model_handle, ":"); + + // cortexso format - model:branch + // NOTE: desired_model_name is not used by cortexso downloader + if (model_and_branch.size() == 2) + return DownloadModelFromCortexsoAsync( + model_and_branch[0], model_and_branch[1], desired_model_id); + + // single file GGUF format - author_id:model_id:GGUF_filename + if (model_and_branch.size() == 3) { + url_parser::Url url; + url.protocol = "https"; + url.host = kHuggingFaceHost; + url.pathParams = {model_and_branch[0], model_and_branch[1], "resolve", + "main", model_and_branch[2]}; + return HandleDownloadUrlAsync(url.ToFullPath(), desired_model_id, + desired_model_name); + } + + return cpp::fail("Invalid model handle or not supported!"); +} + cpp::result ModelService::GetModelPullInfo( const std::string& input) { if (input.empty()) { diff --git a/engine/services/model_service.h b/engine/services/model_service.h index beba91f8c..b5ef24be2 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -39,13 +39,14 @@ class ModelService { std::shared_ptr engine_svc, cortex::TaskQueue& task_queue); + cpp::result PullModel( + const std::string& model_handle, + const std::optional& desired_model_id, + const std::optional& desired_model_name); + cpp::result AbortDownloadModel( const std::string& task_id); - cpp::result DownloadModelFromCortexsoAsync( - const std::string& name, const std::string& branch = "main", - std::optional temp_model_id = std::nullopt); - std::optional GetDownloadedModel( const std::string& modelId) const; @@ -67,10 +68,6 @@ class ModelService { cpp::result GetModelPullInfo( const std::string& model_handle); - cpp::result HandleDownloadUrlAsync( - const std::string& url, std::optional temp_model_id, - std::optional temp_name); - bool HasModel(const std::string& id) const; std::optional GetEstimation( @@ -89,6 +86,14 @@ class ModelService { std::string GetEngineByModelId(const std::string& model_id) const; private: + cpp::result DownloadModelFromCortexsoAsync( + const std::string& name, const std::string& branch = "main", + std::optional temp_model_id = std::nullopt); + + cpp::result HandleDownloadUrlAsync( + const std::string& url, std::optional temp_model_id, + std::optional temp_name); + cpp::result, std::string> MayFallbackToCpu( const std::string& model_path, int ngl, int ctx_len, int n_batch = 2048, int n_ubatch = 2048, const std::string& kv_cache_type = "f16");