From fe036033d0a608f762d34dc8b3b9a328bc32fd83 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 29 Dec 2021 23:01:17 -0600 Subject: [PATCH] Add FixedShapeTensorTypes and refactor basic shape logic to use them --- aesara/compile/ops.py | 16 +++- aesara/tensor/basic.py | 31 +++++-- aesara/tensor/random/utils.py | 3 +- aesara/tensor/shape.py | 120 ++++++++++++++++++++------- aesara/tensor/type.py | 92 +++++++++++++++++++-- aesara/tensor/var.py | 31 +++++-- tests/tensor/test_shape.py | 151 ++++++++++++---------------------- tests/tensor/test_type.py | 82 +++++++++++++++++- tests/tensor/test_var.py | 18 +++- tests/test_raise_op.py | 6 +- 10 files changed, 388 insertions(+), 162 deletions(-) diff --git a/aesara/compile/ops.py b/aesara/compile/ops.py index b780c5a832..0338f390c5 100644 --- a/aesara/compile/ops.py +++ b/aesara/compile/ops.py @@ -15,7 +15,7 @@ from aesara.graph.type import CType -def register_view_op_c_code(type, code, version=()): +def register_view_op_c_code(types, code, version=()): """ Tell ViewOp how to generate C code for an Aesara Type. @@ -30,7 +30,11 @@ def register_view_op_c_code(type, code, version=()): A number indicating the version of the code, for cache. """ - ViewOp.c_code_and_version[type] = (code, version) + if not isinstance(types, list): + types = [types] + + for typ in types: + ViewOp.c_code_and_version[typ] = (code, version) class ViewOp(COp): @@ -127,7 +131,7 @@ class OutputGuard(ViewOp): _output_guard = OutputGuard() -def register_deep_copy_op_c_code(typ, code, version=()): +def register_deep_copy_op_c_code(types, code, version=()): """ Tell DeepCopyOp how to generate C code for an Aesara Type. @@ -142,7 +146,11 @@ def register_deep_copy_op_c_code(typ, code, version=()): A number indicating the version of the code, for cache. """ - DeepCopyOp.c_code_and_version[typ] = (code, version) + if not isinstance(types, list): + types = [types] + + for typ in types: + DeepCopyOp.c_code_and_version[typ] = (code, version) class DeepCopyOp(COp): diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index e5ae1cf180..928a65cfe4 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -51,6 +51,7 @@ shape_tuple, ) from aesara.tensor.type import ( + FixedShapeTensorType, TensorType, discrete_dtypes, float_dtypes, @@ -782,7 +783,7 @@ def c_code_cache_version(self): return tuple(version) -def register_rebroadcast_c_code(typ, code, version=()): +def register_rebroadcast_c_code(types, code, version=()): """ Tell Rebroadcast how to generate C code for an Aesara Type. @@ -797,11 +798,15 @@ def register_rebroadcast_c_code(typ, code, version=()): A number indicating the version of the code, for cache. 
""" - Rebroadcast.c_code_and_version[typ] = (code, version) + if not isinstance(types, list): + types = [types] + + for typ in types: + Rebroadcast.c_code_and_version[typ] = (code, version) register_rebroadcast_c_code( - TensorType, + [FixedShapeTensorType, TensorType], """ if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){ PyErr_Format(PyExc_ValueError, @@ -1631,11 +1636,11 @@ def make_node(self, *inputs): if inputs: dtype = inputs[0].type.dtype + otype = FixedShapeTensorType(dtype, (len(inputs),)) else: dtype = self.dtype - # bcastable = (len(inputs) == 1) - bcastable = False - otype = TensorType(broadcastable=(bcastable,), dtype=dtype) + otype = FixedShapeTensorType(dtype, 0) + return Apply(self, inputs, [otype()]) def perform(self, node, inputs, out_): @@ -2122,6 +2127,13 @@ def addbroadcast(x, *axes): A aesara tensor, which is broadcastable along the specified dimensions. """ + x = as_tensor_variable(x) + + if isinstance(x.type, FixedShapeTensorType): + if not set(i for i, b in enumerate(x.broadcastable) if b).issuperset(axes): + raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}") + return x + rval = Rebroadcast(*[(axis, True) for axis in axes])(x) return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval) @@ -2152,6 +2164,13 @@ def unbroadcast(x, *axes): A aesara tensor, which is unbroadcastable along the specified dimensions. """ + x = as_tensor_variable(x) + + if isinstance(x.type, FixedShapeTensorType): + if not set(i for i, b in enumerate(x.broadcastable) if not b).issuperset(axes): + raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}") + return x + rval = Rebroadcast(*[(axis, False) for axis in axes])(x) return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval) diff --git a/aesara/tensor/random/utils.py b/aesara/tensor/random/utils.py index ecd856d818..1ac7d7968d 100644 --- a/aesara/tensor/random/utils.py +++ b/aesara/tensor/random/utils.py @@ -11,7 +11,7 @@ from aesara.tensor.extra_ops import broadcast_to from aesara.tensor.math import maximum from aesara.tensor.shape import specify_shape -from aesara.tensor.type import int_dtypes +from aesara.tensor.type import FixedShapeTensorType, int_dtypes def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True): @@ -130,6 +130,7 @@ def normalize_size_param(size): # `Scan` performs) size = specify_shape(size, (get_vector_length(size),)) + assert isinstance(size.type, FixedShapeTensorType) assert size.dtype in int_dtypes return size diff --git a/aesara/tensor/shape.py b/aesara/tensor/shape.py index 3b7d413cb4..57418dd73a 100644 --- a/aesara/tensor/shape.py +++ b/aesara/tensor/shape.py @@ -1,23 +1,25 @@ import warnings -from typing import Dict +from numbers import Number +from typing import Dict, List, Tuple, Union import numpy as np import aesara from aesara.gradient import DisconnectedType -from aesara.graph.basic import Apply, Variable +from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import COp from aesara.graph.params_type import ParamsType from aesara.misc.safe_asarray import _asarray from aesara.scalar import int32 from aesara.tensor import _get_vector_length from aesara.tensor import basic as aet +from aesara.tensor import get_vector_length from aesara.tensor.exceptions import NotScalarConstantError -from aesara.tensor.type import TensorType, int_dtypes, tensor +from aesara.tensor.type import FixedShapeTensorType, TensorType, int_dtypes, tensor from aesara.tensor.var import TensorConstant -def register_shape_c_code(type, code, version=()): +def 
register_shape_c_code(types, code, version=()): """ Tell Shape Op how to generate C code for an Aesara Type. @@ -33,7 +35,11 @@ A number indicating the version of the code, for cache. """ - Shape.c_code_and_version[type] = (code, version) + if not isinstance(types, list): + types = [types] + + for typ in types: + Shape.c_code_and_version[typ] = (code, version) class Shape(COp): @@ -61,7 +67,8 @@ def make_node(self, x): # This will fail at execution time. if not isinstance(x, Variable): x = aet.as_tensor_variable(x) - return Apply(self, [x], [aesara.tensor.type.lvector()]) + otype = FixedShapeTensorType("int64", (x.ndim,)) + return Apply(self, [x], [otype()]) def perform(self, node, inp, out_): (x,) = inp @@ -126,8 +133,21 @@ def c_code_cache_version(self): return tuple(version) -shape = Shape() -_shape = shape # was used in the past, now use shape directly. +_shape = Shape() + + +def shape(x: Union[np.ndarray, Number, Variable]) -> Variable: + """Return the shape of `x`, as a constant when the shape is statically known.""" + x = aet.as_tensor_variable(x) + x_type = x.type + + if isinstance(x_type, FixedShapeTensorType): + res = aet.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64) + else: + res = _shape(x) + + assert isinstance(res.type, FixedShapeTensorType) + return res @_get_vector_length.register(Shape) @@ -135,7 +155,7 @@ def _get_vector_length_Shape(op, var): return var.owner.inputs[0].type.ndim -def shape_tuple(x): +def shape_tuple(x: Variable) -> Tuple[Variable, ...]: """Get a tuple of symbolic shape values. This will return a `ScalarConstant` with the value ``1`` wherever @@ -319,7 +339,7 @@ def shape_i_op(i): shape_i_op.cache = {} -def register_shape_i_c_code(typ, code, check_input, version=()): +def register_shape_i_c_code(types, code, check_input, version=()): """ Tell Shape_i how to generate C code for an Aesara Type. @@ -335,10 +355,14 @@ A number indicating the version of the code, for cache. """ - Shape_i.c_code_and_version[typ] = (code, check_input, version) + if not isinstance(types, list): + types = [types] + + for typ in types: + Shape_i.c_code_and_version[typ] = (code, check_input, version) -def register_specify_shape_c_code(typ, code, version=(), c_support_code_apply=None): +def register_specify_shape_c_code(types, code, version=(), c_support_code_apply=None): """ Tell SpecifyShape how to generate C code for an Aesara Type. @@ -357,7 +381,11 @@ Extra code. """ - SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply) + if not isinstance(types, list): + types = [types] + + for typ in types: + SpecifyShape.c_code_and_version[typ] = (code, version, c_support_code_apply) class SpecifyShape(COp): @@ -388,22 +416,29 @@ _f16_ok = True def make_node(self, x, shape): - if not isinstance(x, Variable): - x = aet.as_tensor_variable(x) - if shape == () or shape == []: - tshape = aet.constant([], dtype="int64") + x = aet.as_tensor_variable(x) + shape = aet.as_tensor_variable(shape, ndim=1) + + if isinstance(shape, Constant): + shape = tuple(shape.data) else: - tshape = aet.as_tensor_variable(shape, ndim=1) - if tshape.dtype not in aesara.tensor.type.integer_dtypes: - raise AssertionError( - f"The `shape` must be an integer type. Got {tshape.dtype} instead."
- ) - if isinstance(tshape, TensorConstant) and tshape.data.size != x.ndim: - ndim = len(tshape.data) - raise AssertionError( - f"Input `x` is {x.ndim}-dimensional and will never match a {ndim}-dimensional shape." + shape = tuple(aet.as_tensor_variable(s, ndim=0) for s in shape) + + if any(s.dtype not in aesara.tensor.type.integer_dtypes for s in shape): + raise TypeError("Shape values must be integer types") + + if len(shape) != x.ndim: + raise ValueError( + f"Input `x` is {x.ndim}-dimensional and will never match a shape of length {len(shape)}." ) - return Apply(self, [x, tshape], [x.type()]) + + if all(isinstance(s, Number) for s in shape): + out_var = FixedShapeTensorType(x.type.dtype, shape)() + else: + out_var = x.type() + + in_shape = aet.as_tensor_variable(shape, ndim=1) + return Apply(self, [x, in_shape], [out_var]) def perform(self, node, inp, out_): x, shape = inp @@ -495,7 +530,28 @@ def c_code_cache_version(self): return tuple(version) -specify_shape = SpecifyShape() +_specify_shape = SpecifyShape() + + +def specify_shape( + x: Union[np.ndarray, Number, Variable], + shape: Union[ + int, List[Union[int, Variable]], Tuple[Union[int, Variable], ...], Variable + ], +) -> Variable: + """Specify a fixed shape for a `Variable`, producing a `FixedShapeTensorType` output when every shape value is a constant integer.""" + + x = aet.as_tensor_variable(x) + + try: + get_vector_length(shape) + except ValueError as exc: + raise ValueError("Shape must have fixed dimensions") from exc + + if isinstance(shape, Constant): + shape = tuple(shape.data) + + return _specify_shape(x, shape) @_get_vector_length.register(SpecifyShape) @@ -708,7 +764,7 @@ def reshape(x, newshape, ndim=None): f" scalar. Got {newshape} after conversion to a vector." ) try: - ndim = aet.get_vector_length(newshape) + ndim = get_vector_length(newshape) except ValueError: raise ValueError( f"The length of the provided shape ({newshape}) cannot " @@ -791,7 +847,7 @@ def shape_padaxis(t, axis): register_shape_c_code( - TensorType, + [FixedShapeTensorType, TensorType], """ npy_intp shape[] = {PyArray_NDIM(%(iname)s)}; if(%(oname)s == NULL || (PyArray_DIMS(%(oname)s)[0] != shape[0])) @@ -809,7 +865,7 @@ def shape_padaxis(t, axis): register_shape_i_c_code( - TensorType, + [FixedShapeTensorType, TensorType], """ if(!%(oname)s) %(oname)s=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0); @@ -827,7 +883,7 @@ def shape_padaxis(t, axis): register_specify_shape_c_code( - TensorType, + [FixedShapeTensorType, TensorType], """ if (PyArray_NDIM(%(iname)s) != PyArray_DIMS(%(shape)s)[0]) { PyErr_Format(PyExc_AssertionError, diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 75a7bc8650..367484065d 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -267,11 +267,10 @@ def __eq__(self, other): Compare True iff other is the same kind of TensorType.
""" - return ( - type(self) == type(other) - and other.dtype == self.dtype - and other.broadcastable == self.broadcastable - ) + if type(self) != type(other): + return NotImplemented + + return other.dtype == self.dtype and other.broadcastable == self.broadcastable def is_compatible(self, otype): if ( @@ -596,6 +595,85 @@ def get_size(self, shape_info): return np.dtype(self.dtype).itemsize +class FixedShapeTensorType(TensorType): + """A symbolic type representing a tensor/array with a fixed shape.""" + + __props__ = ("shape",) + + def __init__(self, dtype, shape, name=None): + if not ( + isinstance(shape, (int, np.integer)) + or all(isinstance(s, (int, np.integer)) for s in shape) + ): + raise TypeError(f"{shape} is not a valid shape") + + if isinstance(shape, (list, tuple)): + self.shape = tuple(int(s) for s in shape) + bcast = tuple(s == 1 for s in shape) + else: + self.shape = int(shape) + bcast = () + + super().__init__(dtype, bcast, name=name) + + def __str__(self): + return repr(self) + + def __repr__(self): + return f"FixedShapeTensorType({self.dtype}, {self.shape})" + + def value_zeros(self): + """Create an numpy ndarray full of 0 values.""" + return np.zeros(self.shape, dtype=self.dtype) + + def __eq__(self, other): + if type(other) == TensorType: + # This handles `TensorType` case in which all dimensions are + # broadcastable (i.e. the shape is known) + return ( + other.dtype == self.dtype + and len(other.broadcastable) == len(self.shape) + and all(other.broadcastable) + and all(s == 1 for s in self.shape) + ) + + if type(self) != type(other): + return NotImplemented + + return other.dtype == self.dtype and other.shape == self.shape + + def is_compatible(self, otype): + return self == otype + + def __hash__(self): + return hash((type(self), self.dtype, self.shape)) + + def clone(self, dtype=None, broadcastable=None, shape=None): + if dtype is None: + dtype = self.dtype + + if shape is None: + shape = self.shape + + if broadcastable is not None: + # We accept this argument for compatibility with the base type, + # `TensorType`, and we use it as an assert/check on `shape` + assert len(broadcastable) == len(shape) + assert all( + s == 1 if b else (s > 1 or s == 0) for b, s in zip(broadcastable, shape) + ) + + return self.__class__(dtype, shape, name=self.name) + + def convert_variable(self, var): + if self.is_compatible(var.type): + return var + elif var.type.is_compatible(self): + from aesara.tensor.shape import specify_shape + + return specify_shape(var, self.shape) + + def values_eq_approx( a, b, allow_remove_inf=False, allow_remove_nan=False, rtol=None, atol=None ): @@ -690,7 +768,7 @@ def values_eq_approx_always_true(a, b): aesara.compile.register_view_op_c_code( - TensorType, + [FixedShapeTensorType, TensorType], """ Py_XDECREF(%(oname)s); %(oname)s = %(iname)s; @@ -701,7 +779,7 @@ def values_eq_approx_always_true(a, b): aesara.compile.register_deep_copy_op_c_code( - TensorType, + [FixedShapeTensorType, TensorType], """ int alloc = %(oname)s == NULL; for(int i=0; !alloc && i Optional[Number]: return None -class TensorConstant(TensorVariable, Constant): - """Subclass to add the tensor operators to the basic `Constant` class. +class TensorConstant(FixedShapeTensorVariable, Constant): + """Subclass to add the tensor operators to the basic `Constant` class.""" - To create a TensorConstant, use the `constant` function in this module. 
+ def __init__(self, type, data, name=None): data_shape = np.shape(data) - """ + if not isinstance(type, FixedShapeTensorType): + type = FixedShapeTensorType(type.dtype, data_shape, name=type.name) + + if data_shape != type.shape: + raise ValueError( + f"Shape of data ({data_shape}) does not match shape of type ({type.shape})" + ) - def __init__(self, type, data, name=None): Constant.__init__(self, type, data, name) def __str__(self): diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index acd33be04e..d3f11a0f3b 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -7,14 +7,15 @@ from aesara.configdefaults import config from aesara.graph.fg import FunctionGraph from aesara.misc.safe_asarray import _asarray -from aesara.tensor import get_vector_length -from aesara.tensor.basic import MakeVector, as_tensor_variable, constant +from aesara.tensor import as_tensor_variable, get_vector_length +from aesara.tensor.basic import MakeVector, constant from aesara.tensor.basic_opt import ShapeFeature from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.shape import ( Reshape, Shape_i, SpecifyShape, + _specify_shape, reshape, shape, shape_i, @@ -22,11 +23,13 @@ ) from aesara.tensor.subtensor import Subtensor from aesara.tensor.type import ( + FixedShapeTensorType, TensorType, dmatrix, dtensor4, dvector, fvector, + iscalar, ivector, matrix, scalar, @@ -34,6 +37,7 @@ vector, ) from aesara.tensor.type_other import NoneConst +from aesara.tensor.var import FixedShapeTensorVariable from aesara.typed_list import make_list from tests import unittest_tools as utt from tests.tensor.utils import eval_outputs, random @@ -314,27 +318,26 @@ class TestSpecifyShape(utt.InferShapeTester): mode = None input_type = TensorType - def shortDescription(self): - return None - def test_check_inputs(self): - with pytest.raises(AssertionError, match="must be an integer type"): + with pytest.raises(TypeError, match="must be integer types"): specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3)) - specify_shape([[1, 2, 3], [4, 5, 6]], (2, 3)) - # Incompatible dimensionality is detected right away - with pytest.raises(AssertionError, match="will never match"): - specify_shape( - matrix(), - [ - 4, - ], - ) + with pytest.raises(TypeError, match="must be integer types"): + _specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3)) + + with pytest.raises(ValueError, match="will never match"): + specify_shape(matrix(), [4]) + + with pytest.raises(ValueError, match="will never match"): + _specify_shape(matrix(), [4]) + + with pytest.raises(ValueError, match="must have fixed dimensions"): + specify_shape(matrix(), vector(dtype="int32")) def test_scalar_shapes(self): - with pytest.raises(AssertionError, match="will never match"): + with pytest.raises(ValueError, match="will never match"): specify_shape(vector(), shape=()) - with pytest.raises(AssertionError, match="will never match"): + with pytest.raises(ValueError, match="will never match"): specify_shape(matrix(), shape=[]) x = scalar() @@ -342,41 +345,43 @@ def test_scalar_shapes(self): f = aesara.function([x], y, mode=self.mode) assert f(15) == 15 + def test_fixed_shapes(self): + x = vector() + shape = as_tensor_variable([2]) + y = specify_shape(x, shape) + assert isinstance(y.type, FixedShapeTensorType) + assert y.shape.equals(shape) + def test_python_perform(self): + """Test the Python `Op.perform` implementation.""" x = scalar() - s = vector(dtype="int32") + s = as_tensor_variable([], dtype=np.int32) y = specify_shape(x, s) - f = 
aesara.function([x, s], y, mode=Mode("py")) - assert f(12, ()) == 12 - with pytest.raises( - AssertionError, - match=r"Got 0 dimensions \(shape \(\)\), expected 1 dimensions with shape \(2,\).", - ): - f(12, (2,)) + f = aesara.function([x], y, mode=Mode("py")) + assert f(12) == 12 - x = matrix() - s = vector(dtype="int32") - y = specify_shape(x, s) - f = aesara.function([x, s], y, mode=Mode("py")) - f(np.ones((2, 3)).astype(config.floatX), (2, 3)) - with pytest.raises( - AssertionError, match=r"Got shape \(3, 4\), expected \(2, 3\)." - ): - f(np.ones((3, 4)).astype(config.floatX), (2, 3)) + x = vector() + s1 = iscalar() + shape = as_tensor_variable([s1]) + y = specify_shape(x, shape) + f = aesara.function([x, shape], y, mode=Mode("py")) + assert f([1], (1,)) == [1] + + with pytest.raises(AssertionError, match="SpecifyShape:.*"): + assert f([1], (2,)) == [1] def test_bad_shape(self): - # Test that at run time we raise an exception when the shape - # is not the one specified + """Test that at run-time we raise an exception when the shape is not the one specified.""" specify_shape = SpecifyShape() x = vector() xval = np.random.random((2)).astype(config.floatX) f = aesara.function([x], specify_shape(x, [2]), mode=self.mode) - f(xval) + + assert np.array_equal(f(xval), xval) + xval = np.random.random((3)).astype(config.floatX) - expected = r"(Got shape \(3,\), expected \(2,\))" - expected += r"|(dim 0 of input has shape 3, expected 2.)" - with pytest.raises(AssertionError, match=expected): + with pytest.raises(AssertionError, match="SpecifyShape:.*"): f(xval) assert isinstance( @@ -395,77 +400,23 @@ def test_bad_shape(self): .type, self.input_type, ) - f(xval) + + assert np.array_equal(f(xval), xval) + for shape_ in [(4, 3), (2, 8)]: xval = np.random.random(shape_).astype(config.floatX) - s_exp = str(shape_).replace("(", r"\(").replace(")", r"\)") - expected = rf"(Got shape {s_exp}, expected \(2, 3\).)" - expected += r"|(dim 0 of input has shape 4, expected 2)" - expected += r"|(dim 1 of input has shape 8, expected 3)" - with pytest.raises(AssertionError, match=expected): + with pytest.raises(AssertionError, match="SpecifyShape:.*"): f(xval) - def test_bad_number_of_shape(self): - # Test that the number of dimensions provided is good - specify_shape = SpecifyShape() - - x = vector() - shape_vec = ivector() - xval = np.random.random((2)).astype(config.floatX) - with pytest.raises(AssertionError, match="will never match"): - specify_shape(x, []) - with pytest.raises(AssertionError, match="will never match"): - specify_shape(x, [2, 2]) - - f = aesara.function([x, shape_vec], specify_shape(x, shape_vec), mode=self.mode) - assert isinstance( - [n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] - .inputs[0] - .type, - self.input_type, - ) - expected = r"(Got 1 dimensions \(shape \(2,\)\), expected 0 dimensions with shape \(\).)" - expected += r"|(Got 1 dimensions, expected 0 dimensions.)" - with pytest.raises(AssertionError, match=expected): - f(xval, []) - expected = r"(Got 1 dimensions \(shape \(2,\)\), expected 2 dimensions with shape \(2, 2\).)" - expected += r"|(SpecifyShape: Got 1 dimensions, expected 2 dimensions.)" - with pytest.raises(AssertionError, match=expected): - f(xval, [2, 2]) - - x = matrix() - xval = np.random.random((2, 3)).astype(config.floatX) - for shape_ in [(), (1,), (2, 3, 4)]: - with pytest.raises(AssertionError, match="will never match"): - specify_shape(x, shape_) - f = aesara.function( - [x, shape_vec], specify_shape(x, shape_vec), mode=self.mode 
- ) - assert isinstance( - [ - n - for n in f.maker.fgraph.toposort() - if isinstance(n.op, SpecifyShape) - ][0] - .inputs[0] - .type, - self.input_type, - ) - s_exp = str(shape_).replace("(", r"\(").replace(")", r"\)") - expected = rf"(Got 2 dimensions \(shape \(2, 3\)\), expected {len(shape_)} dimensions with shape {s_exp}.)" - expected += rf"|(SpecifyShape: Got 2 dimensions, expected {len(shape_)} dimensions.)" - with pytest.raises(AssertionError, match=expected): - f(xval, shape_) - def test_infer_shape(self): rng = np.random.default_rng(3453) adtens4 = dtensor4() - aivec = ivector() + aivec = FixedShapeTensorVariable(FixedShapeTensorType("int64", (4,))) aivec_val = [3, 4, 2, 5] adtens4_val = rng.random(aivec_val) self._compile_and_check( [adtens4, aivec], - [SpecifyShape()(adtens4, aivec)], + [specify_shape(adtens4, aivec)], [adtens4_val, aivec_val], SpecifyShape, ) diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py index 2452612227..0eddcdb271 100644 --- a/tests/tensor/test_type.py +++ b/tests/tensor/test_type.py @@ -6,7 +6,8 @@ import aesara.tensor as at from aesara.configdefaults import config -from aesara.tensor.type import TensorType +from aesara.tensor.shape import SpecifyShape +from aesara.tensor.type import FixedShapeTensorType, TensorType def test_numpy_dtype(): @@ -168,3 +169,82 @@ def test_tensor_values_eq_approx(): b = np.asarray([-np.inf, -1, 0, 1, np.inf, 6]) with pytest.warns(RuntimeWarning): assert not TensorType.values_eq_approx(a, b, allow_remove_nan=False) + + +def test_FixedShapeTensorType_basic(): + t1 = FixedShapeTensorType("float64", (1, 1)) + assert t1.broadcastable == (True, True) + + t1 = FixedShapeTensorType("float64", (2, 3)) + assert t1.broadcastable == (False, False) + assert t1.value_zeros().shape == t1.shape + + assert str(t1) == "FixedShapeTensorType(float64, (2, 3))" + + t1 = FixedShapeTensorType("float64", (1,)) + assert t1.broadcastable == (True,) + + t2 = t1.clone() + assert t1 is not t2 + assert t1 == t2 + + t2 = t1.clone(dtype="float32", shape=(2, 4)) + assert t2.dtype == "float32" + assert t2.shape == (2, 4) + + t2 = t1.clone(dtype="float32", shape=(2, 4), broadcastable=(False, False)) + assert t2.shape == (2, 4) + + t2 = t1.clone(dtype="float32", shape=(2, 1), broadcastable=(False, True)) + assert t2.shape == (2, 1) + + with pytest.raises(AssertionError): + t1.clone(dtype="float32", shape=(2, 4), broadcastable=(True, False)) + + with pytest.raises(AssertionError): + t1.clone(dtype="float32", shape=(2, 1), broadcastable=(False, False)) + + with pytest.raises(AssertionError): + t1.clone(dtype="float32", shape=(2, 4), broadcastable=(False,)) + + +def test_FixedShapeTensorType_comparisons(): + t1 = TensorType("float64", (True, True)) + t2 = FixedShapeTensorType("float64", (1, 1)) + assert t1 == t2 + + assert t1.is_compatible(t2) + assert t2.is_compatible(t1) + + # TODO FIXME: This should be true. 
+ # assert hash(t1) == hash(t2) + + t3 = TensorType("float64", (True, False)) + t4 = FixedShapeTensorType("float64", (1, 2)) + assert t3 != t4 + + t1 = TensorType("float64", (True, True)) + t2 = FixedShapeTensorType("float64", ()) + assert t1 != t2 + + +def test_FixedShapeTensorType_convert_variable(): + # These are equivalent types + t1 = TensorType("float64", (True, True)) + t2 = FixedShapeTensorType("float64", (1, 1)) + + t2_var = t2() + res = t2.convert_variable(t2_var) + assert res is t2_var + + res = t1.convert_variable(t2_var) + assert res is t2_var + + t1_var = t1() + res = t2.convert_variable(t1_var) + assert res is t1_var + + t3 = TensorType("float64", (False, True)) + t3_var = t3() + res = t2.convert_variable(t3_var) + assert isinstance(res.owner.op, SpecifyShape) diff --git a/tests/tensor/test_var.py b/tests/tensor/test_var.py index 36a930599f..5d5658ce94 100644 --- a/tests/tensor/test_var.py +++ b/tests/tensor/test_var.py @@ -4,11 +4,13 @@ import aesara import tests.unittest_tools as utt -from aesara.graph.basic import equal_computations +from aesara.graph.basic import Constant, equal_computations +from aesara.tensor import get_vector_length from aesara.tensor.elemwise import DimShuffle from aesara.tensor.math import dot from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor from aesara.tensor.type import ( + FixedShapeTensorType, TensorType, cscalar, dmatrix, @@ -20,7 +22,7 @@ tensor3, ) from aesara.tensor.type_other import MakeSlice -from aesara.tensor.var import TensorConstant +from aesara.tensor.var import FixedShapeTensorVariable, TensorConstant @pytest.mark.parametrize( @@ -217,3 +219,15 @@ def test__getitem__newaxis(x, indices, new_order): assert isinstance(res.owner.op, DimShuffle) assert res.broadcastable == tuple(i == "x" for i in new_order) assert res.owner.op.new_order == new_order + + +def test_FixedShapeTensorVariable_basic(): + x = FixedShapeTensorVariable(FixedShapeTensorType("int64", (4,))) + assert isinstance(x.shape, Constant) + assert np.array_equal(x.shape.data, (4,)) + + +def test_FixedShapeTensorVariable_get_vector_length(): + x = FixedShapeTensorVariable(FixedShapeTensorType("int64", (4,))) + res = get_vector_length(x) + assert res == 4 diff --git a/tests/test_raise_op.py b/tests/test_raise_op.py index ba18fe1bff..80bf8f2144 100644 --- a/tests/test_raise_op.py +++ b/tests/test_raise_op.py @@ -79,10 +79,10 @@ def test_CheckAndRaise_basic_c(linker): with pytest.raises(CustomException, match=exc_msg): y_fn(0) - y = check_and_raise(at.as_tensor(1), conds) - y_fn = aesara.function([conds], y.shape, mode=Mode(linker, OPT_FAST_RUN)) + y = check_and_raise(x, conds) + y_fn = aesara.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN)) - assert np.array_equal(y_fn(0), []) + assert np.array_equal(y_fn(0, 0), []) y = check_and_raise(x, at.as_tensor(0)) y_grad = aesara.grad(y, [x])
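
Reviewer note, appended for context rather than as part of the patch: below is a minimal sketch of how the pieces above compose once the patch is applied. The names (`FixedShapeTensorType`, `specify_shape`, `shape`, `matrix`) are taken from the hunks above, and the import paths assume the file layout shown in the diff.

    import numpy as np

    import aesara.tensor as at
    from aesara.tensor.shape import shape, specify_shape
    from aesara.tensor.type import FixedShapeTensorType, matrix

    # Constants now carry a `FixedShapeTensorType`, so their static shapes
    # are available directly from the type, without compiling anything.
    c = at.as_tensor_variable(np.zeros((2, 3)))
    assert isinstance(c.type, FixedShapeTensorType)
    assert c.type.shape == (2, 3)

    # `shape` short-circuits to a constant vector for fixed-shape inputs
    # instead of building a symbolic `Shape` node.
    s = shape(c)
    assert s.owner is None

    # `specify_shape` with an all-integer shape yields a fixed-shape output;
    # symbolic shape values still fall back to `x.type()`.
    x = matrix("x")
    y = specify_shape(x, (2, 3))
    assert isinstance(y.type, FixedShapeTensorType)
    assert y.type.shape == (2, 3)

The underlying design point is that static shape information moves out of the `broadcastable` flags and into the type itself, so any graph position with a fully known shape (constants, `Shape` outputs, `specify_shape` with literal dimensions) exposes constant-foldable shape information to later rewrites.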