Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved pruning protocol from model to factory #882

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions ramalama/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ramalama.gpu_detector import GPUDetector
from ramalama.model import MODEL_TYPES
from ramalama.model_factory import ModelFactory
from ramalama.oci import OCI
from ramalama.shortnames import Shortnames
from ramalama.toml_parser import TOMLParser
from ramalama.version import print_version, version
Expand Down Expand Up @@ -699,7 +698,7 @@ def convert_cli(args):
if not tgt:
tgt = target

model = OCI(tgt, args.engine)
model = ModelFactory(tgt, engine=args.engine).create_oci()
model.convert(source, args)


Expand Down Expand Up @@ -775,7 +774,7 @@ def push_cli(args):
raise e
try:
# attempt to push as a container image
m = OCI(tgt, config.get('engine', container_manager()))
m = ModelFactory(tgt, engine=config.get('engine', container_manager())).create_oci()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably not for this PR, but it would be better if we resolved all environment variables, configuration files, CLI args, once early in execution of ramalama. I started that here at one point, but it drifted:

load_and_merge_config

we tend to do things like, use the "engine" value or container_manager() over and over again in this codebase which isn't great.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree and good to know. I'll keep that in mind and will try to refactor occurrences like this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config here is already the result of load_and_merge_config, where we do exactly the same thing. Maybe a dataclass or something similar might help. Although it's less flexible, the fields are stated explicitly, and it doesn't encourage specifying ad-hoc defaults when a field is missing.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we should probably be doing something like:

config['engine']

here. The code is misleading, it suggests we might not have resolved this value correctly up to this point.

m.push(source, args)
except Exception:
raise e
Expand Down Expand Up @@ -845,7 +844,7 @@ def run_cli(args):
except KeyError as e:
try:
args.quiet = True
model = OCI(args.MODEL, args.engine, ignore_stderr=True)
model = ModelFactory(args.MODEL, engine=args.engine, ignore_stderr=True).create_oci()
model.run(args)
except Exception:
raise e
Expand Down Expand Up @@ -879,7 +878,7 @@ def serve_cli(args):
except KeyError as e:
try:
args.quiet = True
model = OCI(args.MODEL, args.engine, ignore_stderr=True)
model = ModelFactory(args.MODEL, engine=args.engine, ignore_stderr=True).create_oci()
model.serve(args)
except Exception:
raise e
Expand Down Expand Up @@ -987,7 +986,7 @@ def _rm_model(models, args):
raise e
try:
# attempt to remove as a container image
m = OCI(model, args.engine, ignore_stderr=True)
m = ModelFactory(model, engine=args.engine, ignore_stderr=True).create_oci()
m.remove(args)
return
except Exception:
Expand Down
9 changes: 9 additions & 0 deletions ramalama/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,12 @@ def get_env_vars():
# env_vars[gpu_type] = str(gpu_num)

return env_vars


def rm_until_substring(input, substring):
    """Return *input* with everything up to and including the first
    occurrence of *substring* removed.

    If *substring* does not occur, *input* is returned unchanged.
    An empty *substring* matches at position 0, so the input comes
    back unchanged as well.
    """
    pos = input.find(substring)
    if pos == -1:
        return input

    # Slice off the prefix in one O(n) step instead of rebuilding the
    # string character by character.
    return input[pos + len(substring):]
5 changes: 2 additions & 3 deletions ramalama/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import urllib.request

from ramalama.common import available, download_file, exec_cmd, perror, run_cmd, verify_checksum
from ramalama.model import Model, rm_until_substring
from ramalama.model import Model

missing_huggingface = """
Optional: Huggingface models require the huggingface-cli module.
Expand Down Expand Up @@ -34,9 +34,8 @@ def fetch_checksum_from_api(url):

class Huggingface(Model):
def __init__(self, model):
    """Initialize a Huggingface-backed model.

    The diff view interleaved the removed prune calls with the kept
    lines; in the merged state the model string arrives already pruned
    by ModelFactory, so no prefix stripping happens here.
    """
    super().__init__(model)

    self.type = "huggingface"
    # Whether the optional huggingface-cli tool is installed locally.
    self.hf_cli_available = is_huggingface_cli_available()

Expand Down
9 changes: 0 additions & 9 deletions ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,3 @@ def distinfo_volume():
return ""

return f"-v{path}:/usr/share/ramalama/{dist_info}:ro"


def rm_until_substring(model, substring):
    """Return *model* with everything up to and including the first
    occurrence of *substring* removed; unchanged if absent.
    """
    pos = model.find(substring)
    if pos == -1:
        return model

    # A single slice replaces the per-character join-over-range rebuild.
    return model[pos + len(substring):]
52 changes: 44 additions & 8 deletions ramalama/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,69 @@
from typing import Union
from urllib.parse import urlparse

from ramalama.common import rm_until_substring
from ramalama.huggingface import Huggingface
from ramalama.model import MODEL_TYPES
from ramalama.oci import OCI
from ramalama.ollama import Ollama
from ramalama.url import URL


class ModelFactory:
    """Builds the concrete Model implementation (Huggingface, Ollama,
    OCI, or URL) for a model reference string.

    Protocol prefixes in the model string (e.g. ``hf://``, ``oci://``)
    take precedence; the configured default *transport* is used when the
    string carries no recognizable prefix.

    NOTE: this span in the diff view contained both the removed and the
    added lines of each statement; the body below is the merged result.
    """

    def __init__(self, model: str, transport: str = "ollama", engine: str = "podman", ignore_stderr: bool = False):
        self.model = model
        self.transport = transport
        self.engine = engine
        self.ignore_stderr = ignore_stderr

    def prune_model_input(self, cls: type[Union[Huggingface, Ollama, OCI, URL]]) -> str:
        """Strip the transport protocol — and any class-specific registry
        prefix — from the model input, returning the bare reference."""
        # remove protocol from model input
        pruned_model_input = rm_until_substring(self.model, "://")

        if cls == Huggingface:
            pruned_model_input = rm_until_substring(pruned_model_input, "hf.co/")
        elif cls == Ollama:
            pruned_model_input = rm_until_substring(pruned_model_input, "ollama.com/library/")

        return pruned_model_input

    def validate_oci_model_input(self):
        """Raise ValueError when the model names a non-OCI transport.

        ``oci://`` and ``docker://`` are accepted outright; any other
        known transport prefix is rejected.
        """
        if self.model.startswith("oci://") or self.model.startswith("docker://"):
            return

        for t in MODEL_TYPES:
            if self.model.startswith(t + "://"):
                raise ValueError(f"{self.model} invalid: Only OCI Model types supported")

    def create_huggingface(self) -> Huggingface:
        return Huggingface(self.prune_model_input(Huggingface))

    def create_ollama(self) -> Ollama:
        return Ollama(self.prune_model_input(Ollama))

    def create_oci(self) -> OCI:
        # Validate before pruning so the error message shows the raw input.
        self.validate_oci_model_input()
        return OCI(self.prune_model_input(OCI), self.engine, self.ignore_stderr)

    def create_url(self) -> URL:
        # The scheme is extracted from the raw input before pruning removes it.
        return URL(self.prune_model_input(URL), urlparse(self.model).scheme)

    def create(self) -> Union[Huggingface, Ollama, OCI, URL]:
        """Dispatch on the model string's protocol prefix first, then fall
        back to the configured default transport.

        Raises KeyError for an unsupported transport.
        """
        if self.model.startswith("huggingface://") or self.model.startswith("hf://") or self.model.startswith("hf.co/"):
            return self.create_huggingface()
        if self.model.startswith("ollama://") or "ollama.com/library/" in self.model:
            return self.create_ollama()
        if self.model.startswith("oci://") or self.model.startswith("docker://"):
            return self.create_oci()
        if self.model.startswith("http://") or self.model.startswith("https://") or self.model.startswith("file://"):
            return self.create_url()

        if self.transport == "huggingface":
            return self.create_huggingface()
        if self.transport == "ollama":
            return self.create_ollama()
        if self.transport == "oci":
            return self.create_oci()

        raise KeyError(f'transport "{self.transport}" not supported. Must be oci, huggingface, or ollama.')
8 changes: 3 additions & 5 deletions ramalama/oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import ramalama.annotations as annotations
from ramalama.common import MNT_FILE, engine_version, exec_cmd, perror, run_cmd
from ramalama.model import MODEL_TYPES, Model
from ramalama.model import Model

prefix = "oci://"

Expand Down Expand Up @@ -123,10 +123,8 @@ def list_models(args):

class OCI(Model):
def __init__(self, model, conman, ignore_stderr=False):
    """Initialize an OCI-backed model.

    Prefix stripping and transport validation (the removed diff lines)
    now live in ModelFactory; *model* arrives here already pruned.
    """
    super().__init__(model)

    self.type = "OCI"
    # Container manager (e.g. podman/docker) used to run OCI commands.
    self.conman = conman
    self.ignore_stderr = ignore_stderr
Expand Down
5 changes: 2 additions & 3 deletions ramalama/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import urllib.request

from ramalama.common import available, download_file, run_cmd, verify_checksum
from ramalama.model import Model, rm_until_substring
from ramalama.model import Model


def fetch_manifest_data(registry_head, model_tag, accept):
Expand Down Expand Up @@ -87,9 +87,8 @@ def in_existing_cache(model_name, model_tag):

class Ollama(Model):
def __init__(self, model):
    """Initialize an Ollama-backed model.

    The removed prune calls moved to ModelFactory.prune_model_input;
    *model* arrives here already stripped of its registry prefix.
    """
    super().__init__(model)

    self.type = "Ollama"

def _local(self, args):
Expand Down
9 changes: 4 additions & 5 deletions ramalama/url.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os
from urllib.parse import urlparse

from ramalama.common import download_file
from ramalama.model import Model, rm_until_substring
from ramalama.model import Model


class URL(Model):
def __init__(self, model, scheme):
    """Initialize a URL-backed model.

    The old signature derived the scheme from the raw model string;
    in the merged state the caller (ModelFactory) passes the parsed
    *scheme* explicitly and supplies an already-pruned *model*.
    """
    super().__init__(model)

    self.type = scheme
    # Split off the final path component; everything before it is the
    # (root-relative) directory the model file lives in.
    split = self.model.rsplit("/", 1)
    self.directory = split[0].removeprefix("/") if len(split) > 1 else ""

Expand Down
18 changes: 18 additions & 0 deletions test/unit/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from ramalama.common import rm_until_substring

# Table-driven cases for rm_until_substring: (input, substring-to-cut-through,
# expected remainder). Covers protocol prefixes, the hf.co/ registry prefix,
# an unmatched empty substring (input returned whole), and the empty input.
@pytest.mark.parametrize(
"input,rm_until,expected",
[
("", "", ""),
("huggingface://granite-code", "://", "granite-code"),
("hf://granite-code", "://", "granite-code"),
("hf.co/granite-code", "hf.co/", "granite-code"),
("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", ".co/", "ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", "file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf"),
]
)
def test_rm_until_substring(input: str, rm_until: str, expected: str):
actual = rm_until_substring(input, rm_until)
assert actual == expected
55 changes: 50 additions & 5 deletions test/unit/test_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ramalama.url import URL

from dataclasses import dataclass
from typing import Union

@dataclass
class Input():
Expand All @@ -27,19 +28,63 @@ class Input():
(Input("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
(Input("https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),
(Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, None),

(Input("granite-code", "huggingface", ""), Huggingface, None),
(Input("granite-code", "ollama", ""), Ollama, None),
(Input("granite-code", "oci", ""), OCI, None),



]
)
def test_model_factory_create(input: Input, expected, error):
def test_model_factory_create(input: Input, expected: type[Union[Huggingface, Ollama, URL, OCI]], error):
if error is not None:
with pytest.raises(error):
ModelFactory(input.Model, input.Transport, input.Engine).create()
else:
model = ModelFactory(input.Model, input.Transport, input.Engine).create()
assert isinstance(model, expected)

# Cases for ModelFactory.validate_oci_model_input: (input, expected error).
# oci:// and docker:// prefixes pass; any other known transport prefix
# (huggingface://, hf://, ollama://, file://) must raise ValueError;
# bare names and registry-host forms without "://" pass through.
@pytest.mark.parametrize(
"input,error",
[
(Input("", "", ""), None),
(Input("oci://granite-code", "", "podman"), None),
(Input("docker://granite-code", "", "podman"), None),
(Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), ValueError),
(Input("huggingface://granite-code", "", ""), ValueError),
(Input("hf://granite-code", "", ""), ValueError),
(Input("hf.co/granite-code", "", ""), None),
(Input("ollama://granite-code", "", ""), ValueError),
(Input("ollama.com/library/granite-code", "", ""), None),
(Input("granite-code", "", ""), None),
]
)
def test_validate_oci_model_input(input: Input, error):
if error is not None:
with pytest.raises(error):
ModelFactory(input.Model, input.Transport, input.Engine).validate_oci_model_input()
return

ModelFactory(input.Model, input.Transport, input.Engine).validate_oci_model_input()


# Cases for ModelFactory.prune_model_input: (input, model class, expected
# pruned string). The protocol ("...://") is always stripped; Huggingface
# additionally strips "hf.co/" and Ollama strips "ollama.com/library/".
@pytest.mark.parametrize(
"input,cls,expected",
[
(Input("huggingface://granite-code", "", ""), Huggingface, "granite-code"),
(Input("huggingface://ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""), Huggingface, "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code"),
(Input("hf://granite-code", "", ""), Huggingface, "granite-code"),
(Input("hf.co/granite-code", "", ""), Huggingface, "granite-code"),
(Input("ollama://granite-code", "", ""), Ollama, "granite-code"),
(Input("ollama.com/library/granite-code", "", ""), Ollama, "granite-code"),
(Input("ollama.com/library/ibm-granite/granite-3b-code-base-2k-GGUF/granite-code", "", ""), Ollama, "ibm-granite/granite-3b-code-base-2k-GGUF/granite-code"),
(Input("oci://granite-code", "", "podman"), OCI, "granite-code"),
(Input("docker://granite-code", "", "podman"), OCI, "granite-code"),
(Input("http://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
(Input("https://huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "huggingface.co/ibm-granite/granite-3b-code-base-2k-GGUF/blob/main/granite-3b-code-base.Q4_K_M.gguf"),
(Input("file:///tmp/models/granite-3b-code-base.Q4_K_M.gguf", "", ""), URL, "/tmp/models/granite-3b-code-base.Q4_K_M.gguf"),
(Input("granite-code", "huggingface", ""), Huggingface, "granite-code"),
(Input("granite-code", "ollama", ""), Ollama, "granite-code"),
(Input("granite-code", "oci", ""), OCI, "granite-code"),
]
)
def test_prune_model_input(input: Input, cls: type[Union[Huggingface, Ollama, URL, OCI]], expected: str):
pruned_model_input = ModelFactory(input.Model, input.Transport, input.Engine).prune_model_input(cls)
assert pruned_model_input == expected
Loading