Skip to content

Commit

Permalink
Moved pruning protocol from model to factory
Browse files Browse the repository at this point in the history
By moving the pruning of the protocol from the model input to
the model_factory and encapsulating it in a dedicated function,
unit tests can be written more easily.

Signed-off-by: Michael Engel <mengel@redhat.com>
  • Loading branch information
engelmi committed Feb 25, 2025
1 parent 00839ee commit e5d7c3b
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 35 deletions.
9 changes: 9 additions & 0 deletions ramalama/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,12 @@ def get_env_vars():
# env_vars[gpu_type] = str(gpu_num)

return env_vars


def rm_until_substring(input, substring):
    """Return the part of *input* after the first occurrence of *substring*.

    If *substring* does not occur in *input*, the input is returned
    unchanged. An empty *substring* matches at position 0, so the whole
    input is returned.
    """
    pos = input.find(substring)
    if pos == -1:
        return input

    # Slice directly instead of rebuilding the string character by character.
    return input[pos + len(substring):]
5 changes: 2 additions & 3 deletions ramalama/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import urllib.request

from ramalama.common import available, download_file, exec_cmd, perror, run_cmd, verify_checksum
from ramalama.model import Model, rm_until_substring
from ramalama.model import Model

missing_huggingface = """
Optional: Huggingface models require the huggingface-cli module.
Expand Down Expand Up @@ -34,9 +34,8 @@ def fetch_checksum_from_api(url):

class Huggingface(Model):
def __init__(self, model):
    # Protocol/prefix pruning ("hf://", "hf.co/", ...) is performed by
    # ModelFactory.prune_model_input before this constructor runs; the
    # old rm_until_substring calls here are stale pre-change lines and
    # would raise NameError (the name is no longer imported).
    super().__init__(model)

    self.type = "huggingface"
    self.hf_cli_available = is_huggingface_cli_available()

Expand Down
9 changes: 0 additions & 9 deletions ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,12 +587,3 @@ def distinfo_volume():
return ""

return f"-v{path}:/usr/share/ramalama/{dist_info}:ro"


def rm_until_substring(model, substring):
pos = model.find(substring)
if pos == -1:
return model

# Create a new string starting after the found substring
return ''.join(model[i] for i in range(pos + len(substring), len(model)))
37 changes: 30 additions & 7 deletions ramalama/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Union

from ramalama.common import rm_until_substring
from ramalama.huggingface import Huggingface
from ramalama.model import MODEL_TYPES
from ramalama.oci import OCI
from ramalama.ollama import Ollama
from ramalama.url import URL
Expand All @@ -13,21 +15,42 @@ def __init__(self, model: str, transport: str = "ollama", engine: str = "podman"
self.transport = transport
self.engine = engine

def prune_model_input(self, cls: type[Union[Huggingface, Ollama, OCI, URL]]) -> str:
    """Strip the transport protocol (and any class-specific host prefix)
    from the raw model input and return the pruned string."""
    # Everything up to and including the first "://" is protocol information.
    pruned = rm_until_substring(self.model, "://")

    # Some model classes additionally allow a well-known host/library prefix.
    extra_prefix = {Huggingface: "hf.co/", Ollama: "ollama.com/library/"}.get(cls)
    if extra_prefix is not None:
        pruned = rm_until_substring(pruned, extra_prefix)

    return pruned

def validate_oci_model_input(self):
    """Raise ValueError when a non-OCI protocol prefix is used on an OCI model input."""
    # Explicit OCI protocols are always acceptable.
    if self.model.startswith(("oci://", "docker://")):
        return

    # Any other recognized model-type protocol is invalid for OCI.
    for model_type in MODEL_TYPES:
        if self.model.startswith(model_type + "://"):
            raise ValueError(f"{self.model} invalid: Only OCI Model types supported")

def create(self) -> Union[Huggingface, Ollama, OCI, URL]:
    """Instantiate the model class matching the input's protocol or transport.

    Protocol prefixes in the model input take precedence; the configured
    transport is only consulted when no recognized protocol is present.
    The diff rendering interleaved the stale pre-change calls (e.g.
    "return Huggingface(self.model)") with the new ones; only the
    pruned/validated calls below are the intended post-change behavior.

    Raises:
        ValueError: via validate_oci_model_input, for non-OCI protocols on OCI models.
        KeyError: when the transport is not oci, huggingface, or ollama.
    """
    if self.model.startswith("huggingface://") or self.model.startswith("hf://") or self.model.startswith("hf.co/"):
        return Huggingface(self.prune_model_input(Huggingface))
    if self.model.startswith("ollama://") or "ollama.com/library/" in self.model:
        return Ollama(self.prune_model_input(Ollama))
    if self.model.startswith("oci://") or self.model.startswith("docker://"):
        self.validate_oci_model_input()
        return OCI(self.prune_model_input(OCI), self.engine)
    if self.model.startswith("http://") or self.model.startswith("https://") or self.model.startswith("file://"):
        return URL(self.prune_model_input(URL))

    if self.transport == "huggingface":
        return Huggingface(self.prune_model_input(Huggingface))
    if self.transport == "ollama":
        return Ollama(self.prune_model_input(Ollama))
    if self.transport == "oci":
        self.validate_oci_model_input()
        return OCI(self.prune_model_input(OCI), self.engine)

    raise KeyError(f'transport "{self.transport}" not supported. Must be oci, huggingface, or ollama.')
8 changes: 3 additions & 5 deletions ramalama/oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import ramalama.annotations as annotations
from ramalama.common import MNT_FILE, engine_version, exec_cmd, perror, run_cmd
from ramalama.model import MODEL_TYPES, Model
from ramalama.model import Model

prefix = "oci://"

Expand Down Expand Up @@ -123,10 +123,8 @@ def list_models(args):

class OCI(Model):
def __init__(self, model, conman, ignore_stderr=False):
    # Prefix stripping and model-type validation now happen in
    # ModelFactory (prune_model_input / validate_oci_model_input); the
    # old prefix-removal and MODEL_TYPES loop here are stale pre-change
    # lines — MODEL_TYPES is no longer imported by this module.
    super().__init__(model)

    self.type = "OCI"
    self.conman = conman
    self.ignore_stderr = ignore_stderr
Expand Down
5 changes: 2 additions & 3 deletions ramalama/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import urllib.request

from ramalama.common import available, download_file, run_cmd, verify_checksum
from ramalama.model import Model, rm_until_substring
from ramalama.model import Model


def fetch_manifest_data(registry_head, model_tag, accept):
Expand Down Expand Up @@ -87,9 +87,8 @@ def in_existing_cache(model_name, model_tag):

class Ollama(Model):
def __init__(self, model):
    # Protocol/library-prefix pruning is performed by
    # ModelFactory.prune_model_input; the old rm_until_substring calls
    # here are stale pre-change lines and would raise NameError (the
    # name is no longer imported by this module).
    super().__init__(model)

    self.type = "Ollama"

def _local(self, args):
Expand Down
6 changes: 3 additions & 3 deletions ramalama/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from urllib.parse import urlparse

from ramalama.common import download_file
from ramalama.model import Model, rm_until_substring
from ramalama.model import Model


class URL(Model):
def __init__(self, model):
    # The old rm_until_substring call is a stale pre-change line (the
    # name is no longer imported); protocol pruning moved to the factory.
    super().__init__(model)

    # NOTE(review): the factory prunes "://" from the input before
    # constructing URL, so urlparse(model).scheme may be empty here —
    # verify this against ModelFactory.create.
    self.type = urlparse(model).scheme
    split = self.model.rsplit("/", 1)
    self.directory = split[0].removeprefix("/") if len(split) > 1 else ""

Expand Down
18 changes: 18 additions & 0 deletions test/unit/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from ramalama.common import rm_until_substring

# Table-driven check of rm_until_substring: each row is
# (input, substring to cut through, expected remainder).
@pytest.mark.parametrize(
    "input,rm_until,expected",
    [
        ("", "", ""),
        ("huggingface://granite-code", "://", "granite-code"),
        ("hf://granite-code", "://", "granite-code"),
        ("hf.co/granite-code", "hf.co/", "granite-code"),
        ("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", ".co/", "ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
        ("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", "file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf"),
    ],
)
def test_rm_until_substring(input: str, rm_until: str, expected: str):
    assert rm_until_substring(input, rm_until) == expected
55 changes: 50 additions & 5 deletions test/unit/test_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ramalama.url import URL

from dataclasses import dataclass
from typing import Union

@dataclass
class Input():
Expand All @@ -27,19 +28,63 @@ class Input():
(Input("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
(Input("https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
(Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
(Input("granite-code", "huggingface", ""), Huggingface, None),
(Input("granite-code", "ollama", ""), Ollama, None),
(Input("granite-code", "oci", ""), OCI, None),
]
)
def test_model_factory_create(input: Input, expected, error):
def test_model_factory_create(input: Input, expected: type[Union[Huggingface, Ollama, URL, OCI]], error):
if error is not None:
with pytest.raises(error):
ModelFactory(input.Model, input.Transport, input.Engine).create()
else:
model = ModelFactory(input.Model, input.Transport, input.Engine).create()
assert isinstance(model, expected)

# Each row is (factory input, expected exception type or None).
@pytest.mark.parametrize(
    "input,error",
    [
        (Input("", "", ""), None),
        (Input("oci://granite-code", "", "podman"), None),
        (Input("docker://granite-code", "", "podman"), None),
        (Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), ValueError),
        (Input("huggingface://granite-code", "", ""), ValueError),
        (Input("hf://granite-code", "", ""), ValueError),
        (Input("hf.co/granite-code", "", ""), None),
        (Input("ollama://granite-code", "", ""), ValueError),
        (Input("ollama.com/library/granite-code", "", ""), None),
        (Input("granite-code", "", ""), None),
    ],
)
def test_validate_oci_model_input(input: Input, error):
    factory = ModelFactory(input.Model, input.Transport, input.Engine)
    if error is None:
        # Valid inputs must pass without raising.
        factory.validate_oci_model_input()
    else:
        with pytest.raises(error):
            factory.validate_oci_model_input()


# Each row is (factory input, model class to prune for, expected pruned string).
@pytest.mark.parametrize(
    "input,cls,expected",
    [
        (Input("huggingface://granite-code", "", ""), Huggingface, "granite-code"),
        (Input("huggingface://ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""), Huggingface, "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code"),
        (Input("hf://granite-code", "", ""), Huggingface, "granite-code"),
        (Input("hf.co/granite-code", "", ""), Huggingface, "granite-code"),
        (Input("ollama://granite-code", "", ""), Ollama, "granite-code"),
        (Input("ollama.com/library/granite-code", "", ""), Ollama, "granite-code"),
        (Input("ollama.com/library/ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""), Ollama, "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code"),
        (Input("oci://granite-code", "", "podman"), OCI, "granite-code"),
        (Input("docker://granite-code", "", "podman"), OCI, "granite-code"),
        (Input("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
        (Input("https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
        (Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "/tmp/models/granite-3b-code-base.Q4_K_M.gguf"),
        (Input("granite-code", "huggingface", ""), Huggingface, "granite-code"),
        (Input("granite-code", "ollama", ""), Ollama, "granite-code"),
        (Input("granite-code", "oci", ""), OCI, "granite-code"),
    ],
)
def test_prune_model_input(input: Input, cls: type[Union[Huggingface, Ollama, URL, OCI]], expected: str):
    factory = ModelFactory(input.Model, input.Transport, input.Engine)
    assert factory.prune_model_input(cls) == expected

0 comments on commit e5d7c3b

Please sign in to comment.