diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py
index bcaa8acbe3..d4c434e218 100644
--- a/aesara/tensor/basic.py
+++ b/aesara/tensor/basic.py
@@ -51,6 +51,7 @@
     shape_tuple,
 )
 from aesara.tensor.type import (
+    FixedShapeTensorType,
     TensorType,
     discrete_dtypes,
     float_dtypes,
@@ -1633,11 +1634,11 @@ def make_node(self, *inputs):
 
         if inputs:
             dtype = inputs[0].type.dtype
+            otype = FixedShapeTensorType(dtype, (len(inputs),))
         else:
             dtype = self.dtype
-        # bcastable = (len(inputs) == 1)
-        bcastable = False
-        otype = TensorType(broadcastable=(bcastable,), dtype=dtype)
+            otype = FixedShapeTensorType(dtype, (0,))
+
         return Apply(self, inputs, [otype()])
 
     def perform(self, node, inputs, out_):
@@ -2131,6 +2132,13 @@ def addbroadcast(x, *axes):
         A aesara tensor, which is broadcastable along the specified dimensions.
 
     """
+    x = as_tensor_variable(x)
+
+    if isinstance(x.type, FixedShapeTensorType):
+        if not set(i for i, b in enumerate(x.broadcastable) if b).issuperset(axes):
+            raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}")
+        return x
+
     rval = Rebroadcast(*[(axis, True) for axis in axes])(x)
     return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
 
@@ -2161,6 +2169,13 @@ def unbroadcast(x, *axes):
         A aesara tensor, which is unbroadcastable along the specified dimensions.
 
     """
+    x = as_tensor_variable(x)
+
+    if isinstance(x.type, FixedShapeTensorType):
+        if not set(i for i, b in enumerate(x.broadcastable) if not b).issuperset(axes):
+            raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}")
+        return x
+
     rval = Rebroadcast(*[(axis, False) for axis in axes])(x)
     return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
 
diff --git a/aesara/tensor/random/utils.py b/aesara/tensor/random/utils.py
index ecd856d818..1ac7d7968d 100644
--- a/aesara/tensor/random/utils.py
+++ b/aesara/tensor/random/utils.py
@@ -11,7 +11,7 @@
 from aesara.tensor.extra_ops import broadcast_to
 from aesara.tensor.math import maximum
 from aesara.tensor.shape import specify_shape
-from aesara.tensor.type import int_dtypes
+from aesara.tensor.type import FixedShapeTensorType, int_dtypes
 
 
 def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True):
@@ -130,6 +130,7 @@ def normalize_size_param(size):
         # `Scan` performs)
         size = specify_shape(size, (get_vector_length(size),))
 
+    assert isinstance(size.type, FixedShapeTensorType)
     assert size.dtype in int_dtypes
 
     return size
diff --git a/aesara/tensor/shape.py b/aesara/tensor/shape.py
index 3b7d413cb4..82046228b8 100644
--- a/aesara/tensor/shape.py
+++ b/aesara/tensor/shape.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Dict
+from typing import Dict, List, Tuple, Union
 
 import numpy as np
 
@@ -13,7 +13,7 @@
 from aesara.tensor import _get_vector_length
 from aesara.tensor import basic as aet
 from aesara.tensor.exceptions import NotScalarConstantError
-from aesara.tensor.type import TensorType, int_dtypes, tensor
+from aesara.tensor.type import FixedShapeTensorType, TensorType, int_dtypes, tensor
 from aesara.tensor.var import TensorConstant
 
 
@@ -61,7 +61,8 @@ def make_node(self, x):
         # This will fail at execution time.
         if not isinstance(x, Variable):
             x = aet.as_tensor_variable(x)
-        return Apply(self, [x], [aesara.tensor.type.lvector()])
+        otype = FixedShapeTensorType("int64", (x.ndim,))
+        return Apply(self, [x], [otype()])
 
     def perform(self, node, inp, out_):
         (x,) = inp
@@ -126,8 +127,16 @@ def c_code_cache_version(self):
         return tuple(version)
 
 
-shape = Shape()
-_shape = shape  # was used in the past, now use shape directly.
+_shape = Shape()
+
+
+def shape(x: Variable) -> Variable:
+    """Return the shape of `x`."""
+    x_type = getattr(x, "type", None)
+    if x_type and isinstance(x_type, FixedShapeTensorType):
+        return aet.as_tensor_variable(x_type.shape)
+    else:
+        return _shape(x)
 
 
 @_get_vector_length.register(Shape)
@@ -135,7 +144,7 @@ def _get_vector_length_Shape(op, var):
     return var.owner.inputs[0].type.ndim
 
 
-def shape_tuple(x):
+def shape_tuple(x: Variable) -> Tuple[Variable, ...]:
     """Get a tuple of symbolic shape values.
 
     This will return a `ScalarConstant` with the value ``1`` wherever
@@ -495,7 +504,36 @@ def c_code_cache_version(self):
         return tuple(version)
 
 
-specify_shape = SpecifyShape()
+_specify_shape = SpecifyShape()
+
+
+def specify_shape(
+    x: Variable,
+    shape: Union[
+        int, List[Union[int, Variable]], Tuple[Union[int, Variable], ...], Variable
+    ],
+):
+    """Specify a fixed shape for a `Variable`."""
+
+    x = aet.as_tensor_variable(x)
+
+    if isinstance(x.type, FixedShapeTensorType):
+        # The `Variable` already has a fixed-shape `Type`, so there's
+        # no need to put a useless `Op` in the graph
+        return x
+
+    if isinstance(shape, int) or (
+        isinstance(shape, (list, tuple)) and all(isinstance(s, int) for s in shape)
+    ):
+        x_type = FixedShapeTensorType(x.type.dtype, shape, name=x.type.name)
+        x_new = x_type.convert_variable(x)
+
+        if x_new is None:
+            raise ValueError(f"{x} cannot be assigned the shape {shape}")
+
+        return x_new
+
+    return _specify_shape(x, shape)
 
 
 @_get_vector_length.register(SpecifyShape)
diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py
index 37c5b9d8ea..5c93a4d7aa 100644
--- a/aesara/tensor/type.py
+++ b/aesara/tensor/type.py
@@ -5,7 +5,7 @@
 import aesara
 from aesara import scalar as aes
 from aesara.configdefaults import config
-from aesara.graph.basic import Constant, Variable
+from aesara.graph.basic import Apply, Constant, Variable
 from aesara.graph.type import CType
 from aesara.misc.safe_asarray import _asarray
 from aesara.utils import apply_across_args
@@ -603,6 +603,76 @@ def get_size(self, shape_info):
         return np.dtype(self.dtype).itemsize
 
 
+class FixedShapeTensorType(TensorType):
+    """A symbolic type representing a tensor/array with a fixed shape."""
+
+    __props__ = ("shape",)
+
+    def __init__(self, dtype, shape, name=None):
+        if not (isinstance(shape, int) or all(isinstance(s, int) for s in shape)):
+            raise TypeError(f"{shape} is not a valid shape")
+
+        if isinstance(shape, (list, tuple)):
+            self.shape = tuple(shape)
+            bcast = tuple(s == 1 for s in shape)
+        else:
+            self.shape = int(shape)
+            bcast = ()
+
+        super().__init__(dtype, bcast, name=name)
+
+    def __str__(self):
+        return repr(self)
+
+    def __repr__(self):
+        return f"FixedShapeTensorType({self.dtype}, {self.shape})"
+
+    def value_zeros(self):
+        """Create a numpy ndarray full of 0 values."""
+        return np.zeros(self.shape, dtype=self.dtype)
+
+    def __eq__(self, other):
+        return (
+            type(self) == type(other)
+            and other.dtype == self.dtype
+            and other.shape == self.shape
+        )
+
+    def is_compatible(self, otype):
+        # Fixed-shape types can only be replaced by equal fixed-shape types
+        return self == otype
+
+    def __hash__(self):
+        return hash((type(self), self.dtype, self.shape))
+
+    def clone(self, dtype=None, shape=None):
+        if dtype is None:
+            dtype = self.dtype
+        if shape is None:
+            shape = self.shape
+        return self.__class__(dtype, shape, name=self.name)
+
+    def convert_variable(self, var):
+        if self == var.type:
+            return var
+
+        if isinstance(var.type, TensorType) and var.type.is_compatible(self):
+            new_type = FixedShapeTensorType(
+                var.type.dtype, self.shape, name=var.type.name
+            )
+            new_var = new_type()
+            new_var.name = var.name
+
+            if var.owner:
+                outputs_new = list(var.owner.outputs)
+                outputs_new[var.index] = new_var
+                # Construct a new graph in which the fixed-shape tensor is the
+                # output
+                _ = Apply(var.owner.op, var.owner.inputs, outputs_new)
+
+            return new_var
+
+
 def values_eq_approx(
     a, b, allow_remove_inf=False, allow_remove_nan=False, rtol=None, atol=None
 ):
diff --git a/aesara/tensor/var.py b/aesara/tensor/var.py
index a4589c0b4f..3b2a029b01 100644
--- a/aesara/tensor/var.py
+++ b/aesara/tensor/var.py
@@ -11,8 +11,9 @@
 from aesara.configdefaults import config
 from aesara.graph.basic import Constant, Variable
 from aesara.scalar import ComplexError, IntegerDivisionError
+from aesara.tensor import _get_vector_length
 from aesara.tensor.exceptions import AdvancedIndexingError
-from aesara.tensor.type import TensorType
+from aesara.tensor.type import FixedShapeTensorType, TensorType
 from aesara.tensor.utils import hash_from_ndarray
 
 
@@ -856,7 +857,19 @@ def __init__(self, type, owner=None, index=None, name=None):
             pdb.set_trace()
 
 
+class FixedShapeTensorVariable(TensorVariable):
+    @property
+    def shape(self):
+        return self.type.shape
+
+
+@_get_vector_length.register(FixedShapeTensorVariable)
+def _get_vector_length_FixedShapeTensorVariable(op_or_var, var):
+    return var.size
+
+
 TensorType.Variable = TensorVariable
+FixedShapeTensorType.Variable = FixedShapeTensorVariable
 
 
 class TensorConstantSignature(tuple):
@@ -973,14 +986,20 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]:
     return None
 
 
-class TensorConstant(TensorVariable, Constant):
-    """Subclass to add the tensor operators to the basic `Constant` class.
+class TensorConstant(FixedShapeTensorVariable, Constant):
+    """Subclass to add the tensor operators to the basic `Constant` class."""
 
-    To create a TensorConstant, use the `constant` function in this module.
+    def __init__(self, type, data, name=None):
+        data_shape = np.shape(data)
 
-    """
+        if not isinstance(type, FixedShapeTensorType):
+            type = FixedShapeTensorType(type.dtype, data_shape, name=type.name)
+
+        if np.shape(data) != type.shape:
+            raise ValueError(
+                f"Shape of data ({data_shape}) does not match shape of type ({type.shape})"
+            )
 
-    def __init__(self, type, data, name=None):
         Constant.__init__(self, type, data, name)
 
     def __str__(self):
diff --git a/tests/test_raise_op.py b/tests/test_raise_op.py
index ba18fe1bff..80bf8f2144 100644
--- a/tests/test_raise_op.py
+++ b/tests/test_raise_op.py
@@ -79,10 +79,10 @@ def test_CheckAndRaise_basic_c(linker):
     with pytest.raises(CustomException, match=exc_msg):
         y_fn(0)
 
-    y = check_and_raise(at.as_tensor(1), conds)
-    y_fn = aesara.function([conds], y.shape, mode=Mode(linker, OPT_FAST_RUN))
+    y = check_and_raise(x, conds)
+    y_fn = aesara.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN))
 
-    assert np.array_equal(y_fn(0), [])
+    assert np.array_equal(y_fn(0, 0), [])
 
     y = check_and_raise(x, at.as_tensor(0))
     y_grad = aesara.grad(y, [x])
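
For review purposes, a minimal sketch of the behavior this patch is aiming for. It is illustrative only and not part of the patch: the variable names are invented, and it assumes the diff applies cleanly to an Aesara checkout of the same vintage, where the base `TensorType.is_compatible` check used by `convert_variable` succeeds for an ordinary tensor variable:

    import numpy as np
    import aesara.tensor as at
    from aesara.tensor.shape import shape, specify_shape
    from aesara.tensor.type import FixedShapeTensorType

    x = at.matrix("x")

    # With all-integer shape info, `specify_shape` now rewrites the
    # variable's type instead of inserting a `SpecifyShape` node.
    x_fixed = specify_shape(x, (2, 3))
    assert isinstance(x_fixed.type, FixedShapeTensorType)

    # `FixedShapeTensorVariable.shape` is a plain Python tuple, not a
    # symbolic graph...
    assert x_fixed.shape == (2, 3)

    # ...and the new `shape` helper constant-folds it instead of
    # building a `Shape` node.
    assert np.array_equal(shape(x_fixed).data, (2, 3))

    # `TensorConstant`s are fixed-shape by construction.
    c = at.as_tensor(np.zeros((2, 3)))
    assert isinstance(c.type, FixedShapeTensorType)

The `Shape`, `MakeVector`, and `TensorConstant` changes above all follow the same pattern: wherever an output's shape is statically known, the node gets a `FixedShapeTensorType` output rather than a generic `TensorType`.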