Skip to content

Commit

Permalink
Add Type.is_compatible and update type conversion
Browse files Browse the repository at this point in the history
These changes enforce a strict narrowing-only conversion policy; i.e. `Type`s
can only be converted to equal or more specific types.
  • Loading branch information
brandonwillard committed Dec 31, 2021
1 parent de52629 commit dd6973d
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 87 deletions.
67 changes: 43 additions & 24 deletions aesara/graph/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ class Type(MetaObject):
The `Type` that will be created by a call to `Type.make_constant`.
"""

def is_compatible(self, otype: "Type") -> Optional[bool]:
    """Determine whether `otype` is compatible with this `Type`.

    Compatibility means substitutability: ``t1.is_compatible(t2) == True``
    implies that ``t1`` can be replaced with ``t2``.  Since a type may in
    general be replaced by an equal or more specific type, this is an
    ``issubclass``-like, type-level ordering check.

    Returns
    -------
    ``True`` when the two types are equal; ``None`` (indeterminate)
    otherwise.
    """
    # Equal types are trivially inter-substitutable; anything beyond that
    # cannot be decided at this generic level.
    return True if self == otype else None

@abstractmethod
def filter(
self, data: D, strict: bool = False, allow_downcast: Optional[bool] = None
Expand Down Expand Up @@ -101,14 +119,9 @@ def filter_inplace(
def filter_variable(
self, other: Union[Variable, D], allow_convert: bool = True
) -> Variable:
r"""Convert a symbolic variable into this `Type`, if compatible.
For the moment, the only `Type`\s compatible with one another are
`TensorType` and `GpuArrayType`, provided they have the same number of
dimensions, same broadcasting pattern, and same dtype.
If `Type`\s are not compatible, a ``TypeError`` should be raised.
r"""Convert a `other` into a `Variable` with a `Type` that's compatible with `self`.
If the involved `Type`\s are not compatible, a `TypeError` will be raised.
"""
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
Expand All @@ -122,30 +135,36 @@ def filter_variable(

if other.type != self:
raise TypeError(
"Cannot convert Type %(othertype)s "
"(of Variable %(other)s) into Type %(self)s. "
"You can try to manually convert %(other)s into a %(self)s."
% dict(othertype=other.type, other=other, self=self)
f"Cannot convert Type {other.type} "
f"(of Variable {other}) into Type {self}. "
f"You can try to manually convert {other} into a {self}."
)
return other

def convert_variable(self, var: Union[Variable, D]) -> Optional[Variable]:
"""Patch a variable so that its `Type` will match ``self``, if possible.
If the variable can't be converted, this should return None.
def convert_variable(self, var: "Variable") -> Optional["Variable"]:
    """Produce a `Variable` compatible with both `self` and ``var.type``, if possible.

    A compatible `Variable` is a `Variable` with a `Type` that's the
    "narrower" (i.e. more specific) of `self` and ``var.type``.

    Parameters
    ----------
    var
        The `Variable` to convert.

    Returns
    -------
    ``var`` itself when ``var.type`` is at least as specific as `self`;
    ``None`` when no compatible `Type` can be found.

    Raises
    ------
    NotImplementedError
        When ``var.type`` is strictly less specific than `self`, so that a
        narrowing conversion would be required; the generic `Type` cannot
        perform it.
    """
    var_type = var.type

    if self.is_compatible(var_type):
        # `var.type` is at least as specific as `self`, so we return it
        # as-is.
        return var
    elif var_type.is_compatible(self):
        # `var.type` is less specific than `self`, so we need to convert it
        # to `self`'s `Type`.
        #
        # Note that, in this case, `var.type != self`, because equality is
        # covered by the branch above.
        raise NotImplementedError()

    return None

def is_valid_value(self, data: D) -> bool:
Expand Down
45 changes: 29 additions & 16 deletions aesara/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,6 @@ def filter(self, data, strict=False, allow_downcast=None):
return data

def filter_variable(self, other, allow_convert=True):
"""
Convert a symbolic Variable into a TensorType, if compatible.
For the moment, only a TensorType and GpuArrayType will be
converted, provided they have the same number of dimensions
and dtype and have "compatible" broadcastable pattern.
"""
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
Expand All @@ -235,9 +227,8 @@ def filter_variable(self, other, allow_convert=True):
return other

if allow_convert:
# Attempt safe broadcast conversion.
other2 = self.convert_variable(other)
if other2 is not None and other2.type == self:
if other2 is not None:
return other2

raise TypeError(
Expand Down Expand Up @@ -282,17 +273,39 @@ def __eq__(self, other):
and other.broadcastable == self.broadcastable
)

def convert_variable(self, var):
def is_compatible(self, otype):
    """Determine whether `otype` is compatible with this `TensorType`.

    ``otype`` is compatible when it is a `TensorType` (or an instance of
    this type's class) with the same dtype and number of dimensions, and a
    broadcastable pattern that is at least as shape-specific as this one's
    (i.e. ``True`` may replace ``False`` per dimension, but not the
    reverse).
    """
    if (
        isinstance(otype, type(self))
        and otype.dtype == self.dtype
        and self.ndim == otype.ndim
        # `otype` is allowed to be as or more shape-specific than `self`,
        # but not less
        and all(
            sb == ob or ob
            for sb, ob in zip(self.broadcastable, otype.broadcastable)
        )
    ):
        return True

    return False

def convert_variable(self, var):
    """Produce a `Variable` compatible with both `self` and ``var.type``, if possible.

    Returns ``var`` unchanged when ``var.type`` is at least as specific as
    `self`; rebroadcasts ``var`` to `self`'s (more specific) broadcast
    pattern when `self` is the narrower of the two; otherwise returns
    ``None``.
    """
    if self.is_compatible(var.type):
        # `var.type` is at least as specific as `self`, so we return
        # `var` as-is
        return var
    elif var.type.is_compatible(self):
        # `var.type` is less specific than `self`, so we convert
        # `var` to `self`'s `Type`.
        # Note that, in this case, `var.type != self`, because that's
        # covered by the branch above.

        # Use the more specific broadcast/shape information of the two
        # TODO: Why do we need/want `Rebroadcast`? It's basically just
        # another `CheckAndRaise` `Op` that can be avoided entirely.
        return aesara.tensor.basic.Rebroadcast(
            *[(i, b) for i, b in enumerate(self.broadcastable)]
        )(var)

    # Neither type can be narrowed to the other; be explicit about the
    # "no conversion" result (consistent with `Type.convert_variable`).
    return None

@staticmethod
def may_share_memory(a, b):
Expand Down
135 changes: 88 additions & 47 deletions tests/graph/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,118 @@

import aesara
from aesara import scalar as aes
from aesara.graph.basic import Apply
from aesara.graph.basic import Apply, Variable
from aesara.graph.op import COp
from aesara.graph.type import CDataType, CEnumType, EnumList, EnumType
from aesara.graph.type import CDataType, CEnumType, EnumList, EnumType, Type
from aesara.tensor.type import TensorType, continuous_dtypes


# todo: test generic
class MyType(Type):
    """A minimal `Type` whose identity is determined by a single value."""

    def __init__(self, thingy):
        # Arbitrary payload that defines this type's identity.
        self.thingy = thingy

    def filter(self, *args, **kwargs):
        raise NotImplementedError()

    def __eq__(self, other):
        return isinstance(other, MyType) and other.thingy == self.thingy

    def __str__(self):
        return f"R{self.thingy}"

    def __repr__(self):
        return f"R{self.thingy}"
class MyType2(MyType):
    """A `MyType` that is compatible with any `MyType` having a
    greater-or-equal ``thingy`` (i.e. an equal or more specific type)."""

    def is_compatible(self, other):
        # Implicitly returns ``None`` (indeterminate) when `other` is not
        # at least as specific as `self`.
        if self.thingy <= other.thingy:
            return True

def test_is_compatible():
    t1 = MyType(1)
    t2 = MyType(2)

    # Unequal types are indeterminate under the default implementation.
    assert t1.is_compatible(t2) is None

    t1_2 = MyType(1)
    # Equal types are always compatible.
    assert t1.is_compatible(t1_2)

def c_support_code(self, **kwargs):
return """
void py_decref(void *p) {
Py_XDECREF((PyObject *)p);
}
"""

def c_code(self, node, name, inps, outs, sub):
return """
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)%(inp)s;
Py_INCREF(%(out)s);
""" % dict(
out=outs[0], inp=inps[0]
)
def test_convert_variable():
    t1 = MyType(1)
    v1 = Variable(MyType(1), None, None)
    v2 = Variable(MyType(2), None, None)
    v3 = Variable(MyType2(0), None, None)

    # `v1.type` equals `t1`, so `v1` is returned as-is.
    assert t1.convert_variable(v1) is v1
    # `MyType(2)` is neither narrower nor wider than `t1`: no conversion.
    assert t1.convert_variable(v2) is None

    # `v3.type` is wider than `t1`, so a narrowing conversion is required,
    # which the base `Type` does not implement.
    with pytest.raises(NotImplementedError):
        t1.convert_variable(v3)


@pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_cdata():
class ProdOp(COp):
__props__ = ()

def make_node(self, i):
return Apply(self, [i], [CDataType("void *", "py_decref")()])

def c_support_code(self, **kwargs):
return """
void py_decref(void *p) {
Py_XDECREF((PyObject *)p);
}
"""

def c_code(self, node, name, inps, outs, sub):
return """
Py_XDECREF(%(out)s);
%(out)s = (void *)%(inp)s;
Py_INCREF(%(inp)s);
""" % dict(
out=outs[0], inp=inps[0]
)

def c_code_cache_version(self):
return (0,)

def perform(self, *args, **kwargs):
raise NotImplementedError()

class GetOp(COp):
__props__ = ()

def make_node(self, c):
return Apply(self, [c], [TensorType("float32", (False,))()])

def c_support_code(self, **kwargs):
return """
void py_decref(void *p) {
Py_XDECREF((PyObject *)p);
}
"""

def c_code(self, node, name, inps, outs, sub):
return """
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)%(inp)s;
Py_INCREF(%(out)s);
""" % dict(
out=outs[0], inp=inps[0]
)

def c_code_cache_version(self):
return (0,)

def perform(self, *args, **kwargs):
raise NotImplementedError()

i = TensorType("float32", (False,))()
c = ProdOp()(i)
i2 = GetOp()(c)
Expand Down
Loading

0 comments on commit dd6973d

Please sign in to comment.