Skip to content

Commit

Permalink
feat: cleaner merge_with_sig, more annotations, and more signature co…
Browse files Browse the repository at this point in the history
…mparison tools
  • Loading branch information
thorwhalen committed Aug 14, 2024
1 parent d1d342a commit 68bc934
Showing 1 changed file with 156 additions and 79 deletions.
235 changes: 156 additions & 79 deletions i2/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@
Iterator,
TypeVar,
Mapping as MappingType,
Literal,
Optional,
get_args,
)
from typing import KT, VT, T
from types import FunctionType
Expand Down Expand Up @@ -154,6 +157,8 @@ def wrapper(*args, **kwargs):
var_param_kind_dflts_items = tuple({VP: (), VK: {}}.items())

DFLT_DEFAULT_CONFLICT_METHOD = 'strict'
SigMergeOptions = Literal[None, 'strict', 'take_first', 'fill_defaults_and_annotations']

param_attributes = {'name', 'kind', 'default', 'annotation'}


Expand Down Expand Up @@ -710,7 +715,10 @@ def extract_arguments(

if include_all_when_var_keywords_in_params:
if (
next((p.name for p in params if p.kind == Parameter.VAR_KEYWORD), None,)
next(
(p.name for p in params if p.kind == Parameter.VAR_KEYWORD),
None,
)
is not None
):
param_kwargs.update(remaining_kwargs)
Expand Down Expand Up @@ -1980,7 +1988,7 @@ def merge_with_sig(
sig: ParamsAble,
ch_to_all_pk: bool = False,
*,
default_conflict_method: str = DFLT_DEFAULT_CONFLICT_METHOD,
default_conflict_method: SigMergeOptions = DFLT_DEFAULT_CONFLICT_METHOD,
):
"""Return a signature obtained by merging self signature with another signature.
Insofar as it can, given the kind precedence rules, the arguments of self will
Expand Down Expand Up @@ -2048,14 +2056,8 @@ def merge_with_sig(
'them with a signature mapping that avoids the argument name clashing'
)

assert default_conflict_method in {
None,
'strict',
'take_first',
'fill_defaults_and_annotations',
}, (
'default_conflict_method should be in '
"{None, 'strict', 'take_first', 'fill_defaults_and_annotations'}"
assert default_conflict_method in get_args(SigMergeOptions), (
'default_conflict_method should be one of: ' f"{get_args(SigMergeOptions)}"
)

if default_conflict_method == 'take_first':
Expand Down Expand Up @@ -2930,7 +2932,9 @@ def extract_args_and_kwargs(
ignore_kind=_ignore_kind,
)
return self.mk_args_and_kwargs(
arguments, allow_partial=_allow_partial, args_limit=_args_limit,
arguments,
allow_partial=_allow_partial,
args_limit=_args_limit,
)

def source_arguments(
Expand Down Expand Up @@ -3095,7 +3099,9 @@ def source_args_and_kwargs(
**kwargs,
)
return self.mk_args_and_kwargs(
arguments, allow_partial=_allow_partial, args_limit=_args_limit,
arguments,
allow_partial=_allow_partial,
args_limit=_args_limit,
)


Expand Down Expand Up @@ -4232,32 +4238,23 @@ def zip(*iterables):
zip(*iterables) --> A zip object yielding tuples until an input is exhausted.
"""

def bool(x: Any, /) -> bool:
...
def bool(x: Any, /) -> bool: ...

def bytearray(iterable_of_ints: Iterable[int], /):
...
def bytearray(iterable_of_ints: Iterable[int], /): ...

def classmethod(function: Callable, /):
...
def classmethod(function: Callable, /): ...

def int(x, base=10, /):
...
def int(x, base=10, /): ...

def iter(callable: Callable, sentinel=None, /):
...
def iter(callable: Callable, sentinel=None, /): ...

def next(iterator: Iterator, default=None, /):
...
def next(iterator: Iterator, default=None, /): ...

def staticmethod(function: Callable, /):
...
def staticmethod(function: Callable, /): ...

def str(bytes_or_buffer, encoding=None, errors=None, /):
...
def str(bytes_or_buffer, encoding=None, errors=None, /): ...

def super(type_, obj=None, /):
...
def super(type_, obj=None, /): ...

# def type(name, bases=None, dict=None, /):
# ...
Expand Down Expand Up @@ -4427,14 +4424,11 @@ class sigs_for_type_name:
signatures (through ``inspect.signature``),
"""

def itemgetter(iterable: Iterable[VT], /) -> Union[VT, Tuple[VT]]:
...
def itemgetter(iterable: Iterable[VT], /) -> Union[VT, Tuple[VT]]: ...

def attrgetter(iterable: Iterable[VT], /) -> Union[VT, Tuple[VT]]:
...
def attrgetter(iterable: Iterable[VT], /) -> Union[VT, Tuple[VT]]: ...

def methodcaller(obj: Any) -> Any:
...
def methodcaller(obj: Any) -> Any: ...


############# Tools for testing #########################################################
Expand Down Expand Up @@ -4479,7 +4473,9 @@ def param_for_kind(
lower_kind = kind.lower()
setattr(param_for_kind, lower_kind, partial(param_for_kind, kind=kind))
setattr(
param_for_kind, 'with_default', partial(param_for_kind, with_default=True),
param_for_kind,
'with_default',
partial(param_for_kind, with_default=True),
)
setattr(
getattr(param_for_kind, lower_kind),
Expand All @@ -4493,7 +4489,7 @@ def param_for_kind(
)

########################################################################################
# Signature Compatibility #
# Signature Comparison and Compatibility #
########################################################################################

Compared = TypeVar('Compared')
Expand Down Expand Up @@ -4533,7 +4529,10 @@ def mk_func_comparator_based_on_signature_comparator(


def _keyed_comparator(
comparator: Comparator, key: KeyFunction, x: CT, y: CT,
comparator: Comparator,
key: KeyFunction,
x: CT,
y: CT,
) -> Comparison:
"""Apply a comparator after transforming inputs through a key function.
Expand All @@ -4547,7 +4546,10 @@ def _keyed_comparator(
return comparator(key(x), key(y))


def keyed_comparator(comparator: Comparator, key: KeyFunction,) -> Comparator:
def keyed_comparator(
comparator: Comparator,
key: KeyFunction,
) -> Comparator:
"""Create a key-function enabled binary operator.
In various places in python functionality is extended by allowing a key function.
Expand Down Expand Up @@ -4630,59 +4632,135 @@ def param_comparator(
)


param_comparator: ParamComparator
param_binary_func = param_comparator # back compatibility alias


# TODO: Implement annotation compatibility
def is_annotation_compatible_with(annot1, annot2):
def ignore_any_differences(x, y):
return True


def is_default_value_compatible_with(dflt1, dflt2):
return dflt1 is empty or dflt2 is not empty
permissive_param_comparator = partial(
param_comparator,
name=ignore_any_differences,
kind=ignore_any_differences,
default=ignore_any_differences,
annotation=ignore_any_differences,
)
permissive_param_comparator.__doc__ = """
Permissive version of param_comparator that ignores any differences of parameter
attributes.
It is meant to be used with partial, but with a permissive base, contrary to the
base param_comparator which requires strict equality (`eq`) for all attributes.
"""

def is_param_compatible_with(
p1: Parameter,
p2: Parameter,
annotation_comparator: Comparator = None,
default_value_comparator: Comparator = None,
):
"""Return True if ``p1`` is compatible with ``p2``. Meaning that any value valid
for ``p1`` is valid for ``p2``.

:param p1: The main parameter.
:param p2: The parameter to be compared with.
:param annotation_comparator: The function used to compare the annotations
:param default_value_comparator: The function used to compare the default values
def return_tuple(x, y):
return x, y

>>> is_param_compatible_with(
... Parameter('a', PO),
... Parameter('b', PO)
... )
True
>>> is_param_compatible_with(
... Parameter('a', PO),
... Parameter('b', PO, default=0)
... )
True
>>> is_param_compatible_with(
... Parameter('a', PO, default=0),
... Parameter('b', PO)
... )
False

param_attribute_dict: ComparisonAggreg


def param_attribute_dict(name_kind_default_annotation: Iterable[Comparison]) -> dict:
keys = ['name', 'kind', 'default', 'annotation']
return {key: value for key, value in zip(keys, name_kind_default_annotation)}


param_comparison_dict = partial(
param_comparator,
name=return_tuple,
kind=return_tuple,
default=return_tuple,
annotation=return_tuple,
aggreg=param_attribute_dict,
)

param_comparison_dict.__doc__ = """
A ParamComparator that returns a dictionary with pairs parameter attributes.
>>> param1 = Sig('(a: int = 1)')['a']
>>> param2 = Sig('(a: str = 2)')['a']
>>> param_comparison_dict(param1, param2) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
{'name': ('a', 'a'), 'kind': ..., 'default': (1, 2), 'annotation': (<class 'int'>, <class 'str'>)}
"""


def param_differences_dict(
param1: Parameter,
param2: Parameter,
*,
name: Comparator = eq,
kind: Comparator = eq,
default: Comparator = eq,
annotation: Comparator = eq,
):
"""Makes a dictionary exibiting the differences between two parameters.
>>> param1 = Sig('(a: int = 1)')['a']
>>> param2 = Sig('(a: str = 2)')['a']
>>> param_differences_dict(param1, param2)
{'default': (1, 2), 'annotation': (<class 'int'>, <class 'str'>)}
>>> param_differences_dict(param1, param2, default=lambda x, y: isinstance(x, type(y)))
{'annotation': (<class 'int'>, <class 'str'>)}
"""
# TODO: Consider using functions as defaults instead of None
annotation_comparator = annotation_comparator or is_annotation_compatible_with
default_value_comparator = (
default_value_comparator or is_default_value_compatible_with
equality_vector = param_comparator(
param1,
param2,
name=name,
kind=kind,
default=default,
annotation=annotation,
aggreg=tuple,
)
comparison_dict = param_comparison_dict(param1, param2)
return {
key: comparison_dict[key]
for key, equal in zip(comparison_dict, equality_vector)
if not equal
}


def defaults_are_the_same_when_not_empty(dflt1, dflt2):
"""
Check if two defaults are the same when they are not empty.
# >>> defaults_are_the_same_when_not_empty(1, 1)
# True
# >>> defaults_are_the_same_when_not_empty(1, 2)
# False
# >>> defaults_are_the_same_when_not_empty(1, None)
# False
# >>> defaults_are_the_same_when_not_empty(1, Parameter.empty)
# True
"""
return dflt1 is empty or dflt2 is empty or dflt1 == dflt2


return annotation_comparator(
p1.annotation, p2.annotation
) and default_value_comparator(p1.default, p2.default)
def dflt1_is_empty_or_dflt2_is_not(dflt1, dflt2):
"""
Why such a strange default comparison function?
This is to be used as a default in is_call_compatible_with.
Consider two functions func1 and func2 with a parameter p with default values
dflt1 and dflt2 respectively.
If dflt1 was not empty and dflt2 was, this would mean that func1 could be called
without specifying p, but func2 couldn't.
So to avoid this situation, we use dflt1_is_empty_or_dflt2_is_not as the default
"""
return dflt1 is empty or dflt2 is not empty


# TODO: It seems like param_comparator is really only used to compare parameters on defaults.
# This may be due to the fact that is_call_compatible_with was developed independently
# from the other general param_comparator functionality that was developed (see above)
# The code of is_call_compatible_with should be reviwed and refactored to use general
# tools.
def is_call_compatible_with(
sig1: Sig, sig2: Sig, *, param_comparator: ParamComparator = None
) -> bool:
Expand Down Expand Up @@ -4814,8 +4892,7 @@ def validate_param_compatibility():
return False
return True

# TODO: Consider putting is_param_compatible_with as default instead
param_comparator = param_comparator or is_param_compatible_with
param_comparator = param_comparator or dflt1_is_empty_or_dflt2_is_not

pos1, pks1, vp1, kos1, vk1 = sig1.detail_names_by_kind()
ps1 = pos1 + pks1
Expand Down

0 comments on commit 68bc934

Please sign in to comment.