Skip to content

Commit

Permalink
Merge pull request #266 from caracal-pipeline/file_finder
Browse files Browse the repository at this point in the history
Add recursive finder for filelike objects
  • Loading branch information
o-smirnov authored Apr 5, 2024
2 parents 915f0b3 + f3ae235 commit 7a28ba5
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 13 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pydantic = "^1.10.2"
psutil = "^5.9.3"
rich = "^13.7.0"
dill = "^0.3.6"
typeguard = "^4.2.1"

[tool.poetry.scripts]
stimela = "stimela.main:cli"
Expand Down
84 changes: 80 additions & 4 deletions scabha/basetypes.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from __future__ import annotations
from dataclasses import field, dataclass
from collections import OrderedDict
from typing import List
from typing import List, Union, get_args, get_origin, Any
import os.path
import re
from .exceptions import UnsetError
from itertools import zip_longest
from typeguard import (
check_type, TypeCheckError, TypeCheckerCallable, TypeCheckMemo, checker_lookup_functions
)
from inspect import isclass


def EmptyDictDefault():
return field(default_factory=lambda:OrderedDict())
Expand Down Expand Up @@ -55,7 +62,7 @@ def __init__(self, value):
def parse(value: str, expand_user=True):
"""
Parses URI. If URI does not start with "protocol://", assumes "file://"
Returns tuple of (protocol, path, is_remote)
If expand_user is True, ~ in (file-protocol) paths will be expanded.
Expand All @@ -75,7 +82,7 @@ class File(URI):
@property
def NAME(self):
return File(os.path.basename(self))

@property
def PATH(self):
return File(os.path.abspath(self))
Expand All @@ -95,7 +102,7 @@ def BASENAME(self):
@property
def EXT(self):
return os.path.splitext(self)[1]

@property
def EXISTS(self):
return os.path.exists(self)
Expand All @@ -114,3 +121,72 @@ def is_file_type(dtype):
def is_file_list_type(dtype):
return any(dtype == List[t] for t in FILE_TYPES)


def check_filelike(value: Any, origin_type: Any, args: tuple[Any, ...], memo: TypeCheckMemo) -> None:
"""Custom checker for filelike objects. Currently checks for strings."""
if not isinstance(value, str):
raise TypeCheckError(f'{value} is not compatible with URI or its subclasses.')


def filelike_lookup(origin_type: Any, args: tuple[Any, ...], extras: tuple[Any, ...]) -> TypeCheckerCallable | None:
"""Lookup the custom checker for filelike objects."""
if isclass(origin_type) and issubclass(origin_type, URI):
return check_filelike

return None

checker_lookup_functions.append(filelike_lookup) # Register custom type checker.

def get_filelikes(dtype, value, filelikes=None):
"""Recursively recover all filelike elements from a composite dtype."""

filelikes = set() if filelikes is None else filelikes

origin = get_origin(dtype)
args = get_args(dtype)

if origin: # Implies composition.

if origin is dict:

# No further work required for empty collections.
if len(value) == 0:
return filelikes

k_dtype, v_dtype = args

for k, v in value.items():
filelikes = get_filelikes(k_dtype, k, filelikes)
filelikes = get_filelikes(v_dtype, v, filelikes)

elif origin in (tuple, list, set):

# No further work required for empty collections.
if len(value) == 0:
return filelikes

# This is a special case for tuples of arbitrary
# length i.e. list-like behaviour. We can simply
# strip out the Ellipsis.
args = tuple([arg for arg in args if arg != ...])

for dt, v in zip_longest(args, value, fillvalue=args[0]):
filelikes = get_filelikes(dt, v, filelikes)

elif origin is Union:

for dt in args:
try:
check_type(value, dt)
except TypeCheckError: # Value doesn't match dtype - incorrect branch.
continue
filelikes = get_filelikes(dt, value, filelikes)

else:
raise ValueError(f"Failed to traverse {dtype} dtype when looking for files.")

else:
if is_file_type(dtype):
filelikes.add(value)

return filelikes
15 changes: 12 additions & 3 deletions stimela/backends/singularity.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import subprocess
import os
import re
import logging
from stimela import utils
from enum import Enum
import stimela
from shutil import which
from dataclasses import dataclass
from omegaconf import OmegaConf
from typing import Dict, List, Any, Optional, Callable
from contextlib import ExitStack
from scabha.basetypes import EmptyListDefault
from scabha.basetypes import EmptyDictDefault
import datetime
from stimela.utils.xrun_asyncio import xrun

from stimela.exceptions import BackendError

from . import native

ReadWriteMode = Enum("ReadWriteMode", "ro rw", module=__name__)


@dataclass
class SingularityBackendOptions(object):
enable: bool = True
Expand All @@ -26,6 +28,8 @@ class SingularityBackendOptions(object):
executable: Optional[str] = None
remote_only: bool = False # if True, won't look for singularity on local system -- useful in combination with slurm wrapper

# optional extra bindings
bind_dirs: Dict[str, ReadWriteMode] = EmptyDictDefault()
# @dataclass
# class EmptyVolume(object):
# name: str
Expand Down Expand Up @@ -250,6 +254,11 @@ def run(cab: 'stimela.kitchen.cab.Cab', params: Dict[str, Any], fqname: str,

# initial set of mounts has cwd as read-write
mounts = {cwd: True}
# add extra binds
for path, rw in backend.singularity.bind_dirs.items():
path = os.path.expanduser(path)
mounts[path] = mounts.get(path, False) or (rw == ReadWriteMode.rw)

# get extra required filesystem bindings
resolve_required_mounts(mounts, params, cab.inputs, cab.outputs)

Expand Down
10 changes: 4 additions & 6 deletions stimela/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from stimela.kitchen.cab import Cab, Parameter
from scabha.exceptions import SchemaError
from stimela.exceptions import BackendError
from scabha.basetypes import File, Directory, MS, URI
from scabha.basetypes import File, Directory, MS, URI, get_filelikes

## commenting out for now -- will need to fix when we reactive the kube backend (and have tests for it)

Expand Down Expand Up @@ -34,11 +34,9 @@ def add_target(param_name, path, must_exist, readwrite):
if schema is None:
raise SchemaError(f"parameter {name} not in defined inputs or outputs for this cab. This should have been caught by validation earlier!")

if schema.is_file_type:
files = [value]
elif schema.is_file_list_type:
files = value
else:
files = get_filelikes(schema._dtype, value)

if not files:
continue

must_exist = schema.must_exist and name in inputs
Expand Down
42 changes: 42 additions & 0 deletions tests/scabha_tests/test_filelikes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from scabha.basetypes import get_filelikes, File, URI, Directory, MS
from typing import Dict, List, Set, Tuple, Union, Optional
import pytest


@pytest.fixture(scope="module", params=[File, URI, Directory, MS])
def templates(request):

ft = request.param

TEMPLATES = (
(Tuple, (), set()),
(Tuple[int, ...], [1, 2], set()),
(Tuple[ft, ...], ("foo", "bar"), {"foo", "bar"}),
(Tuple[ft, str], ("foo", "bar"), {"foo"}),
(Dict[str, int], {"a": 1, "b": 2}, set()),
(Dict[str, ft], {"a": "foo", "b": "bar"}, {"foo", "bar"}),
(Dict[ft, str], {"foo": "a", "bar": "b"}, {"foo", "bar"}),
(List[ft], [], set()),
(List[int], [1, 2], set()),
(List[ft], ["foo", "bar"], {"foo", "bar"}),
(Set[ft], set(), set()),
(Set[int], {1, 2}, set()),
(Set[ft], {"foo", "bar"}, {"foo", "bar"}),
(Union[str, List[ft]], "foo", set()),
(Union[str, List[ft]], ["foo"], {"foo"}),
(Union[str, Tuple[ft]], "foo", set()),
(Union[str, Tuple[ft]], ("foo",), {"foo"}),
(Optional[ft], None, set()),
(Optional[ft], "foo", {"foo"}),
(Optional[Union[ft, int]], 1, set()),
(Optional[Union[ft, int]], "foo", {"foo"}),
(Dict[str, Tuple[ft, str]], {"a": ("foo", "bar")}, {"foo"})
)

return TEMPLATES


def test_get_filelikes(templates):

for dt, v, res in templates:
assert get_filelikes(dt, v) == res, f"Failed for dtype {dt} and value {v}."

0 comments on commit 7a28ba5

Please sign in to comment.