Add FixedShapeTensorTypes and refactor basic shape logic to use them
brandonwillard committed Dec 30, 2021
1 parent a670243 commit 7da3d5d
Showing 6 changed files with 164 additions and 21 deletions.
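
For orientation, here is a minimal sketch of what the new fixed-shape machinery enables, assuming the post-commit API (illustrative only; not part of the diff):

    import aesara.tensor as at
    from aesara.tensor.shape import specify_shape

    x = at.matrix("x")                  # shape unknown at compile time
    x_fixed = specify_shape(x, (3, 4))  # x_fixed.type is a FixedShapeTensorType

    # The shape is now known statically, without evaluating a graph:
    print(x_fixed.type.shape)  # (3, 4)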
21 changes: 18 additions & 3 deletions aesara/tensor/basic.py
@@ -51,6 +51,7 @@
shape_tuple,
)
from aesara.tensor.type import (
FixedShapeTensorType,
TensorType,
discrete_dtypes,
float_dtypes,
@@ -1633,11 +1634,11 @@ def make_node(self, *inputs):

if inputs:
dtype = inputs[0].type.dtype
otype = FixedShapeTensorType(dtype, (len(inputs),))
else:
dtype = self.dtype
# bcastable = (len(inputs) == 1)
bcastable = False
otype = TensorType(broadcastable=(bcastable,), dtype=dtype)
otype = FixedShapeTensorType(dtype, 0)

return Apply(self, inputs, [otype()])

def perform(self, node, inputs, out_):
@@ -2131,6 +2132,13 @@ def addbroadcast(x, *axes):
An Aesara tensor that is broadcastable along the specified dimensions.
"""
x = as_tensor_variable(x)

if isinstance(x.type, FixedShapeTensorType):
if not set(i for i, b in enumerate(x.broadcastable) if b).issuperset(axes):
raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}")
return x

rval = Rebroadcast(*[(axis, True) for axis in axes])(x)
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)

@@ -2161,6 +2169,13 @@ def unbroadcast(x, *axes):
An Aesara tensor that is unbroadcastable along the specified dimensions.
"""
x = as_tensor_variable(x)

if isinstance(x.type, FixedShapeTensorType):
if not set(i for i, b in enumerate(x.broadcastable) if not b).issuperset(axes):
raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}")
return x

rval = Rebroadcast(*[(axis, False) for axis in axes])(x)
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)

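A sketch of the new short-circuit in `addbroadcast`/`unbroadcast`, assuming constants carry fixed-shape types as in the `var.py` changes below (illustrative only):

    import numpy as np
    import aesara.tensor as at

    c = at.as_tensor(np.zeros((1, 5)))  # constant -> fixed shape (1, 5)

    # Dimension 0 has length 1, so the fixed broadcast pattern already
    # satisfies the request; the input is returned with no Rebroadcast Op.
    y = at.addbroadcast(c, 0)
    assert y is c

    # A dimension of length other than 1 now fails eagerly:
    # at.addbroadcast(c, 1)  # ValueError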
3 changes: 2 additions & 1 deletion aesara/tensor/random/utils.py
@@ -11,7 +11,7 @@
from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import maximum
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import int_dtypes
from aesara.tensor.type import FixedShapeTensorType, int_dtypes


def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True):
@@ -130,6 +130,7 @@ def normalize_size_param(size):
# `Scan` performs)
size = specify_shape(size, (get_vector_length(size),))

assert isinstance(size.type, FixedShapeTensorType)
assert size.dtype in int_dtypes

return size
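What the new assertion guarantees, sketched for a constant size vector (illustrative; assumes the constant reaches this point with its fixed-shape type intact):

    import aesara.tensor as at
    from aesara.tensor.random.utils import normalize_size_param
    from aesara.tensor.type import FixedShapeTensorType

    size = normalize_size_param(at.as_tensor([2, 3]))

    # Every size parameter now leaves this function with a statically
    # known length that downstream shape inference can rely on.
    assert isinstance(size.type, FixedShapeTensorType)
    assert size.type.shape == (2,)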
52 changes: 45 additions & 7 deletions aesara/tensor/shape.py
@@ -1,5 +1,5 @@
import warnings
from typing import Dict
from typing import Dict, List, Tuple, Union

import numpy as np

@@ -13,7 +13,7 @@
from aesara.tensor import _get_vector_length
from aesara.tensor import basic as aet
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.type import TensorType, int_dtypes, tensor
from aesara.tensor.type import FixedShapeTensorType, TensorType, int_dtypes, tensor
from aesara.tensor.var import TensorConstant


@@ -61,7 +61,8 @@ def make_node(self, x):
# This will fail at execution time.
if not isinstance(x, Variable):
x = aet.as_tensor_variable(x)
return Apply(self, [x], [aesara.tensor.type.lvector()])
otype = FixedShapeTensorType("int64", (x.ndim,))
return Apply(self, [x], [otype()])

def perform(self, node, inp, out_):
(x,) = inp
@@ -126,16 +127,24 @@ def c_code_cache_version(self):
return tuple(version)


shape = Shape()
_shape = shape # was used in the past, now use shape directly.
_shape = Shape()


def shape(x: Variable) -> Variable:
"""Return the shape of `x`."""
x_type = getattr(x, "type", None)
if x_type and isinstance(x_type, FixedShapeTensorType):
return aet.as_tensor_variable(x_type.shape)
else:
return _shape(x)


@_get_vector_length.register(Shape)
def _get_vector_length_Shape(op, var):
return var.owner.inputs[0].type.ndim


def shape_tuple(x):
def shape_tuple(x: Variable) -> Tuple[Variable, ...]:
"""Get a tuple of symbolic shape values.
This will return a `ScalarConstant` with the value ``1`` wherever
@@ -495,7 +504,36 @@ def c_code_cache_version(self):
return tuple(version)


specify_shape = SpecifyShape()
_specify_shape = SpecifyShape()


def specify_shape(
x: Variable,
shape: Union[
int, List[Union[int, Variable]], Tuple[Union[int, Variable], ...], Variable
],
):
"""Specify a fixed shape for a `Variable`."""

x = aet.as_tensor_variable(x)

if isinstance(x.type, FixedShapeTensorType):
# The `Variable` already has a fixed-shape `Type`, so there's
# no need to put a useless `Op` in the graph
return x

if isinstance(shape, int) or (
isinstance(shape, (list, tuple)) and all(isinstance(s, int) for s in shape)
):
x_type = FixedShapeTensorType(x.type.dtype, shape, name=x.type.name)
x_new = x_type.convert_variable(x)

if x_new is None:
raise ValueError(f"{x} cannot be assigned the shape {shape}")

return x_new

return _specify_shape(x, shape)


@_get_vector_length.register(SpecifyShape)
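A sketch of the new `shape` helper's fast path for fixed-shape inputs (illustrative only):

    import aesara.tensor as at
    from aesara.tensor.shape import shape

    v = at.as_tensor([1.0, 2.0, 3.0])  # constant -> fixed shape (3,)
    s = shape(v)

    # No Shape Op is inserted; the result is already a constant vector.
    print(s.eval())  # [3]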
72 changes: 71 additions & 1 deletion aesara/tensor/type.py
@@ -5,7 +5,7 @@
import aesara
from aesara import scalar as aes
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.type import CType
from aesara.misc.safe_asarray import _asarray
from aesara.utils import apply_across_args
@@ -603,6 +603,76 @@ def get_size(self, shape_info):
return np.dtype(self.dtype).itemsize


class FixedShapeTensorType(TensorType):
"""A symbolic type representing a tensor/array with a fixed shape."""

__props__ = ("shape",)

def __init__(self, dtype, shape, name=None):
if not (isinstance(shape, int) or all(isinstance(s, int) for s in shape)):
raise TypeError(f"{shape} is not a valid shape")

if isinstance(shape, (list, tuple)):
self.shape = tuple(shape)
bcast = tuple(s == 1 for s in shape)
else:
self.shape = int(shape)
bcast = ()

super().__init__(dtype, bcast, name=name)

def __str__(self):
return repr(self)

def __repr__(self):
return f"FixedShapeTensorType({self.dtype}, {self.shape})"

def value_zeros(self):
"""Create an numpy ndarray full of 0 values."""
return np.zeros(self.shape, dtype=self.dtype)

def __eq__(self, other):
return (
type(self) == type(other)
and other.dtype == self.dtype
and other.shape == self.shape
)

def is_compatible(self, otype):
# Fixed-shape types can only be replaced by equal fixed-shape types
return self == otype

def __hash__(self):
return hash((type(self), self.dtype, self.shape))

def clone(self, dtype=None, shape=None):
if dtype is None:
dtype = self.dtype
if shape is None:
shape = self.shape
return self.__class__(dtype, shape, name=self.name)

def convert_variable(self, var):
if self == var.type:
return var

if isinstance(var.type, TensorType) and var.type.is_compatible(self):
new_type = FixedShapeTensorType(
var.type.dtype, self.shape, name=var.type.name
)
new_var = new_type()
new_var.name = var.name

if var.owner:
outputs_new = list(var.owner.outputs)
outputs_new[var.index] = new_var
# Construct a new graph in which the fixed-shape tensor is the
# output
_ = Apply(var.owner.op, var.owner.inputs, outputs_new)

return new_var


def values_eq_approx(
a, b, allow_remove_inf=False, allow_remove_nan=False, rtol=None, atol=None
):
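A minimal sketch of the new type's behavior under direct construction (illustrative only):

    from aesara.tensor.type import FixedShapeTensorType

    ftype = FixedShapeTensorType("float64", (1, 3))

    print(ftype)                # FixedShapeTensorType(float64, (1, 3))
    print(ftype.broadcastable)  # (True, False), derived from the shape
    print(ftype.value_zeros())  # a (1, 3) array of zeros

    # Equality and hashing consider only dtype and shape:
    assert ftype == FixedShapeTensorType("float64", [1, 3])
    assert ftype != FixedShapeTensorType("float64", (2, 3))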
31 changes: 25 additions & 6 deletions aesara/tensor/var.py
@@ -11,8 +11,9 @@
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.scalar import ComplexError, IntegerDivisionError
from aesara.tensor import _get_vector_length
from aesara.tensor.exceptions import AdvancedIndexingError
from aesara.tensor.type import TensorType
from aesara.tensor.type import FixedShapeTensorType, TensorType
from aesara.tensor.utils import hash_from_ndarray


@@ -856,7 +857,19 @@ def __init__(self, type, owner=None, index=None, name=None):
pdb.set_trace()


class FixedShapeTensorVariable(TensorVariable):
@property
def shape(self):
return self.type.shape


@_get_vector_length.register(FixedShapeTensorVariable)
def _get_vector_length_FixedShapeTensorVariable(op_or_var, var):
return var.size


TensorType.Variable = TensorVariable
FixedShapeTensorType.Variable = FixedShapeTensorVariable


class TensorConstantSignature(tuple):
@@ -973,14 +986,20 @@ def get_unique_value(x: TensorVariable) -> Optional[Number]:
return None


class TensorConstant(TensorVariable, Constant):
    """Subclass to add the tensor operators to the basic `Constant` class.

    To create a TensorConstant, use the `constant` function in this module.

    """

    def __init__(self, type, data, name=None):
        Constant.__init__(self, type, data, name)

class TensorConstant(FixedShapeTensorVariable, Constant):
    """Subclass to add the tensor operators to the basic `Constant` class."""

    def __init__(self, type, data, name=None):
        data_shape = np.shape(data)

        if not isinstance(type, FixedShapeTensorType):
            type = FixedShapeTensorType(type.dtype, data_shape, name=type.name)

        if np.shape(data) != type.shape:
            raise ValueError(
                f"Shape of data ({data_shape}) does not match shape of type ({type.shape})"
            )

        Constant.__init__(self, type, data, name)

def __str__(self):
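The effect on constants, sketched under the new behavior (illustrative only):

    import numpy as np
    import aesara.tensor as at

    c = at.as_tensor(np.ones((2, 2)))

    # Constants are now built with a FixedShapeTensorType, so the
    # `.shape` property from FixedShapeTensorVariable is a plain tuple
    # rather than a symbolic expression:
    print(type(c.type).__name__)  # FixedShapeTensorType
    print(c.shape)                # (2, 2)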
6 changes: 3 additions & 3 deletions tests/test_raise_op.py
@@ -79,10 +79,10 @@ def test_CheckAndRaise_basic_c(linker):
with pytest.raises(CustomException, match=exc_msg):
y_fn(0)

y = check_and_raise(at.as_tensor(1), conds)
y_fn = aesara.function([conds], y.shape, mode=Mode(linker, OPT_FAST_RUN))
y = check_and_raise(x, conds)
y_fn = aesara.function([conds, x], y.shape, mode=Mode(linker, OPT_FAST_RUN))

assert np.array_equal(y_fn(0), [])
assert np.array_equal(y_fn(0, 0), [])

y = check_and_raise(x, at.as_tensor(0))
y_grad = aesara.grad(y, [x])
