From 864c2ad10a54681aace9840209f053196cb85aac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20de=20la=20R=C3=BAa=20Mart=C3=ADnez?= Date: Tue, 10 Sep 2024 08:37:38 +0200 Subject: [PATCH 1/3] [HWORKS-1221] Bump protobuf==^4.25.4 --- python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 6cd64077e..07e75ef8b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -55,7 +55,7 @@ dependencies = [ "opensearch-py>=1.1.0,<=2.4.2", "tqdm", "grpcio>=1.49.1,<2.0.0", # ^1.49.1 - "protobuf>=3.19.0,<4.0.0", # ^3.19.0 + "protobuf>=4.25.4,<5.0.0", # ^4.25.4 ] [project.optional-dependencies] From 3a87c9d13ad9115e3080df0f3ecd2f8f5f27c36a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20de=20la=20R=C3=BAa=20Mart=C3=ADnez?= Date: Fri, 13 Sep 2024 17:29:42 +0200 Subject: [PATCH 2/3] [HWORKS-1224] Add llm signature and openai endpoint --- python/hopsworks_common/constants.py | 2 + python/hsml/core/serving_api.py | 3 ++ python/hsml/engine/serving_engine.py | 5 +- python/hsml/llm/__init__.py | 15 ++++++ python/hsml/llm/model.py | 75 +++++++++++++++++++++++++++ python/hsml/llm/predictor.py | 28 ++++++++++ python/hsml/llm/signature.py | 77 ++++++++++++++++++++++++++++ python/hsml/model_registry.py | 9 ++++ python/hsml/util.py | 7 +++ 9 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 python/hsml/llm/__init__.py create mode 100644 python/hsml/llm/model.py create mode 100644 python/hsml/llm/predictor.py create mode 100644 python/hsml/llm/signature.py diff --git a/python/hopsworks_common/constants.py b/python/hopsworks_common/constants.py index 72672dae8..b9f90f22b 100644 --- a/python/hopsworks_common/constants.py +++ b/python/hopsworks_common/constants.py @@ -158,6 +158,7 @@ class MODEL: FRAMEWORK_TORCH = "TORCH" FRAMEWORK_PYTHON = "PYTHON" FRAMEWORK_SKLEARN = "SKLEARN" + FRAMEWORK_LLM = "LLM" class MODEL_REGISTRY: @@ -210,6 +211,7 @@ class PREDICTOR: # model server MODEL_SERVER_PYTHON = "PYTHON" MODEL_SERVER_TF_SERVING = "TENSORFLOW_SERVING" + MODEL_SERVER_VLLM = "VLLM" # serving tool SERVING_TOOL_DEFAULT = "DEFAULT" SERVING_TOOL_KSERVE = "KSERVE" diff --git a/python/hsml/core/serving_api.py b/python/hsml/core/serving_api.py index c17eba65c..1ac975fe7 100644 --- a/python/hsml/core/serving_api.py +++ b/python/hsml/core/serving_api.py @@ -417,4 +417,7 @@ def _get_hopsworks_inference_path(self, project_id: int, deployment_instance): ] def _get_istio_inference_path(self, deployment_instance): + if deployment_instance.model_server == "VLLM": + return ["openai", "v1", "completions"] + return ["v1", "models", deployment_instance.name + ":predict"] diff --git a/python/hsml/engine/serving_engine.py b/python/hsml/engine/serving_engine.py index 15e2b3fa6..12f311d17 100644 --- a/python/hsml/engine/serving_engine.py +++ b/python/hsml/engine/serving_engine.py @@ -493,7 +493,10 @@ def predict( inputs: Union[Dict, List[Dict]], ): # validate user-provided payload - self._validate_inference_payload(deployment_instance.api_protocol, data, inputs) + if deployment_instance.model_server != "VLLM": + self._validate_inference_payload( + deployment_instance.api_protocol, data, inputs + ) # build inference payload based on API protocol payload = self._build_inference_payload( diff --git a/python/hsml/llm/__init__.py b/python/hsml/llm/__init__.py new file mode 100644 index 000000000..ff8055b9b --- /dev/null +++ b/python/hsml/llm/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/hsml/llm/model.py b/python/hsml/llm/model.py new file mode 100644 index 000000000..b52cf6398 --- /dev/null +++ b/python/hsml/llm/model.py @@ -0,0 +1,75 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import humps +from hsml.constants import MODEL +from hsml.model import Model + + +class Model(Model): + """Metadata object representing a LLM model in the Model Registry.""" + + def __init__( + self, + id, + name, + version=None, + created=None, + creator=None, + environment=None, + description=None, + project_name=None, + metrics=None, + program=None, + user_full_name=None, + model_schema=None, + training_dataset=None, + input_example=None, + model_registry_id=None, + tags=None, + href=None, + feature_view=None, + training_dataset_version=None, + **kwargs, + ): + super().__init__( + id, + name, + version=version, + created=created, + creator=creator, + environment=environment, + description=description, + project_name=project_name, + metrics=metrics, + program=program, + user_full_name=user_full_name, + model_schema=model_schema, + training_dataset=training_dataset, + input_example=input_example, + framework=MODEL.FRAMEWORK_LLM, + model_registry_id=model_registry_id, + feature_view=feature_view, + training_dataset_version=training_dataset_version, + ) + + def update_from_response_json(self, json_dict): + json_decamelized = humps.decamelize(json_dict) + json_decamelized.pop("framework") + if "type" in json_decamelized: # backwards compatibility + _ = json_decamelized.pop("type") + self.__init__(**json_decamelized) + return self diff --git a/python/hsml/llm/predictor.py b/python/hsml/llm/predictor.py new file mode 100644 index 000000000..814edc522 --- /dev/null +++ b/python/hsml/llm/predictor.py @@ -0,0 +1,28 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from hsml.constants import MODEL, PREDICTOR +from hsml.predictor import Predictor + + +class Predictor(Predictor): + """Configuration for a predictor running with the vLLM backend""" + + def __init__(self, **kwargs): + kwargs["model_framework"] = MODEL.FRAMEWORK_LLM + kwargs["model_server"] = PREDICTOR.MODEL_SERVER_VLLM + + super().__init__(**kwargs) diff --git a/python/hsml/llm/signature.py b/python/hsml/llm/signature.py new file mode 100644 index 000000000..742022a52 --- /dev/null +++ b/python/hsml/llm/signature.py @@ -0,0 +1,77 @@ +# +# Copyright 2024 Hopsworks AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Union + +import numpy +import pandas +from hopsworks_common import usage +from hsml.model_schema import ModelSchema +from hsml.llm.model import Model + + +_mr = None + + +@usage.method_logger +def create_model( + name: str, + version: Optional[int] = None, + metrics: Optional[dict] = None, + description: Optional[str] = None, + input_example: Optional[ + Union[pandas.DataFrame, pandas.Series, numpy.ndarray, list] + ] = None, + model_schema: Optional[ModelSchema] = None, + feature_view=None, + training_dataset_version: Optional[int] = None, +): + """Create an LLM model metadata object. + + !!! note "Lazy" + This method is lazy and does not persist any metadata or uploads model artifacts in the + model registry on its own. To save the model object and the model artifacts, call the `save()` method with a + local file path to the directory containing the model artifacts. + + # Arguments + name: Name of the model to create. + version: Optionally version of the model to create, defaults to `None` and + will create the model with incremented version from the last + version in the model registry. + metrics: Optionally a dictionary with model evaluation metrics (e.g., accuracy, MAE) + description: Optionally a string describing the model, defaults to empty string + `""`. + input_example: Optionally an input example that represents a single input for the model, defaults to `None`. + model_schema: Optionally a model schema for the model inputs and/or outputs. + + # Returns + `Model`. The model metadata object. + """ + model = Model( + id=None, + name=name, + version=version, + description=description, + metrics=metrics, + input_example=input_example, + model_schema=model_schema, + feature_view=feature_view, + training_dataset_version=training_dataset_version, + ) + model._shared_registry_project_name = _mr.shared_registry_project_name + model._model_registry_id = _mr.model_registry_id + + return model diff --git a/python/hsml/model_registry.py b/python/hsml/model_registry.py index 8968e6d16..761c7c496 100644 --- a/python/hsml/model_registry.py +++ b/python/hsml/model_registry.py @@ -24,6 +24,7 @@ from hsml.sklearn import signature as sklearn_signature # noqa: F401 from hsml.tensorflow import signature as tensorflow_signature # noqa: F401 from hsml.torch import signature as torch_signature # noqa: F401 +from hsml.llm import signature as llm_signature # noqa: F401 class ModelRegistry: @@ -49,11 +50,13 @@ def __init__( self._python = python_signature self._sklearn = sklearn_signature self._torch = torch_signature + self._llm = llm_signature tensorflow_signature._mr = self python_signature._mr = self sklearn_signature._mr = self torch_signature._mr = self + llm_signature._mr = self @classmethod def from_response_json(cls, json_dict): @@ -191,6 +194,12 @@ def python(self): return python_signature + @property + def llm(self): + """Module for exporting a Large Language Model.""" + + return llm_signature + def __repr__(self): project_name = ( self._shared_registry_project_name diff --git a/python/hsml/util.py b/python/hsml/util.py index 3fb243566..130f91b6d 100644 --- a/python/hsml/util.py +++ b/python/hsml/util.py @@ -100,6 +100,7 @@ def set_model_class(model): from hsml.sklearn.model import Model as SkLearnModel from hsml.tensorflow.model import Model as TFModel from hsml.torch.model import Model as TorchModel + from hsml.llm.model import Model as LLMModel if "href" in model: _ = model.pop("href") @@ -120,6 +121,8 @@ def set_model_class(model): return SkLearnModel(**model) elif framework == MODEL.FRAMEWORK_PYTHON: return PyModel(**model) + elif framework == MODEL.FRAMEWORK_LLM: + return LLMModel(**model) else: raise ValueError( "framework {} is not a supported framework".format(str(framework)) @@ -242,6 +245,8 @@ def get_predictor_for_model(model, **kwargs): from hsml.tensorflow.predictor import Predictor as TFPredictor from hsml.torch.model import Model as TorchModel from hsml.torch.predictor import Predictor as TorchPredictor + from hsml.llm.model import Model as LLMModel + from hsml.llm.predictor import Predictor as vLLMPredictor if not isinstance(model, BaseModel): raise ValueError( @@ -258,6 +263,8 @@ def get_predictor_for_model(model, **kwargs): return SkLearnPredictor(**kwargs) if type(model) is PyModel: return PyPredictor(**kwargs) + if type(model) is LLMModel: + return vLLMPredictor(**kwargs) if type(model) is BaseModel: return BasePredictor( # python as default framework and model server model_framework=MODEL.FRAMEWORK_PYTHON, From be95ada0b02082f69c702c99f53eb9127d46d11b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20de=20la=20R=C3=BAa=20Mart=C3=ADnez?= Date: Mon, 7 Oct 2024 17:16:30 +0200 Subject: [PATCH 3/3] [HWORKS-1224][Append] Fix imports and add pytests --- python/hsml/llm/signature.py | 2 +- python/hsml/model_registry.py | 2 +- python/hsml/predictor.py | 18 +++++---- python/hsml/util.py | 6 +-- python/tests/fixtures/model_fixtures.json | 27 +++++++++++++ python/tests/fixtures/model_fixtures.py | 7 ++++ python/tests/test_constants.py | 2 + python/tests/test_model.py | 13 ++++++ python/tests/test_predictor.py | 19 +++++++++ python/tests/test_util.py | 48 +++++++++++++++++++++++ 10 files changed, 132 insertions(+), 12 deletions(-) diff --git a/python/hsml/llm/signature.py b/python/hsml/llm/signature.py index 742022a52..9ac7db9ff 100644 --- a/python/hsml/llm/signature.py +++ b/python/hsml/llm/signature.py @@ -19,8 +19,8 @@ import numpy import pandas from hopsworks_common import usage -from hsml.model_schema import ModelSchema from hsml.llm.model import Model +from hsml.model_schema import ModelSchema _mr = None diff --git a/python/hsml/model_registry.py b/python/hsml/model_registry.py index 761c7c496..70b90b989 100644 --- a/python/hsml/model_registry.py +++ b/python/hsml/model_registry.py @@ -20,11 +20,11 @@ from hopsworks_common import usage from hsml import util from hsml.core import model_api +from hsml.llm import signature as llm_signature # noqa: F401 from hsml.python import signature as python_signature # noqa: F401 from hsml.sklearn import signature as sklearn_signature # noqa: F401 from hsml.tensorflow import signature as tensorflow_signature # noqa: F401 from hsml.torch import signature as torch_signature # noqa: F401 -from hsml.llm import signature as llm_signature # noqa: F401 class ModelRegistry: diff --git a/python/hsml/predictor.py b/python/hsml/predictor.py index 87f00c9aa..f1d458a3f 100644 --- a/python/hsml/predictor.py +++ b/python/hsml/predictor.py @@ -167,18 +167,22 @@ def _validate_serving_tool(cls, serving_tool): @classmethod def _validate_script_file(cls, model_framework, script_file): - if model_framework == MODEL.FRAMEWORK_PYTHON and script_file is None: + if script_file is None and ( + model_framework == MODEL.FRAMEWORK_PYTHON + or model_framework == MODEL.FRAMEWORK_LLM + ): raise ValueError( - "Predictor scripts are required in deployments for custom Python models" + "Predictor scripts are required in deployments for custom Python models and LLMs." ) @classmethod def _infer_model_server(cls, model_framework): - return ( - PREDICTOR.MODEL_SERVER_TF_SERVING - if model_framework == MODEL.FRAMEWORK_TENSORFLOW - else PREDICTOR.MODEL_SERVER_PYTHON - ) + if model_framework == MODEL.FRAMEWORK_TENSORFLOW: + return PREDICTOR.MODEL_SERVER_TF_SERVING + elif model_framework == MODEL.FRAMEWORK_LLM: + return PREDICTOR.MODEL_SERVER_VLLM + else: + return PREDICTOR.MODEL_SERVER_PYTHON @classmethod def _get_default_serving_tool(cls): diff --git a/python/hsml/util.py b/python/hsml/util.py index 130f91b6d..461793ebf 100644 --- a/python/hsml/util.py +++ b/python/hsml/util.py @@ -95,12 +95,12 @@ def default(self, obj): # pylint: disable=E0202 def set_model_class(model): + from hsml.llm.model import Model as LLMModel from hsml.model import Model as BaseModel from hsml.python.model import Model as PyModel from hsml.sklearn.model import Model as SkLearnModel from hsml.tensorflow.model import Model as TFModel from hsml.torch.model import Model as TorchModel - from hsml.llm.model import Model as LLMModel if "href" in model: _ = model.pop("href") @@ -235,6 +235,8 @@ def validate_metrics(metrics): def get_predictor_for_model(model, **kwargs): + from hsml.llm.model import Model as LLMModel + from hsml.llm.predictor import Predictor as vLLMPredictor from hsml.model import Model as BaseModel from hsml.predictor import Predictor as BasePredictor from hsml.python.model import Model as PyModel @@ -245,8 +247,6 @@ def get_predictor_for_model(model, **kwargs): from hsml.tensorflow.predictor import Predictor as TFPredictor from hsml.torch.model import Model as TorchModel from hsml.torch.predictor import Predictor as TorchPredictor - from hsml.llm.model import Model as LLMModel - from hsml.llm.predictor import Predictor as vLLMPredictor if not isinstance(model, BaseModel): raise ValueError( diff --git a/python/tests/fixtures/model_fixtures.json b/python/tests/fixtures/model_fixtures.json index 40c0b8002..cf44c3111 100644 --- a/python/tests/fixtures/model_fixtures.json +++ b/python/tests/fixtures/model_fixtures.json @@ -133,6 +133,33 @@ ] } }, + "get_llm": { + "response": { + "count": 1, + "items": [ + { + "id": "5", + "name": "llmmodel", + "version": 0, + "created": "created", + "creator": "creator", + "environment": "environment.yml", + "description": "description", + "project_name": "myproject", + "metrics": { "acc": 0.7 }, + "program": "program", + "user_full_name": "Full Name", + "model_schema": "model_schema.json", + "training_dataset": "training_dataset", + "input_example": "input_example.json", + "model_registry_id": 1, + "tags": [], + "framework": "LLM", + "href": "test_href" + } + ] + } + }, "get_list": { "response": { "count": 2, diff --git a/python/tests/fixtures/model_fixtures.py b/python/tests/fixtures/model_fixtures.py index 32fe396de..9b3796d05 100644 --- a/python/tests/fixtures/model_fixtures.py +++ b/python/tests/fixtures/model_fixtures.py @@ -17,6 +17,7 @@ import numpy as np import pandas as pd import pytest +from hsml.llm.model import Model as LLMModel from hsml.model import Model as BaseModel from hsml.python.model import Model as PythonModel from hsml.sklearn.model import Model as SklearnModel @@ -29,12 +30,14 @@ MODEL_SKLEARN_ID = 2 MODEL_TENSORFLOW_ID = 3 MODEL_TORCH_ID = 4 +MODEL_LLM_ID = 5 MODEL_BASE_NAME = "basemodel" MODEL_PYTHON_NAME = "pythonmodel" MODEL_SKLEARN_NAME = "sklearnmodel" MODEL_TENSORFLOW_NAME = "tensorflowmodel" MODEL_TORCH_NAME = "torchmodel" +MODEL_LLM_NAME = "llmmodel" # models @@ -63,6 +66,10 @@ def model_tensorflow(): def model_torch(): return TorchModel(MODEL_TORCH_ID, MODEL_TORCH_NAME) +@pytest.fixture +def model_llm(): + return LLMModel(MODEL_LLM_ID, MODEL_LLM_NAME) + # input example diff --git a/python/tests/test_constants.py b/python/tests/test_constants.py index 7a923d8d8..783770b14 100644 --- a/python/tests/test_constants.py +++ b/python/tests/test_constants.py @@ -38,6 +38,7 @@ def test_model_framework_constants(self): "FRAMEWORK_TORCH": "TORCH", "FRAMEWORK_PYTHON": "PYTHON", "FRAMEWORK_SKLEARN": "SKLEARN", + "FRAMEWORK_LLM": "LLM", } # Assert @@ -193,6 +194,7 @@ def test_predictor_model_server_constants(self): model_servers = { "MODEL_SERVER_PYTHON": "PYTHON", "MODEL_SERVER_TF_SERVING": "TENSORFLOW_SERVING", + "MODEL_SERVER_VLLM": "VLLM" } # Assert diff --git a/python/tests/test_model.py b/python/tests/test_model.py index 1f706a845..2442ac7fb 100644 --- a/python/tests/test_model.py +++ b/python/tests/test_model.py @@ -138,6 +138,19 @@ def test_constructor_torch(self, mocker, backend_fixtures): # Assert self.assert_model(mocker, m, json, MODEL.FRAMEWORK_TORCH) + def test_constructor_llm(self, mocker, backend_fixtures): + # Arrange + json = backend_fixtures["model"]["get_llm"]["response"]["items"][0] + m_json = copy.deepcopy(json) + id = m_json.pop("id") + name = m_json.pop("name") + + # Act + m = model.Model(id=id, name=name, **m_json) + + # Assert + self.assert_model(mocker, m, json, MODEL.FRAMEWORK_LLM) + # save def test_save(self, mocker, backend_fixtures): diff --git a/python/tests/test_predictor.py b/python/tests/test_predictor.py index e2e5485fc..166666baf 100644 --- a/python/tests/test_predictor.py +++ b/python/tests/test_predictor.py @@ -340,6 +340,14 @@ def test_validate_script_file_py_none(self): # Assert assert "Predictor scripts are required" in str(e_info.value) + def test_validate_script_file_llm_none(self): + # Act + with pytest.raises(ValueError) as e_info: + _ = predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_LLM, None) + + # Assert + assert "Predictor scripts are required" in str(e_info.value) + def test_validate_script_file_tf_script_file(self): # Act predictor.Predictor._validate_script_file( @@ -360,6 +368,10 @@ def test_validate_script_file_py_script_file(self): # Act predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_PYTHON, "script_file") + def test_validate_script_file_llm_script_file(self): + # Act + predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_LLM, "script_file") + # infer model server def test_infer_model_server_tf(self): @@ -390,6 +402,13 @@ def test_infer_model_server_py(self): # Assert assert ms == PREDICTOR.MODEL_SERVER_PYTHON + def test_infer_model_server_llm(self): + # Act + ms = predictor.Predictor._infer_model_server(MODEL.FRAMEWORK_LLM) + + # Assert + assert ms == PREDICTOR.MODEL_SERVER_VLLM + # default serving tool def test_get_default_serving_tool_kserve_installed(self, mocker): diff --git a/python/tests/test_util.py b/python/tests/test_util.py index 21b411a71..076b2aea7 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -28,6 +28,8 @@ from hsfs.feature import Feature from hsml import util from hsml.constants import MODEL +from hsml.llm.model import Model as LLMModel +from hsml.llm.predictor import Predictor as LLMPredictor from hsml.model import Model as BaseModel from hsml.predictor import Predictor as BasePredictor from hsml.python.model import Model as PythonModel @@ -105,6 +107,17 @@ def test_set_model_class_torch(self, backend_fixtures): assert isinstance(model, TorchModel) assert model.framework == MODEL.FRAMEWORK_TORCH + def test_set_model_class_llm(self, backend_fixtures): + # Arrange + json = backend_fixtures["model"]["get_llm"]["response"]["items"][0] + + # Act + model = util.set_model_class(json) + + # Assert + assert isinstance(model, LLMModel) + assert model.framework == MODEL.FRAMEWORK_LLM + def test_set_model_class_unsupported(self, backend_fixtures): # Arrange json = backend_fixtures["model"]["get_base"]["response"]["items"][0] @@ -361,6 +374,7 @@ def pred_base_spec(model_framework, model_server): pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__") pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__") pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__") + pred_llm = mocker.patch("hsml.llm.predictor.Predictor.__init__") # Act predictor = util.get_predictor_for_model(model_base) @@ -374,6 +388,7 @@ def pred_base_spec(model_framework, model_server): pred_sklearn.assert_not_called() pred_tensorflow.assert_not_called() pred_torch.assert_not_called() + pred_llm.assert_not_called() def test_get_predictor_for_model_python(self, mocker, model_python): # Arrange @@ -384,6 +399,7 @@ def test_get_predictor_for_model_python(self, mocker, model_python): pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__") pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__") pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__") + pred_llm = mocker.patch("hsml.llm.predictor.Predictor.__init__") # Act predictor = util.get_predictor_for_model(model_python) @@ -395,6 +411,7 @@ def test_get_predictor_for_model_python(self, mocker, model_python): pred_sklearn.assert_not_called() pred_tensorflow.assert_not_called() pred_torch.assert_not_called() + pred_llm.assert_not_called() def test_get_predictor_for_model_sklearn(self, mocker, model_sklearn): # Arrange @@ -405,6 +422,7 @@ def test_get_predictor_for_model_sklearn(self, mocker, model_sklearn): ) pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__") pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__") + pred_llm = mocker.patch("hsml.llm.predictor.Predictor.__init__") # Act predictor = util.get_predictor_for_model(model_sklearn) @@ -416,6 +434,7 @@ def test_get_predictor_for_model_sklearn(self, mocker, model_sklearn): pred_sklearn.assert_called_once() pred_tensorflow.assert_not_called() pred_torch.assert_not_called() + pred_llm.assert_not_called() def test_get_predictor_for_model_tensorflow(self, mocker, model_tensorflow): # Arrange @@ -426,6 +445,7 @@ def test_get_predictor_for_model_tensorflow(self, mocker, model_tensorflow): "hsml.tensorflow.predictor.Predictor.__init__", return_value=None ) pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__") + pred_llm = mocker.patch("hsml.llm.predictor.Predictor.__init__") # Act predictor = util.get_predictor_for_model(model_tensorflow) @@ -437,6 +457,7 @@ def test_get_predictor_for_model_tensorflow(self, mocker, model_tensorflow): pred_sklearn.assert_not_called() pred_tensorflow.assert_called_once() pred_torch.assert_not_called() + pred_llm.assert_not_called() def test_get_predictor_for_model_torch(self, mocker, model_torch): # Arrange @@ -447,6 +468,7 @@ def test_get_predictor_for_model_torch(self, mocker, model_torch): pred_torch = mocker.patch( "hsml.torch.predictor.Predictor.__init__", return_value=None ) + pred_llm = mocker.patch("hsml.llm.predictor.Predictor.__init__") # Act predictor = util.get_predictor_for_model(model_torch) @@ -458,6 +480,30 @@ def test_get_predictor_for_model_torch(self, mocker, model_torch): pred_sklearn.assert_not_called() pred_tensorflow.assert_not_called() pred_torch.assert_called_once() + pred_llm.assert_not_called() + + def test_get_predictor_for_model_llm(self, mocker, model_llm): + # Arrange + pred_base = mocker.patch("hsml.predictor.Predictor.__init__") + pred_python = mocker.patch("hsml.python.predictor.Predictor.__init__") + pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__") + pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__") + pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__") + pred_llm = mocker.patch( + "hsml.llm.predictor.Predictor.__init__", return_value=None + ) + + # Act + predictor = util.get_predictor_for_model(model_llm) + + # Assert + assert isinstance(predictor, LLMPredictor) + pred_base.assert_not_called() + pred_python.assert_not_called() + pred_sklearn.assert_not_called() + pred_tensorflow.assert_not_called() + pred_torch.assert_not_called() + pred_llm.assert_called_once() def test_get_predictor_for_model_non_base(self, mocker): # Arrange @@ -466,6 +512,7 @@ def test_get_predictor_for_model_non_base(self, mocker): pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__") pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__") pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__") + pred_llm = mocker.patch("hsml.llm.predictor.Predictor.__init__") class NonBaseModel: pass @@ -482,6 +529,7 @@ class NonBaseModel: pred_sklearn.assert_not_called() pred_tensorflow.assert_not_called() pred_torch.assert_not_called() + pred_llm.assert_not_called() def test_get_hostname_replaced_url(self, mocker): # Arrange