From fc75d9f59341581249b79d7daa168a9e080a5e6b Mon Sep 17 00:00:00 2001
From: Michael Engel
Date: Tue, 25 Feb 2025 08:14:43 +0100
Subject: [PATCH] Moved pruning protocol from model to factory

By moving the pruning of the protocol from the model input to the
model_factory and encapsulating it in a dedicated function, unit tests
can be written more easily.

Signed-off-by: Michael Engel
---
 ramalama/cli.py                 | 11 +++----
 ramalama/common.py              |  9 ++++++
 ramalama/huggingface.py         |  5 ++-
 ramalama/model.py               |  9 ------
 ramalama/model_factory.py       | 52 ++++++++++++++++++++++++++-----
 ramalama/oci.py                 |  8 ++---
 ramalama/ollama.py              |  5 ++-
 ramalama/url.py                 |  9 +++---
 test/unit/test_common.py        | 18 +++++++++++
 test/unit/test_model_factory.py | 55 ++++++++++++++++++++++++++++++---
 10 files changed, 137 insertions(+), 44 deletions(-)
 create mode 100644 test/unit/test_common.py

diff --git a/ramalama/cli.py b/ramalama/cli.py
index d6ffe8ec..28833304 100644
--- a/ramalama/cli.py
+++ b/ramalama/cli.py
@@ -13,7 +13,6 @@
 from ramalama.gpu_detector import GPUDetector
 from ramalama.model import MODEL_TYPES
 from ramalama.model_factory import ModelFactory
-from ramalama.oci import OCI
 from ramalama.shortnames import Shortnames
 from ramalama.toml_parser import TOMLParser
 from ramalama.version import print_version, version
@@ -699,7 +698,7 @@ def convert_cli(args):
     if not tgt:
         tgt = target
 
-    model = OCI(tgt, args.engine)
+    model = ModelFactory(tgt, engine=args.engine).create_oci()
     model.convert(source, args)
@@ -775,7 +774,7 @@ def push_cli(args):
                 raise e
         try:  # attempt to push as a container image
-            m = OCI(tgt, config.get('engine', container_manager()))
+            m = ModelFactory(tgt, engine=config.get('engine', container_manager())).create_oci()
             m.push(source, args)
         except Exception:
             raise e
@@ -845,7 +844,7 @@ def run_cli(args):
     except KeyError as e:
         try:
             args.quiet = True
-            model = OCI(args.MODEL, args.engine, ignore_stderr=True)
+            model = ModelFactory(args.MODEL, engine=args.engine, ignore_stderr=True).create_oci()
             model.run(args)
         except Exception:
             raise e
@@ -879,7 +878,7 @@ def serve_cli(args):
     except KeyError as e:
         try:
             args.quiet = True
-            model = OCI(args.MODEL, args.engine, ignore_stderr=True)
+            model = ModelFactory(args.MODEL, engine=args.engine, ignore_stderr=True).create_oci()
             model.serve(args)
         except Exception:
             raise e
@@ -987,7 +986,7 @@ def _rm_model(models, args):
                 raise e
             try:  # attempt to remove as a container image
-                m = OCI(model, args.engine, ignore_stderr=True)
+                m = ModelFactory(model, engine=args.engine, ignore_stderr=True).create_oci()
                 m.remove(args)
                 return
             except Exception:
diff --git a/ramalama/common.py b/ramalama/common.py
index 6c9ce878..a13c9be4 100644
--- a/ramalama/common.py
+++ b/ramalama/common.py
@@ -311,3 +311,12 @@ def get_env_vars():
     # env_vars[gpu_type] = str(gpu_num)
 
     return env_vars
+
+
+def rm_until_substring(input, substring):
+    pos = input.find(substring)
+    if pos == -1:
+        return input
+
+    # Create a new string starting after the found substring
+    return ''.join(input[i] for i in range(pos + len(substring), len(input)))
diff --git a/ramalama/huggingface.py b/ramalama/huggingface.py
index 4bc74396..9fa34388 100644
--- a/ramalama/huggingface.py
+++ b/ramalama/huggingface.py
@@ -3,7 +3,7 @@
 import urllib.request
 
 from ramalama.common import available, download_file, exec_cmd, perror, run_cmd, verify_checksum
-from ramalama.model import Model, rm_until_substring
+from ramalama.model import Model
 
 missing_huggingface = """
 Optional: Huggingface models require the huggingface-cli module.
@@ -34,9 +34,8 @@ def fetch_checksum_from_api(url):
 
 
 class Huggingface(Model):
     def __init__(self, model):
-        model = rm_until_substring(model, "hf.co/")
-        model = rm_until_substring(model, "://")
         super().__init__(model)
+        self.type = "huggingface"
         self.hf_cli_available = is_huggingface_cli_available()
diff --git a/ramalama/model.py b/ramalama/model.py
index 205b0a20..3567de6d 100644
--- a/ramalama/model.py
+++ b/ramalama/model.py
@@ -584,12 +584,3 @@ def distinfo_volume():
         return ""
 
     return f"-v{path}:/usr/share/ramalama/{dist_info}:ro"
-
-
-def rm_until_substring(model, substring):
-    pos = model.find(substring)
-    if pos == -1:
-        return model
-
-    # Create a new string starting after the found substring
-    return ''.join(model[i] for i in range(pos + len(substring), len(model)))
diff --git a/ramalama/model_factory.py b/ramalama/model_factory.py
index c4c95363..7ba650a7 100644
--- a/ramalama/model_factory.py
+++ b/ramalama/model_factory.py
@@ -1,6 +1,9 @@
 from typing import Union
+from urllib.parse import urlparse
 
+from ramalama.common import rm_until_substring
 from ramalama.huggingface import Huggingface
+from ramalama.model import MODEL_TYPES
 from ramalama.oci import OCI
 from ramalama.ollama import Ollama
 from ramalama.url import URL
@@ -8,26 +11,59 @@
 
 class ModelFactory:
-    def __init__(self, model: str, transport: str = "ollama", engine: str = "podman"):
+    def __init__(self, model: str, transport: str = "ollama", engine: str = "podman", ignore_stderr: bool = False):
         self.model = model
         self.transport = transport
         self.engine = engine
+        self.ignore_stderr = ignore_stderr
+
+    def prune_model_input(self, cls: type[Union[Huggingface, Ollama, OCI, URL]]) -> str:
+        # remove protocol from model input
+        pruned_model_input = rm_until_substring(self.model, "://")
+
+        if cls == Huggingface:
+            pruned_model_input = rm_until_substring(pruned_model_input, "hf.co/")
+        elif cls == Ollama:
+            pruned_model_input = rm_until_substring(pruned_model_input, "ollama.com/library/")
+
+        return pruned_model_input
+
+    def validate_oci_model_input(self):
+        if self.model.startswith("oci://") or self.model.startswith("docker://"):
+            return
+
+        for t in MODEL_TYPES:
+            if self.model.startswith(t + "://"):
+                raise ValueError(f"{self.model} invalid: Only OCI Model types supported")
+
+    def create_huggingface(self) -> Huggingface:
+        return Huggingface(self.prune_model_input(Huggingface))
+
+    def create_ollama(self) -> Ollama:
+        return Ollama(self.prune_model_input(Ollama))
+
+    def create_oci(self) -> OCI:
+        self.validate_oci_model_input()
+        return OCI(self.prune_model_input(OCI), self.engine, self.ignore_stderr)
+
+    def create_url(self) -> URL:
+        return URL(self.prune_model_input(URL), urlparse(self.model).scheme)
 
     def create(self) -> Union[Huggingface, Ollama, OCI, URL]:
         if self.model.startswith("huggingface://") or self.model.startswith("hf://") or self.model.startswith("hf.co/"):
-            return Huggingface(self.model)
+            return self.create_huggingface()
         if self.model.startswith("ollama://") or "ollama.com/library/" in self.model:
-            return Ollama(self.model)
+            return self.create_ollama()
         if self.model.startswith("oci://") or self.model.startswith("docker://"):
-            return OCI(self.model, self.engine)
+            return self.create_oci()
         if self.model.startswith("http://") or self.model.startswith("https://") or self.model.startswith("file://"):
-            return URL(self.model)
+            return self.create_url()
 
         if self.transport == "huggingface":
-            return Huggingface(self.model)
+            return self.create_huggingface()
         if self.transport == "ollama":
-            return Ollama(self.model)
+            return self.create_ollama()
         if self.transport == "oci":
-            return OCI(self.model, self.engine)
+            return self.create_oci()
 
         raise KeyError(f'transport "{self.transport}" not supported. Must be oci, huggingface, or ollama.')
diff --git a/ramalama/oci.py b/ramalama/oci.py
index 136f06e8..920ef5b1 100644
--- a/ramalama/oci.py
+++ b/ramalama/oci.py
@@ -6,7 +6,7 @@
 import ramalama.annotations as annotations
 from ramalama.common import MNT_FILE, engine_version, exec_cmd, perror, run_cmd
-from ramalama.model import MODEL_TYPES, Model
+from ramalama.model import Model
 
 prefix = "oci://"
@@ -123,10 +123,8 @@ def list_models(args):
 
 
 class OCI(Model):
     def __init__(self, model, conman, ignore_stderr=False):
-        super().__init__(model.removeprefix(prefix).removeprefix("docker://"))
-        for t in MODEL_TYPES:
-            if self.model.startswith(t + "://"):
-                raise ValueError(f"{model} invalid: Only OCI Model types supported")
+        super().__init__(model)
+        self.type = "OCI"
         self.conman = conman
         self.ignore_stderr = ignore_stderr
diff --git a/ramalama/ollama.py b/ramalama/ollama.py
index 25103381..2dc2604e 100644
--- a/ramalama/ollama.py
+++ b/ramalama/ollama.py
@@ -3,7 +3,7 @@
 import urllib.request
 
 from ramalama.common import available, download_file, run_cmd, verify_checksum
-from ramalama.model import Model, rm_until_substring
+from ramalama.model import Model
 
 
 def fetch_manifest_data(registry_head, model_tag, accept):
@@ -87,9 +87,8 @@ def in_existing_cache(model_name, model_tag):
 
 
 class Ollama(Model):
     def __init__(self, model):
-        model = rm_until_substring(model, "ollama.com/library/")
-        model = rm_until_substring(model, "://")
         super().__init__(model)
+        self.type = "Ollama"
 
     def _local(self, args):
diff --git a/ramalama/url.py b/ramalama/url.py
index 036c9065..af0eb369 100644
--- a/ramalama/url.py
+++ b/ramalama/url.py
@@ -1,15 +1,14 @@
 import os
-from urllib.parse import urlparse
 
 from ramalama.common import download_file
-from ramalama.model import Model, rm_until_substring
+from ramalama.model import Model
 
 
 class URL(Model):
-    def __init__(self, model):
-        self.type = urlparse(model).scheme
-        model = rm_until_substring(model, "://")
+    def __init__(self, model, scheme):
         super().__init__(model)
+
+        self.type = scheme
 
         split = self.model.rsplit("/", 1)
         self.directory = split[0].removeprefix("/") if len(split) > 1 else ""
diff --git a/test/unit/test_common.py b/test/unit/test_common.py
new file mode 100644
index 00000000..932ceaca
--- /dev/null
+++ b/test/unit/test_common.py
@@ -0,0 +1,18 @@
+import pytest
+
+from ramalama.common import rm_until_substring
+
+@pytest.mark.parametrize(
+    "input,rm_until,expected",
+    [
+        ("", "", ""),
+        ("huggingface://granite-code", "://", "granite-code"),
+        ("hf://granite-code", "://", "granite-code"),
+        ("hf.co/granite-code", "hf.co/", "granite-code"),
+        ("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", ".co/", "ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
+        ("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", "file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf"),
+    ]
+)
+def test_rm_until_substring(input: str, rm_until: str, expected: str):
+    actual = rm_until_substring(input, rm_until)
+    assert actual == expected
diff --git a/test/unit/test_model_factory.py b/test/unit/test_model_factory.py
index e3fbbb08..cc447cfb 100644
--- a/test/unit/test_model_factory.py
+++ b/test/unit/test_model_factory.py
@@ -6,6 +6,7 @@
 from ramalama.url import URL
 
 from dataclasses import dataclass
+from typing import Union
 
 @dataclass
 class Input():
@@ -27,19 +28,63 @@ class Input():
         (Input("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
         (Input("https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
         (Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
-
         (Input("granite-code", "huggingface", ""), Huggingface, None),
         (Input("granite-code", "ollama", ""), Ollama, None),
         (Input("granite-code", "oci", ""), OCI, None),
-
-
-
     ]
 )
-def test_model_factory_create(input: Input, expected, error):
+def test_model_factory_create(input: Input, expected: type[Union[Huggingface, Ollama, URL, OCI]], error):
     if error is not None:
         with pytest.raises(error):
             ModelFactory(input.Model, input.Transport, input.Engine).create()
     else:
         model = ModelFactory(input.Model, input.Transport, input.Engine).create()
         assert isinstance(model, expected)
+
+@pytest.mark.parametrize(
+    "input,error",
+    [
+        (Input("", "", ""), None),
+        (Input("oci://granite-code", "", "podman"), None),
+        (Input("docker://granite-code", "", "podman"), None),
+        (Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), ValueError),
+        (Input("huggingface://granite-code", "", ""), ValueError),
+        (Input("hf://granite-code", "", ""), ValueError),
+        (Input("hf.co/granite-code", "", ""), None),
+        (Input("ollama://granite-code", "", ""), ValueError),
+        (Input("ollama.com/library/granite-code", "", ""), None),
+        (Input("granite-code", "", ""), None),
+    ]
+)
+def test_validate_oci_model_input(input: Input, error):
+    if error is not None:
+        with pytest.raises(error):
+            ModelFactory(input.Model, input.Transport, input.Engine).validate_oci_model_input()
+        return
+
+    ModelFactory(input.Model, input.Transport, input.Engine).validate_oci_model_input()
+
+
+@pytest.mark.parametrize(
+    "input,cls,expected",
+    [
+        (Input("huggingface://granite-code", "", ""), Huggingface, "granite-code"),
+        (Input("huggingface://ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""), Huggingface, "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code"),
+        (Input("hf://granite-code", "", ""), Huggingface, "granite-code"),
+        (Input("hf.co/granite-code", "", ""), Huggingface, "granite-code"),
+        (Input("ollama://granite-code", "", ""), Ollama, "granite-code"),
+        (Input("ollama.com/library/granite-code", "", ""), Ollama, "granite-code"),
+        (Input("ollama.com/library/ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""), Ollama, "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code"),
+        (Input("oci://granite-code", "", "podman"), OCI, "granite-code"),
+        (Input("docker://granite-code", "", "podman"), OCI, "granite-code"),
+        (Input("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
+        (Input("https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
+        (Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "/tmp/models/granite-3b-code-base.Q4_K_M.gguf"),
+        (Input("granite-code", "huggingface", ""), Huggingface, "granite-code"),
+        (Input("granite-code", "ollama", ""), Ollama, "granite-code"),
+        (Input("granite-code", "oci", ""), OCI, "granite-code"),
+    ]
+)
+def test_prune_model_input(input: Input, cls: type[Union[Huggingface, Ollama, URL, OCI]], expected: str):
+    pruned_model_input = ModelFactory(input.Model, input.Transport, input.Engine).prune_model_input(cls)
+    assert pruned_model_input == expected
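
Usage sketch (reviewer note, not part of the patch): after this change, protocol pruning and OCI validation are plain ModelFactory methods, so they can be exercised without instantiating a model class or touching a container engine. A minimal example, assuming the module layout above; it only restates behaviour already encoded in the new unit tests.

    # Illustrative only; mirrors the cases covered by test_prune_model_input
    # and test_validate_oci_model_input above.
    from ramalama.huggingface import Huggingface
    from ramalama.model_factory import ModelFactory
    from ramalama.ollama import Ollama

    # The factory strips the transport prefix before the model class ever sees it.
    assert ModelFactory("hf://granite-code").prune_model_input(Huggingface) == "granite-code"
    assert ModelFactory("ollama.com/library/granite-code").prune_model_input(Ollama) == "granite-code"

    # Non-OCI references are rejected before an OCI model is constructed,
    # so no engine call is needed to test the failure path.
    try:
        ModelFactory("ollama://granite-code", engine="podman").create_oci()
    except ValueError as err:
        print(err)  # ollama://granite-code invalid: Only OCI Model types supported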