diff --git a/projects/fal/src/fal/api.py b/projects/fal/src/fal/api.py index 9dba7c61..5005bd4b 100644 --- a/projects/fal/src/fal/api.py +++ b/projects/fal/src/fal/api.py @@ -142,6 +142,8 @@ def parse_key(cls, key: str, value: Any) -> tuple[Any, Any]: # Conda environment definition should be parsed before sending to serverless with open(value) as f: return "env_dict", yaml.safe_load(f) + elif key == "image" and isinstance(value, ContainerImage): + return "image", value.to_dict() else: return key, value diff --git a/projects/fal/src/fal/container.py b/projects/fal/src/fal/container.py index 2f98f6c1..d97c8255 100644 --- a/projects/fal/src/fal/container.py +++ b/projects/fal/src/fal/container.py @@ -1,19 +1,49 @@ +from dataclasses import dataclass, field +from typing import Dict, Literal + +Builder = Literal["depot", "service", "worker"] +BUILDERS = {"depot", "service", "worker"} +DEFAULT_BUILDER: Builder = "depot" + + +@dataclass class ContainerImage: """ContainerImage represents a Docker image that can be built from a Dockerfile. """ - _known_keys = {"dockerfile_str", "build_args", "registries", "builder"} + dockerfile_str: str + build_args: Dict[str, str] = field(default_factory=dict) + registries: Dict[str, Dict[str, str]] = field(default_factory=dict) + builder: Builder = field(default=DEFAULT_BUILDER) + + def __post_init__(self) -> None: + if self.registries: + for registry in self.registries.values(): + keys = registry.keys() + if "username" not in keys or "password" not in keys: + raise ValueError( + "Username and password are required for each registry" + ) + + if self.builder not in BUILDERS: + raise ValueError( + f"Invalid builder: {self.builder}, must be one of {BUILDERS}" + ) @classmethod - def from_dockerfile_str(cls, text: str, **kwargs): - # Check for unknown keys and return them as a dict. - return dict( - dockerfile_str=text, - **{k: v for k, v in kwargs.items() if k in cls._known_keys}, - ) + def from_dockerfile_str(cls, text: str, **kwargs) -> "ContainerImage": + return cls(dockerfile_str=text, **kwargs) @classmethod - def from_dockerfile(cls, path: str, **kwargs): + def from_dockerfile(cls, path: str, **kwargs) -> "ContainerImage": with open(path) as fobj: return cls.from_dockerfile_str(fobj.read(), **kwargs) + + def to_dict(self) -> dict: + return { + "dockerfile_str": self.dockerfile_str, + "build_args": self.build_args, + "registries": self.registries, + "builder": self.builder, + }