Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes #275 #276

Merged
merged 6 commits into from
Apr 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 98 additions & 39 deletions scabha/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import click
from scabha.exceptions import SchemaError
from .cargo import Parameter, UNSET, _UNSET_DEFAULT, Cargo
from .cargo import Parameter, UNSET, _UNSET_DEFAULT, Cargo, ParameterPolicies
from typing import List, Union, Optional, Callable, Dict, DefaultDict, Any
from .basetypes import EmptyDictDefault, File, is_file_type
from dataclasses import dataclass, make_dataclass, field
Expand Down Expand Up @@ -125,46 +125,79 @@ def nested_schema_to_dataclass(nested: Dict[str, Dict], class_name: str, bases=(

_atomic_types = dict(bool=bool, str=str, int=int, float=float)

def _validate_list(text: str, element_type, schema):
def _validate_list(text: str, element_type, schema, sep=",", brackets=True):
if not text:
return schema.default
if text == "[]":
return []
if text[0] == "[" and text[-1] == "]":
if brackets:
if text == "[]":
return []
if text[0] != "[" or text[-1] != "]":
raise click.BadParameter(f"can't convert to '{schema.dtype}', missing '[]' brackets")
text = text[1:-1]
try:
return [element_type(x) for x in text.split(",")]
return list(element_type(x) for x in text.split(sep))
except ValueError:
raise click.BadParameter(f"can't convert to '{schema.dtype}'")

def _validate_tuple(text: str, element_types, schema, sep=",", brackets=True):
if not text:
return schema.default
if brackets:
if text == "[]":
return []
if text[0] != "[" or text[-1] != "]":
raise click.BadParameter(f"can't convert to '{schema.dtype}', missing '[]' brackets")
text = text[1:-1]
elems = text.split(sep)
if len(elems) != len(element_types):
raise click.BadParameter(f"can't convert to '{schema.dtype}', tuple length mismatch")
try:
return tuple(element_type(x) for x, element_type in zip(elems, element_types))
except ValueError:
raise click.BadParameter(f"can't convert to '{schema.dtype}'")

@dataclass
class Schema(object):
inputs: Dict[str, Parameter] = EmptyDictDefault()
outputs: Dict[str, Parameter] = EmptyDictDefault()
policies: Optional[Dict[str, Any]] = None


def clickify_parameters(schemas: Union[str, Dict[str, Any]]):
def clickify_parameters(schemas: Union[str, Dict[str, Any]],
default_policies: Dict[str, Any] = None):

if type(schemas) is str:
schemas = OmegaConf.merge(OmegaConf.structured(Schema),
OmegaConf.load(schemas))

# get default policies from argument or schemas
if default_policies:
default_policies = OmegaConf.merge(OmegaConf.structured(ParameterPolicies), default_policies)
elif getattr(schemas, 'policies', None):
default_policies = OmegaConf.merge(OmegaConf.structured(ParameterPolicies), schemas.policies)
else:
schemas = OmegaConf.merge(OmegaConf.structured(Schema),
dict(inputs=schemas.inputs, outputs=schemas.outputs))
default_policies = ParameterPolicies()

decorator_chain = None
inputs = Cargo.flatten_schemas(OrderedDict(), schemas.inputs, "inputs")
outputs = Cargo.flatten_schemas(OrderedDict(), schemas.outputs, "outputs")
inputs = Cargo.flatten_schemas(OrderedDict(), getattr(schemas, 'inputs', {}), "inputs")
outputs = Cargo.flatten_schemas(OrderedDict(), getattr(schemas, 'outputs', {}), "outputs")
for io in inputs, outputs:
for name, schema in io.items():
# skip outputs, unless they're named outputs
if io is schemas.outputs and not (schema.is_file_type and not schema.implicit):
if io is outputs and not (schema.is_file_type and not schema.implicit):
continue

policies = OmegaConf.merge(default_policies, schema.policies)
# impose default repeat policy of using a single argument for a list, i.e. X1,X2,X3
if policies.repeat is None:
policies.repeat = ","

name = name.replace("_", "-").replace(".", "-")
optname = f"--{name}"
dtype = schema.dtype
validator = None
multiple = False
nargs = 1

# sort out option type. Atomic type?
if dtype in _atomic_types:
Expand All @@ -175,15 +208,47 @@ def clickify_parameters(schemas: Union[str, Dict[str, Any]]):
elif dtype in ("MS", "File", "Directory"):
dtype = click.Path(exists=(io is schemas.inputs))
else:
match = re.fullmatch("List\[(.*)\]", dtype)
list_match = re.fullmatch("List\[(.*)\]", dtype)
tuple_match = re.fullmatch("Tuple\[(.*)\]", dtype)
# List[x] type? Add validation callback to convert elements
if match:
elem_type_name = match.group(1)
# convert "x" to type object -- unknown element types will get treated as a string
elem_type = _atomic_types.get(elem_type_name, str)
validator = lambda ctx, param, value, etype=elem_type, schema=schema: _validate_list(value, element_type=etype, schema=schema)
# anything else will be just a string
dtype = str
if list_match:
elem_type = _atomic_types.get(list_match.group(1).strip(), str)
if policies.repeat == 'list':
nargs = -1
dtype = elem_type
elif policies.repeat == 'repeat':
multiple = True
dtype = elem_type
elif policies.repeat == '[]': # else assume [X,Y] or X,Y syntax
dtype = str
validator = lambda ctx, param, value, etype=dtype, schema=schema: \
_validate_list(value, element_type=elem_type, schema=schema)
elif policies.repeat is not None: # assume XrepY syntax
dtype = str
validator = lambda ctx, param, value, etype=dtype, schema=schema: \
_validate_list(value, element_type=elem_type, schema=schema,
sep=policies.repeat, brackets=False)
else:
raise SchemaError(f"list-type parameter '{name}' does not have a repeat policy set")
elif tuple_match:
elem_types = tuple(_atomic_types.get(t.strip(), str) for t in tuple_match.group(1).split(","))
if policies.repeat == 'list' or policies.repeat == 'repeat':
nargs = len(elem_types)
dtype = elem_types
elif policies.repeat == '[]': # else assume [X,Y] or X,Y syntax
dtype = str
validator = lambda ctx, param, value, etype=dtype, schema=schema: \
_validate_tuple(value, element_types=elem_types, schema=schema)
elif policies.repeat is not None: # assume XrepY syntax
dtype = str
validator = lambda ctx, param, value, etype=dtype, schema=schema: \
_validate_tuple(value, element_types=elem_types, schema=schema,
sep=policies.repeat, brackets=False)
else:
raise SchemaError(f"tuple-type parameter '{name}' does not have a repeat policy set")
else:
# anything else will be just a string
dtype = str

# choices?
if schema.choices:
Expand All @@ -194,31 +259,25 @@ def clickify_parameters(schemas: Union[str, Dict[str, Any]]):
if schema.abbreviation:
optnames.append(f"-{schema.abbreviation}")

if schema.policies.positional:
if schema.default in (UNSET, _UNSET_DEFAULT) or schema.suppress_cli_default:
deco = click.argument(name, type=dtype, callback=validator,
required=schema.required,
metavar=schema.metavar)
else:
deco = click.argument(name, type=dtype, callback=validator,
default=schema.default, required=schema.required,
metavar=schema.metavar)
if policies.positional:
kwargs = dict(type=dtype, callback=validator, required=schema.required, nargs=nargs,
metavar=schema.metavar)
if not schema.default in (UNSET, _UNSET_DEFAULT) and not schema.suppress_cli_default:
kwargs['default'] = schema.default
deco = click.argument(name, **kwargs)
else:
if schema.default in (UNSET, _UNSET_DEFAULT) or schema.suppress_cli_default:
deco = click.option(*optnames, type=dtype, callback=validator,
required=schema.required,
metavar=schema.metavar, help=schema.info)
else:
deco = click.option(*optnames, type=dtype, callback=validator,
default=schema.default, required=schema.required,
metavar=schema.metavar, help=schema.info)

kwargs = dict(type=dtype, callback=validator,
required=schema.required, multiple=multiple,
metavar=schema.metavar, help=schema.info)
if not schema.default in (UNSET, _UNSET_DEFAULT) and not schema.suppress_cli_default:
kwargs['default'] = schema.default
deco = click.option(*optnames, **kwargs)
if decorator_chain is None:
decorator_chain = deco
else:
decorator_chain = lambda x,deco=deco,chain=decorator_chain: chain(deco(x))

return decorator_chain
return decorator_chain or (lambda x: x)

@dataclass
class SchemaSpec:
Expand Down
Loading