From dd6973d6637610cabc12869f41a55a2833f44c4f Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Wed, 29 Dec 2021 22:31:13 -0600
Subject: [PATCH] Add Type.is_compatible and update type conversion

These changes enforce a strict narrowing-only conversion policy; i.e.
`Type`s can only be converted to equal or more specific types.
---
 aesara/graph/type.py      |  67 ++++++++++++-------
 aesara/tensor/type.py     |  45 ++++++++-----
 tests/graph/test_types.py | 135 +++++++++++++++++++++++++-------------
 tests/tensor/test_type.py |  48 ++++++++++++++
 4 files changed, 208 insertions(+), 87 deletions(-)

diff --git a/aesara/graph/type.py b/aesara/graph/type.py
index 3eedabe600..fb9eb38508 100644
--- a/aesara/graph/type.py
+++ b/aesara/graph/type.py
@@ -43,6 +43,24 @@ class Type(MetaObject):
     The `Type` that will be created by a call to `Type.make_constant`.
     """
 
+    def is_compatible(self, otype: "Type") -> Optional[bool]:
+        """Determine if another `Type` is compatible with this `Type`.
+
+        Compatibility is determined by the ability to replace one `Type` with
+        another.  That is, ``t1.is_compatible(t2) == True`` implies that
+        ``t1`` can be replaced with ``t2``.
+
+        In general, a type can be replaced by a more specific type, so this
+        compatibility check is really a type-level ordering or
+        `issubclass`-like check.
+
+        """
+        if self == otype:
+            return True
+
+        # Indeterminate
+        return None
+
     @abstractmethod
     def filter(
         self, data: D, strict: bool = False, allow_downcast: Optional[bool] = None
@@ -101,14 +119,9 @@ def filter_inplace(
     def filter_variable(
         self, other: Union[Variable, D], allow_convert: bool = True
    ) -> Variable:
-        r"""Convert a symbolic variable into this `Type`, if compatible.
-
-        For the moment, the only `Type`\s compatible with one another are
-        `TensorType` and `GpuArrayType`, provided they have the same number of
-        dimensions, same broadcasting pattern, and same dtype.
-
-        If `Type`\s are not compatible, a ``TypeError`` should be raised.
+        r"""Convert `other` into a `Variable` with a `Type` that's compatible with `self`.
 
+        If the involved `Type`\s are not compatible, a `TypeError` will be raised.
         """
         if not isinstance(other, Variable):
             # The value is not a Variable: we cast it into
@@ -122,30 +135,36 @@ def filter_variable(
         if other.type != self:
             raise TypeError(
-                "Cannot convert Type %(othertype)s "
-                "(of Variable %(other)s) into Type %(self)s. "
-                "You can try to manually convert %(other)s into a %(self)s."
-                % dict(othertype=other.type, other=other, self=self)
+                f"Cannot convert Type {other.type} "
+                f"(of Variable {other}) into Type {self}. "
+                f"You can try to manually convert {other} into a {self}."
             )
 
         return other
 
-    def convert_variable(self, var: Union[Variable, D]) -> Optional[Variable]:
-        """Patch a variable so that its `Type` will match ``self``, if possible.
-
-        If the variable can't be converted, this should return None.
+    def convert_variable(self, var: Variable) -> Optional[Variable]:
+        """Produce a `Variable` that's compatible with both `self` and `var.type`, if possible.
 
-        The conversion can only happen if the following implication is
-        true for all possible `val`.
+        A compatible `Variable` is a `Variable` with a `Type` that's the
+        "narrower" of `self` and `var.type`.
 
-            self.is_valid_value(val) => var.type.is_valid_value(val)
-
-        For the majority of types this means that you can only have
-        non-broadcastable dimensions become broadcastable and not the
-        inverse.
-
-        The default is to not convert anything which is always safe.
+        If a compatible `Type` cannot be found, this method will return
+        ``None``.
 
         """
+        var_type = var.type
+
+        if self.is_compatible(var_type):
+            # `var.type` is at least as specific as `self`, so we return it
+            # as-is
+            return var
+        elif var_type.is_compatible(self):
+            # `var.type` is less specific than `self`, so we need to convert it
+            # to `self`'s `Type`.
+            #
+            # Note that, in this case, `var.type != self`, because equality is
+            # covered by the branch above.
+            raise NotImplementedError()
+
         return None
 
     def is_valid_value(self, data: D) -> bool:
diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py
index a0ccde2efa..75a7bc8650 100644
--- a/aesara/tensor/type.py
+++ b/aesara/tensor/type.py
@@ -218,14 +218,6 @@ def filter(self, data, strict=False, allow_downcast=None):
         return data
 
     def filter_variable(self, other, allow_convert=True):
-        """
-        Convert a symbolic Variable into a TensorType, if compatible.
-
-        For the moment, only a TensorType and GpuArrayType will be
-        converted, provided they have the same number of dimensions
-        and dtype and have "compatible" broadcastable pattern.
-
-        """
         if not isinstance(other, Variable):
             # The value is not a Variable: we cast it into
             # a Constant of the appropriate Type.
@@ -235,9 +227,8 @@ def filter_variable(self, other, allow_convert=True):
             return other
 
         if allow_convert:
-            # Attempt safe broadcast conversion.
             other2 = self.convert_variable(other)
-            if other2 is not None and other2.type == self:
+            if other2 is not None:
                 return other2
 
         raise TypeError(
@@ -282,17 +273,39 @@ def __eq__(self, other):
             and other.broadcastable == self.broadcastable
         )
 
-    def convert_variable(self, var):
+    def is_compatible(self, otype):
         if (
-            isinstance(self, type(var.type))
-            and self.dtype == var.type.dtype  # noqa
-            and self.ndim == var.type.ndim
+            isinstance(otype, type(self))
+            and otype.dtype == self.dtype
+            and self.ndim == otype.ndim
+            # `otype` is allowed to be at least as shape-specific as `self`,
+            # but not less
             and all(
                 sb == ob or ob
-                for sb, ob in zip(self.broadcastable, var.type.broadcastable)
+                for sb, ob in zip(self.broadcastable, otype.broadcastable)
             )
         ):
-            return aesara.tensor.patternbroadcast(var, self.broadcastable)
+            return True
+
+        return False
+
+    def convert_variable(self, var):
+        if self.is_compatible(var.type):
+            # `var.type` is at least as specific as `self`, so we return
+            # `var` as-is
+            return var
+        elif var.type.is_compatible(self):
+            # `var.type` is less specific than `self`, so we convert
+            # `var` to `self`'s `Type`.
+            # Note that, in this case, `var.type != self`, because that's
+            # covered by the branch above.
+
+            # Use the more specific broadcast/shape information of the two.
+            # TODO: Why do we need/want `Rebroadcast`? It's basically just
+            # another `CheckAndRaise` `Op` that can be avoided entirely.
+            return aesara.tensor.basic.Rebroadcast(
+                *[(i, b) for i, b in enumerate(self.broadcastable)]
+            )(var)
 
     @staticmethod
     def may_share_memory(a, b):
diff --git a/tests/graph/test_types.py b/tests/graph/test_types.py
index edf5a8539c..b5ad8a8ba9 100644
--- a/tests/graph/test_types.py
+++ b/tests/graph/test_types.py
@@ -5,77 +5,118 @@
 import aesara
 from aesara import scalar as aes
-from aesara.graph.basic import Apply
+from aesara.graph.basic import Apply, Variable
 from aesara.graph.op import COp
-from aesara.graph.type import CDataType, CEnumType, EnumList, EnumType
+from aesara.graph.type import CDataType, CEnumType, EnumList, EnumType, Type
 from aesara.tensor.type import TensorType, continuous_dtypes
 
 
-# todo: test generic
+class MyType(Type):
+    def __init__(self, thingy):
+        self.thingy = thingy
 
+    def filter(self, *args, **kwargs):
+        raise NotImplementedError()
 
-class ProdOp(COp):
-    __props__ = ()
+    def __eq__(self, other):
+        return isinstance(other, MyType) and other.thingy == self.thingy
 
-    def make_node(self, i):
-        return Apply(self, [i], [CDataType("void *", "py_decref")()])
+    def __str__(self):
+        return f"R{self.thingy}"
 
-    def c_support_code(self, **kwargs):
-        return """
-void py_decref(void *p) {
-  Py_XDECREF((PyObject *)p);
-}
-"""
+    def __repr__(self):
+        return f"R{self.thingy}"
 
-    def c_code(self, node, name, inps, outs, sub):
-        return """
-Py_XDECREF(%(out)s);
-%(out)s = (void *)%(inp)s;
-Py_INCREF(%(inp)s);
-""" % dict(
-            out=outs[0], inp=inps[0]
-        )
 
-    def c_code_cache_version(self):
-        return (0,)
+class MyType2(MyType):
+    def is_compatible(self, other):
+        if self.thingy <= other.thingy:
+            return True
 
-    def perform(self, *args, **kwargs):
-        raise NotImplementedError()
 
+def test_is_compatible():
+    t1 = MyType(1)
+    t2 = MyType(2)
 
-class GetOp(COp):
-    __props__ = ()
+    assert t1.is_compatible(t2) is None
 
-    def make_node(self, c):
-        return Apply(self, [c], [TensorType("float32", (False,))()])
+    t1_2 = MyType(1)
+    assert t1.is_compatible(t1_2)
 
-    def c_support_code(self, **kwargs):
-        return """
-void py_decref(void *p) {
-  Py_XDECREF((PyObject *)p);
-}
-"""
 
+def test_convert_variable():
+    t1 = MyType(1)
+    v1 = Variable(MyType(1), None, None)
+    v2 = Variable(MyType(2), None, None)
+    v3 = Variable(MyType2(0), None, None)
 
-    def c_code(self, node, name, inps, outs, sub):
-        return """
-Py_XDECREF(%(out)s);
-%(out)s = (PyArrayObject *)%(inp)s;
-Py_INCREF(%(out)s);
-""" % dict(
-            out=outs[0], inp=inps[0]
-        )
+    assert t1.convert_variable(v1) is v1
+    assert t1.convert_variable(v2) is None
 
-    def c_code_cache_version(self):
-        return (0,)
+    with pytest.raises(NotImplementedError):
+        t1.convert_variable(v3)
 
-    def perform(self, *args, **kwargs):
-        raise NotImplementedError()
 
 
 @pytest.mark.skipif(
     not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
 )
 def test_cdata():
+    class ProdOp(COp):
+        __props__ = ()
+
+        def make_node(self, i):
+            return Apply(self, [i], [CDataType("void *", "py_decref")()])
+
+        def c_support_code(self, **kwargs):
+            return """
+            void py_decref(void *p) {
+                Py_XDECREF((PyObject *)p);
+            }
+            """
+
+        def c_code(self, node, name, inps, outs, sub):
+            return """
+            Py_XDECREF(%(out)s);
+            %(out)s = (void *)%(inp)s;
+            Py_INCREF(%(inp)s);
+            """ % dict(
+                out=outs[0], inp=inps[0]
+            )
+
+        def c_code_cache_version(self):
+            return (0,)
+
+        def perform(self, *args, **kwargs):
+            raise NotImplementedError()
+
+    class GetOp(COp):
+        __props__ = ()
+
+        def make_node(self, c):
+            return Apply(self, [c], [TensorType("float32", (False,))()])
+
+        def c_support_code(self, **kwargs):
+            return """
+            void py_decref(void *p) {
+                Py_XDECREF((PyObject *)p);
+            }
+            """
+
+        def c_code(self, node, name, inps, outs, sub):
+            return """
+            Py_XDECREF(%(out)s);
+            %(out)s = (PyArrayObject *)%(inp)s;
+            Py_INCREF(%(out)s);
+            """ % dict(
+                out=outs[0], inp=inps[0]
+            )
+
+        def c_code_cache_version(self):
+            return (0,)
+
+        def perform(self, *args, **kwargs):
+            raise NotImplementedError()
+
     i = TensorType("float32", (False,))()
     c = ProdOp()(i)
     i2 = GetOp()(c)
diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py
index f1e079d7bf..2452612227 100644
--- a/tests/tensor/test_type.py
+++ b/tests/tensor/test_type.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+import aesara.tensor as at
 from aesara.configdefaults import config
 from aesara.tensor.type import TensorType
 
@@ -13,6 +14,42 @@ def test_numpy_dtype():
     assert test_type.dtype == "int32"
 
 
+def test_is_compatible():
+    test_type = TensorType(config.floatX, [False, False])
+    test_type2 = TensorType(config.floatX, [False, True])
+
+    assert test_type.is_compatible(test_type)
+    assert test_type.is_compatible(test_type2)
+    assert not test_type2.is_compatible(test_type)
+
+
+def test_convert_variable():
+    test_type = TensorType(config.floatX, [False, False])
+    test_var = test_type()
+
+    test_type2 = TensorType(config.floatX, [True, False])
+    test_var2 = test_type2()
+
+    res = test_type.convert_variable(test_var)
+    assert res is test_var
+
+    res = test_type.convert_variable(test_var2)
+    assert res is test_var2
+
+    res = test_type2.convert_variable(test_var)
+    assert res.type == test_type2
+
+    test_type3 = TensorType(config.floatX, [True, False, True])
+    test_var3 = test_type3()
+
+    res = test_type2.convert_variable(test_var3)
+    assert res is None
+
+    const_var = at.as_tensor([[1, 2], [3, 4]], dtype=config.floatX)
+    res = test_type.convert_variable(const_var)
+    assert res is const_var
+
+
 def test_filter_variable():
     test_type = TensorType(config.floatX, [])
 
@@ -33,6 +70,17 @@ def test_filter_variable():
     test_type.filter_checks_isfinite = True
     test_type.filter(np.full((1, 2), np.inf, dtype=config.floatX))
 
+    test_type2 = TensorType(config.floatX, [False, False])
+    test_var = test_type()
+    test_var2 = test_type2()
+
+    res = test_type.filter_variable(test_var, allow_convert=True)
+    assert res is test_var
+
+    # Make sure it returns the more specific type
+    res = test_type.filter_variable(test_var2, allow_convert=True)
+    assert res.type == test_type
+
 
 def test_filter_strict():
     test_type = TensorType(config.floatX, [])
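
Usage note (not part of the patch itself): a minimal sketch of the
narrowing-only policy, assuming the patch is applied. It mirrors the new
tests in tests/tensor/test_type.py, and all names are real Aesara objects
touched by this diff:

    from aesara.configdefaults import config
    from aesara.tensor.type import TensorType

    # A type with no broadcast information vs. a "narrower" type that pins
    # the first dimension as broadcastable
    t_general = TensorType(config.floatX, [False, False])
    t_narrow = TensorType(config.floatX, [True, False])

    assert t_general.is_compatible(t_narrow)      # narrowing: allowed
    assert not t_narrow.is_compatible(t_general)  # widening: rejected

    # `convert_variable` keeps the narrower of the two types: a variable
    # that's already specific enough is returned as-is ...
    v_narrow = t_narrow()
    assert t_general.convert_variable(v_narrow) is v_narrow

    # ... while a less specific one is narrowed (via `Rebroadcast`)
    v_general = t_general()
    assert t_narrow.convert_variable(v_general).type == t_narrow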
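Similarly, a sketch of how a custom `Type` can hook into the new ordering;
`IntervalType` is hypothetical and only illustrates the
`Type.is_compatible` contract (``True`` = replaceable, ``False`` = not,
``None`` = indeterminate):

    from typing import Optional

    from aesara.graph.type import Type

    class IntervalType(Type):
        """A hypothetical `Type` ordered by interval containment."""

        def __init__(self, lo, hi):
            self.lo, self.hi = lo, hi

        def filter(self, data, strict=False, allow_downcast=None):
            raise NotImplementedError()

        def __eq__(self, other):
            return isinstance(other, IntervalType) and (self.lo, self.hi) == (
                other.lo,
                other.hi,
            )

        def is_compatible(self, otype: "Type") -> Optional[bool]:
            if isinstance(otype, IntervalType):
                # `otype` can replace `self` only when it's at least as narrow
                return self.lo <= otype.lo and otype.hi <= self.hi
            # Unrelated `Type`s remain indeterminate, as in the base class
            return None

    assert IntervalType(0, 10).is_compatible(IntervalType(2, 5))      # narrowing
    assert not IntervalType(2, 5).is_compatible(IntervalType(0, 10))  # widening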