Skip to content

Commit

Permalink
Created abstract base class for models
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Engel <mengel@redhat.com>
  • Loading branch information
engelmi committed Feb 24, 2025
1 parent acd8209 commit 38da850
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,46 @@
$(error)s"""


class Model:
class ModelBase:

def __not_implemented_error(self, param):
return NotImplementedError(f"ramalama {param} for '{type(self).__name__}' not implemented")

def login(self, args):
raise self.__not_implemented_error("login")

def logout(self, args):
raise self.__not_implemented_error("logout")

def pull(self, args):
raise self.__not_implemented_error("pull")

def push(self, source, args):
raise self.__not_implemented_error("push")

def remove(self, args):
raise self.__not_implemented_error("rm")

def bench(self, args):
raise self.__not_implemented_error("bench")

def run(self, args):
raise self.__not_implemented_error("run")

def perplexity(self, args):
raise self.__not_implemented_error("perplexity")

def serve(self, args):
raise self.__not_implemented_error("serve")

def exists(self, args):
raise self.__not_implemented_error("exists")

def inspect(self, args):
raise self.__not_implemented_error("inspect")


class Model(ModelBase):
"""Model super class"""

model = ""
Expand All @@ -48,18 +87,6 @@ def __init__(self, model):
self.directory = split[0] if len(split) > 1 else ""
self.filename = split[1] if len(split) > 1 else split[0]

def login(self, args):
raise NotImplementedError(f"ramalama login for {self.type} not implemented")

def logout(self, args):
raise NotImplementedError(f"ramalama logout for {self.type} not implemented")

def pull(self, args):
raise NotImplementedError(f"ramalama pull for {self.type} not implemented")

def push(self, source, args):
raise NotImplementedError(f"ramalama push for {self.type} not implemented")

def is_symlink_to(self, file_path, target_path):
if os.path.islink(file_path):
symlink_target = os.readlink(file_path)
Expand Down Expand Up @@ -109,8 +136,6 @@ def attempt_to_use_versioned(self, conman, image, vers, args):
except Exception:
return False

return False

def _image(self, args):
if args.image != DEFAULT_IMAGE:
return args.image
Expand Down

0 comments on commit 38da850

Please sign in to comment.